mirror of
https://github.com/ollama/ollama.git
synced 2026-04-17 15:53:27 +02:00
* Prefer ROCm v6 on Windows. Avoid building with v7; more changes are needed. * MLX: add header vendoring and remove the Go build tag. This switches to a vendoring approach for the mlx-c headers so that Go can build without requiring a CMake run first. This enables building the new MLX-based code by default. Every time CMake runs, the headers are refreshed, so we can easily keep them in sync when we bump MLX versions. Basic Windows and Linux support is verified. * ci: harden for flaky choco repo servers. CI sometimes fails because choco does not actually install the cache. Since the cache only speeds up the build, we can proceed without it. * Review comments.
5915 lines
276 KiB
C
5915 lines
276 KiB
C
// AUTO-GENERATED by generate_wrappers.go - DO NOT EDIT
|
|
// This file contains the function pointer definitions and initialization
|
|
// All function pointers are in a single compilation unit to avoid duplication
|
|
|
|
#include "mlx/c/mlx.h"
|
|
#include "mlx_dynamic.h"
|
|
#include <stdio.h>
|
|
|
|
// Platform-specific dynamic loading.
//
// GET_SYM(handle, name) looks up the exported symbol `name` in the already
// loaded library `handle` and evaluates to a `void*`. Both GetProcAddress and
// dlsym return NULL when the symbol cannot be found.
//
// The arguments and the whole expansion are parenthesized so the macro acts
// as a single expression at every use site. Without the outer parentheses,
// a postfix operator applied to GET_SYM(...) on Windows would bind to the
// GetProcAddress call rather than to the (void*) cast.
#ifdef _WIN32
#include <windows.h>
// GetProcAddress returns a FARPROC; cast to void* so both branches yield the
// same type.
#define GET_SYM(handle, name) ((void*)GetProcAddress((HMODULE)(handle), (name)))
#else
#include <dlfcn.h>
#define GET_SYM(handle, name) dlsym((handle), (name))
#endif
|
|
|
|
// Function pointer definitions
//
// Each pointer below is a definition (it carries an initializer) and starts
// out NULL. Calling a pointer while it is still NULL is undefined behavior.
// NOTE(review): the code that assigns these pointers is not visible in this
// section (presumably it resolves them with GET_SYM) — confirm callers only
// use a pointer after it has been resolved successfully.
//
// Array API: dtype size, array construction/destruction, in-place setters,
// shape/size queries, evaluation, and scalar item extraction (item_* getters
// write the value through `res`).
// NOTE: the *_float32/*_float variants share a signature, as do the
// *_float64/*_double variants — presumably aliases upstream.
|
|
size_t (*mlx_dtype_size_ptr)(mlx_dtype dtype) = NULL;
|
|
int (*mlx_array_tostring_ptr)(mlx_string* str, const mlx_array arr) = NULL;
|
|
mlx_array (*mlx_array_new_ptr)(void) = NULL;
|
|
int (*mlx_array_free_ptr)(mlx_array arr) = NULL;
|
|
mlx_array (*mlx_array_new_bool_ptr)(bool val) = NULL;
|
|
mlx_array (*mlx_array_new_int_ptr)(int val) = NULL;
|
|
mlx_array (*mlx_array_new_float32_ptr)(float val) = NULL;
|
|
mlx_array (*mlx_array_new_float_ptr)(float val) = NULL;
|
|
mlx_array (*mlx_array_new_float64_ptr)(double val) = NULL;
|
|
mlx_array (*mlx_array_new_double_ptr)(double val) = NULL;
|
|
mlx_array (*mlx_array_new_complex_ptr)(float real_val, float imag_val) = NULL;
|
|
mlx_array (*mlx_array_new_data_ptr)(const void* data, const int* shape, int dim, mlx_dtype dtype) = NULL;
|
|
mlx_array (*mlx_array_new_data_managed_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*)) = NULL;
|
|
mlx_array (*mlx_array_new_data_managed_payload_ptr)(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*)) = NULL;
|
|
int (*mlx_array_set_ptr)(mlx_array* arr, const mlx_array src) = NULL;
|
|
int (*mlx_array_set_bool_ptr)(mlx_array* arr, bool val) = NULL;
|
|
int (*mlx_array_set_int_ptr)(mlx_array* arr, int val) = NULL;
|
|
int (*mlx_array_set_float32_ptr)(mlx_array* arr, float val) = NULL;
|
|
int (*mlx_array_set_float_ptr)(mlx_array* arr, float val) = NULL;
|
|
int (*mlx_array_set_float64_ptr)(mlx_array* arr, double val) = NULL;
|
|
int (*mlx_array_set_double_ptr)(mlx_array* arr, double val) = NULL;
|
|
int (*mlx_array_set_complex_ptr)(mlx_array* arr, float real_val, float imag_val) = NULL;
|
|
int (*mlx_array_set_data_ptr)(mlx_array* arr, const void* data, const int* shape, int dim, mlx_dtype dtype) = NULL;
|
|
size_t (*mlx_array_itemsize_ptr)(const mlx_array arr) = NULL;
|
|
size_t (*mlx_array_size_ptr)(const mlx_array arr) = NULL;
|
|
size_t (*mlx_array_nbytes_ptr)(const mlx_array arr) = NULL;
|
|
size_t (*mlx_array_ndim_ptr)(const mlx_array arr) = NULL;
|
|
const int* (*mlx_array_shape_ptr)(const mlx_array arr) = NULL;
|
|
const size_t* (*mlx_array_strides_ptr)(const mlx_array arr) = NULL;
|
|
int (*mlx_array_dim_ptr)(const mlx_array arr, int dim) = NULL;
|
|
mlx_dtype (*mlx_array_dtype_ptr)(const mlx_array arr) = NULL;
|
|
int (*mlx_array_eval_ptr)(mlx_array arr) = NULL;
|
|
int (*mlx_array_item_bool_ptr)(bool* res, const mlx_array arr) = NULL;
|
|
int (*mlx_array_item_uint8_ptr)(uint8_t* res, const mlx_array arr) = NULL;
|
|
int (*mlx_array_item_uint16_ptr)(uint16_t* res, const mlx_array arr) = NULL;
|
|
int (*mlx_array_item_uint32_ptr)(uint32_t* res, const mlx_array arr) = NULL;
|
|
int (*mlx_array_item_uint64_ptr)(uint64_t* res, const mlx_array arr) = NULL;
|
|
int (*mlx_array_item_int8_ptr)(int8_t* res, const mlx_array arr) = NULL;
|
|
int (*mlx_array_item_int16_ptr)(int16_t* res, const mlx_array arr) = NULL;
|
|
int (*mlx_array_item_int32_ptr)(int32_t* res, const mlx_array arr) = NULL;
|
|
int (*mlx_array_item_int64_ptr)(int64_t* res, const mlx_array arr) = NULL;
|
|
int (*mlx_array_item_float32_ptr)(float* res, const mlx_array arr) = NULL;
|
|
int (*mlx_array_item_float64_ptr)(double* res, const mlx_array arr) = NULL;
|
|
int (*mlx_array_item_complex64_ptr)(mlx_complex64_t* res, const mlx_array arr) = NULL;
|
|
// Half-precision scalar item getters. These are only defined when compiling
// for ARM64 (__aarch64__ or _M_ARM64), presumably because the float16_t and
// bfloat16_t types are only available there — confirm against the mlx-c
// headers.
#if defined(__aarch64__) || defined(_M_ARM64)
|
|
int (*mlx_array_item_float16_ptr)(float16_t* res, const mlx_array arr) = NULL;
|
|
#endif
|
|
#if defined(__aarch64__) || defined(_M_ARM64)
|
|
int (*mlx_array_item_bfloat16_ptr)(bfloat16_t* res, const mlx_array arr) = NULL;
|
|
#endif
|
|
// Typed data accessors. Each takes an mlx_array and returns a const pointer
// of the named element type. Presumably the pointer refers to the array's
// element storage — TODO confirm the lifetime and evaluation requirements
// upstream before relying on it.
const bool* (*mlx_array_data_bool_ptr)(const mlx_array arr) = NULL;
|
|
const uint8_t* (*mlx_array_data_uint8_ptr)(const mlx_array arr) = NULL;
|
|
const uint16_t* (*mlx_array_data_uint16_ptr)(const mlx_array arr) = NULL;
|
|
const uint32_t* (*mlx_array_data_uint32_ptr)(const mlx_array arr) = NULL;
|
|
const uint64_t* (*mlx_array_data_uint64_ptr)(const mlx_array arr) = NULL;
|
|
const int8_t* (*mlx_array_data_int8_ptr)(const mlx_array arr) = NULL;
|
|
const int16_t* (*mlx_array_data_int16_ptr)(const mlx_array arr) = NULL;
|
|
const int32_t* (*mlx_array_data_int32_ptr)(const mlx_array arr) = NULL;
|
|
const int64_t* (*mlx_array_data_int64_ptr)(const mlx_array arr) = NULL;
|
|
const float* (*mlx_array_data_float32_ptr)(const mlx_array arr) = NULL;
|
|
const double* (*mlx_array_data_float64_ptr)(const mlx_array arr) = NULL;
|
|
const mlx_complex64_t* (*mlx_array_data_complex64_ptr)(const mlx_array arr) = NULL;
|
|
// Half-precision data accessors. Like the half-precision item getters, these
// are only defined on ARM64 targets (__aarch64__ or _M_ARM64), presumably
// because float16_t and bfloat16_t are only available there.
#if defined(__aarch64__) || defined(_M_ARM64)
|
|
const float16_t* (*mlx_array_data_float16_ptr)(const mlx_array arr) = NULL;
|
|
#endif
|
|
#if defined(__aarch64__) || defined(_M_ARM64)
|
|
const bfloat16_t* (*mlx_array_data_bfloat16_ptr)(const mlx_array arr) = NULL;
|
|
#endif
|
|
// Underscore-prefixed array helpers: availability, wait, and contiguity
// queries. Results are written through `res`, except for _mlx_array_wait.
// NOTE(review): identifiers beginning with an underscore are reserved at file
// scope (C11 7.1.3). These names mirror the upstream `_mlx_*` functions, so
// any rename would have to happen in the generator and, presumably,
// mlx_dynamic.h together.
int (*_mlx_array_is_available_ptr)(bool* res, const mlx_array arr) = NULL;
|
|
int (*_mlx_array_wait_ptr)(const mlx_array arr) = NULL;
|
|
int (*_mlx_array_is_contiguous_ptr)(bool* res, const mlx_array arr) = NULL;
|
|
int (*_mlx_array_is_row_contiguous_ptr)(bool* res, const mlx_array arr) = NULL;
|
|
int (*_mlx_array_is_col_contiguous_ptr)(bool* res, const mlx_array arr) = NULL;
|
|
// Closure API. Each closure kind gets the same family of entry points:
// _new, _free, _new_func, _new_func_payload, _set, and _apply. The kinds are
// plain, kwargs, value_and_grad, custom, custom_jvp, and custom_vmap.
// _new_func_payload takes a callback plus an opaque `payload` and a `dtor`
// for it. The callback receives the payload as its last argument.
// mlx_closure additionally has _new_unary.
mlx_closure (*mlx_closure_new_ptr)(void) = NULL;
|
|
int (*mlx_closure_free_ptr)(mlx_closure cls) = NULL;
|
|
mlx_closure (*mlx_closure_new_func_ptr)(int (*fun)(mlx_vector_array*, const mlx_vector_array)) = NULL;
|
|
mlx_closure (*mlx_closure_new_func_payload_ptr)(int (*fun)(mlx_vector_array*, const mlx_vector_array, void*), void* payload, void (*dtor)(void*)) = NULL;
|
|
int (*mlx_closure_set_ptr)(mlx_closure* cls, const mlx_closure src) = NULL;
|
|
int (*mlx_closure_apply_ptr)(mlx_vector_array* res, mlx_closure cls, const mlx_vector_array input) = NULL;
|
|
mlx_closure (*mlx_closure_new_unary_ptr)(int (*fun)(mlx_array*, const mlx_array)) = NULL;
|
|
mlx_closure_kwargs (*mlx_closure_kwargs_new_ptr)(void) = NULL;
|
|
int (*mlx_closure_kwargs_free_ptr)(mlx_closure_kwargs cls) = NULL;
|
|
mlx_closure_kwargs (*mlx_closure_kwargs_new_func_ptr)(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_map_string_to_array)) = NULL;
|
|
mlx_closure_kwargs (*mlx_closure_kwargs_new_func_payload_ptr)(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_map_string_to_array, void*), void* payload, void (*dtor)(void*)) = NULL;
|
|
int (*mlx_closure_kwargs_set_ptr)(mlx_closure_kwargs* cls, const mlx_closure_kwargs src) = NULL;
|
|
int (*mlx_closure_kwargs_apply_ptr)(mlx_vector_array* res, mlx_closure_kwargs cls, const mlx_vector_array input_0, const mlx_map_string_to_array input_1) = NULL;
|
|
mlx_closure_value_and_grad (*mlx_closure_value_and_grad_new_ptr)(void) = NULL;
|
|
int (*mlx_closure_value_and_grad_free_ptr)(mlx_closure_value_and_grad cls) = NULL;
|
|
mlx_closure_value_and_grad (*mlx_closure_value_and_grad_new_func_ptr)(int (*fun)(mlx_vector_array*, mlx_vector_array*, const mlx_vector_array)) = NULL;
|
|
mlx_closure_value_and_grad (*mlx_closure_value_and_grad_new_func_payload_ptr)(int (*fun)( mlx_vector_array*, mlx_vector_array*, const mlx_vector_array, void*), void* payload, void (*dtor)(void*)) = NULL;
|
|
int (*mlx_closure_value_and_grad_set_ptr)(mlx_closure_value_and_grad* cls, const mlx_closure_value_and_grad src) = NULL;
|
|
int (*mlx_closure_value_and_grad_apply_ptr)(mlx_vector_array* res_0, mlx_vector_array* res_1, mlx_closure_value_and_grad cls, const mlx_vector_array input) = NULL;
|
|
mlx_closure_custom (*mlx_closure_custom_new_ptr)(void) = NULL;
|
|
int (*mlx_closure_custom_free_ptr)(mlx_closure_custom cls) = NULL;
|
|
mlx_closure_custom (*mlx_closure_custom_new_func_ptr)(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_vector_array, const mlx_vector_array)) = NULL;
|
|
mlx_closure_custom (*mlx_closure_custom_new_func_payload_ptr)(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_vector_array, const mlx_vector_array, void*), void* payload, void (*dtor)(void*)) = NULL;
|
|
int (*mlx_closure_custom_set_ptr)(mlx_closure_custom* cls, const mlx_closure_custom src) = NULL;
|
|
int (*mlx_closure_custom_apply_ptr)(mlx_vector_array* res, mlx_closure_custom cls, const mlx_vector_array input_0, const mlx_vector_array input_1, const mlx_vector_array input_2) = NULL;
|
|
mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_ptr)(void) = NULL;
|
|
int (*mlx_closure_custom_jvp_free_ptr)(mlx_closure_custom_jvp cls) = NULL;
|
|
mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_ptr)(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_vector_array, const int*, size_t _num)) = NULL;
|
|
mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_payload_ptr)(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_vector_array, const int*, size_t _num, void*), void* payload, void (*dtor)(void*)) = NULL;
|
|
int (*mlx_closure_custom_jvp_set_ptr)(mlx_closure_custom_jvp* cls, const mlx_closure_custom_jvp src) = NULL;
|
|
int (*mlx_closure_custom_jvp_apply_ptr)(mlx_vector_array* res, mlx_closure_custom_jvp cls, const mlx_vector_array input_0, const mlx_vector_array input_1, const int* input_2, size_t input_2_num) = NULL;
|
|
mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_ptr)(void) = NULL;
|
|
int (*mlx_closure_custom_vmap_free_ptr)(mlx_closure_custom_vmap cls) = NULL;
|
|
mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_ptr)(int (*fun)( mlx_vector_array*, mlx_vector_int*, const mlx_vector_array, const int*, size_t _num)) = NULL;
|
|
mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_payload_ptr)(int (*fun)( mlx_vector_array*, mlx_vector_int*, const mlx_vector_array, const int*, size_t _num, void*), void* payload, void (*dtor)(void*)) = NULL;
|
|
int (*mlx_closure_custom_vmap_set_ptr)(mlx_closure_custom_vmap* cls, const mlx_closure_custom_vmap src) = NULL;
|
|
int (*mlx_closure_custom_vmap_apply_ptr)(mlx_vector_array* res_0, mlx_vector_int* res_1, mlx_closure_custom_vmap cls, const mlx_vector_array input_0, const int* input_1, size_t input_1_num) = NULL;
|
|
// Compilation control. This group covers:
//   - compiling a closure (with an optional shapeless mode);
//   - detail-level entry points that take a uintptr_t fun_id (compile with
//     constants, clear the cache, erase one entry);
//   - global enable/disable/mode switches.
int (*mlx_compile_ptr)(mlx_closure* res, const mlx_closure fun, bool shapeless) = NULL;
|
|
int (*mlx_detail_compile_ptr)(mlx_closure* res, const mlx_closure fun, uintptr_t fun_id, bool shapeless, const uint64_t* constants, size_t constants_num) = NULL;
|
|
int (*mlx_detail_compile_clear_cache_ptr)(void) = NULL;
|
|
int (*mlx_detail_compile_erase_ptr)(uintptr_t fun_id) = NULL;
|
|
int (*mlx_disable_compile_ptr)(void) = NULL;
|
|
int (*mlx_enable_compile_ptr)(void) = NULL;
|
|
int (*mlx_set_compile_mode_ptr)(mlx_compile_mode mode) = NULL;
|
|
// CUDA backend availability query; the answer is written through `res`.
int (*mlx_cuda_is_available_ptr)(bool* res) = NULL;
|
|
// Device API. This group covers:
//   - construction (by type and index), copy-assignment, string conversion,
//     and equality;
//   - index/type getters and default-device get/set;
//   - per-type availability and count queries;
//   - mlx_device_info objects with key lookup (string or size_t values).
mlx_device (*mlx_device_new_ptr)(void) = NULL;
|
|
mlx_device (*mlx_device_new_type_ptr)(mlx_device_type type, int index) = NULL;
|
|
int (*mlx_device_free_ptr)(mlx_device dev) = NULL;
|
|
int (*mlx_device_set_ptr)(mlx_device* dev, const mlx_device src) = NULL;
|
|
int (*mlx_device_tostring_ptr)(mlx_string* str, mlx_device dev) = NULL;
|
|
bool (*mlx_device_equal_ptr)(mlx_device lhs, mlx_device rhs) = NULL;
|
|
int (*mlx_device_get_index_ptr)(int* index, mlx_device dev) = NULL;
|
|
int (*mlx_device_get_type_ptr)(mlx_device_type* type, mlx_device dev) = NULL;
|
|
int (*mlx_get_default_device_ptr)(mlx_device* dev) = NULL;
|
|
int (*mlx_set_default_device_ptr)(mlx_device dev) = NULL;
|
|
int (*mlx_device_is_available_ptr)(bool* avail, mlx_device dev) = NULL;
|
|
int (*mlx_device_count_ptr)(int* count, mlx_device_type type) = NULL;
|
|
mlx_device_info (*mlx_device_info_new_ptr)(void) = NULL;
|
|
int (*mlx_device_info_get_ptr)(mlx_device_info* info, mlx_device dev) = NULL;
|
|
int (*mlx_device_info_free_ptr)(mlx_device_info info) = NULL;
|
|
int (*mlx_device_info_has_key_ptr)(bool* exists, mlx_device_info info, const char* key) = NULL;
|
|
int (*mlx_device_info_is_string_ptr)(bool* is_string, mlx_device_info info, const char* key) = NULL;
|
|
int (*mlx_device_info_get_string_ptr)(const char** value, mlx_device_info info, const char* key) = NULL;
|
|
int (*mlx_device_info_get_size_ptr)(size_t* value, mlx_device_info info, const char* key) = NULL;
|
|
int (*mlx_device_info_get_keys_ptr)(mlx_vector_string* keys, mlx_device_info info) = NULL;
|
|
// Distributed API: collectives (all_gather, all_max, all_min, all_sum,
// sum_scatter), point-to-point send/recv, and group rank/size/split/init.
// NOTE(review): the space before the comma in `group ,` and the capital `S`
// stream parameter in mlx_distributed_all_gather_ptr look like leftovers from
// the generator's input (presumably a stripped inline comment and the
// upstream parameter name). They are cosmetic only: parameter names in a
// function-pointer type do not affect the type.
int (*mlx_distributed_all_gather_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S) = NULL;
|
|
int (*mlx_distributed_all_max_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) = NULL;
|
|
int (*mlx_distributed_all_min_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) = NULL;
|
|
int (*mlx_distributed_all_sum_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) = NULL;
|
|
int (*mlx_distributed_recv_ptr)(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, int src, const mlx_distributed_group group , const mlx_stream s) = NULL;
|
|
int (*mlx_distributed_recv_like_ptr)(mlx_array* res, const mlx_array x, int src, const mlx_distributed_group group , const mlx_stream s) = NULL;
|
|
int (*mlx_distributed_send_ptr)(mlx_array* res, const mlx_array x, int dst, const mlx_distributed_group group , const mlx_stream s) = NULL;
|
|
int (*mlx_distributed_sum_scatter_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) = NULL;
|
|
int (*mlx_distributed_group_rank_ptr)(mlx_distributed_group group) = NULL;
|
|
int (*mlx_distributed_group_size_ptr)(mlx_distributed_group group) = NULL;
|
|
mlx_distributed_group (*mlx_distributed_group_split_ptr)(mlx_distributed_group group, int color, int key) = NULL;
|
|
bool (*mlx_distributed_is_available_ptr)(void) = NULL;
|
|
mlx_distributed_group (*mlx_distributed_init_ptr)(bool strict) = NULL;
|
|
// Error handling. mlx_set_error_handler installs a handler together with an
// opaque `data` pointer and a `dtor` for it. _mlx_error is a variadic
// reporter taking file, line, a format string, and arguments.
// NOTE(review): `_mlx_error_ptr` begins with an underscore, which is reserved
// at file scope (C11 7.1.3). The name mirrors upstream, so leave it unless the
// generator changes too.
void (*mlx_set_error_handler_ptr)(mlx_error_handler_func handler, void* data, void (*dtor)(void*)) = NULL;
|
|
void (*_mlx_error_ptr)(const char* file, const int line, const char* fmt, ...) = NULL;
|
|
// Function export/import. This group covers:
//   - exporting a closure (optionally with kwargs) to `file`;
//   - incremental mlx_function_exporter objects;
//   - mlx_imported_function objects created from `file` and applied to
//     args/kwargs.
int (*mlx_export_function_ptr)(const char* file, const mlx_closure fun, const mlx_vector_array args, bool shapeless) = NULL;
|
|
int (*mlx_export_function_kwargs_ptr)(const char* file, const mlx_closure_kwargs fun, const mlx_vector_array args, const mlx_map_string_to_array kwargs, bool shapeless) = NULL;
|
|
mlx_function_exporter (*mlx_function_exporter_new_ptr)(const char* file, const mlx_closure fun, bool shapeless) = NULL;
|
|
int (*mlx_function_exporter_free_ptr)(mlx_function_exporter xfunc) = NULL;
|
|
int (*mlx_function_exporter_apply_ptr)(const mlx_function_exporter xfunc, const mlx_vector_array args) = NULL;
|
|
int (*mlx_function_exporter_apply_kwargs_ptr)(const mlx_function_exporter xfunc, const mlx_vector_array args, const mlx_map_string_to_array kwargs) = NULL;
|
|
mlx_imported_function (*mlx_imported_function_new_ptr)(const char* file) = NULL;
|
|
int (*mlx_imported_function_free_ptr)(mlx_imported_function xfunc) = NULL;
|
|
int (*mlx_imported_function_apply_ptr)(mlx_vector_array* res, const mlx_imported_function xfunc, const mlx_vector_array args) = NULL;
|
|
int (*mlx_imported_function_apply_kwargs_ptr)(mlx_vector_array* res, const mlx_imported_function xfunc, const mlx_vector_array args, const mlx_map_string_to_array kwargs) = NULL;
|
|
// "Fast" API: custom CUDA and Metal kernels plus fused ops.
//   - CUDA and Metal expose parallel config + kernel objects. Note the config
//     and kernel _free entry points return void here, unlike most other
//     _free functions in this file.
//   - Fused ops: layer_norm, rms_norm, rope (static and dynamic offset), and
//     scaled_dot_product_attention.
// NOTE(review): the space before the comma after weight/bias/freqs/mask_arr/
// sinks presumably marks where upstream documents optional inputs — confirm
// which of them may be null in the mlx-c headers.
mlx_fast_cuda_kernel_config (*mlx_fast_cuda_kernel_config_new_ptr)(void) = NULL;
|
|
void (*mlx_fast_cuda_kernel_config_free_ptr)(mlx_fast_cuda_kernel_config cls) = NULL;
|
|
int (*mlx_fast_cuda_kernel_config_add_output_arg_ptr)(mlx_fast_cuda_kernel_config cls, const int* shape, size_t size, mlx_dtype dtype) = NULL;
|
|
int (*mlx_fast_cuda_kernel_config_set_grid_ptr)(mlx_fast_cuda_kernel_config cls, int grid1, int grid2, int grid3) = NULL;
|
|
int (*mlx_fast_cuda_kernel_config_set_thread_group_ptr)(mlx_fast_cuda_kernel_config cls, int thread1, int thread2, int thread3) = NULL;
|
|
int (*mlx_fast_cuda_kernel_config_set_init_value_ptr)(mlx_fast_cuda_kernel_config cls, float value) = NULL;
|
|
int (*mlx_fast_cuda_kernel_config_set_verbose_ptr)(mlx_fast_cuda_kernel_config cls, bool verbose) = NULL;
|
|
int (*mlx_fast_cuda_kernel_config_add_template_arg_dtype_ptr)(mlx_fast_cuda_kernel_config cls, const char* name, mlx_dtype dtype) = NULL;
|
|
int (*mlx_fast_cuda_kernel_config_add_template_arg_int_ptr)(mlx_fast_cuda_kernel_config cls, const char* name, int value) = NULL;
|
|
int (*mlx_fast_cuda_kernel_config_add_template_arg_bool_ptr)(mlx_fast_cuda_kernel_config cls, const char* name, bool value) = NULL;
|
|
mlx_fast_cuda_kernel (*mlx_fast_cuda_kernel_new_ptr)(const char* name, const mlx_vector_string input_names, const mlx_vector_string output_names, const char* source, const char* header, bool ensure_row_contiguous, int shared_memory) = NULL;
|
|
void (*mlx_fast_cuda_kernel_free_ptr)(mlx_fast_cuda_kernel cls) = NULL;
|
|
int (*mlx_fast_cuda_kernel_apply_ptr)(mlx_vector_array* outputs, mlx_fast_cuda_kernel cls, const mlx_vector_array inputs, const mlx_fast_cuda_kernel_config config, const mlx_stream stream) = NULL;
|
|
int (*mlx_fast_layer_norm_ptr)(mlx_array* res, const mlx_array x, const mlx_array weight , const mlx_array bias , float eps, const mlx_stream s) = NULL;
|
|
mlx_fast_metal_kernel_config (*mlx_fast_metal_kernel_config_new_ptr)(void) = NULL;
|
|
void (*mlx_fast_metal_kernel_config_free_ptr)(mlx_fast_metal_kernel_config cls) = NULL;
|
|
int (*mlx_fast_metal_kernel_config_add_output_arg_ptr)(mlx_fast_metal_kernel_config cls, const int* shape, size_t size, mlx_dtype dtype) = NULL;
|
|
int (*mlx_fast_metal_kernel_config_set_grid_ptr)(mlx_fast_metal_kernel_config cls, int grid1, int grid2, int grid3) = NULL;
|
|
int (*mlx_fast_metal_kernel_config_set_thread_group_ptr)(mlx_fast_metal_kernel_config cls, int thread1, int thread2, int thread3) = NULL;
|
|
int (*mlx_fast_metal_kernel_config_set_init_value_ptr)(mlx_fast_metal_kernel_config cls, float value) = NULL;
|
|
int (*mlx_fast_metal_kernel_config_set_verbose_ptr)(mlx_fast_metal_kernel_config cls, bool verbose) = NULL;
|
|
int (*mlx_fast_metal_kernel_config_add_template_arg_dtype_ptr)(mlx_fast_metal_kernel_config cls, const char* name, mlx_dtype dtype) = NULL;
|
|
int (*mlx_fast_metal_kernel_config_add_template_arg_int_ptr)(mlx_fast_metal_kernel_config cls, const char* name, int value) = NULL;
|
|
int (*mlx_fast_metal_kernel_config_add_template_arg_bool_ptr)(mlx_fast_metal_kernel_config cls, const char* name, bool value) = NULL;
|
|
mlx_fast_metal_kernel (*mlx_fast_metal_kernel_new_ptr)(const char* name, const mlx_vector_string input_names, const mlx_vector_string output_names, const char* source, const char* header, bool ensure_row_contiguous, bool atomic_outputs) = NULL;
|
|
void (*mlx_fast_metal_kernel_free_ptr)(mlx_fast_metal_kernel cls) = NULL;
|
|
int (*mlx_fast_metal_kernel_apply_ptr)(mlx_vector_array* outputs, mlx_fast_metal_kernel cls, const mlx_vector_array inputs, const mlx_fast_metal_kernel_config config, const mlx_stream stream) = NULL;
|
|
int (*mlx_fast_rms_norm_ptr)(mlx_array* res, const mlx_array x, const mlx_array weight , float eps, const mlx_stream s) = NULL;
|
|
int (*mlx_fast_rope_ptr)(mlx_array* res, const mlx_array x, int dims, bool traditional, mlx_optional_float base, float scale, int offset, const mlx_array freqs , const mlx_stream s) = NULL;
|
|
int (*mlx_fast_rope_dynamic_ptr)(mlx_array* res, const mlx_array x, int dims, bool traditional, mlx_optional_float base, float scale, const mlx_array offset, const mlx_array freqs , const mlx_stream s) = NULL;
|
|
int (*mlx_fast_scaled_dot_product_attention_ptr)(mlx_array* res, const mlx_array queries, const mlx_array keys, const mlx_array values, float scale, const char* mask_mode, const mlx_array mask_arr , const mlx_array sinks , const mlx_stream s) = NULL;
|
|
// FFT family. Each transform comes in three forms:
//   - 1-D: scalar `n` and `axis`;
//   - 2-D and N-D: pointer + count for `n` and `axes`.
// The variants are complex (fft/ifft), real (rfft/irfft), and fftshift/
// ifftshift over an axes list.
int (*mlx_fft_fft_ptr)(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s) = NULL;
|
|
int (*mlx_fft_fft2_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) = NULL;
|
|
int (*mlx_fft_fftn_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) = NULL;
|
|
int (*mlx_fft_fftshift_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s) = NULL;
|
|
int (*mlx_fft_ifft_ptr)(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s) = NULL;
|
|
int (*mlx_fft_ifft2_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) = NULL;
|
|
int (*mlx_fft_ifftn_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) = NULL;
|
|
int (*mlx_fft_ifftshift_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s) = NULL;
|
|
int (*mlx_fft_irfft_ptr)(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s) = NULL;
|
|
int (*mlx_fft_irfft2_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) = NULL;
|
|
int (*mlx_fft_irfftn_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) = NULL;
|
|
int (*mlx_fft_rfft_ptr)(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s) = NULL;
|
|
int (*mlx_fft_rfft2_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) = NULL;
|
|
int (*mlx_fft_rfftn_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) = NULL;
|
|
// Array serialization. Single arrays and safetensors maps can be loaded or
// saved either by file path or through mlx_io_reader / mlx_io_writer
// streams. Safetensors loads return both the tensor map (res_0) and the
// string metadata map (res_1).
// NOTE(review): mlx_save_safetensors_writer_ptr names its mlx_io_writer
// parameter `in_stream`. This is a cosmetic naming oddity only (parameter
// names do not affect the pointer type).
int (*mlx_load_reader_ptr)(mlx_array* res, mlx_io_reader in_stream, const mlx_stream s) = NULL;
|
|
int (*mlx_load_ptr)(mlx_array* res, const char* file, const mlx_stream s) = NULL;
|
|
int (*mlx_load_safetensors_reader_ptr)(mlx_map_string_to_array* res_0, mlx_map_string_to_string* res_1, mlx_io_reader in_stream, const mlx_stream s) = NULL;
|
|
int (*mlx_load_safetensors_ptr)(mlx_map_string_to_array* res_0, mlx_map_string_to_string* res_1, const char* file, const mlx_stream s) = NULL;
|
|
int (*mlx_save_writer_ptr)(mlx_io_writer out_stream, const mlx_array a) = NULL;
|
|
int (*mlx_save_ptr)(const char* file, const mlx_array a) = NULL;
|
|
int (*mlx_save_safetensors_writer_ptr)(mlx_io_writer in_stream, const mlx_map_string_to_array param, const mlx_map_string_to_string metadata) = NULL;
|
|
int (*mlx_save_safetensors_ptr)(const char* file, const mlx_map_string_to_array param, const mlx_map_string_to_string metadata) = NULL;
|
|
// Custom I/O stream objects. Readers and writers are each built from an
// opaque descriptor pointer plus an mlx_io_vtable. Both expose accessors to
// retrieve the descriptor and a string form, and a _free function.
mlx_io_reader (*mlx_io_reader_new_ptr)(void* desc, mlx_io_vtable vtable) = NULL;
|
|
int (*mlx_io_reader_descriptor_ptr)(void** desc_, mlx_io_reader io) = NULL;
|
|
int (*mlx_io_reader_tostring_ptr)(mlx_string* str_, mlx_io_reader io) = NULL;
|
|
int (*mlx_io_reader_free_ptr)(mlx_io_reader io) = NULL;
|
|
mlx_io_writer (*mlx_io_writer_new_ptr)(void* desc, mlx_io_vtable vtable) = NULL;
|
|
int (*mlx_io_writer_descriptor_ptr)(void** desc_, mlx_io_writer io) = NULL;
|
|
int (*mlx_io_writer_tostring_ptr)(mlx_string* str_, mlx_io_writer io) = NULL;
|
|
int (*mlx_io_writer_free_ptr)(mlx_io_writer io) = NULL;
|
|
// Linear-algebra ops.
//   - Ops with two results write through res_0/res_1: eig, eigh, lu_factor,
//     and qr.
//   - lu and svd return an mlx_vector_array.
//   - The norm variants take an axis list (pointer + count) and keepdims.
int (*mlx_linalg_cholesky_ptr)(mlx_array* res, const mlx_array a, bool upper, const mlx_stream s) = NULL;
|
|
int (*mlx_linalg_cholesky_inv_ptr)(mlx_array* res, const mlx_array a, bool upper, const mlx_stream s) = NULL;
|
|
int (*mlx_linalg_cross_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, int axis, const mlx_stream s) = NULL;
|
|
int (*mlx_linalg_eig_ptr)(mlx_array* res_0, mlx_array* res_1, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_linalg_eigh_ptr)(mlx_array* res_0, mlx_array* res_1, const mlx_array a, const char* UPLO, const mlx_stream s) = NULL;
|
|
int (*mlx_linalg_eigvals_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_linalg_eigvalsh_ptr)(mlx_array* res, const mlx_array a, const char* UPLO, const mlx_stream s) = NULL;
|
|
int (*mlx_linalg_inv_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_linalg_lu_ptr)(mlx_vector_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_linalg_lu_factor_ptr)(mlx_array* res_0, mlx_array* res_1, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_linalg_norm_ptr)(mlx_array* res, const mlx_array a, double ord, const int* axis , size_t axis_num, bool keepdims, const mlx_stream s) = NULL;
|
|
int (*mlx_linalg_norm_matrix_ptr)(mlx_array* res, const mlx_array a, const char* ord, const int* axis , size_t axis_num, bool keepdims, const mlx_stream s) = NULL;
|
|
int (*mlx_linalg_norm_l2_ptr)(mlx_array* res, const mlx_array a, const int* axis , size_t axis_num, bool keepdims, const mlx_stream s) = NULL;
|
|
int (*mlx_linalg_pinv_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_linalg_qr_ptr)(mlx_array* res_0, mlx_array* res_1, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_linalg_solve_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
|
int (*mlx_linalg_solve_triangular_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, bool upper, const mlx_stream s) = NULL;
|
|
int (*mlx_linalg_svd_ptr)(mlx_vector_array* res, const mlx_array a, bool compute_uv, const mlx_stream s) = NULL;
|
|
int (*mlx_linalg_tri_inv_ptr)(mlx_array* res, const mlx_array a, bool upper, const mlx_stream s) = NULL;
|
|
// String-keyed maps, one to mlx_array values and one to C-string values.
// Each has new/set/free/insert/get plus an iterator type with
// new/free/next. The next call yields a key and value through out-pointers.
mlx_map_string_to_array (*mlx_map_string_to_array_new_ptr)(void) = NULL;
|
|
int (*mlx_map_string_to_array_set_ptr)(mlx_map_string_to_array* map, const mlx_map_string_to_array src) = NULL;
|
|
int (*mlx_map_string_to_array_free_ptr)(mlx_map_string_to_array map) = NULL;
|
|
int (*mlx_map_string_to_array_insert_ptr)(mlx_map_string_to_array map, const char* key, const mlx_array value) = NULL;
|
|
int (*mlx_map_string_to_array_get_ptr)(mlx_array* value, const mlx_map_string_to_array map, const char* key) = NULL;
|
|
mlx_map_string_to_array_iterator (*mlx_map_string_to_array_iterator_new_ptr)(mlx_map_string_to_array map) = NULL;
|
|
int (*mlx_map_string_to_array_iterator_free_ptr)(mlx_map_string_to_array_iterator it) = NULL;
|
|
int (*mlx_map_string_to_array_iterator_next_ptr)(const char** key, mlx_array* value, mlx_map_string_to_array_iterator it) = NULL;
|
|
mlx_map_string_to_string (*mlx_map_string_to_string_new_ptr)(void) = NULL;
|
|
int (*mlx_map_string_to_string_set_ptr)(mlx_map_string_to_string* map, const mlx_map_string_to_string src) = NULL;
|
|
int (*mlx_map_string_to_string_free_ptr)(mlx_map_string_to_string map) = NULL;
|
|
int (*mlx_map_string_to_string_insert_ptr)(mlx_map_string_to_string map, const char* key, const char* value) = NULL;
|
|
int (*mlx_map_string_to_string_get_ptr)(const char** value, const mlx_map_string_to_string map, const char* key) = NULL;
|
|
mlx_map_string_to_string_iterator (*mlx_map_string_to_string_iterator_new_ptr)(mlx_map_string_to_string map) = NULL;
|
|
int (*mlx_map_string_to_string_iterator_free_ptr)(mlx_map_string_to_string_iterator it) = NULL;
|
|
int (*mlx_map_string_to_string_iterator_next_ptr)(const char** key, const char** value, mlx_map_string_to_string_iterator it) = NULL;
|
|
// Memory accounting and limits. The get_* queries write a size_t through
// `res`. The set_*_limit calls take the new `limit` and also write a size_t
// through `res` (presumably the previous limit — confirm upstream).
int (*mlx_clear_cache_ptr)(void) = NULL;
|
|
int (*mlx_get_active_memory_ptr)(size_t* res) = NULL;
|
|
int (*mlx_get_cache_memory_ptr)(size_t* res) = NULL;
|
|
int (*mlx_get_memory_limit_ptr)(size_t* res) = NULL;
|
|
int (*mlx_get_peak_memory_ptr)(size_t* res) = NULL;
|
|
int (*mlx_reset_peak_memory_ptr)(void) = NULL;
|
|
int (*mlx_set_cache_limit_ptr)(size_t* res, size_t limit) = NULL;
|
|
int (*mlx_set_memory_limit_ptr)(size_t* res, size_t limit) = NULL;
|
|
int (*mlx_set_wired_limit_ptr)(size_t* res, size_t limit) = NULL;
|
|
// Metal backend: availability query (written through `res`) and start/stop
// of a capture written to `path`.
int (*mlx_metal_is_available_ptr)(bool* res) = NULL;
|
|
int (*mlx_metal_start_capture_ptr)(const char* path) = NULL;
|
|
int (*mlx_metal_stop_capture_ptr)(void) = NULL;
|
|
int (*mlx_abs_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_add_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
|
int (*mlx_addmm_ptr)(mlx_array* res, const mlx_array c, const mlx_array a, const mlx_array b, float alpha, float beta, const mlx_stream s) = NULL;
|
|
int (*mlx_all_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) = NULL;
|
|
int (*mlx_all_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL;
|
|
// mlx-c array ops, part 1 (mlx_all .. mlx_atleast_3d).
// Generated by generate_wrappers.go (see file header): change the generator,
// not these lines, or the edit is lost on the next regeneration.
// Each `<name>_ptr` mirrors the mlx-c function `<name>` and starts out NULL;
// presumably the loader later in this file binds it via GET_SYM — TODO
// confirm. Calling a pointer that is still NULL dereferences a null pointer.
// Signature convention visible here: `int` return, result written through
// the leading `mlx_array* res`, trailing `const mlx_stream s`, and
// `const int* X, size_t X_num` pairs for array arguments plus their length.
// NOTE(review): the int is presumably a status code (0 = success) — confirm
// against the mlx-c documentation.
int (*mlx_all_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL;
|
|
int (*mlx_allclose_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, double rtol, double atol, bool equal_nan, const mlx_stream s) = NULL;
|
|
int (*mlx_any_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) = NULL;
|
|
int (*mlx_any_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL;
|
|
int (*mlx_any_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL;
|
|
int (*mlx_arange_ptr)(mlx_array* res, double start, double stop, double step, mlx_dtype dtype, const mlx_stream s) = NULL;
|
|
int (*mlx_arccos_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_arccosh_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_arcsin_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_arcsinh_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_arctan_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_arctan2_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
|
int (*mlx_arctanh_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_argmax_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL;
|
|
int (*mlx_argmax_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL;
|
|
int (*mlx_argmin_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL;
|
|
int (*mlx_argmin_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL;
|
|
int (*mlx_argpartition_axis_ptr)(mlx_array* res, const mlx_array a, int kth, int axis, const mlx_stream s) = NULL;
|
|
int (*mlx_argpartition_ptr)(mlx_array* res, const mlx_array a, int kth, const mlx_stream s) = NULL;
|
|
int (*mlx_argsort_axis_ptr)(mlx_array* res, const mlx_array a, int axis, const mlx_stream s) = NULL;
|
|
int (*mlx_argsort_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_array_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, bool equal_nan, const mlx_stream s) = NULL;
|
|
// Note: shape is `const int*` but strides are `const int64_t*`; the types
// must match the extern declaration exactly.
int (*mlx_as_strided_ptr)(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const int64_t* strides, size_t strides_num, size_t offset, const mlx_stream s) = NULL;
|
|
int (*mlx_astype_ptr)(mlx_array* res, const mlx_array a, mlx_dtype dtype, const mlx_stream s) = NULL;
|
|
int (*mlx_atleast_1d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_atleast_2d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_atleast_3d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
// mlx-c array ops, part 2 (mlx_bitwise_and .. mlx_cumsum).
// Generated by generate_wrappers.go: edit the generator, not these lines.
// Each `<name>_ptr` mirrors mlx-c `<name>` and starts out NULL; presumably
// it is bound via GET_SYM by the loader later in this file (TODO confirm).
// Calling one while it is still NULL dereferences a null pointer.
// NOTE(review): the stray space before some commas (e.g. `mask_out ,`,
// `a_min ,`) looks like a stripped upstream `/* may be null */` annotation,
// meaning those arrays are optional — confirm against the mlx-c headers.
int (*mlx_bitwise_and_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
|
int (*mlx_bitwise_invert_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_bitwise_or_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
|
int (*mlx_bitwise_xor_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
|
int (*mlx_block_masked_mm_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s) = NULL;
|
|
// Writes a vector of arrays (mlx_vector_array*) rather than a single array.
int (*mlx_broadcast_arrays_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s) = NULL;
|
|
int (*mlx_broadcast_to_ptr)(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s) = NULL;
|
|
int (*mlx_ceil_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_clip_ptr)(mlx_array* res, const mlx_array a, const mlx_array a_min , const mlx_array a_max , const mlx_stream s) = NULL;
|
|
int (*mlx_concatenate_axis_ptr)(mlx_array* res, const mlx_vector_array arrays, int axis, const mlx_stream s) = NULL;
|
|
int (*mlx_concatenate_ptr)(mlx_array* res, const mlx_vector_array arrays, const mlx_stream s) = NULL;
|
|
int (*mlx_conjugate_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_contiguous_ptr)(mlx_array* res, const mlx_array a, bool allow_col_major, const mlx_stream s) = NULL;
|
|
int (*mlx_conv1d_ptr)(mlx_array* res, const mlx_array input, const mlx_array weight, int stride, int padding, int dilation, int groups, const mlx_stream s) = NULL;
|
|
int (*mlx_conv2d_ptr)(mlx_array* res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int padding_0, int padding_1, int dilation_0, int dilation_1, int groups, const mlx_stream s) = NULL;
|
|
int (*mlx_conv3d_ptr)(mlx_array* res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int stride_2, int padding_0, int padding_1, int padding_2, int dilation_0, int dilation_1, int dilation_2, int groups, const mlx_stream s) = NULL;
|
|
// Each `const int* X` argument is paired with an `X_num` element count.
int (*mlx_conv_general_ptr)(mlx_array* res, const mlx_array input, const mlx_array weight, const int* stride, size_t stride_num, const int* padding_lo, size_t padding_lo_num, const int* padding_hi, size_t padding_hi_num, const int* kernel_dilation, size_t kernel_dilation_num, const int* input_dilation, size_t input_dilation_num, int groups, bool flip, const mlx_stream s) = NULL;
|
|
int (*mlx_conv_transpose1d_ptr)(mlx_array* res, const mlx_array input, const mlx_array weight, int stride, int padding, int dilation, int output_padding, int groups, const mlx_stream s) = NULL;
|
|
int (*mlx_conv_transpose2d_ptr)(mlx_array* res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int padding_0, int padding_1, int dilation_0, int dilation_1, int output_padding_0, int output_padding_1, int groups, const mlx_stream s) = NULL;
|
|
int (*mlx_conv_transpose3d_ptr)(mlx_array* res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int stride_2, int padding_0, int padding_1, int padding_2, int dilation_0, int dilation_1, int dilation_2, int output_padding_0, int output_padding_1, int output_padding_2, int groups, const mlx_stream s) = NULL;
|
|
int (*mlx_copy_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_cos_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_cosh_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_cummax_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) = NULL;
|
|
int (*mlx_cummin_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) = NULL;
|
|
int (*mlx_cumprod_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) = NULL;
|
|
int (*mlx_cumsum_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) = NULL;
|
|
// mlx-c array ops, part 3 (mlx_degrees .. mlx_eye).
// Generated by generate_wrappers.go: edit the generator, not these lines.
// Each `<name>_ptr` mirrors mlx-c `<name>` and starts out NULL; presumably
// it is bound via GET_SYM by the loader later in this file (TODO confirm).
// Calling one while it is still NULL dereferences a null pointer.
int (*mlx_degrees_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
// Unlike its neighbours, this signature has no trailing mlx_stream argument.
int (*mlx_depends_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies) = NULL;
|
|
// NOTE(review): the space in `biases ,` looks like a stripped upstream
// `/* may be null */` annotation (optional biases) — confirm in mlx-c.
int (*mlx_dequantize_ptr)(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, mlx_optional_dtype dtype, const mlx_stream s) = NULL;
|
|
int (*mlx_diag_ptr)(mlx_array* res, const mlx_array a, int k, const mlx_stream s) = NULL;
|
|
int (*mlx_diagonal_ptr)(mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, const mlx_stream s) = NULL;
|
|
int (*mlx_divide_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
|
// Writes a vector of arrays (mlx_vector_array*) rather than a single array.
int (*mlx_divmod_ptr)(mlx_vector_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
|
int (*mlx_einsum_ptr)(mlx_array* res, const char* subscripts, const mlx_vector_array operands, const mlx_stream s) = NULL;
|
|
int (*mlx_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
|
int (*mlx_erf_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_erfinv_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_exp_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_expand_dims_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s) = NULL;
|
|
int (*mlx_expand_dims_ptr)(mlx_array* res, const mlx_array a, int axis, const mlx_stream s) = NULL;
|
|
int (*mlx_expm1_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_eye_ptr)(mlx_array* res, int n, int m, int k, mlx_dtype dtype, const mlx_stream s) = NULL;
|
|
// mlx-c array ops, part 4 (mlx_flatten .. mlx_isposinf).
// Generated by generate_wrappers.go: edit the generator, not these lines.
// Each `<name>_ptr` mirrors mlx-c `<name>` and starts out NULL; presumably
// it is bound via GET_SYM by the loader later in this file (TODO confirm).
// Calling one while it is still NULL dereferences a null pointer.
// NOTE(review): the stray space before commas in the gather_mm / gather_qmm
// signatures (`lhs_indices ,`, `biases ,`) looks like a stripped upstream
// `/* may be null */` annotation (optional arguments) — confirm in mlx-c.
int (*mlx_flatten_ptr)(mlx_array* res, const mlx_array a, int start_axis, int end_axis, const mlx_stream s) = NULL;
|
|
int (*mlx_floor_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_floor_divide_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
|
int (*mlx_from_fp8_ptr)(mlx_array* res, const mlx_array x, mlx_dtype dtype, const mlx_stream s) = NULL;
|
|
int (*mlx_full_ptr)(mlx_array* res, const int* shape, size_t shape_num, const mlx_array vals, mlx_dtype dtype, const mlx_stream s) = NULL;
|
|
int (*mlx_full_like_ptr)(mlx_array* res, const mlx_array a, const mlx_array vals, mlx_dtype dtype, const mlx_stream s) = NULL;
|
|
int (*mlx_gather_ptr)(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const int* axes, size_t axes_num, const int* slice_sizes, size_t slice_sizes_num, const mlx_stream s) = NULL;
|
|
int (*mlx_gather_single_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, int axis, const int* slice_sizes, size_t slice_sizes_num, const mlx_stream s) = NULL;
|
|
int (*mlx_gather_mm_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_array lhs_indices , const mlx_array rhs_indices , bool sorted_indices, const mlx_stream s) = NULL;
|
|
int (*mlx_gather_qmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , const mlx_array lhs_indices , const mlx_array rhs_indices , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, bool sorted_indices, const mlx_stream s) = NULL;
|
|
int (*mlx_greater_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
|
int (*mlx_greater_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
|
int (*mlx_hadamard_transform_ptr)(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s) = NULL;
|
|
int (*mlx_identity_ptr)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) = NULL;
|
|
int (*mlx_imag_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_inner_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
|
int (*mlx_isclose_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, double rtol, double atol, bool equal_nan, const mlx_stream s) = NULL;
|
|
int (*mlx_isfinite_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_isinf_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_isnan_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_isneginf_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_isposinf_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
// mlx-c array ops, part 5 (mlx_kron .. mlx_logsumexp).
// Generated by generate_wrappers.go: edit the generator, not these lines.
// Each `<name>_ptr` mirrors mlx-c `<name>` and starts out NULL; presumably
// it is bound via GET_SYM by the loader later in this file (TODO confirm).
// Calling one while it is still NULL dereferences a null pointer.
// Reductions come in three variants: `_axes` (array + count), `_axis` (one
// int) and the plain form, which takes no axis argument.
int (*mlx_kron_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
|
int (*mlx_left_shift_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
|
int (*mlx_less_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
|
int (*mlx_less_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
|
int (*mlx_linspace_ptr)(mlx_array* res, double start, double stop, int num, mlx_dtype dtype, const mlx_stream s) = NULL;
|
|
int (*mlx_log_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_log10_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_log1p_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_log2_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_logaddexp_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
|
int (*mlx_logcumsumexp_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) = NULL;
|
|
int (*mlx_logical_and_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
|
int (*mlx_logical_not_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_logical_or_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
|
int (*mlx_logsumexp_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) = NULL;
|
|
int (*mlx_logsumexp_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL;
|
|
int (*mlx_logsumexp_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL;
|
|
// mlx-c array ops, part 6 (mlx_masked_scatter .. mlx_number_of_elements).
// Generated by generate_wrappers.go: edit the generator, not these lines.
// Each `<name>_ptr` mirrors mlx-c `<name>` and starts out NULL; presumably
// it is bound via GET_SYM by the loader later in this file (TODO confirm).
// Calling one while it is still NULL dereferences a null pointer.
int (*mlx_masked_scatter_ptr)(mlx_array* res, const mlx_array a, const mlx_array mask, const mlx_array src, const mlx_stream s) = NULL;
|
|
int (*mlx_matmul_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
|
int (*mlx_max_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) = NULL;
|
|
int (*mlx_max_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL;
|
|
int (*mlx_max_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL;
|
|
int (*mlx_maximum_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
|
int (*mlx_mean_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) = NULL;
|
|
int (*mlx_mean_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL;
|
|
int (*mlx_mean_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL;
|
|
int (*mlx_median_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) = NULL;
|
|
// Writes a vector of arrays (mlx_vector_array*) rather than a single array.
int (*mlx_meshgrid_ptr)(mlx_vector_array* res, const mlx_vector_array arrays, bool sparse, const char* indexing, const mlx_stream s) = NULL;
|
|
int (*mlx_min_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) = NULL;
|
|
int (*mlx_min_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL;
|
|
int (*mlx_min_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL;
|
|
int (*mlx_minimum_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
|
int (*mlx_moveaxis_ptr)(mlx_array* res, const mlx_array a, int source, int destination, const mlx_stream s) = NULL;
|
|
int (*mlx_multiply_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
|
int (*mlx_nan_to_num_ptr)(mlx_array* res, const mlx_array a, float nan, mlx_optional_float posinf, mlx_optional_float neginf, const mlx_stream s) = NULL;
|
|
int (*mlx_negative_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_not_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
|
int (*mlx_number_of_elements_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool inverted, mlx_dtype dtype, const mlx_stream s) = NULL;
|
|
// mlx-c array ops, part 7 (mlx_ones .. mlx_rsqrt).
// Generated by generate_wrappers.go: edit the generator, not these lines.
// Each `<name>_ptr` mirrors mlx-c `<name>` and starts out NULL; presumably
// it is bound via GET_SYM by the loader later in this file (TODO confirm).
// Calling one while it is still NULL dereferences a null pointer.
// NOTE(review): the stray space before commas in the quantization
// signatures (`w_scales ,`, `biases ,`) looks like a stripped upstream
// `/* may be null */` annotation (optional arguments) — confirm in mlx-c.
int (*mlx_ones_ptr)(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_stream s) = NULL;
|
|
int (*mlx_ones_like_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_outer_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
|
int (*mlx_pad_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const int* low_pad_size, size_t low_pad_size_num, const int* high_pad_size, size_t high_pad_size_num, const mlx_array pad_value, const char* mode, const mlx_stream s) = NULL;
|
|
int (*mlx_pad_symmetric_ptr)(mlx_array* res, const mlx_array a, int pad_width, const mlx_array pad_value, const char* mode, const mlx_stream s) = NULL;
|
|
int (*mlx_partition_axis_ptr)(mlx_array* res, const mlx_array a, int kth, int axis, const mlx_stream s) = NULL;
|
|
int (*mlx_partition_ptr)(mlx_array* res, const mlx_array a, int kth, const mlx_stream s) = NULL;
|
|
int (*mlx_power_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
|
int (*mlx_prod_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) = NULL;
|
|
int (*mlx_prod_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL;
|
|
int (*mlx_prod_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL;
|
|
int (*mlx_put_along_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s) = NULL;
|
|
int (*mlx_qqmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) = NULL;
|
|
// Writes a vector of arrays (mlx_vector_array*) rather than a single array.
int (*mlx_quantize_ptr)(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) = NULL;
|
|
int (*mlx_quantized_matmul_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) = NULL;
|
|
int (*mlx_radians_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_real_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_reciprocal_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_remainder_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
|
int (*mlx_repeat_axis_ptr)(mlx_array* res, const mlx_array arr, int repeats, int axis, const mlx_stream s) = NULL;
|
|
int (*mlx_repeat_ptr)(mlx_array* res, const mlx_array arr, int repeats, const mlx_stream s) = NULL;
|
|
int (*mlx_reshape_ptr)(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s) = NULL;
|
|
int (*mlx_right_shift_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
|
int (*mlx_roll_axis_ptr)(mlx_array* res, const mlx_array a, const int* shift, size_t shift_num, int axis, const mlx_stream s) = NULL;
|
|
int (*mlx_roll_axes_ptr)(mlx_array* res, const mlx_array a, const int* shift, size_t shift_num, const int* axes, size_t axes_num, const mlx_stream s) = NULL;
|
|
int (*mlx_roll_ptr)(mlx_array* res, const mlx_array a, const int* shift, size_t shift_num, const mlx_stream s) = NULL;
|
|
int (*mlx_round_ptr)(mlx_array* res, const mlx_array a, int decimals, const mlx_stream s) = NULL;
|
|
int (*mlx_rsqrt_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
// mlx-c array ops, part 8 (mlx_scatter .. mlx_swapaxes).
// Generated by generate_wrappers.go: edit the generator, not these lines.
// Each `<name>_ptr` mirrors mlx-c `<name>` and starts out NULL; presumably
// it is bound via GET_SYM by the loader later in this file (TODO confirm).
// Calling one while it is still NULL dereferences a null pointer.
// Scatter variants come in pairs: the plain form takes an mlx_vector_array of
// indices plus an axes array, and the `_single` form takes one indices array
// and one int axis.
int (*mlx_scatter_ptr)(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s) = NULL;
|
|
int (*mlx_scatter_single_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s) = NULL;
|
|
int (*mlx_scatter_add_ptr)(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s) = NULL;
|
|
int (*mlx_scatter_add_single_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s) = NULL;
|
|
int (*mlx_scatter_add_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s) = NULL;
|
|
int (*mlx_scatter_max_ptr)(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s) = NULL;
|
|
int (*mlx_scatter_max_single_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s) = NULL;
|
|
int (*mlx_scatter_min_ptr)(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s) = NULL;
|
|
int (*mlx_scatter_min_single_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s) = NULL;
|
|
int (*mlx_scatter_prod_ptr)(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s) = NULL;
|
|
int (*mlx_scatter_prod_single_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s) = NULL;
|
|
int (*mlx_segmented_mm_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_array segments, const mlx_stream s) = NULL;
|
|
int (*mlx_sigmoid_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_sign_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_sin_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_sinh_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_slice_ptr)(mlx_array* res, const mlx_array a, const int* start, size_t start_num, const int* stop, size_t stop_num, const int* strides, size_t strides_num, const mlx_stream s) = NULL;
|
|
// `_dynamic` variants take the start offsets as an mlx_array instead of a
// host `const int*` array.
int (*mlx_slice_dynamic_ptr)(mlx_array* res, const mlx_array a, const mlx_array start, const int* axes, size_t axes_num, const int* slice_size, size_t slice_size_num, const mlx_stream s) = NULL;
|
|
int (*mlx_slice_update_ptr)(mlx_array* res, const mlx_array src, const mlx_array update, const int* start, size_t start_num, const int* stop, size_t stop_num, const int* strides, size_t strides_num, const mlx_stream s) = NULL;
|
|
int (*mlx_slice_update_dynamic_ptr)(mlx_array* res, const mlx_array src, const mlx_array update, const mlx_array start, const int* axes, size_t axes_num, const mlx_stream s) = NULL;
|
|
int (*mlx_softmax_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool precise, const mlx_stream s) = NULL;
|
|
int (*mlx_softmax_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool precise, const mlx_stream s) = NULL;
|
|
int (*mlx_softmax_ptr)(mlx_array* res, const mlx_array a, bool precise, const mlx_stream s) = NULL;
|
|
int (*mlx_sort_axis_ptr)(mlx_array* res, const mlx_array a, int axis, const mlx_stream s) = NULL;
|
|
int (*mlx_sort_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
// Both split variants write a vector of arrays (mlx_vector_array*).
int (*mlx_split_ptr)(mlx_vector_array* res, const mlx_array a, int num_splits, int axis, const mlx_stream s) = NULL;
|
|
int (*mlx_split_sections_ptr)(mlx_vector_array* res, const mlx_array a, const int* indices, size_t indices_num, int axis, const mlx_stream s) = NULL;
|
|
int (*mlx_sqrt_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_square_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_squeeze_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s) = NULL;
|
|
int (*mlx_squeeze_axis_ptr)(mlx_array* res, const mlx_array a, int axis, const mlx_stream s) = NULL;
|
|
int (*mlx_squeeze_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_stack_axis_ptr)(mlx_array* res, const mlx_vector_array arrays, int axis, const mlx_stream s) = NULL;
|
|
int (*mlx_stack_ptr)(mlx_array* res, const mlx_vector_array arrays, const mlx_stream s) = NULL;
|
|
int (*mlx_std_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, int ddof, const mlx_stream s) = NULL;
|
|
int (*mlx_std_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, int ddof, const mlx_stream s) = NULL;
|
|
int (*mlx_std_ptr)(mlx_array* res, const mlx_array a, bool keepdims, int ddof, const mlx_stream s) = NULL;
|
|
int (*mlx_stop_gradient_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_subtract_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
|
int (*mlx_sum_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) = NULL;
|
|
int (*mlx_sum_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL;
|
|
int (*mlx_sum_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL;
|
|
int (*mlx_swapaxes_ptr)(mlx_array* res, const mlx_array a, int axis1, int axis2, const mlx_stream s) = NULL;
|
|
// mlx-c array ops, part 9 (mlx_take_axis .. mlx_triu).
// Generated by generate_wrappers.go: edit the generator, not these lines.
// Each `<name>_ptr` mirrors mlx-c `<name>` and starts out NULL; presumably
// it is bound via GET_SYM by the loader later in this file (TODO confirm).
// Calling one while it is still NULL dereferences a null pointer.
int (*mlx_take_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, int axis, const mlx_stream s) = NULL;
|
|
int (*mlx_take_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_stream s) = NULL;
|
|
int (*mlx_take_along_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, int axis, const mlx_stream s) = NULL;
|
|
int (*mlx_tan_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_tanh_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_tensordot_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const int* axes_a, size_t axes_a_num, const int* axes_b, size_t axes_b_num, const mlx_stream s) = NULL;
|
|
int (*mlx_tensordot_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, int axis, const mlx_stream s) = NULL;
|
|
int (*mlx_tile_ptr)(mlx_array* res, const mlx_array arr, const int* reps, size_t reps_num, const mlx_stream s) = NULL;
|
|
int (*mlx_to_fp8_ptr)(mlx_array* res, const mlx_array x, const mlx_stream s) = NULL;
|
|
int (*mlx_topk_axis_ptr)(mlx_array* res, const mlx_array a, int k, int axis, const mlx_stream s) = NULL;
|
|
int (*mlx_topk_ptr)(mlx_array* res, const mlx_array a, int k, const mlx_stream s) = NULL;
|
|
int (*mlx_trace_ptr)(mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, mlx_dtype dtype, const mlx_stream s) = NULL;
|
|
int (*mlx_transpose_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s) = NULL;
|
|
int (*mlx_transpose_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
int (*mlx_tri_ptr)(mlx_array* res, int n, int m, int k, mlx_dtype type, const mlx_stream s) = NULL;
|
|
int (*mlx_tril_ptr)(mlx_array* res, const mlx_array x, int k, const mlx_stream s) = NULL;
|
|
int (*mlx_triu_ptr)(mlx_array* res, const mlx_array x, int k, const mlx_stream s) = NULL;
|
|
// mlx-c array ops, part 10 (mlx_unflatten .. mlx_zeros_like).
// Generated by generate_wrappers.go: edit the generator, not these lines.
// Each `<name>_ptr` mirrors mlx-c `<name>` and starts out NULL; presumably
// it is bound via GET_SYM by the loader later in this file (TODO confirm).
// Calling one while it is still NULL dereferences a null pointer.
int (*mlx_unflatten_ptr)(mlx_array* res, const mlx_array a, int axis, const int* shape, size_t shape_num, const mlx_stream s) = NULL;
|
|
int (*mlx_var_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, int ddof, const mlx_stream s) = NULL;
|
|
int (*mlx_var_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, int ddof, const mlx_stream s) = NULL;
|
|
int (*mlx_var_ptr)(mlx_array* res, const mlx_array a, bool keepdims, int ddof, const mlx_stream s) = NULL;
|
|
int (*mlx_view_ptr)(mlx_array* res, const mlx_array a, mlx_dtype dtype, const mlx_stream s) = NULL;
|
|
int (*mlx_where_ptr)(mlx_array* res, const mlx_array condition, const mlx_array x, const mlx_array y, const mlx_stream s) = NULL;
|
|
int (*mlx_zeros_ptr)(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_stream s) = NULL;
|
|
int (*mlx_zeros_like_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
|
// mlx-c random-number entry points (mlx_random_*).
// Generated by generate_wrappers.go: edit the generator, not these lines.
// Each `<name>_ptr` mirrors mlx-c `<name>` and starts out NULL; presumably
// it is bound via GET_SYM by the loader later in this file (TODO confirm).
// Calling one while it is still NULL dereferences a null pointer.
// NOTE(review): the stray space in `key ,` (and `loc ,` / `scale ,`) looks
// like a stripped upstream `/* may be null */` annotation, i.e. an optional
// PRNG key — confirm against the mlx-c random API.
int (*mlx_random_bernoulli_ptr)(mlx_array* res, const mlx_array p, const int* shape, size_t shape_num, const mlx_array key , const mlx_stream s) = NULL;
|
|
int (*mlx_random_bits_ptr)(mlx_array* res, const int* shape, size_t shape_num, int width, const mlx_array key , const mlx_stream s) = NULL;
|
|
int (*mlx_random_categorical_shape_ptr)(mlx_array* res, const mlx_array logits, int axis, const int* shape, size_t shape_num, const mlx_array key , const mlx_stream s) = NULL;
|
|
int (*mlx_random_categorical_num_samples_ptr)(mlx_array* res, const mlx_array logits_, int axis, int num_samples, const mlx_array key , const mlx_stream s) = NULL;
|
|
int (*mlx_random_categorical_ptr)(mlx_array* res, const mlx_array logits, int axis, const mlx_array key , const mlx_stream s) = NULL;
|
|
int (*mlx_random_gumbel_ptr)(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key , const mlx_stream s) = NULL;
|
|
// No stream argument: builds a key array from a 64-bit seed.
int (*mlx_random_key_ptr)(mlx_array* res, uint64_t seed) = NULL;
|
|
int (*mlx_random_laplace_ptr)(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, float loc, float scale, const mlx_array key , const mlx_stream s) = NULL;
|
|
int (*mlx_random_multivariate_normal_ptr)(mlx_array* res, const mlx_array mean, const mlx_array cov, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key , const mlx_stream s) = NULL;
|
|
int (*mlx_random_normal_broadcast_ptr)(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array loc , const mlx_array scale , const mlx_array key , const mlx_stream s) = NULL;
|
|
int (*mlx_random_normal_ptr)(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, float loc, float scale, const mlx_array key , const mlx_stream s) = NULL;
|
|
int (*mlx_random_permutation_ptr)(mlx_array* res, const mlx_array x, int axis, const mlx_array key , const mlx_stream s) = NULL;
|
|
int (*mlx_random_permutation_arange_ptr)(mlx_array* res, int x, const mlx_array key , const mlx_stream s) = NULL;
|
|
int (*mlx_random_randint_ptr)(mlx_array* res, const mlx_array low, const mlx_array high, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key , const mlx_stream s) = NULL;
|
|
// Takes only a seed: no output array and no stream. NOTE(review): presumably
// this reseeds mlx's global PRNG state — confirm.
int (*mlx_random_seed_ptr)(uint64_t seed) = NULL;
|
|
int (*mlx_random_split_num_ptr)(mlx_array* res, const mlx_array key, int num, const mlx_stream s) = NULL;
|
|
// Two outputs (res_0, res_1) written through separate pointers.
int (*mlx_random_split_ptr)(mlx_array* res_0, mlx_array* res_1, const mlx_array key, const mlx_stream s) = NULL;
|
|
int (*mlx_random_truncated_normal_ptr)(mlx_array* res, const mlx_array lower, const mlx_array upper, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key , const mlx_stream s) = NULL;
|
|
int (*mlx_random_uniform_ptr)(mlx_array* res, const mlx_array low, const mlx_array high, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key , const mlx_stream s) = NULL;
|
|
// mlx-c stream entry points (mlx_stream_*, mlx_synchronize and the
// default-stream getters/setters).
// Generated by generate_wrappers.go: edit the generator, not these lines.
// Each `<name>_ptr` mirrors mlx-c `<name>` and starts out NULL; presumably
// it is bound via GET_SYM by the loader later in this file (TODO confirm).
// Calling one while it is still NULL dereferences a null pointer.
// The `*_new` constructors return an mlx_stream by value. NOTE(review):
// presumably each must be released with mlx_stream_free — confirm the
// ownership rules in the mlx-c documentation.
mlx_stream (*mlx_stream_new_ptr)(void) = NULL;
|
|
mlx_stream (*mlx_stream_new_device_ptr)(mlx_device dev) = NULL;
|
|
int (*mlx_stream_set_ptr)(mlx_stream* stream, const mlx_stream src) = NULL;
|
|
int (*mlx_stream_free_ptr)(mlx_stream stream) = NULL;
|
|
int (*mlx_stream_tostring_ptr)(mlx_string* str, mlx_stream stream) = NULL;
|
|
// Returns bool directly instead of the int-status / out-parameter pattern.
bool (*mlx_stream_equal_ptr)(mlx_stream lhs, mlx_stream rhs) = NULL;
|
|
int (*mlx_stream_get_device_ptr)(mlx_device* dev, mlx_stream stream) = NULL;
|
|
int (*mlx_stream_get_index_ptr)(int* index, mlx_stream stream) = NULL;
|
|
int (*mlx_synchronize_ptr)(mlx_stream stream) = NULL;
|
|
int (*mlx_get_default_stream_ptr)(mlx_stream* stream, mlx_device dev) = NULL;
|
|
int (*mlx_set_default_stream_ptr)(mlx_stream stream) = NULL;
|
|
mlx_stream (*mlx_default_cpu_stream_new_ptr)(void) = NULL;
|
|
mlx_stream (*mlx_default_gpu_stream_new_ptr)(void) = NULL;
|
|
// mlx-c string-handle entry points (mlx_string_*).
// Generated by generate_wrappers.go: edit the generator, not these lines.
// Each `<name>_ptr` mirrors mlx-c `<name>` and starts out NULL; presumably
// it is bound via GET_SYM by the loader later in this file (TODO confirm).
// Calling one while it is still NULL dereferences a null pointer.
mlx_string (*mlx_string_new_ptr)(void) = NULL;
|
|
mlx_string (*mlx_string_new_data_ptr)(const char* str) = NULL;
|
|
int (*mlx_string_set_ptr)(mlx_string* str, const mlx_string src) = NULL;
|
|
// NOTE(review): the returned `const char*` presumably points into storage
// owned by the mlx_string (do not free it; it is likely invalid after
// mlx_string_free) — confirm in the mlx-c documentation.
const char* (*mlx_string_data_ptr)(mlx_string str) = NULL;
|
|
int (*mlx_string_free_ptr)(mlx_string str) = NULL;
|
|
// mlx-c evaluation and transform entry points (eval, checkpoint, custom
// functions, jvp/vjp, value_and_grad, detail vmap helpers).
// Generated by generate_wrappers.go: edit the generator, not these lines.
// Each `<name>_ptr` mirrors mlx-c `<name>` and starts out NULL; presumably
// it is bound via GET_SYM by the loader later in this file (TODO confirm).
// Calling one while it is still NULL dereferences a null pointer.
// Unlike the array ops, none of these signatures take an mlx_stream.
int (*mlx_async_eval_ptr)(const mlx_vector_array outputs) = NULL;
|
|
int (*mlx_checkpoint_ptr)(mlx_closure* res, const mlx_closure fun) = NULL;
|
|
// NOTE(review): the stray space in `fun_vjp ,` / `fun_jvp ,` looks like a
// stripped upstream `/* may be null */` annotation (optional callbacks) —
// confirm against the mlx-c transforms header.
int (*mlx_custom_function_ptr)(mlx_closure* res, const mlx_closure fun, const mlx_closure_custom fun_vjp , const mlx_closure_custom_jvp fun_jvp , const mlx_closure_custom_vmap fun_vmap) = NULL;
|
|
int (*mlx_custom_vjp_ptr)(mlx_closure* res, const mlx_closure fun, const mlx_closure_custom fun_vjp) = NULL;
|
|
int (*mlx_eval_ptr)(const mlx_vector_array outputs) = NULL;
|
|
// Two vector outputs (res_0, res_1) written through separate pointers.
int (*mlx_jvp_ptr)(mlx_vector_array* res_0, mlx_vector_array* res_1, const mlx_closure fun, const mlx_vector_array primals, const mlx_vector_array tangents) = NULL;
|
|
int (*mlx_value_and_grad_ptr)(mlx_closure_value_and_grad* res, const mlx_closure fun, const int* argnums, size_t argnums_num) = NULL;
|
|
int (*mlx_vjp_ptr)(mlx_vector_array* res_0, mlx_vector_array* res_1, const mlx_closure fun, const mlx_vector_array primals, const mlx_vector_array cotangents) = NULL;
|
|
// NOTE(review): the `mlx_detail_` prefix suggests an internal/unstable
// mlx-c API — confirm before relying on it.
int (*mlx_detail_vmap_replace_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array s_inputs, const mlx_vector_array s_outputs, const int* in_axes, size_t in_axes_num, const int* out_axes, size_t out_axes_num) = NULL;
|
|
int (*mlx_detail_vmap_trace_ptr)(mlx_vector_array* res_0, mlx_vector_array* res_1, const mlx_closure fun, const mlx_vector_array inputs, const int* in_axes, size_t in_axes_num) = NULL;
|
|
mlx_vector_array (*mlx_vector_array_new_ptr)(void) = NULL;
|
|
int (*mlx_vector_array_set_ptr)(mlx_vector_array* vec, const mlx_vector_array src) = NULL;
|
|
int (*mlx_vector_array_free_ptr)(mlx_vector_array vec) = NULL;
|
|
mlx_vector_array (*mlx_vector_array_new_data_ptr)(const mlx_array* data, size_t size) = NULL;
|
|
mlx_vector_array (*mlx_vector_array_new_value_ptr)(const mlx_array val) = NULL;
|
|
int (*mlx_vector_array_set_data_ptr)(mlx_vector_array* vec, const mlx_array* data, size_t size) = NULL;
|
|
int (*mlx_vector_array_set_value_ptr)(mlx_vector_array* vec, const mlx_array val) = NULL;
|
|
int (*mlx_vector_array_append_data_ptr)(mlx_vector_array vec, const mlx_array* data, size_t size) = NULL;
|
|
int (*mlx_vector_array_append_value_ptr)(mlx_vector_array vec, const mlx_array val) = NULL;
|
|
size_t (*mlx_vector_array_size_ptr)(mlx_vector_array vec) = NULL;
|
|
int (*mlx_vector_array_get_ptr)(mlx_array* res, const mlx_vector_array vec, size_t idx) = NULL;
|
|
mlx_vector_vector_array (*mlx_vector_vector_array_new_ptr)(void) = NULL;
|
|
int (*mlx_vector_vector_array_set_ptr)(mlx_vector_vector_array* vec, const mlx_vector_vector_array src) = NULL;
|
|
int (*mlx_vector_vector_array_free_ptr)(mlx_vector_vector_array vec) = NULL;
|
|
mlx_vector_vector_array (*mlx_vector_vector_array_new_data_ptr)(const mlx_vector_array* data, size_t size) = NULL;
|
|
mlx_vector_vector_array (*mlx_vector_vector_array_new_value_ptr)(const mlx_vector_array val) = NULL;
|
|
int (*mlx_vector_vector_array_set_data_ptr)(mlx_vector_vector_array* vec, const mlx_vector_array* data, size_t size) = NULL;
|
|
int (*mlx_vector_vector_array_set_value_ptr)(mlx_vector_vector_array* vec, const mlx_vector_array val) = NULL;
|
|
int (*mlx_vector_vector_array_append_data_ptr)(mlx_vector_vector_array vec, const mlx_vector_array* data, size_t size) = NULL;
|
|
int (*mlx_vector_vector_array_append_value_ptr)(mlx_vector_vector_array vec, const mlx_vector_array val) = NULL;
|
|
size_t (*mlx_vector_vector_array_size_ptr)(mlx_vector_vector_array vec) = NULL;
|
|
int (*mlx_vector_vector_array_get_ptr)(mlx_vector_array* res, const mlx_vector_vector_array vec, size_t idx) = NULL;
|
|
mlx_vector_int (*mlx_vector_int_new_ptr)(void) = NULL;
|
|
int (*mlx_vector_int_set_ptr)(mlx_vector_int* vec, const mlx_vector_int src) = NULL;
|
|
int (*mlx_vector_int_free_ptr)(mlx_vector_int vec) = NULL;
|
|
mlx_vector_int (*mlx_vector_int_new_data_ptr)(int* data, size_t size) = NULL;
|
|
mlx_vector_int (*mlx_vector_int_new_value_ptr)(int val) = NULL;
|
|
int (*mlx_vector_int_set_data_ptr)(mlx_vector_int* vec, int* data, size_t size) = NULL;
|
|
int (*mlx_vector_int_set_value_ptr)(mlx_vector_int* vec, int val) = NULL;
|
|
int (*mlx_vector_int_append_data_ptr)(mlx_vector_int vec, int* data, size_t size) = NULL;
|
|
int (*mlx_vector_int_append_value_ptr)(mlx_vector_int vec, int val) = NULL;
|
|
size_t (*mlx_vector_int_size_ptr)(mlx_vector_int vec) = NULL;
|
|
int (*mlx_vector_int_get_ptr)(int* res, const mlx_vector_int vec, size_t idx) = NULL;
|
|
mlx_vector_string (*mlx_vector_string_new_ptr)(void) = NULL;
|
|
int (*mlx_vector_string_set_ptr)(mlx_vector_string* vec, const mlx_vector_string src) = NULL;
|
|
int (*mlx_vector_string_free_ptr)(mlx_vector_string vec) = NULL;
|
|
mlx_vector_string (*mlx_vector_string_new_data_ptr)(const char** data, size_t size) = NULL;
|
|
mlx_vector_string (*mlx_vector_string_new_value_ptr)(const char* val) = NULL;
|
|
int (*mlx_vector_string_set_data_ptr)(mlx_vector_string* vec, const char** data, size_t size) = NULL;
|
|
int (*mlx_vector_string_set_value_ptr)(mlx_vector_string* vec, const char* val) = NULL;
|
|
int (*mlx_vector_string_append_data_ptr)(mlx_vector_string vec, const char** data, size_t size) = NULL;
|
|
int (*mlx_vector_string_append_value_ptr)(mlx_vector_string vec, const char* val) = NULL;
|
|
size_t (*mlx_vector_string_size_ptr)(mlx_vector_string vec) = NULL;
|
|
int (*mlx_vector_string_get_ptr)(char** res, const mlx_vector_string vec, size_t idx) = NULL;
|
|
int (*mlx_version_ptr)(mlx_string* str_) = NULL;
|
|
|
|
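// Resolve the function pointers defined above from the loaded MLX library
// handle via GET_SYM. Returns -1, after logging to stderr, on a NULL handle or
// on the first symbol that fails to resolve.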
|
|
int mlx_load_functions(void* handle) {
|
|
if (handle == NULL) {
|
|
fprintf(stderr, "MLX: Invalid library handle\n");
|
|
return -1;
|
|
}
|
|
|
|
mlx_dtype_size_ptr = GET_SYM(handle, "mlx_dtype_size");
|
|
if (mlx_dtype_size_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_dtype_size\n");
|
|
return -1;
|
|
}
|
|
mlx_array_tostring_ptr = GET_SYM(handle, "mlx_array_tostring");
|
|
if (mlx_array_tostring_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_tostring\n");
|
|
return -1;
|
|
}
|
|
mlx_array_new_ptr = GET_SYM(handle, "mlx_array_new");
|
|
if (mlx_array_new_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new\n");
|
|
return -1;
|
|
}
|
|
mlx_array_free_ptr = GET_SYM(handle, "mlx_array_free");
|
|
if (mlx_array_free_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_free\n");
|
|
return -1;
|
|
}
|
|
mlx_array_new_bool_ptr = GET_SYM(handle, "mlx_array_new_bool");
|
|
if (mlx_array_new_bool_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_bool\n");
|
|
return -1;
|
|
}
|
|
mlx_array_new_int_ptr = GET_SYM(handle, "mlx_array_new_int");
|
|
if (mlx_array_new_int_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_int\n");
|
|
return -1;
|
|
}
|
|
mlx_array_new_float32_ptr = GET_SYM(handle, "mlx_array_new_float32");
|
|
if (mlx_array_new_float32_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_float32\n");
|
|
return -1;
|
|
}
|
|
mlx_array_new_float_ptr = GET_SYM(handle, "mlx_array_new_float");
|
|
if (mlx_array_new_float_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_float\n");
|
|
return -1;
|
|
}
|
|
mlx_array_new_float64_ptr = GET_SYM(handle, "mlx_array_new_float64");
|
|
if (mlx_array_new_float64_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_float64\n");
|
|
return -1;
|
|
}
|
|
mlx_array_new_double_ptr = GET_SYM(handle, "mlx_array_new_double");
|
|
if (mlx_array_new_double_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_double\n");
|
|
return -1;
|
|
}
|
|
mlx_array_new_complex_ptr = GET_SYM(handle, "mlx_array_new_complex");
|
|
if (mlx_array_new_complex_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_complex\n");
|
|
return -1;
|
|
}
|
|
mlx_array_new_data_ptr = GET_SYM(handle, "mlx_array_new_data");
|
|
if (mlx_array_new_data_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data\n");
|
|
return -1;
|
|
}
|
|
mlx_array_new_data_managed_ptr = GET_SYM(handle, "mlx_array_new_data_managed");
|
|
if (mlx_array_new_data_managed_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data_managed\n");
|
|
return -1;
|
|
}
|
|
mlx_array_new_data_managed_payload_ptr = GET_SYM(handle, "mlx_array_new_data_managed_payload");
|
|
if (mlx_array_new_data_managed_payload_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data_managed_payload\n");
|
|
return -1;
|
|
}
|
|
mlx_array_set_ptr = GET_SYM(handle, "mlx_array_set");
|
|
if (mlx_array_set_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set\n");
|
|
return -1;
|
|
}
|
|
mlx_array_set_bool_ptr = GET_SYM(handle, "mlx_array_set_bool");
|
|
if (mlx_array_set_bool_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_bool\n");
|
|
return -1;
|
|
}
|
|
mlx_array_set_int_ptr = GET_SYM(handle, "mlx_array_set_int");
|
|
if (mlx_array_set_int_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_int\n");
|
|
return -1;
|
|
}
|
|
mlx_array_set_float32_ptr = GET_SYM(handle, "mlx_array_set_float32");
|
|
if (mlx_array_set_float32_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_float32\n");
|
|
return -1;
|
|
}
|
|
mlx_array_set_float_ptr = GET_SYM(handle, "mlx_array_set_float");
|
|
if (mlx_array_set_float_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_float\n");
|
|
return -1;
|
|
}
|
|
mlx_array_set_float64_ptr = GET_SYM(handle, "mlx_array_set_float64");
|
|
if (mlx_array_set_float64_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_float64\n");
|
|
return -1;
|
|
}
|
|
mlx_array_set_double_ptr = GET_SYM(handle, "mlx_array_set_double");
|
|
if (mlx_array_set_double_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_double\n");
|
|
return -1;
|
|
}
|
|
mlx_array_set_complex_ptr = GET_SYM(handle, "mlx_array_set_complex");
|
|
if (mlx_array_set_complex_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_complex\n");
|
|
return -1;
|
|
}
|
|
mlx_array_set_data_ptr = GET_SYM(handle, "mlx_array_set_data");
|
|
if (mlx_array_set_data_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_data\n");
|
|
return -1;
|
|
}
|
|
mlx_array_itemsize_ptr = GET_SYM(handle, "mlx_array_itemsize");
|
|
if (mlx_array_itemsize_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_itemsize\n");
|
|
return -1;
|
|
}
|
|
mlx_array_size_ptr = GET_SYM(handle, "mlx_array_size");
|
|
if (mlx_array_size_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_size\n");
|
|
return -1;
|
|
}
|
|
mlx_array_nbytes_ptr = GET_SYM(handle, "mlx_array_nbytes");
|
|
if (mlx_array_nbytes_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_nbytes\n");
|
|
return -1;
|
|
}
|
|
mlx_array_ndim_ptr = GET_SYM(handle, "mlx_array_ndim");
|
|
if (mlx_array_ndim_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_ndim\n");
|
|
return -1;
|
|
}
|
|
mlx_array_shape_ptr = GET_SYM(handle, "mlx_array_shape");
|
|
if (mlx_array_shape_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_shape\n");
|
|
return -1;
|
|
}
|
|
mlx_array_strides_ptr = GET_SYM(handle, "mlx_array_strides");
|
|
if (mlx_array_strides_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_strides\n");
|
|
return -1;
|
|
}
|
|
mlx_array_dim_ptr = GET_SYM(handle, "mlx_array_dim");
|
|
if (mlx_array_dim_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_dim\n");
|
|
return -1;
|
|
}
|
|
mlx_array_dtype_ptr = GET_SYM(handle, "mlx_array_dtype");
|
|
if (mlx_array_dtype_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_dtype\n");
|
|
return -1;
|
|
}
|
|
mlx_array_eval_ptr = GET_SYM(handle, "mlx_array_eval");
|
|
if (mlx_array_eval_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_eval\n");
|
|
return -1;
|
|
}
|
|
mlx_array_item_bool_ptr = GET_SYM(handle, "mlx_array_item_bool");
|
|
if (mlx_array_item_bool_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_bool\n");
|
|
return -1;
|
|
}
|
|
mlx_array_item_uint8_ptr = GET_SYM(handle, "mlx_array_item_uint8");
|
|
if (mlx_array_item_uint8_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_uint8\n");
|
|
return -1;
|
|
}
|
|
mlx_array_item_uint16_ptr = GET_SYM(handle, "mlx_array_item_uint16");
|
|
if (mlx_array_item_uint16_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_uint16\n");
|
|
return -1;
|
|
}
|
|
mlx_array_item_uint32_ptr = GET_SYM(handle, "mlx_array_item_uint32");
|
|
if (mlx_array_item_uint32_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_uint32\n");
|
|
return -1;
|
|
}
|
|
mlx_array_item_uint64_ptr = GET_SYM(handle, "mlx_array_item_uint64");
|
|
if (mlx_array_item_uint64_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_uint64\n");
|
|
return -1;
|
|
}
|
|
mlx_array_item_int8_ptr = GET_SYM(handle, "mlx_array_item_int8");
|
|
if (mlx_array_item_int8_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_int8\n");
|
|
return -1;
|
|
}
|
|
mlx_array_item_int16_ptr = GET_SYM(handle, "mlx_array_item_int16");
|
|
if (mlx_array_item_int16_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_int16\n");
|
|
return -1;
|
|
}
|
|
mlx_array_item_int32_ptr = GET_SYM(handle, "mlx_array_item_int32");
|
|
if (mlx_array_item_int32_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_int32\n");
|
|
return -1;
|
|
}
|
|
mlx_array_item_int64_ptr = GET_SYM(handle, "mlx_array_item_int64");
|
|
if (mlx_array_item_int64_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_int64\n");
|
|
return -1;
|
|
}
|
|
mlx_array_item_float32_ptr = GET_SYM(handle, "mlx_array_item_float32");
|
|
if (mlx_array_item_float32_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_float32\n");
|
|
return -1;
|
|
}
|
|
mlx_array_item_float64_ptr = GET_SYM(handle, "mlx_array_item_float64");
|
|
if (mlx_array_item_float64_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_float64\n");
|
|
return -1;
|
|
}
|
|
mlx_array_item_complex64_ptr = GET_SYM(handle, "mlx_array_item_complex64");
|
|
if (mlx_array_item_complex64_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_complex64\n");
|
|
return -1;
|
|
}
|
|
#if defined(__aarch64__) || defined(_M_ARM64)
|
|
mlx_array_item_float16_ptr = GET_SYM(handle, "mlx_array_item_float16");
|
|
if (mlx_array_item_float16_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_float16\n");
|
|
return -1;
|
|
}
|
|
#endif
|
|
#if defined(__aarch64__) || defined(_M_ARM64)
|
|
mlx_array_item_bfloat16_ptr = GET_SYM(handle, "mlx_array_item_bfloat16");
|
|
if (mlx_array_item_bfloat16_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_bfloat16\n");
|
|
return -1;
|
|
}
|
|
#endif
|
|
mlx_array_data_bool_ptr = GET_SYM(handle, "mlx_array_data_bool");
|
|
if (mlx_array_data_bool_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_bool\n");
|
|
return -1;
|
|
}
|
|
mlx_array_data_uint8_ptr = GET_SYM(handle, "mlx_array_data_uint8");
|
|
if (mlx_array_data_uint8_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_uint8\n");
|
|
return -1;
|
|
}
|
|
mlx_array_data_uint16_ptr = GET_SYM(handle, "mlx_array_data_uint16");
|
|
if (mlx_array_data_uint16_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_uint16\n");
|
|
return -1;
|
|
}
|
|
mlx_array_data_uint32_ptr = GET_SYM(handle, "mlx_array_data_uint32");
|
|
if (mlx_array_data_uint32_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_uint32\n");
|
|
return -1;
|
|
}
|
|
mlx_array_data_uint64_ptr = GET_SYM(handle, "mlx_array_data_uint64");
|
|
if (mlx_array_data_uint64_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_uint64\n");
|
|
return -1;
|
|
}
|
|
mlx_array_data_int8_ptr = GET_SYM(handle, "mlx_array_data_int8");
|
|
if (mlx_array_data_int8_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_int8\n");
|
|
return -1;
|
|
}
|
|
mlx_array_data_int16_ptr = GET_SYM(handle, "mlx_array_data_int16");
|
|
if (mlx_array_data_int16_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_int16\n");
|
|
return -1;
|
|
}
|
|
mlx_array_data_int32_ptr = GET_SYM(handle, "mlx_array_data_int32");
|
|
if (mlx_array_data_int32_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_int32\n");
|
|
return -1;
|
|
}
|
|
mlx_array_data_int64_ptr = GET_SYM(handle, "mlx_array_data_int64");
|
|
if (mlx_array_data_int64_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_int64\n");
|
|
return -1;
|
|
}
|
|
mlx_array_data_float32_ptr = GET_SYM(handle, "mlx_array_data_float32");
|
|
if (mlx_array_data_float32_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_float32\n");
|
|
return -1;
|
|
}
|
|
mlx_array_data_float64_ptr = GET_SYM(handle, "mlx_array_data_float64");
|
|
if (mlx_array_data_float64_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_float64\n");
|
|
return -1;
|
|
}
|
|
mlx_array_data_complex64_ptr = GET_SYM(handle, "mlx_array_data_complex64");
|
|
if (mlx_array_data_complex64_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_complex64\n");
|
|
return -1;
|
|
}
|
|
#if defined(__aarch64__) || defined(_M_ARM64)
|
|
mlx_array_data_float16_ptr = GET_SYM(handle, "mlx_array_data_float16");
|
|
if (mlx_array_data_float16_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_float16\n");
|
|
return -1;
|
|
}
|
|
#endif
|
|
#if defined(__aarch64__) || defined(_M_ARM64)
|
|
mlx_array_data_bfloat16_ptr = GET_SYM(handle, "mlx_array_data_bfloat16");
|
|
if (mlx_array_data_bfloat16_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_bfloat16\n");
|
|
return -1;
|
|
}
|
|
#endif
|
|
_mlx_array_is_available_ptr = GET_SYM(handle, "_mlx_array_is_available");
|
|
if (_mlx_array_is_available_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: _mlx_array_is_available\n");
|
|
return -1;
|
|
}
|
|
_mlx_array_wait_ptr = GET_SYM(handle, "_mlx_array_wait");
|
|
if (_mlx_array_wait_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: _mlx_array_wait\n");
|
|
return -1;
|
|
}
|
|
_mlx_array_is_contiguous_ptr = GET_SYM(handle, "_mlx_array_is_contiguous");
|
|
if (_mlx_array_is_contiguous_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: _mlx_array_is_contiguous\n");
|
|
return -1;
|
|
}
|
|
_mlx_array_is_row_contiguous_ptr = GET_SYM(handle, "_mlx_array_is_row_contiguous");
|
|
if (_mlx_array_is_row_contiguous_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: _mlx_array_is_row_contiguous\n");
|
|
return -1;
|
|
}
|
|
_mlx_array_is_col_contiguous_ptr = GET_SYM(handle, "_mlx_array_is_col_contiguous");
|
|
if (_mlx_array_is_col_contiguous_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: _mlx_array_is_col_contiguous\n");
|
|
return -1;
|
|
}
|
|
mlx_closure_new_ptr = GET_SYM(handle, "mlx_closure_new");
|
|
if (mlx_closure_new_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_new\n");
|
|
return -1;
|
|
}
|
|
mlx_closure_free_ptr = GET_SYM(handle, "mlx_closure_free");
|
|
if (mlx_closure_free_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_free\n");
|
|
return -1;
|
|
}
|
|
mlx_closure_new_func_ptr = GET_SYM(handle, "mlx_closure_new_func");
|
|
if (mlx_closure_new_func_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_new_func\n");
|
|
return -1;
|
|
}
|
|
mlx_closure_new_func_payload_ptr = GET_SYM(handle, "mlx_closure_new_func_payload");
|
|
if (mlx_closure_new_func_payload_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_new_func_payload\n");
|
|
return -1;
|
|
}
|
|
mlx_closure_set_ptr = GET_SYM(handle, "mlx_closure_set");
|
|
if (mlx_closure_set_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_set\n");
|
|
return -1;
|
|
}
|
|
mlx_closure_apply_ptr = GET_SYM(handle, "mlx_closure_apply");
|
|
if (mlx_closure_apply_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_apply\n");
|
|
return -1;
|
|
}
|
|
mlx_closure_new_unary_ptr = GET_SYM(handle, "mlx_closure_new_unary");
|
|
if (mlx_closure_new_unary_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_new_unary\n");
|
|
return -1;
|
|
}
|
|
mlx_closure_kwargs_new_ptr = GET_SYM(handle, "mlx_closure_kwargs_new");
|
|
if (mlx_closure_kwargs_new_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_kwargs_new\n");
|
|
return -1;
|
|
}
|
|
mlx_closure_kwargs_free_ptr = GET_SYM(handle, "mlx_closure_kwargs_free");
|
|
if (mlx_closure_kwargs_free_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_kwargs_free\n");
|
|
return -1;
|
|
}
|
|
mlx_closure_kwargs_new_func_ptr = GET_SYM(handle, "mlx_closure_kwargs_new_func");
|
|
if (mlx_closure_kwargs_new_func_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_kwargs_new_func\n");
|
|
return -1;
|
|
}
|
|
mlx_closure_kwargs_new_func_payload_ptr = GET_SYM(handle, "mlx_closure_kwargs_new_func_payload");
|
|
if (mlx_closure_kwargs_new_func_payload_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_kwargs_new_func_payload\n");
|
|
return -1;
|
|
}
|
|
mlx_closure_kwargs_set_ptr = GET_SYM(handle, "mlx_closure_kwargs_set");
|
|
if (mlx_closure_kwargs_set_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_kwargs_set\n");
|
|
return -1;
|
|
}
|
|
mlx_closure_kwargs_apply_ptr = GET_SYM(handle, "mlx_closure_kwargs_apply");
|
|
if (mlx_closure_kwargs_apply_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_kwargs_apply\n");
|
|
return -1;
|
|
}
|
|
mlx_closure_value_and_grad_new_ptr = GET_SYM(handle, "mlx_closure_value_and_grad_new");
|
|
if (mlx_closure_value_and_grad_new_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_value_and_grad_new\n");
|
|
return -1;
|
|
}
|
|
mlx_closure_value_and_grad_free_ptr = GET_SYM(handle, "mlx_closure_value_and_grad_free");
|
|
if (mlx_closure_value_and_grad_free_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_value_and_grad_free\n");
|
|
return -1;
|
|
}
|
|
mlx_closure_value_and_grad_new_func_ptr = GET_SYM(handle, "mlx_closure_value_and_grad_new_func");
|
|
if (mlx_closure_value_and_grad_new_func_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_value_and_grad_new_func\n");
|
|
return -1;
|
|
}
|
|
mlx_closure_value_and_grad_new_func_payload_ptr = GET_SYM(handle, "mlx_closure_value_and_grad_new_func_payload");
|
|
if (mlx_closure_value_and_grad_new_func_payload_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_value_and_grad_new_func_payload\n");
|
|
return -1;
|
|
}
|
|
mlx_closure_value_and_grad_set_ptr = GET_SYM(handle, "mlx_closure_value_and_grad_set");
|
|
if (mlx_closure_value_and_grad_set_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_value_and_grad_set\n");
|
|
return -1;
|
|
}
|
|
mlx_closure_value_and_grad_apply_ptr = GET_SYM(handle, "mlx_closure_value_and_grad_apply");
|
|
if (mlx_closure_value_and_grad_apply_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_value_and_grad_apply\n");
|
|
return -1;
|
|
}
|
|
mlx_closure_custom_new_ptr = GET_SYM(handle, "mlx_closure_custom_new");
|
|
if (mlx_closure_custom_new_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_new\n");
|
|
return -1;
|
|
}
|
|
mlx_closure_custom_free_ptr = GET_SYM(handle, "mlx_closure_custom_free");
|
|
if (mlx_closure_custom_free_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_free\n");
|
|
return -1;
|
|
}
|
|
mlx_closure_custom_new_func_ptr = GET_SYM(handle, "mlx_closure_custom_new_func");
|
|
if (mlx_closure_custom_new_func_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_new_func\n");
|
|
return -1;
|
|
}
|
|
mlx_closure_custom_new_func_payload_ptr = GET_SYM(handle, "mlx_closure_custom_new_func_payload");
|
|
if (mlx_closure_custom_new_func_payload_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_new_func_payload\n");
|
|
return -1;
|
|
}
|
|
mlx_closure_custom_set_ptr = GET_SYM(handle, "mlx_closure_custom_set");
|
|
if (mlx_closure_custom_set_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_set\n");
|
|
return -1;
|
|
}
|
|
mlx_closure_custom_apply_ptr = GET_SYM(handle, "mlx_closure_custom_apply");
|
|
if (mlx_closure_custom_apply_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_apply\n");
|
|
return -1;
|
|
}
|
|
mlx_closure_custom_jvp_new_ptr = GET_SYM(handle, "mlx_closure_custom_jvp_new");
|
|
if (mlx_closure_custom_jvp_new_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_jvp_new\n");
|
|
return -1;
|
|
}
|
|
mlx_closure_custom_jvp_free_ptr = GET_SYM(handle, "mlx_closure_custom_jvp_free");
|
|
if (mlx_closure_custom_jvp_free_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_jvp_free\n");
|
|
return -1;
|
|
}
|
|
mlx_closure_custom_jvp_new_func_ptr = GET_SYM(handle, "mlx_closure_custom_jvp_new_func");
|
|
if (mlx_closure_custom_jvp_new_func_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_jvp_new_func\n");
|
|
return -1;
|
|
}
|
|
mlx_closure_custom_jvp_new_func_payload_ptr = GET_SYM(handle, "mlx_closure_custom_jvp_new_func_payload");
|
|
if (mlx_closure_custom_jvp_new_func_payload_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_jvp_new_func_payload\n");
|
|
return -1;
|
|
}
|
|
mlx_closure_custom_jvp_set_ptr = GET_SYM(handle, "mlx_closure_custom_jvp_set");
|
|
if (mlx_closure_custom_jvp_set_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_jvp_set\n");
|
|
return -1;
|
|
}
|
|
mlx_closure_custom_jvp_apply_ptr = GET_SYM(handle, "mlx_closure_custom_jvp_apply");
|
|
if (mlx_closure_custom_jvp_apply_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_jvp_apply\n");
|
|
return -1;
|
|
}
|
|
mlx_closure_custom_vmap_new_ptr = GET_SYM(handle, "mlx_closure_custom_vmap_new");
|
|
if (mlx_closure_custom_vmap_new_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_vmap_new\n");
|
|
return -1;
|
|
}
|
|
mlx_closure_custom_vmap_free_ptr = GET_SYM(handle, "mlx_closure_custom_vmap_free");
|
|
if (mlx_closure_custom_vmap_free_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_vmap_free\n");
|
|
return -1;
|
|
}
|
|
mlx_closure_custom_vmap_new_func_ptr = GET_SYM(handle, "mlx_closure_custom_vmap_new_func");
|
|
if (mlx_closure_custom_vmap_new_func_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_vmap_new_func\n");
|
|
return -1;
|
|
}
|
|
mlx_closure_custom_vmap_new_func_payload_ptr = GET_SYM(handle, "mlx_closure_custom_vmap_new_func_payload");
|
|
if (mlx_closure_custom_vmap_new_func_payload_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_vmap_new_func_payload\n");
|
|
return -1;
|
|
}
|
|
mlx_closure_custom_vmap_set_ptr = GET_SYM(handle, "mlx_closure_custom_vmap_set");
|
|
if (mlx_closure_custom_vmap_set_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_vmap_set\n");
|
|
return -1;
|
|
}
|
|
mlx_closure_custom_vmap_apply_ptr = GET_SYM(handle, "mlx_closure_custom_vmap_apply");
|
|
if (mlx_closure_custom_vmap_apply_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_vmap_apply\n");
|
|
return -1;
|
|
}
|
|
mlx_compile_ptr = GET_SYM(handle, "mlx_compile");
|
|
if (mlx_compile_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_compile\n");
|
|
return -1;
|
|
}
|
|
mlx_detail_compile_ptr = GET_SYM(handle, "mlx_detail_compile");
|
|
if (mlx_detail_compile_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_detail_compile\n");
|
|
return -1;
|
|
}
|
|
mlx_detail_compile_clear_cache_ptr = GET_SYM(handle, "mlx_detail_compile_clear_cache");
|
|
if (mlx_detail_compile_clear_cache_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_detail_compile_clear_cache\n");
|
|
return -1;
|
|
}
|
|
mlx_detail_compile_erase_ptr = GET_SYM(handle, "mlx_detail_compile_erase");
|
|
if (mlx_detail_compile_erase_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_detail_compile_erase\n");
|
|
return -1;
|
|
}
|
|
mlx_disable_compile_ptr = GET_SYM(handle, "mlx_disable_compile");
|
|
if (mlx_disable_compile_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_disable_compile\n");
|
|
return -1;
|
|
}
|
|
mlx_enable_compile_ptr = GET_SYM(handle, "mlx_enable_compile");
|
|
if (mlx_enable_compile_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_enable_compile\n");
|
|
return -1;
|
|
}
|
|
mlx_set_compile_mode_ptr = GET_SYM(handle, "mlx_set_compile_mode");
|
|
if (mlx_set_compile_mode_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_compile_mode\n");
|
|
return -1;
|
|
}
|
|
mlx_cuda_is_available_ptr = GET_SYM(handle, "mlx_cuda_is_available");
|
|
if (mlx_cuda_is_available_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_cuda_is_available\n");
|
|
return -1;
|
|
}
|
|
mlx_device_new_ptr = GET_SYM(handle, "mlx_device_new");
|
|
if (mlx_device_new_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_new\n");
|
|
return -1;
|
|
}
|
|
mlx_device_new_type_ptr = GET_SYM(handle, "mlx_device_new_type");
|
|
if (mlx_device_new_type_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_new_type\n");
|
|
return -1;
|
|
}
|
|
mlx_device_free_ptr = GET_SYM(handle, "mlx_device_free");
|
|
if (mlx_device_free_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_free\n");
|
|
return -1;
|
|
}
|
|
mlx_device_set_ptr = GET_SYM(handle, "mlx_device_set");
|
|
if (mlx_device_set_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_set\n");
|
|
return -1;
|
|
}
|
|
mlx_device_tostring_ptr = GET_SYM(handle, "mlx_device_tostring");
|
|
if (mlx_device_tostring_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_tostring\n");
|
|
return -1;
|
|
}
|
|
mlx_device_equal_ptr = GET_SYM(handle, "mlx_device_equal");
|
|
if (mlx_device_equal_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_equal\n");
|
|
return -1;
|
|
}
|
|
mlx_device_get_index_ptr = GET_SYM(handle, "mlx_device_get_index");
|
|
if (mlx_device_get_index_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_get_index\n");
|
|
return -1;
|
|
}
|
|
mlx_device_get_type_ptr = GET_SYM(handle, "mlx_device_get_type");
|
|
if (mlx_device_get_type_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_get_type\n");
|
|
return -1;
|
|
}
|
|
mlx_get_default_device_ptr = GET_SYM(handle, "mlx_get_default_device");
|
|
if (mlx_get_default_device_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_get_default_device\n");
|
|
return -1;
|
|
}
|
|
mlx_set_default_device_ptr = GET_SYM(handle, "mlx_set_default_device");
|
|
if (mlx_set_default_device_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_default_device\n");
|
|
return -1;
|
|
}
|
|
mlx_device_is_available_ptr = GET_SYM(handle, "mlx_device_is_available");
|
|
if (mlx_device_is_available_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_is_available\n");
|
|
return -1;
|
|
}
|
|
mlx_device_count_ptr = GET_SYM(handle, "mlx_device_count");
|
|
if (mlx_device_count_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_count\n");
|
|
return -1;
|
|
}
|
|
mlx_device_info_new_ptr = GET_SYM(handle, "mlx_device_info_new");
|
|
if (mlx_device_info_new_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_new\n");
|
|
return -1;
|
|
}
|
|
mlx_device_info_get_ptr = GET_SYM(handle, "mlx_device_info_get");
|
|
if (mlx_device_info_get_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get\n");
|
|
return -1;
|
|
}
|
|
mlx_device_info_free_ptr = GET_SYM(handle, "mlx_device_info_free");
|
|
if (mlx_device_info_free_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_free\n");
|
|
return -1;
|
|
}
|
|
mlx_device_info_has_key_ptr = GET_SYM(handle, "mlx_device_info_has_key");
|
|
if (mlx_device_info_has_key_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_has_key\n");
|
|
return -1;
|
|
}
|
|
mlx_device_info_is_string_ptr = GET_SYM(handle, "mlx_device_info_is_string");
|
|
if (mlx_device_info_is_string_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_is_string\n");
|
|
return -1;
|
|
}
|
|
mlx_device_info_get_string_ptr = GET_SYM(handle, "mlx_device_info_get_string");
|
|
if (mlx_device_info_get_string_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get_string\n");
|
|
return -1;
|
|
}
|
|
mlx_device_info_get_size_ptr = GET_SYM(handle, "mlx_device_info_get_size");
|
|
if (mlx_device_info_get_size_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get_size\n");
|
|
return -1;
|
|
}
|
|
mlx_device_info_get_keys_ptr = GET_SYM(handle, "mlx_device_info_get_keys");
|
|
if (mlx_device_info_get_keys_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_info_get_keys\n");
|
|
return -1;
|
|
}
|
|
mlx_distributed_all_gather_ptr = GET_SYM(handle, "mlx_distributed_all_gather");
|
|
if (mlx_distributed_all_gather_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_all_gather\n");
|
|
return -1;
|
|
}
|
|
mlx_distributed_all_max_ptr = GET_SYM(handle, "mlx_distributed_all_max");
|
|
if (mlx_distributed_all_max_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_all_max\n");
|
|
return -1;
|
|
}
|
|
mlx_distributed_all_min_ptr = GET_SYM(handle, "mlx_distributed_all_min");
|
|
if (mlx_distributed_all_min_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_all_min\n");
|
|
return -1;
|
|
}
|
|
mlx_distributed_all_sum_ptr = GET_SYM(handle, "mlx_distributed_all_sum");
|
|
if (mlx_distributed_all_sum_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_all_sum\n");
|
|
return -1;
|
|
}
|
|
mlx_distributed_recv_ptr = GET_SYM(handle, "mlx_distributed_recv");
|
|
if (mlx_distributed_recv_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_recv\n");
|
|
return -1;
|
|
}
|
|
mlx_distributed_recv_like_ptr = GET_SYM(handle, "mlx_distributed_recv_like");
|
|
if (mlx_distributed_recv_like_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_recv_like\n");
|
|
return -1;
|
|
}
|
|
mlx_distributed_send_ptr = GET_SYM(handle, "mlx_distributed_send");
|
|
if (mlx_distributed_send_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_send\n");
|
|
return -1;
|
|
}
|
|
mlx_distributed_sum_scatter_ptr = GET_SYM(handle, "mlx_distributed_sum_scatter");
|
|
if (mlx_distributed_sum_scatter_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_sum_scatter\n");
|
|
return -1;
|
|
}
|
|
mlx_distributed_group_rank_ptr = GET_SYM(handle, "mlx_distributed_group_rank");
|
|
if (mlx_distributed_group_rank_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_group_rank\n");
|
|
return -1;
|
|
}
|
|
mlx_distributed_group_size_ptr = GET_SYM(handle, "mlx_distributed_group_size");
|
|
if (mlx_distributed_group_size_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_group_size\n");
|
|
return -1;
|
|
}
|
|
mlx_distributed_group_split_ptr = GET_SYM(handle, "mlx_distributed_group_split");
|
|
if (mlx_distributed_group_split_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_group_split\n");
|
|
return -1;
|
|
}
|
|
mlx_distributed_is_available_ptr = GET_SYM(handle, "mlx_distributed_is_available");
|
|
if (mlx_distributed_is_available_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_is_available\n");
|
|
return -1;
|
|
}
|
|
mlx_distributed_init_ptr = GET_SYM(handle, "mlx_distributed_init");
|
|
if (mlx_distributed_init_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_init\n");
|
|
return -1;
|
|
}
|
|
mlx_set_error_handler_ptr = GET_SYM(handle, "mlx_set_error_handler");
|
|
if (mlx_set_error_handler_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_error_handler\n");
|
|
return -1;
|
|
}
|
|
_mlx_error_ptr = GET_SYM(handle, "_mlx_error");
|
|
if (_mlx_error_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: _mlx_error\n");
|
|
return -1;
|
|
}
|
|
mlx_export_function_ptr = GET_SYM(handle, "mlx_export_function");
|
|
if (mlx_export_function_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_export_function\n");
|
|
return -1;
|
|
}
|
|
mlx_export_function_kwargs_ptr = GET_SYM(handle, "mlx_export_function_kwargs");
|
|
if (mlx_export_function_kwargs_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_export_function_kwargs\n");
|
|
return -1;
|
|
}
|
|
mlx_function_exporter_new_ptr = GET_SYM(handle, "mlx_function_exporter_new");
|
|
if (mlx_function_exporter_new_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_function_exporter_new\n");
|
|
return -1;
|
|
}
|
|
mlx_function_exporter_free_ptr = GET_SYM(handle, "mlx_function_exporter_free");
|
|
if (mlx_function_exporter_free_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_function_exporter_free\n");
|
|
return -1;
|
|
}
|
|
mlx_function_exporter_apply_ptr = GET_SYM(handle, "mlx_function_exporter_apply");
|
|
if (mlx_function_exporter_apply_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_function_exporter_apply\n");
|
|
return -1;
|
|
}
|
|
mlx_function_exporter_apply_kwargs_ptr = GET_SYM(handle, "mlx_function_exporter_apply_kwargs");
|
|
if (mlx_function_exporter_apply_kwargs_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_function_exporter_apply_kwargs\n");
|
|
return -1;
|
|
}
|
|
mlx_imported_function_new_ptr = GET_SYM(handle, "mlx_imported_function_new");
|
|
if (mlx_imported_function_new_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_imported_function_new\n");
|
|
return -1;
|
|
}
|
|
mlx_imported_function_free_ptr = GET_SYM(handle, "mlx_imported_function_free");
|
|
if (mlx_imported_function_free_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_imported_function_free\n");
|
|
return -1;
|
|
}
|
|
mlx_imported_function_apply_ptr = GET_SYM(handle, "mlx_imported_function_apply");
|
|
if (mlx_imported_function_apply_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_imported_function_apply\n");
|
|
return -1;
|
|
}
|
|
mlx_imported_function_apply_kwargs_ptr = GET_SYM(handle, "mlx_imported_function_apply_kwargs");
|
|
if (mlx_imported_function_apply_kwargs_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_imported_function_apply_kwargs\n");
|
|
return -1;
|
|
}
|
|
mlx_fast_cuda_kernel_config_new_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_config_new");
|
|
if (mlx_fast_cuda_kernel_config_new_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_new\n");
|
|
return -1;
|
|
}
|
|
mlx_fast_cuda_kernel_config_free_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_config_free");
|
|
if (mlx_fast_cuda_kernel_config_free_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_free\n");
|
|
return -1;
|
|
}
|
|
mlx_fast_cuda_kernel_config_add_output_arg_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_config_add_output_arg");
|
|
if (mlx_fast_cuda_kernel_config_add_output_arg_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_add_output_arg\n");
|
|
return -1;
|
|
}
|
|
mlx_fast_cuda_kernel_config_set_grid_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_config_set_grid");
|
|
if (mlx_fast_cuda_kernel_config_set_grid_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_set_grid\n");
|
|
return -1;
|
|
}
|
|
mlx_fast_cuda_kernel_config_set_thread_group_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_config_set_thread_group");
|
|
if (mlx_fast_cuda_kernel_config_set_thread_group_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_set_thread_group\n");
|
|
return -1;
|
|
}
|
|
mlx_fast_cuda_kernel_config_set_init_value_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_config_set_init_value");
|
|
if (mlx_fast_cuda_kernel_config_set_init_value_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_set_init_value\n");
|
|
return -1;
|
|
}
|
|
mlx_fast_cuda_kernel_config_set_verbose_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_config_set_verbose");
|
|
if (mlx_fast_cuda_kernel_config_set_verbose_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_set_verbose\n");
|
|
return -1;
|
|
}
|
|
mlx_fast_cuda_kernel_config_add_template_arg_dtype_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_config_add_template_arg_dtype");
|
|
if (mlx_fast_cuda_kernel_config_add_template_arg_dtype_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_add_template_arg_dtype\n");
|
|
return -1;
|
|
}
|
|
mlx_fast_cuda_kernel_config_add_template_arg_int_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_config_add_template_arg_int");
|
|
if (mlx_fast_cuda_kernel_config_add_template_arg_int_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_add_template_arg_int\n");
|
|
return -1;
|
|
}
|
|
mlx_fast_cuda_kernel_config_add_template_arg_bool_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_config_add_template_arg_bool");
|
|
if (mlx_fast_cuda_kernel_config_add_template_arg_bool_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_add_template_arg_bool\n");
|
|
return -1;
|
|
}
|
|
mlx_fast_cuda_kernel_new_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_new");
|
|
if (mlx_fast_cuda_kernel_new_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_new\n");
|
|
return -1;
|
|
}
|
|
mlx_fast_cuda_kernel_free_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_free");
|
|
if (mlx_fast_cuda_kernel_free_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_free\n");
|
|
return -1;
|
|
}
|
|
mlx_fast_cuda_kernel_apply_ptr = GET_SYM(handle, "mlx_fast_cuda_kernel_apply");
|
|
if (mlx_fast_cuda_kernel_apply_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_apply\n");
|
|
return -1;
|
|
}
|
|
mlx_fast_layer_norm_ptr = GET_SYM(handle, "mlx_fast_layer_norm");
|
|
if (mlx_fast_layer_norm_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_layer_norm\n");
|
|
return -1;
|
|
}
|
|
mlx_fast_metal_kernel_config_new_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_config_new");
|
|
if (mlx_fast_metal_kernel_config_new_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_new\n");
|
|
return -1;
|
|
}
|
|
mlx_fast_metal_kernel_config_free_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_config_free");
|
|
if (mlx_fast_metal_kernel_config_free_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_free\n");
|
|
return -1;
|
|
}
|
|
mlx_fast_metal_kernel_config_add_output_arg_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_config_add_output_arg");
|
|
if (mlx_fast_metal_kernel_config_add_output_arg_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_add_output_arg\n");
|
|
return -1;
|
|
}
|
|
mlx_fast_metal_kernel_config_set_grid_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_config_set_grid");
|
|
if (mlx_fast_metal_kernel_config_set_grid_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_set_grid\n");
|
|
return -1;
|
|
}
|
|
mlx_fast_metal_kernel_config_set_thread_group_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_config_set_thread_group");
|
|
if (mlx_fast_metal_kernel_config_set_thread_group_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_set_thread_group\n");
|
|
return -1;
|
|
}
|
|
mlx_fast_metal_kernel_config_set_init_value_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_config_set_init_value");
|
|
if (mlx_fast_metal_kernel_config_set_init_value_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_set_init_value\n");
|
|
return -1;
|
|
}
|
|
mlx_fast_metal_kernel_config_set_verbose_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_config_set_verbose");
|
|
if (mlx_fast_metal_kernel_config_set_verbose_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_set_verbose\n");
|
|
return -1;
|
|
}
|
|
mlx_fast_metal_kernel_config_add_template_arg_dtype_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_config_add_template_arg_dtype");
|
|
if (mlx_fast_metal_kernel_config_add_template_arg_dtype_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_add_template_arg_dtype\n");
|
|
return -1;
|
|
}
|
|
mlx_fast_metal_kernel_config_add_template_arg_int_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_config_add_template_arg_int");
|
|
if (mlx_fast_metal_kernel_config_add_template_arg_int_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_add_template_arg_int\n");
|
|
return -1;
|
|
}
|
|
mlx_fast_metal_kernel_config_add_template_arg_bool_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_config_add_template_arg_bool");
|
|
if (mlx_fast_metal_kernel_config_add_template_arg_bool_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_add_template_arg_bool\n");
|
|
return -1;
|
|
}
|
|
mlx_fast_metal_kernel_new_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_new");
|
|
if (mlx_fast_metal_kernel_new_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_new\n");
|
|
return -1;
|
|
}
|
|
mlx_fast_metal_kernel_free_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_free");
|
|
if (mlx_fast_metal_kernel_free_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_free\n");
|
|
return -1;
|
|
}
|
|
mlx_fast_metal_kernel_apply_ptr = GET_SYM(handle, "mlx_fast_metal_kernel_apply");
|
|
if (mlx_fast_metal_kernel_apply_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_apply\n");
|
|
return -1;
|
|
}
|
|
mlx_fast_rms_norm_ptr = GET_SYM(handle, "mlx_fast_rms_norm");
|
|
if (mlx_fast_rms_norm_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_rms_norm\n");
|
|
return -1;
|
|
}
|
|
mlx_fast_rope_ptr = GET_SYM(handle, "mlx_fast_rope");
|
|
if (mlx_fast_rope_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_rope\n");
|
|
return -1;
|
|
}
|
|
mlx_fast_rope_dynamic_ptr = GET_SYM(handle, "mlx_fast_rope_dynamic");
|
|
if (mlx_fast_rope_dynamic_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_rope_dynamic\n");
|
|
return -1;
|
|
}
|
|
mlx_fast_scaled_dot_product_attention_ptr = GET_SYM(handle, "mlx_fast_scaled_dot_product_attention");
|
|
if (mlx_fast_scaled_dot_product_attention_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_scaled_dot_product_attention\n");
|
|
return -1;
|
|
}
|
|
mlx_fft_fft_ptr = GET_SYM(handle, "mlx_fft_fft");
|
|
if (mlx_fft_fft_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_fft\n");
|
|
return -1;
|
|
}
|
|
mlx_fft_fft2_ptr = GET_SYM(handle, "mlx_fft_fft2");
|
|
if (mlx_fft_fft2_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_fft2\n");
|
|
return -1;
|
|
}
|
|
mlx_fft_fftn_ptr = GET_SYM(handle, "mlx_fft_fftn");
|
|
if (mlx_fft_fftn_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_fftn\n");
|
|
return -1;
|
|
}
|
|
mlx_fft_fftshift_ptr = GET_SYM(handle, "mlx_fft_fftshift");
|
|
if (mlx_fft_fftshift_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_fftshift\n");
|
|
return -1;
|
|
}
|
|
mlx_fft_ifft_ptr = GET_SYM(handle, "mlx_fft_ifft");
|
|
if (mlx_fft_ifft_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_ifft\n");
|
|
return -1;
|
|
}
|
|
mlx_fft_ifft2_ptr = GET_SYM(handle, "mlx_fft_ifft2");
|
|
if (mlx_fft_ifft2_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_ifft2\n");
|
|
return -1;
|
|
}
|
|
mlx_fft_ifftn_ptr = GET_SYM(handle, "mlx_fft_ifftn");
|
|
if (mlx_fft_ifftn_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_ifftn\n");
|
|
return -1;
|
|
}
|
|
mlx_fft_ifftshift_ptr = GET_SYM(handle, "mlx_fft_ifftshift");
|
|
if (mlx_fft_ifftshift_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_ifftshift\n");
|
|
return -1;
|
|
}
|
|
mlx_fft_irfft_ptr = GET_SYM(handle, "mlx_fft_irfft");
|
|
if (mlx_fft_irfft_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_irfft\n");
|
|
return -1;
|
|
}
|
|
mlx_fft_irfft2_ptr = GET_SYM(handle, "mlx_fft_irfft2");
|
|
if (mlx_fft_irfft2_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_irfft2\n");
|
|
return -1;
|
|
}
|
|
mlx_fft_irfftn_ptr = GET_SYM(handle, "mlx_fft_irfftn");
|
|
if (mlx_fft_irfftn_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_irfftn\n");
|
|
return -1;
|
|
}
|
|
mlx_fft_rfft_ptr = GET_SYM(handle, "mlx_fft_rfft");
|
|
if (mlx_fft_rfft_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_rfft\n");
|
|
return -1;
|
|
}
|
|
mlx_fft_rfft2_ptr = GET_SYM(handle, "mlx_fft_rfft2");
|
|
if (mlx_fft_rfft2_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_rfft2\n");
|
|
return -1;
|
|
}
|
|
mlx_fft_rfftn_ptr = GET_SYM(handle, "mlx_fft_rfftn");
|
|
if (mlx_fft_rfftn_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_rfftn\n");
|
|
return -1;
|
|
}
|
|
mlx_load_reader_ptr = GET_SYM(handle, "mlx_load_reader");
|
|
if (mlx_load_reader_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_load_reader\n");
|
|
return -1;
|
|
}
|
|
mlx_load_ptr = GET_SYM(handle, "mlx_load");
|
|
if (mlx_load_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_load\n");
|
|
return -1;
|
|
}
|
|
mlx_load_safetensors_reader_ptr = GET_SYM(handle, "mlx_load_safetensors_reader");
|
|
if (mlx_load_safetensors_reader_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_load_safetensors_reader\n");
|
|
return -1;
|
|
}
|
|
mlx_load_safetensors_ptr = GET_SYM(handle, "mlx_load_safetensors");
|
|
if (mlx_load_safetensors_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_load_safetensors\n");
|
|
return -1;
|
|
}
|
|
mlx_save_writer_ptr = GET_SYM(handle, "mlx_save_writer");
|
|
if (mlx_save_writer_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_save_writer\n");
|
|
return -1;
|
|
}
|
|
mlx_save_ptr = GET_SYM(handle, "mlx_save");
|
|
if (mlx_save_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_save\n");
|
|
return -1;
|
|
}
|
|
mlx_save_safetensors_writer_ptr = GET_SYM(handle, "mlx_save_safetensors_writer");
|
|
if (mlx_save_safetensors_writer_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_save_safetensors_writer\n");
|
|
return -1;
|
|
}
|
|
mlx_save_safetensors_ptr = GET_SYM(handle, "mlx_save_safetensors");
|
|
if (mlx_save_safetensors_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_save_safetensors\n");
|
|
return -1;
|
|
}
|
|
mlx_io_reader_new_ptr = GET_SYM(handle, "mlx_io_reader_new");
|
|
if (mlx_io_reader_new_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_io_reader_new\n");
|
|
return -1;
|
|
}
|
|
mlx_io_reader_descriptor_ptr = GET_SYM(handle, "mlx_io_reader_descriptor");
|
|
if (mlx_io_reader_descriptor_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_io_reader_descriptor\n");
|
|
return -1;
|
|
}
|
|
mlx_io_reader_tostring_ptr = GET_SYM(handle, "mlx_io_reader_tostring");
|
|
if (mlx_io_reader_tostring_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_io_reader_tostring\n");
|
|
return -1;
|
|
}
|
|
mlx_io_reader_free_ptr = GET_SYM(handle, "mlx_io_reader_free");
|
|
if (mlx_io_reader_free_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_io_reader_free\n");
|
|
return -1;
|
|
}
|
|
mlx_io_writer_new_ptr = GET_SYM(handle, "mlx_io_writer_new");
|
|
if (mlx_io_writer_new_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_io_writer_new\n");
|
|
return -1;
|
|
}
|
|
mlx_io_writer_descriptor_ptr = GET_SYM(handle, "mlx_io_writer_descriptor");
|
|
if (mlx_io_writer_descriptor_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_io_writer_descriptor\n");
|
|
return -1;
|
|
}
|
|
mlx_io_writer_tostring_ptr = GET_SYM(handle, "mlx_io_writer_tostring");
|
|
if (mlx_io_writer_tostring_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_io_writer_tostring\n");
|
|
return -1;
|
|
}
|
|
mlx_io_writer_free_ptr = GET_SYM(handle, "mlx_io_writer_free");
|
|
if (mlx_io_writer_free_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_io_writer_free\n");
|
|
return -1;
|
|
}
|
|
mlx_linalg_cholesky_ptr = GET_SYM(handle, "mlx_linalg_cholesky");
|
|
if (mlx_linalg_cholesky_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_cholesky\n");
|
|
return -1;
|
|
}
|
|
mlx_linalg_cholesky_inv_ptr = GET_SYM(handle, "mlx_linalg_cholesky_inv");
|
|
if (mlx_linalg_cholesky_inv_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_cholesky_inv\n");
|
|
return -1;
|
|
}
|
|
mlx_linalg_cross_ptr = GET_SYM(handle, "mlx_linalg_cross");
|
|
if (mlx_linalg_cross_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_cross\n");
|
|
return -1;
|
|
}
|
|
mlx_linalg_eig_ptr = GET_SYM(handle, "mlx_linalg_eig");
|
|
if (mlx_linalg_eig_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_eig\n");
|
|
return -1;
|
|
}
|
|
mlx_linalg_eigh_ptr = GET_SYM(handle, "mlx_linalg_eigh");
|
|
if (mlx_linalg_eigh_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_eigh\n");
|
|
return -1;
|
|
}
|
|
mlx_linalg_eigvals_ptr = GET_SYM(handle, "mlx_linalg_eigvals");
|
|
if (mlx_linalg_eigvals_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_eigvals\n");
|
|
return -1;
|
|
}
|
|
mlx_linalg_eigvalsh_ptr = GET_SYM(handle, "mlx_linalg_eigvalsh");
|
|
if (mlx_linalg_eigvalsh_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_eigvalsh\n");
|
|
return -1;
|
|
}
|
|
mlx_linalg_inv_ptr = GET_SYM(handle, "mlx_linalg_inv");
|
|
if (mlx_linalg_inv_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_inv\n");
|
|
return -1;
|
|
}
|
|
mlx_linalg_lu_ptr = GET_SYM(handle, "mlx_linalg_lu");
|
|
if (mlx_linalg_lu_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_lu\n");
|
|
return -1;
|
|
}
|
|
mlx_linalg_lu_factor_ptr = GET_SYM(handle, "mlx_linalg_lu_factor");
|
|
if (mlx_linalg_lu_factor_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_lu_factor\n");
|
|
return -1;
|
|
}
|
|
mlx_linalg_norm_ptr = GET_SYM(handle, "mlx_linalg_norm");
|
|
if (mlx_linalg_norm_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_norm\n");
|
|
return -1;
|
|
}
|
|
mlx_linalg_norm_matrix_ptr = GET_SYM(handle, "mlx_linalg_norm_matrix");
|
|
if (mlx_linalg_norm_matrix_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_norm_matrix\n");
|
|
return -1;
|
|
}
|
|
mlx_linalg_norm_l2_ptr = GET_SYM(handle, "mlx_linalg_norm_l2");
|
|
if (mlx_linalg_norm_l2_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_norm_l2\n");
|
|
return -1;
|
|
}
|
|
mlx_linalg_pinv_ptr = GET_SYM(handle, "mlx_linalg_pinv");
|
|
if (mlx_linalg_pinv_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_pinv\n");
|
|
return -1;
|
|
}
|
|
mlx_linalg_qr_ptr = GET_SYM(handle, "mlx_linalg_qr");
|
|
if (mlx_linalg_qr_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_qr\n");
|
|
return -1;
|
|
}
|
|
mlx_linalg_solve_ptr = GET_SYM(handle, "mlx_linalg_solve");
|
|
if (mlx_linalg_solve_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_solve\n");
|
|
return -1;
|
|
}
|
|
mlx_linalg_solve_triangular_ptr = GET_SYM(handle, "mlx_linalg_solve_triangular");
|
|
if (mlx_linalg_solve_triangular_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_solve_triangular\n");
|
|
return -1;
|
|
}
|
|
mlx_linalg_svd_ptr = GET_SYM(handle, "mlx_linalg_svd");
|
|
if (mlx_linalg_svd_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_svd\n");
|
|
return -1;
|
|
}
|
|
mlx_linalg_tri_inv_ptr = GET_SYM(handle, "mlx_linalg_tri_inv");
|
|
if (mlx_linalg_tri_inv_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_tri_inv\n");
|
|
return -1;
|
|
}
|
|
mlx_map_string_to_array_new_ptr = GET_SYM(handle, "mlx_map_string_to_array_new");
|
|
if (mlx_map_string_to_array_new_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_new\n");
|
|
return -1;
|
|
}
|
|
mlx_map_string_to_array_set_ptr = GET_SYM(handle, "mlx_map_string_to_array_set");
|
|
if (mlx_map_string_to_array_set_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_set\n");
|
|
return -1;
|
|
}
|
|
mlx_map_string_to_array_free_ptr = GET_SYM(handle, "mlx_map_string_to_array_free");
|
|
if (mlx_map_string_to_array_free_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_free\n");
|
|
return -1;
|
|
}
|
|
mlx_map_string_to_array_insert_ptr = GET_SYM(handle, "mlx_map_string_to_array_insert");
|
|
if (mlx_map_string_to_array_insert_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_insert\n");
|
|
return -1;
|
|
}
|
|
mlx_map_string_to_array_get_ptr = GET_SYM(handle, "mlx_map_string_to_array_get");
|
|
if (mlx_map_string_to_array_get_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_get\n");
|
|
return -1;
|
|
}
|
|
mlx_map_string_to_array_iterator_new_ptr = GET_SYM(handle, "mlx_map_string_to_array_iterator_new");
|
|
if (mlx_map_string_to_array_iterator_new_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_iterator_new\n");
|
|
return -1;
|
|
}
|
|
mlx_map_string_to_array_iterator_free_ptr = GET_SYM(handle, "mlx_map_string_to_array_iterator_free");
|
|
if (mlx_map_string_to_array_iterator_free_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_iterator_free\n");
|
|
return -1;
|
|
}
|
|
mlx_map_string_to_array_iterator_next_ptr = GET_SYM(handle, "mlx_map_string_to_array_iterator_next");
|
|
if (mlx_map_string_to_array_iterator_next_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_iterator_next\n");
|
|
return -1;
|
|
}
|
|
mlx_map_string_to_string_new_ptr = GET_SYM(handle, "mlx_map_string_to_string_new");
|
|
if (mlx_map_string_to_string_new_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_new\n");
|
|
return -1;
|
|
}
|
|
mlx_map_string_to_string_set_ptr = GET_SYM(handle, "mlx_map_string_to_string_set");
|
|
if (mlx_map_string_to_string_set_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_set\n");
|
|
return -1;
|
|
}
|
|
mlx_map_string_to_string_free_ptr = GET_SYM(handle, "mlx_map_string_to_string_free");
|
|
if (mlx_map_string_to_string_free_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_free\n");
|
|
return -1;
|
|
}
|
|
mlx_map_string_to_string_insert_ptr = GET_SYM(handle, "mlx_map_string_to_string_insert");
|
|
if (mlx_map_string_to_string_insert_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_insert\n");
|
|
return -1;
|
|
}
|
|
mlx_map_string_to_string_get_ptr = GET_SYM(handle, "mlx_map_string_to_string_get");
|
|
if (mlx_map_string_to_string_get_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_get\n");
|
|
return -1;
|
|
}
|
|
mlx_map_string_to_string_iterator_new_ptr = GET_SYM(handle, "mlx_map_string_to_string_iterator_new");
|
|
if (mlx_map_string_to_string_iterator_new_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_iterator_new\n");
|
|
return -1;
|
|
}
|
|
mlx_map_string_to_string_iterator_free_ptr = GET_SYM(handle, "mlx_map_string_to_string_iterator_free");
|
|
if (mlx_map_string_to_string_iterator_free_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_iterator_free\n");
|
|
return -1;
|
|
}
|
|
mlx_map_string_to_string_iterator_next_ptr = GET_SYM(handle, "mlx_map_string_to_string_iterator_next");
|
|
if (mlx_map_string_to_string_iterator_next_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_iterator_next\n");
|
|
return -1;
|
|
}
|
|
mlx_clear_cache_ptr = GET_SYM(handle, "mlx_clear_cache");
|
|
if (mlx_clear_cache_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_clear_cache\n");
|
|
return -1;
|
|
}
|
|
mlx_get_active_memory_ptr = GET_SYM(handle, "mlx_get_active_memory");
|
|
if (mlx_get_active_memory_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_get_active_memory\n");
|
|
return -1;
|
|
}
|
|
mlx_get_cache_memory_ptr = GET_SYM(handle, "mlx_get_cache_memory");
|
|
if (mlx_get_cache_memory_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_get_cache_memory\n");
|
|
return -1;
|
|
}
|
|
mlx_get_memory_limit_ptr = GET_SYM(handle, "mlx_get_memory_limit");
|
|
if (mlx_get_memory_limit_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_get_memory_limit\n");
|
|
return -1;
|
|
}
|
|
mlx_get_peak_memory_ptr = GET_SYM(handle, "mlx_get_peak_memory");
|
|
if (mlx_get_peak_memory_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_get_peak_memory\n");
|
|
return -1;
|
|
}
|
|
mlx_reset_peak_memory_ptr = GET_SYM(handle, "mlx_reset_peak_memory");
|
|
if (mlx_reset_peak_memory_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_reset_peak_memory\n");
|
|
return -1;
|
|
}
|
|
mlx_set_cache_limit_ptr = GET_SYM(handle, "mlx_set_cache_limit");
|
|
if (mlx_set_cache_limit_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_cache_limit\n");
|
|
return -1;
|
|
}
|
|
mlx_set_memory_limit_ptr = GET_SYM(handle, "mlx_set_memory_limit");
|
|
if (mlx_set_memory_limit_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_memory_limit\n");
|
|
return -1;
|
|
}
|
|
mlx_set_wired_limit_ptr = GET_SYM(handle, "mlx_set_wired_limit");
|
|
if (mlx_set_wired_limit_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_wired_limit\n");
|
|
return -1;
|
|
}
|
|
mlx_metal_is_available_ptr = GET_SYM(handle, "mlx_metal_is_available");
|
|
if (mlx_metal_is_available_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_metal_is_available\n");
|
|
return -1;
|
|
}
|
|
mlx_metal_start_capture_ptr = GET_SYM(handle, "mlx_metal_start_capture");
|
|
if (mlx_metal_start_capture_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_metal_start_capture\n");
|
|
return -1;
|
|
}
|
|
mlx_metal_stop_capture_ptr = GET_SYM(handle, "mlx_metal_stop_capture");
|
|
if (mlx_metal_stop_capture_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_metal_stop_capture\n");
|
|
return -1;
|
|
}
|
|
mlx_abs_ptr = GET_SYM(handle, "mlx_abs");
|
|
if (mlx_abs_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_abs\n");
|
|
return -1;
|
|
}
|
|
mlx_add_ptr = GET_SYM(handle, "mlx_add");
|
|
if (mlx_add_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_add\n");
|
|
return -1;
|
|
}
|
|
mlx_addmm_ptr = GET_SYM(handle, "mlx_addmm");
|
|
if (mlx_addmm_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_addmm\n");
|
|
return -1;
|
|
}
|
|
mlx_all_axes_ptr = GET_SYM(handle, "mlx_all_axes");
|
|
if (mlx_all_axes_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_all_axes\n");
|
|
return -1;
|
|
}
|
|
mlx_all_axis_ptr = GET_SYM(handle, "mlx_all_axis");
|
|
if (mlx_all_axis_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_all_axis\n");
|
|
return -1;
|
|
}
|
|
mlx_all_ptr = GET_SYM(handle, "mlx_all");
|
|
if (mlx_all_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_all\n");
|
|
return -1;
|
|
}
|
|
mlx_allclose_ptr = GET_SYM(handle, "mlx_allclose");
|
|
if (mlx_allclose_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_allclose\n");
|
|
return -1;
|
|
}
|
|
mlx_any_axes_ptr = GET_SYM(handle, "mlx_any_axes");
|
|
if (mlx_any_axes_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_any_axes\n");
|
|
return -1;
|
|
}
|
|
mlx_any_axis_ptr = GET_SYM(handle, "mlx_any_axis");
|
|
if (mlx_any_axis_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_any_axis\n");
|
|
return -1;
|
|
}
|
|
mlx_any_ptr = GET_SYM(handle, "mlx_any");
|
|
if (mlx_any_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_any\n");
|
|
return -1;
|
|
}
|
|
mlx_arange_ptr = GET_SYM(handle, "mlx_arange");
|
|
if (mlx_arange_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_arange\n");
|
|
return -1;
|
|
}
|
|
mlx_arccos_ptr = GET_SYM(handle, "mlx_arccos");
|
|
if (mlx_arccos_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_arccos\n");
|
|
return -1;
|
|
}
|
|
mlx_arccosh_ptr = GET_SYM(handle, "mlx_arccosh");
|
|
if (mlx_arccosh_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_arccosh\n");
|
|
return -1;
|
|
}
|
|
mlx_arcsin_ptr = GET_SYM(handle, "mlx_arcsin");
|
|
if (mlx_arcsin_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_arcsin\n");
|
|
return -1;
|
|
}
|
|
mlx_arcsinh_ptr = GET_SYM(handle, "mlx_arcsinh");
|
|
if (mlx_arcsinh_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_arcsinh\n");
|
|
return -1;
|
|
}
|
|
mlx_arctan_ptr = GET_SYM(handle, "mlx_arctan");
|
|
if (mlx_arctan_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_arctan\n");
|
|
return -1;
|
|
}
|
|
mlx_arctan2_ptr = GET_SYM(handle, "mlx_arctan2");
|
|
if (mlx_arctan2_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_arctan2\n");
|
|
return -1;
|
|
}
|
|
mlx_arctanh_ptr = GET_SYM(handle, "mlx_arctanh");
|
|
if (mlx_arctanh_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_arctanh\n");
|
|
return -1;
|
|
}
|
|
mlx_argmax_axis_ptr = GET_SYM(handle, "mlx_argmax_axis");
|
|
if (mlx_argmax_axis_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_argmax_axis\n");
|
|
return -1;
|
|
}
|
|
mlx_argmax_ptr = GET_SYM(handle, "mlx_argmax");
|
|
if (mlx_argmax_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_argmax\n");
|
|
return -1;
|
|
}
|
|
mlx_argmin_axis_ptr = GET_SYM(handle, "mlx_argmin_axis");
|
|
if (mlx_argmin_axis_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_argmin_axis\n");
|
|
return -1;
|
|
}
|
|
mlx_argmin_ptr = GET_SYM(handle, "mlx_argmin");
|
|
if (mlx_argmin_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_argmin\n");
|
|
return -1;
|
|
}
|
|
mlx_argpartition_axis_ptr = GET_SYM(handle, "mlx_argpartition_axis");
|
|
if (mlx_argpartition_axis_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_argpartition_axis\n");
|
|
return -1;
|
|
}
|
|
mlx_argpartition_ptr = GET_SYM(handle, "mlx_argpartition");
|
|
if (mlx_argpartition_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_argpartition\n");
|
|
return -1;
|
|
}
|
|
mlx_argsort_axis_ptr = GET_SYM(handle, "mlx_argsort_axis");
|
|
if (mlx_argsort_axis_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_argsort_axis\n");
|
|
return -1;
|
|
}
|
|
mlx_argsort_ptr = GET_SYM(handle, "mlx_argsort");
|
|
if (mlx_argsort_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_argsort\n");
|
|
return -1;
|
|
}
|
|
mlx_array_equal_ptr = GET_SYM(handle, "mlx_array_equal");
|
|
if (mlx_array_equal_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_equal\n");
|
|
return -1;
|
|
}
|
|
mlx_as_strided_ptr = GET_SYM(handle, "mlx_as_strided");
|
|
if (mlx_as_strided_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_as_strided\n");
|
|
return -1;
|
|
}
|
|
mlx_astype_ptr = GET_SYM(handle, "mlx_astype");
|
|
if (mlx_astype_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_astype\n");
|
|
return -1;
|
|
}
|
|
mlx_atleast_1d_ptr = GET_SYM(handle, "mlx_atleast_1d");
|
|
if (mlx_atleast_1d_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_atleast_1d\n");
|
|
return -1;
|
|
}
|
|
mlx_atleast_2d_ptr = GET_SYM(handle, "mlx_atleast_2d");
|
|
if (mlx_atleast_2d_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_atleast_2d\n");
|
|
return -1;
|
|
}
|
|
mlx_atleast_3d_ptr = GET_SYM(handle, "mlx_atleast_3d");
|
|
if (mlx_atleast_3d_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_atleast_3d\n");
|
|
return -1;
|
|
}
|
|
mlx_bitwise_and_ptr = GET_SYM(handle, "mlx_bitwise_and");
|
|
if (mlx_bitwise_and_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_and\n");
|
|
return -1;
|
|
}
|
|
mlx_bitwise_invert_ptr = GET_SYM(handle, "mlx_bitwise_invert");
|
|
if (mlx_bitwise_invert_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_invert\n");
|
|
return -1;
|
|
}
|
|
mlx_bitwise_or_ptr = GET_SYM(handle, "mlx_bitwise_or");
|
|
if (mlx_bitwise_or_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_or\n");
|
|
return -1;
|
|
}
|
|
mlx_bitwise_xor_ptr = GET_SYM(handle, "mlx_bitwise_xor");
|
|
if (mlx_bitwise_xor_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_xor\n");
|
|
return -1;
|
|
}
|
|
mlx_block_masked_mm_ptr = GET_SYM(handle, "mlx_block_masked_mm");
|
|
if (mlx_block_masked_mm_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_block_masked_mm\n");
|
|
return -1;
|
|
}
|
|
mlx_broadcast_arrays_ptr = GET_SYM(handle, "mlx_broadcast_arrays");
|
|
if (mlx_broadcast_arrays_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_broadcast_arrays\n");
|
|
return -1;
|
|
}
|
|
mlx_broadcast_to_ptr = GET_SYM(handle, "mlx_broadcast_to");
|
|
if (mlx_broadcast_to_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_broadcast_to\n");
|
|
return -1;
|
|
}
|
|
mlx_ceil_ptr = GET_SYM(handle, "mlx_ceil");
|
|
if (mlx_ceil_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_ceil\n");
|
|
return -1;
|
|
}
|
|
mlx_clip_ptr = GET_SYM(handle, "mlx_clip");
|
|
if (mlx_clip_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_clip\n");
|
|
return -1;
|
|
}
|
|
mlx_concatenate_axis_ptr = GET_SYM(handle, "mlx_concatenate_axis");
|
|
if (mlx_concatenate_axis_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_concatenate_axis\n");
|
|
return -1;
|
|
}
|
|
mlx_concatenate_ptr = GET_SYM(handle, "mlx_concatenate");
|
|
if (mlx_concatenate_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_concatenate\n");
|
|
return -1;
|
|
}
|
|
mlx_conjugate_ptr = GET_SYM(handle, "mlx_conjugate");
|
|
if (mlx_conjugate_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_conjugate\n");
|
|
return -1;
|
|
}
|
|
mlx_contiguous_ptr = GET_SYM(handle, "mlx_contiguous");
|
|
if (mlx_contiguous_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_contiguous\n");
|
|
return -1;
|
|
}
|
|
mlx_conv1d_ptr = GET_SYM(handle, "mlx_conv1d");
|
|
if (mlx_conv1d_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_conv1d\n");
|
|
return -1;
|
|
}
|
|
mlx_conv2d_ptr = GET_SYM(handle, "mlx_conv2d");
|
|
if (mlx_conv2d_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_conv2d\n");
|
|
return -1;
|
|
}
|
|
mlx_conv3d_ptr = GET_SYM(handle, "mlx_conv3d");
|
|
if (mlx_conv3d_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_conv3d\n");
|
|
return -1;
|
|
}
|
|
mlx_conv_general_ptr = GET_SYM(handle, "mlx_conv_general");
|
|
if (mlx_conv_general_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_conv_general\n");
|
|
return -1;
|
|
}
|
|
mlx_conv_transpose1d_ptr = GET_SYM(handle, "mlx_conv_transpose1d");
|
|
if (mlx_conv_transpose1d_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_conv_transpose1d\n");
|
|
return -1;
|
|
}
|
|
mlx_conv_transpose2d_ptr = GET_SYM(handle, "mlx_conv_transpose2d");
|
|
if (mlx_conv_transpose2d_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_conv_transpose2d\n");
|
|
return -1;
|
|
}
|
|
mlx_conv_transpose3d_ptr = GET_SYM(handle, "mlx_conv_transpose3d");
|
|
if (mlx_conv_transpose3d_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_conv_transpose3d\n");
|
|
return -1;
|
|
}
|
|
mlx_copy_ptr = GET_SYM(handle, "mlx_copy");
|
|
if (mlx_copy_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_copy\n");
|
|
return -1;
|
|
}
|
|
mlx_cos_ptr = GET_SYM(handle, "mlx_cos");
|
|
if (mlx_cos_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_cos\n");
|
|
return -1;
|
|
}
|
|
mlx_cosh_ptr = GET_SYM(handle, "mlx_cosh");
|
|
if (mlx_cosh_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_cosh\n");
|
|
return -1;
|
|
}
|
|
mlx_cummax_ptr = GET_SYM(handle, "mlx_cummax");
|
|
if (mlx_cummax_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_cummax\n");
|
|
return -1;
|
|
}
|
|
mlx_cummin_ptr = GET_SYM(handle, "mlx_cummin");
|
|
if (mlx_cummin_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_cummin\n");
|
|
return -1;
|
|
}
|
|
mlx_cumprod_ptr = GET_SYM(handle, "mlx_cumprod");
|
|
if (mlx_cumprod_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_cumprod\n");
|
|
return -1;
|
|
}
|
|
mlx_cumsum_ptr = GET_SYM(handle, "mlx_cumsum");
|
|
if (mlx_cumsum_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_cumsum\n");
|
|
return -1;
|
|
}
|
|
mlx_degrees_ptr = GET_SYM(handle, "mlx_degrees");
|
|
if (mlx_degrees_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_degrees\n");
|
|
return -1;
|
|
}
|
|
mlx_depends_ptr = GET_SYM(handle, "mlx_depends");
|
|
if (mlx_depends_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_depends\n");
|
|
return -1;
|
|
}
|
|
mlx_dequantize_ptr = GET_SYM(handle, "mlx_dequantize");
|
|
if (mlx_dequantize_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_dequantize\n");
|
|
return -1;
|
|
}
|
|
mlx_diag_ptr = GET_SYM(handle, "mlx_diag");
|
|
if (mlx_diag_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_diag\n");
|
|
return -1;
|
|
}
|
|
mlx_diagonal_ptr = GET_SYM(handle, "mlx_diagonal");
|
|
if (mlx_diagonal_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_diagonal\n");
|
|
return -1;
|
|
}
|
|
mlx_divide_ptr = GET_SYM(handle, "mlx_divide");
|
|
if (mlx_divide_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_divide\n");
|
|
return -1;
|
|
}
|
|
mlx_divmod_ptr = GET_SYM(handle, "mlx_divmod");
|
|
if (mlx_divmod_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_divmod\n");
|
|
return -1;
|
|
}
|
|
mlx_einsum_ptr = GET_SYM(handle, "mlx_einsum");
|
|
if (mlx_einsum_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_einsum\n");
|
|
return -1;
|
|
}
|
|
mlx_equal_ptr = GET_SYM(handle, "mlx_equal");
|
|
if (mlx_equal_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_equal\n");
|
|
return -1;
|
|
}
|
|
mlx_erf_ptr = GET_SYM(handle, "mlx_erf");
|
|
if (mlx_erf_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_erf\n");
|
|
return -1;
|
|
}
|
|
mlx_erfinv_ptr = GET_SYM(handle, "mlx_erfinv");
|
|
if (mlx_erfinv_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_erfinv\n");
|
|
return -1;
|
|
}
|
|
mlx_exp_ptr = GET_SYM(handle, "mlx_exp");
|
|
if (mlx_exp_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_exp\n");
|
|
return -1;
|
|
}
|
|
mlx_expand_dims_axes_ptr = GET_SYM(handle, "mlx_expand_dims_axes");
|
|
if (mlx_expand_dims_axes_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_expand_dims_axes\n");
|
|
return -1;
|
|
}
|
|
mlx_expand_dims_ptr = GET_SYM(handle, "mlx_expand_dims");
|
|
if (mlx_expand_dims_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_expand_dims\n");
|
|
return -1;
|
|
}
|
|
mlx_expm1_ptr = GET_SYM(handle, "mlx_expm1");
|
|
if (mlx_expm1_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_expm1\n");
|
|
return -1;
|
|
}
|
|
mlx_eye_ptr = GET_SYM(handle, "mlx_eye");
|
|
if (mlx_eye_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_eye\n");
|
|
return -1;
|
|
}
|
|
mlx_flatten_ptr = GET_SYM(handle, "mlx_flatten");
|
|
if (mlx_flatten_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_flatten\n");
|
|
return -1;
|
|
}
|
|
mlx_floor_ptr = GET_SYM(handle, "mlx_floor");
|
|
if (mlx_floor_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_floor\n");
|
|
return -1;
|
|
}
|
|
mlx_floor_divide_ptr = GET_SYM(handle, "mlx_floor_divide");
|
|
if (mlx_floor_divide_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_floor_divide\n");
|
|
return -1;
|
|
}
|
|
mlx_from_fp8_ptr = GET_SYM(handle, "mlx_from_fp8");
|
|
if (mlx_from_fp8_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_from_fp8\n");
|
|
return -1;
|
|
}
|
|
mlx_full_ptr = GET_SYM(handle, "mlx_full");
|
|
if (mlx_full_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_full\n");
|
|
return -1;
|
|
}
|
|
mlx_full_like_ptr = GET_SYM(handle, "mlx_full_like");
|
|
if (mlx_full_like_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_full_like\n");
|
|
return -1;
|
|
}
|
|
mlx_gather_ptr = GET_SYM(handle, "mlx_gather");
|
|
if (mlx_gather_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_gather\n");
|
|
return -1;
|
|
}
|
|
mlx_gather_single_ptr = GET_SYM(handle, "mlx_gather_single");
|
|
if (mlx_gather_single_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_gather_single\n");
|
|
return -1;
|
|
}
|
|
mlx_gather_mm_ptr = GET_SYM(handle, "mlx_gather_mm");
|
|
if (mlx_gather_mm_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_gather_mm\n");
|
|
return -1;
|
|
}
|
|
mlx_gather_qmm_ptr = GET_SYM(handle, "mlx_gather_qmm");
|
|
if (mlx_gather_qmm_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_gather_qmm\n");
|
|
return -1;
|
|
}
|
|
mlx_greater_ptr = GET_SYM(handle, "mlx_greater");
|
|
if (mlx_greater_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_greater\n");
|
|
return -1;
|
|
}
|
|
mlx_greater_equal_ptr = GET_SYM(handle, "mlx_greater_equal");
|
|
if (mlx_greater_equal_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_greater_equal\n");
|
|
return -1;
|
|
}
|
|
mlx_hadamard_transform_ptr = GET_SYM(handle, "mlx_hadamard_transform");
|
|
if (mlx_hadamard_transform_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_hadamard_transform\n");
|
|
return -1;
|
|
}
|
|
mlx_identity_ptr = GET_SYM(handle, "mlx_identity");
|
|
if (mlx_identity_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_identity\n");
|
|
return -1;
|
|
}
|
|
mlx_imag_ptr = GET_SYM(handle, "mlx_imag");
|
|
if (mlx_imag_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_imag\n");
|
|
return -1;
|
|
}
|
|
mlx_inner_ptr = GET_SYM(handle, "mlx_inner");
|
|
if (mlx_inner_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_inner\n");
|
|
return -1;
|
|
}
|
|
mlx_isclose_ptr = GET_SYM(handle, "mlx_isclose");
|
|
if (mlx_isclose_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_isclose\n");
|
|
return -1;
|
|
}
|
|
mlx_isfinite_ptr = GET_SYM(handle, "mlx_isfinite");
|
|
if (mlx_isfinite_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_isfinite\n");
|
|
return -1;
|
|
}
|
|
mlx_isinf_ptr = GET_SYM(handle, "mlx_isinf");
|
|
if (mlx_isinf_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_isinf\n");
|
|
return -1;
|
|
}
|
|
mlx_isnan_ptr = GET_SYM(handle, "mlx_isnan");
|
|
if (mlx_isnan_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_isnan\n");
|
|
return -1;
|
|
}
|
|
mlx_isneginf_ptr = GET_SYM(handle, "mlx_isneginf");
|
|
if (mlx_isneginf_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_isneginf\n");
|
|
return -1;
|
|
}
|
|
mlx_isposinf_ptr = GET_SYM(handle, "mlx_isposinf");
|
|
if (mlx_isposinf_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_isposinf\n");
|
|
return -1;
|
|
}
|
|
mlx_kron_ptr = GET_SYM(handle, "mlx_kron");
|
|
if (mlx_kron_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_kron\n");
|
|
return -1;
|
|
}
|
|
mlx_left_shift_ptr = GET_SYM(handle, "mlx_left_shift");
|
|
if (mlx_left_shift_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_left_shift\n");
|
|
return -1;
|
|
}
|
|
mlx_less_ptr = GET_SYM(handle, "mlx_less");
|
|
if (mlx_less_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_less\n");
|
|
return -1;
|
|
}
|
|
mlx_less_equal_ptr = GET_SYM(handle, "mlx_less_equal");
|
|
if (mlx_less_equal_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_less_equal\n");
|
|
return -1;
|
|
}
|
|
mlx_linspace_ptr = GET_SYM(handle, "mlx_linspace");
|
|
if (mlx_linspace_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_linspace\n");
|
|
return -1;
|
|
}
|
|
mlx_log_ptr = GET_SYM(handle, "mlx_log");
|
|
if (mlx_log_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_log\n");
|
|
return -1;
|
|
}
|
|
mlx_log10_ptr = GET_SYM(handle, "mlx_log10");
|
|
if (mlx_log10_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_log10\n");
|
|
return -1;
|
|
}
|
|
mlx_log1p_ptr = GET_SYM(handle, "mlx_log1p");
|
|
if (mlx_log1p_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_log1p\n");
|
|
return -1;
|
|
}
|
|
mlx_log2_ptr = GET_SYM(handle, "mlx_log2");
|
|
if (mlx_log2_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_log2\n");
|
|
return -1;
|
|
}
|
|
mlx_logaddexp_ptr = GET_SYM(handle, "mlx_logaddexp");
|
|
if (mlx_logaddexp_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_logaddexp\n");
|
|
return -1;
|
|
}
|
|
mlx_logcumsumexp_ptr = GET_SYM(handle, "mlx_logcumsumexp");
|
|
if (mlx_logcumsumexp_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_logcumsumexp\n");
|
|
return -1;
|
|
}
|
|
mlx_logical_and_ptr = GET_SYM(handle, "mlx_logical_and");
|
|
if (mlx_logical_and_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_logical_and\n");
|
|
return -1;
|
|
}
|
|
mlx_logical_not_ptr = GET_SYM(handle, "mlx_logical_not");
|
|
if (mlx_logical_not_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_logical_not\n");
|
|
return -1;
|
|
}
|
|
mlx_logical_or_ptr = GET_SYM(handle, "mlx_logical_or");
|
|
if (mlx_logical_or_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_logical_or\n");
|
|
return -1;
|
|
}
|
|
mlx_logsumexp_axes_ptr = GET_SYM(handle, "mlx_logsumexp_axes");
|
|
if (mlx_logsumexp_axes_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_logsumexp_axes\n");
|
|
return -1;
|
|
}
|
|
mlx_logsumexp_axis_ptr = GET_SYM(handle, "mlx_logsumexp_axis");
|
|
if (mlx_logsumexp_axis_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_logsumexp_axis\n");
|
|
return -1;
|
|
}
|
|
mlx_logsumexp_ptr = GET_SYM(handle, "mlx_logsumexp");
|
|
if (mlx_logsumexp_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_logsumexp\n");
|
|
return -1;
|
|
}
|
|
mlx_masked_scatter_ptr = GET_SYM(handle, "mlx_masked_scatter");
|
|
if (mlx_masked_scatter_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_masked_scatter\n");
|
|
return -1;
|
|
}
|
|
mlx_matmul_ptr = GET_SYM(handle, "mlx_matmul");
|
|
if (mlx_matmul_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_matmul\n");
|
|
return -1;
|
|
}
|
|
mlx_max_axes_ptr = GET_SYM(handle, "mlx_max_axes");
|
|
if (mlx_max_axes_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_max_axes\n");
|
|
return -1;
|
|
}
|
|
mlx_max_axis_ptr = GET_SYM(handle, "mlx_max_axis");
|
|
if (mlx_max_axis_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_max_axis\n");
|
|
return -1;
|
|
}
|
|
mlx_max_ptr = GET_SYM(handle, "mlx_max");
|
|
if (mlx_max_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_max\n");
|
|
return -1;
|
|
}
|
|
mlx_maximum_ptr = GET_SYM(handle, "mlx_maximum");
|
|
if (mlx_maximum_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_maximum\n");
|
|
return -1;
|
|
}
|
|
mlx_mean_axes_ptr = GET_SYM(handle, "mlx_mean_axes");
|
|
if (mlx_mean_axes_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_mean_axes\n");
|
|
return -1;
|
|
}
|
|
mlx_mean_axis_ptr = GET_SYM(handle, "mlx_mean_axis");
|
|
if (mlx_mean_axis_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_mean_axis\n");
|
|
return -1;
|
|
}
|
|
mlx_mean_ptr = GET_SYM(handle, "mlx_mean");
|
|
if (mlx_mean_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_mean\n");
|
|
return -1;
|
|
}
|
|
mlx_median_ptr = GET_SYM(handle, "mlx_median");
|
|
if (mlx_median_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_median\n");
|
|
return -1;
|
|
}
|
|
mlx_meshgrid_ptr = GET_SYM(handle, "mlx_meshgrid");
|
|
if (mlx_meshgrid_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_meshgrid\n");
|
|
return -1;
|
|
}
|
|
mlx_min_axes_ptr = GET_SYM(handle, "mlx_min_axes");
|
|
if (mlx_min_axes_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_min_axes\n");
|
|
return -1;
|
|
}
|
|
mlx_min_axis_ptr = GET_SYM(handle, "mlx_min_axis");
|
|
if (mlx_min_axis_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_min_axis\n");
|
|
return -1;
|
|
}
|
|
mlx_min_ptr = GET_SYM(handle, "mlx_min");
|
|
if (mlx_min_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_min\n");
|
|
return -1;
|
|
}
|
|
mlx_minimum_ptr = GET_SYM(handle, "mlx_minimum");
|
|
if (mlx_minimum_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_minimum\n");
|
|
return -1;
|
|
}
|
|
mlx_moveaxis_ptr = GET_SYM(handle, "mlx_moveaxis");
|
|
if (mlx_moveaxis_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_moveaxis\n");
|
|
return -1;
|
|
}
|
|
mlx_multiply_ptr = GET_SYM(handle, "mlx_multiply");
|
|
if (mlx_multiply_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_multiply\n");
|
|
return -1;
|
|
}
|
|
mlx_nan_to_num_ptr = GET_SYM(handle, "mlx_nan_to_num");
|
|
if (mlx_nan_to_num_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_nan_to_num\n");
|
|
return -1;
|
|
}
|
|
mlx_negative_ptr = GET_SYM(handle, "mlx_negative");
|
|
if (mlx_negative_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_negative\n");
|
|
return -1;
|
|
}
|
|
mlx_not_equal_ptr = GET_SYM(handle, "mlx_not_equal");
|
|
if (mlx_not_equal_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_not_equal\n");
|
|
return -1;
|
|
}
|
|
mlx_number_of_elements_ptr = GET_SYM(handle, "mlx_number_of_elements");
|
|
if (mlx_number_of_elements_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_number_of_elements\n");
|
|
return -1;
|
|
}
|
|
mlx_ones_ptr = GET_SYM(handle, "mlx_ones");
|
|
if (mlx_ones_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_ones\n");
|
|
return -1;
|
|
}
|
|
mlx_ones_like_ptr = GET_SYM(handle, "mlx_ones_like");
|
|
if (mlx_ones_like_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_ones_like\n");
|
|
return -1;
|
|
}
|
|
mlx_outer_ptr = GET_SYM(handle, "mlx_outer");
|
|
if (mlx_outer_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_outer\n");
|
|
return -1;
|
|
}
|
|
mlx_pad_ptr = GET_SYM(handle, "mlx_pad");
|
|
if (mlx_pad_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_pad\n");
|
|
return -1;
|
|
}
|
|
mlx_pad_symmetric_ptr = GET_SYM(handle, "mlx_pad_symmetric");
|
|
if (mlx_pad_symmetric_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_pad_symmetric\n");
|
|
return -1;
|
|
}
|
|
mlx_partition_axis_ptr = GET_SYM(handle, "mlx_partition_axis");
|
|
if (mlx_partition_axis_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_partition_axis\n");
|
|
return -1;
|
|
}
|
|
mlx_partition_ptr = GET_SYM(handle, "mlx_partition");
|
|
if (mlx_partition_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_partition\n");
|
|
return -1;
|
|
}
|
|
mlx_power_ptr = GET_SYM(handle, "mlx_power");
|
|
if (mlx_power_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_power\n");
|
|
return -1;
|
|
}
|
|
mlx_prod_axes_ptr = GET_SYM(handle, "mlx_prod_axes");
|
|
if (mlx_prod_axes_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_prod_axes\n");
|
|
return -1;
|
|
}
|
|
mlx_prod_axis_ptr = GET_SYM(handle, "mlx_prod_axis");
|
|
if (mlx_prod_axis_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_prod_axis\n");
|
|
return -1;
|
|
}
|
|
mlx_prod_ptr = GET_SYM(handle, "mlx_prod");
|
|
if (mlx_prod_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_prod\n");
|
|
return -1;
|
|
}
|
|
mlx_put_along_axis_ptr = GET_SYM(handle, "mlx_put_along_axis");
|
|
if (mlx_put_along_axis_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_put_along_axis\n");
|
|
return -1;
|
|
}
|
|
mlx_qqmm_ptr = GET_SYM(handle, "mlx_qqmm");
|
|
if (mlx_qqmm_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_qqmm\n");
|
|
return -1;
|
|
}
|
|
mlx_quantize_ptr = GET_SYM(handle, "mlx_quantize");
|
|
if (mlx_quantize_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_quantize\n");
|
|
return -1;
|
|
}
|
|
mlx_quantized_matmul_ptr = GET_SYM(handle, "mlx_quantized_matmul");
|
|
if (mlx_quantized_matmul_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_quantized_matmul\n");
|
|
return -1;
|
|
}
|
|
mlx_radians_ptr = GET_SYM(handle, "mlx_radians");
|
|
if (mlx_radians_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_radians\n");
|
|
return -1;
|
|
}
|
|
mlx_real_ptr = GET_SYM(handle, "mlx_real");
|
|
if (mlx_real_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_real\n");
|
|
return -1;
|
|
}
|
|
mlx_reciprocal_ptr = GET_SYM(handle, "mlx_reciprocal");
|
|
if (mlx_reciprocal_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_reciprocal\n");
|
|
return -1;
|
|
}
|
|
mlx_remainder_ptr = GET_SYM(handle, "mlx_remainder");
|
|
if (mlx_remainder_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_remainder\n");
|
|
return -1;
|
|
}
|
|
mlx_repeat_axis_ptr = GET_SYM(handle, "mlx_repeat_axis");
|
|
if (mlx_repeat_axis_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_repeat_axis\n");
|
|
return -1;
|
|
}
|
|
mlx_repeat_ptr = GET_SYM(handle, "mlx_repeat");
|
|
if (mlx_repeat_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_repeat\n");
|
|
return -1;
|
|
}
|
|
mlx_reshape_ptr = GET_SYM(handle, "mlx_reshape");
|
|
if (mlx_reshape_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_reshape\n");
|
|
return -1;
|
|
}
|
|
mlx_right_shift_ptr = GET_SYM(handle, "mlx_right_shift");
|
|
if (mlx_right_shift_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_right_shift\n");
|
|
return -1;
|
|
}
|
|
mlx_roll_axis_ptr = GET_SYM(handle, "mlx_roll_axis");
|
|
if (mlx_roll_axis_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_roll_axis\n");
|
|
return -1;
|
|
}
|
|
mlx_roll_axes_ptr = GET_SYM(handle, "mlx_roll_axes");
|
|
if (mlx_roll_axes_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_roll_axes\n");
|
|
return -1;
|
|
}
|
|
mlx_roll_ptr = GET_SYM(handle, "mlx_roll");
|
|
if (mlx_roll_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_roll\n");
|
|
return -1;
|
|
}
|
|
mlx_round_ptr = GET_SYM(handle, "mlx_round");
|
|
if (mlx_round_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_round\n");
|
|
return -1;
|
|
}
|
|
mlx_rsqrt_ptr = GET_SYM(handle, "mlx_rsqrt");
|
|
if (mlx_rsqrt_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_rsqrt\n");
|
|
return -1;
|
|
}
|
|
mlx_scatter_ptr = GET_SYM(handle, "mlx_scatter");
|
|
if (mlx_scatter_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter\n");
|
|
return -1;
|
|
}
|
|
mlx_scatter_single_ptr = GET_SYM(handle, "mlx_scatter_single");
|
|
if (mlx_scatter_single_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_single\n");
|
|
return -1;
|
|
}
|
|
mlx_scatter_add_ptr = GET_SYM(handle, "mlx_scatter_add");
|
|
if (mlx_scatter_add_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_add\n");
|
|
return -1;
|
|
}
|
|
mlx_scatter_add_single_ptr = GET_SYM(handle, "mlx_scatter_add_single");
|
|
if (mlx_scatter_add_single_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_add_single\n");
|
|
return -1;
|
|
}
|
|
mlx_scatter_add_axis_ptr = GET_SYM(handle, "mlx_scatter_add_axis");
|
|
if (mlx_scatter_add_axis_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_add_axis\n");
|
|
return -1;
|
|
}
|
|
mlx_scatter_max_ptr = GET_SYM(handle, "mlx_scatter_max");
|
|
if (mlx_scatter_max_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_max\n");
|
|
return -1;
|
|
}
|
|
mlx_scatter_max_single_ptr = GET_SYM(handle, "mlx_scatter_max_single");
|
|
if (mlx_scatter_max_single_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_max_single\n");
|
|
return -1;
|
|
}
|
|
mlx_scatter_min_ptr = GET_SYM(handle, "mlx_scatter_min");
|
|
if (mlx_scatter_min_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_min\n");
|
|
return -1;
|
|
}
|
|
mlx_scatter_min_single_ptr = GET_SYM(handle, "mlx_scatter_min_single");
|
|
if (mlx_scatter_min_single_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_min_single\n");
|
|
return -1;
|
|
}
|
|
mlx_scatter_prod_ptr = GET_SYM(handle, "mlx_scatter_prod");
|
|
if (mlx_scatter_prod_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_prod\n");
|
|
return -1;
|
|
}
|
|
mlx_scatter_prod_single_ptr = GET_SYM(handle, "mlx_scatter_prod_single");
|
|
if (mlx_scatter_prod_single_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_prod_single\n");
|
|
return -1;
|
|
}
|
|
mlx_segmented_mm_ptr = GET_SYM(handle, "mlx_segmented_mm");
|
|
if (mlx_segmented_mm_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_segmented_mm\n");
|
|
return -1;
|
|
}
|
|
mlx_sigmoid_ptr = GET_SYM(handle, "mlx_sigmoid");
|
|
if (mlx_sigmoid_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_sigmoid\n");
|
|
return -1;
|
|
}
|
|
mlx_sign_ptr = GET_SYM(handle, "mlx_sign");
|
|
if (mlx_sign_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_sign\n");
|
|
return -1;
|
|
}
|
|
mlx_sin_ptr = GET_SYM(handle, "mlx_sin");
|
|
if (mlx_sin_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_sin\n");
|
|
return -1;
|
|
}
|
|
mlx_sinh_ptr = GET_SYM(handle, "mlx_sinh");
|
|
if (mlx_sinh_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_sinh\n");
|
|
return -1;
|
|
}
|
|
mlx_slice_ptr = GET_SYM(handle, "mlx_slice");
|
|
if (mlx_slice_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_slice\n");
|
|
return -1;
|
|
}
|
|
mlx_slice_dynamic_ptr = GET_SYM(handle, "mlx_slice_dynamic");
|
|
if (mlx_slice_dynamic_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_slice_dynamic\n");
|
|
return -1;
|
|
}
|
|
mlx_slice_update_ptr = GET_SYM(handle, "mlx_slice_update");
|
|
if (mlx_slice_update_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_slice_update\n");
|
|
return -1;
|
|
}
|
|
mlx_slice_update_dynamic_ptr = GET_SYM(handle, "mlx_slice_update_dynamic");
|
|
if (mlx_slice_update_dynamic_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_slice_update_dynamic\n");
|
|
return -1;
|
|
}
|
|
mlx_softmax_axes_ptr = GET_SYM(handle, "mlx_softmax_axes");
|
|
if (mlx_softmax_axes_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_softmax_axes\n");
|
|
return -1;
|
|
}
|
|
mlx_softmax_axis_ptr = GET_SYM(handle, "mlx_softmax_axis");
|
|
if (mlx_softmax_axis_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_softmax_axis\n");
|
|
return -1;
|
|
}
|
|
mlx_softmax_ptr = GET_SYM(handle, "mlx_softmax");
|
|
if (mlx_softmax_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_softmax\n");
|
|
return -1;
|
|
}
|
|
mlx_sort_axis_ptr = GET_SYM(handle, "mlx_sort_axis");
|
|
if (mlx_sort_axis_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_sort_axis\n");
|
|
return -1;
|
|
}
|
|
mlx_sort_ptr = GET_SYM(handle, "mlx_sort");
|
|
if (mlx_sort_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_sort\n");
|
|
return -1;
|
|
}
|
|
mlx_split_ptr = GET_SYM(handle, "mlx_split");
|
|
if (mlx_split_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_split\n");
|
|
return -1;
|
|
}
|
|
mlx_split_sections_ptr = GET_SYM(handle, "mlx_split_sections");
|
|
if (mlx_split_sections_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_split_sections\n");
|
|
return -1;
|
|
}
|
|
mlx_sqrt_ptr = GET_SYM(handle, "mlx_sqrt");
|
|
if (mlx_sqrt_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_sqrt\n");
|
|
return -1;
|
|
}
|
|
mlx_square_ptr = GET_SYM(handle, "mlx_square");
|
|
if (mlx_square_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_square\n");
|
|
return -1;
|
|
}
|
|
mlx_squeeze_axes_ptr = GET_SYM(handle, "mlx_squeeze_axes");
|
|
if (mlx_squeeze_axes_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_squeeze_axes\n");
|
|
return -1;
|
|
}
|
|
mlx_squeeze_axis_ptr = GET_SYM(handle, "mlx_squeeze_axis");
|
|
if (mlx_squeeze_axis_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_squeeze_axis\n");
|
|
return -1;
|
|
}
|
|
mlx_squeeze_ptr = GET_SYM(handle, "mlx_squeeze");
|
|
if (mlx_squeeze_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_squeeze\n");
|
|
return -1;
|
|
}
|
|
mlx_stack_axis_ptr = GET_SYM(handle, "mlx_stack_axis");
|
|
if (mlx_stack_axis_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_stack_axis\n");
|
|
return -1;
|
|
}
|
|
mlx_stack_ptr = GET_SYM(handle, "mlx_stack");
|
|
if (mlx_stack_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_stack\n");
|
|
return -1;
|
|
}
|
|
mlx_std_axes_ptr = GET_SYM(handle, "mlx_std_axes");
|
|
if (mlx_std_axes_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_std_axes\n");
|
|
return -1;
|
|
}
|
|
mlx_std_axis_ptr = GET_SYM(handle, "mlx_std_axis");
|
|
if (mlx_std_axis_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_std_axis\n");
|
|
return -1;
|
|
}
|
|
mlx_std_ptr = GET_SYM(handle, "mlx_std");
|
|
if (mlx_std_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_std\n");
|
|
return -1;
|
|
}
|
|
mlx_stop_gradient_ptr = GET_SYM(handle, "mlx_stop_gradient");
|
|
if (mlx_stop_gradient_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_stop_gradient\n");
|
|
return -1;
|
|
}
|
|
mlx_subtract_ptr = GET_SYM(handle, "mlx_subtract");
|
|
if (mlx_subtract_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_subtract\n");
|
|
return -1;
|
|
}
|
|
mlx_sum_axes_ptr = GET_SYM(handle, "mlx_sum_axes");
|
|
if (mlx_sum_axes_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_sum_axes\n");
|
|
return -1;
|
|
}
|
|
mlx_sum_axis_ptr = GET_SYM(handle, "mlx_sum_axis");
|
|
if (mlx_sum_axis_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_sum_axis\n");
|
|
return -1;
|
|
}
|
|
mlx_sum_ptr = GET_SYM(handle, "mlx_sum");
|
|
if (mlx_sum_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_sum\n");
|
|
return -1;
|
|
}
|
|
mlx_swapaxes_ptr = GET_SYM(handle, "mlx_swapaxes");
|
|
if (mlx_swapaxes_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_swapaxes\n");
|
|
return -1;
|
|
}
|
|
mlx_take_axis_ptr = GET_SYM(handle, "mlx_take_axis");
|
|
if (mlx_take_axis_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_take_axis\n");
|
|
return -1;
|
|
}
|
|
mlx_take_ptr = GET_SYM(handle, "mlx_take");
|
|
if (mlx_take_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_take\n");
|
|
return -1;
|
|
}
|
|
mlx_take_along_axis_ptr = GET_SYM(handle, "mlx_take_along_axis");
|
|
if (mlx_take_along_axis_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_take_along_axis\n");
|
|
return -1;
|
|
}
|
|
mlx_tan_ptr = GET_SYM(handle, "mlx_tan");
|
|
if (mlx_tan_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_tan\n");
|
|
return -1;
|
|
}
|
|
mlx_tanh_ptr = GET_SYM(handle, "mlx_tanh");
|
|
if (mlx_tanh_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_tanh\n");
|
|
return -1;
|
|
}
|
|
mlx_tensordot_ptr = GET_SYM(handle, "mlx_tensordot");
|
|
if (mlx_tensordot_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_tensordot\n");
|
|
return -1;
|
|
}
|
|
mlx_tensordot_axis_ptr = GET_SYM(handle, "mlx_tensordot_axis");
|
|
if (mlx_tensordot_axis_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_tensordot_axis\n");
|
|
return -1;
|
|
}
|
|
mlx_tile_ptr = GET_SYM(handle, "mlx_tile");
|
|
if (mlx_tile_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_tile\n");
|
|
return -1;
|
|
}
|
|
mlx_to_fp8_ptr = GET_SYM(handle, "mlx_to_fp8");
|
|
if (mlx_to_fp8_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_to_fp8\n");
|
|
return -1;
|
|
}
|
|
mlx_topk_axis_ptr = GET_SYM(handle, "mlx_topk_axis");
|
|
if (mlx_topk_axis_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_topk_axis\n");
|
|
return -1;
|
|
}
|
|
mlx_topk_ptr = GET_SYM(handle, "mlx_topk");
|
|
if (mlx_topk_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_topk\n");
|
|
return -1;
|
|
}
|
|
mlx_trace_ptr = GET_SYM(handle, "mlx_trace");
|
|
if (mlx_trace_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_trace\n");
|
|
return -1;
|
|
}
|
|
mlx_transpose_axes_ptr = GET_SYM(handle, "mlx_transpose_axes");
|
|
if (mlx_transpose_axes_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_transpose_axes\n");
|
|
return -1;
|
|
}
|
|
mlx_transpose_ptr = GET_SYM(handle, "mlx_transpose");
|
|
if (mlx_transpose_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_transpose\n");
|
|
return -1;
|
|
}
|
|
mlx_tri_ptr = GET_SYM(handle, "mlx_tri");
|
|
if (mlx_tri_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_tri\n");
|
|
return -1;
|
|
}
|
|
mlx_tril_ptr = GET_SYM(handle, "mlx_tril");
|
|
if (mlx_tril_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_tril\n");
|
|
return -1;
|
|
}
|
|
mlx_triu_ptr = GET_SYM(handle, "mlx_triu");
|
|
if (mlx_triu_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_triu\n");
|
|
return -1;
|
|
}
|
|
mlx_unflatten_ptr = GET_SYM(handle, "mlx_unflatten");
|
|
if (mlx_unflatten_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_unflatten\n");
|
|
return -1;
|
|
}
|
|
mlx_var_axes_ptr = GET_SYM(handle, "mlx_var_axes");
|
|
if (mlx_var_axes_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_var_axes\n");
|
|
return -1;
|
|
}
|
|
mlx_var_axis_ptr = GET_SYM(handle, "mlx_var_axis");
|
|
if (mlx_var_axis_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_var_axis\n");
|
|
return -1;
|
|
}
|
|
mlx_var_ptr = GET_SYM(handle, "mlx_var");
|
|
if (mlx_var_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_var\n");
|
|
return -1;
|
|
}
|
|
mlx_view_ptr = GET_SYM(handle, "mlx_view");
|
|
if (mlx_view_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_view\n");
|
|
return -1;
|
|
}
|
|
mlx_where_ptr = GET_SYM(handle, "mlx_where");
|
|
if (mlx_where_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_where\n");
|
|
return -1;
|
|
}
|
|
mlx_zeros_ptr = GET_SYM(handle, "mlx_zeros");
|
|
if (mlx_zeros_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_zeros\n");
|
|
return -1;
|
|
}
|
|
mlx_zeros_like_ptr = GET_SYM(handle, "mlx_zeros_like");
|
|
if (mlx_zeros_like_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_zeros_like\n");
|
|
return -1;
|
|
}
|
|
mlx_random_bernoulli_ptr = GET_SYM(handle, "mlx_random_bernoulli");
|
|
if (mlx_random_bernoulli_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_bernoulli\n");
|
|
return -1;
|
|
}
|
|
mlx_random_bits_ptr = GET_SYM(handle, "mlx_random_bits");
|
|
if (mlx_random_bits_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_bits\n");
|
|
return -1;
|
|
}
|
|
mlx_random_categorical_shape_ptr = GET_SYM(handle, "mlx_random_categorical_shape");
|
|
if (mlx_random_categorical_shape_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_categorical_shape\n");
|
|
return -1;
|
|
}
|
|
mlx_random_categorical_num_samples_ptr = GET_SYM(handle, "mlx_random_categorical_num_samples");
|
|
if (mlx_random_categorical_num_samples_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_categorical_num_samples\n");
|
|
return -1;
|
|
}
|
|
mlx_random_categorical_ptr = GET_SYM(handle, "mlx_random_categorical");
|
|
if (mlx_random_categorical_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_categorical\n");
|
|
return -1;
|
|
}
|
|
mlx_random_gumbel_ptr = GET_SYM(handle, "mlx_random_gumbel");
|
|
if (mlx_random_gumbel_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_gumbel\n");
|
|
return -1;
|
|
}
|
|
mlx_random_key_ptr = GET_SYM(handle, "mlx_random_key");
|
|
if (mlx_random_key_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_key\n");
|
|
return -1;
|
|
}
|
|
mlx_random_laplace_ptr = GET_SYM(handle, "mlx_random_laplace");
|
|
if (mlx_random_laplace_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_laplace\n");
|
|
return -1;
|
|
}
|
|
mlx_random_multivariate_normal_ptr = GET_SYM(handle, "mlx_random_multivariate_normal");
|
|
if (mlx_random_multivariate_normal_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_multivariate_normal\n");
|
|
return -1;
|
|
}
|
|
mlx_random_normal_broadcast_ptr = GET_SYM(handle, "mlx_random_normal_broadcast");
|
|
if (mlx_random_normal_broadcast_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_normal_broadcast\n");
|
|
return -1;
|
|
}
|
|
mlx_random_normal_ptr = GET_SYM(handle, "mlx_random_normal");
|
|
if (mlx_random_normal_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_normal\n");
|
|
return -1;
|
|
}
|
|
mlx_random_permutation_ptr = GET_SYM(handle, "mlx_random_permutation");
|
|
if (mlx_random_permutation_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_permutation\n");
|
|
return -1;
|
|
}
|
|
mlx_random_permutation_arange_ptr = GET_SYM(handle, "mlx_random_permutation_arange");
|
|
if (mlx_random_permutation_arange_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_permutation_arange\n");
|
|
return -1;
|
|
}
|
|
mlx_random_randint_ptr = GET_SYM(handle, "mlx_random_randint");
|
|
if (mlx_random_randint_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_randint\n");
|
|
return -1;
|
|
}
|
|
mlx_random_seed_ptr = GET_SYM(handle, "mlx_random_seed");
|
|
if (mlx_random_seed_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_seed\n");
|
|
return -1;
|
|
}
|
|
mlx_random_split_num_ptr = GET_SYM(handle, "mlx_random_split_num");
|
|
if (mlx_random_split_num_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_split_num\n");
|
|
return -1;
|
|
}
|
|
mlx_random_split_ptr = GET_SYM(handle, "mlx_random_split");
|
|
if (mlx_random_split_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_split\n");
|
|
return -1;
|
|
}
|
|
mlx_random_truncated_normal_ptr = GET_SYM(handle, "mlx_random_truncated_normal");
|
|
if (mlx_random_truncated_normal_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_truncated_normal\n");
|
|
return -1;
|
|
}
|
|
mlx_random_uniform_ptr = GET_SYM(handle, "mlx_random_uniform");
|
|
if (mlx_random_uniform_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_uniform\n");
|
|
return -1;
|
|
}
|
|
mlx_stream_new_ptr = GET_SYM(handle, "mlx_stream_new");
|
|
if (mlx_stream_new_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_new\n");
|
|
return -1;
|
|
}
|
|
mlx_stream_new_device_ptr = GET_SYM(handle, "mlx_stream_new_device");
|
|
if (mlx_stream_new_device_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_new_device\n");
|
|
return -1;
|
|
}
|
|
mlx_stream_set_ptr = GET_SYM(handle, "mlx_stream_set");
|
|
if (mlx_stream_set_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_set\n");
|
|
return -1;
|
|
}
|
|
mlx_stream_free_ptr = GET_SYM(handle, "mlx_stream_free");
|
|
if (mlx_stream_free_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_free\n");
|
|
return -1;
|
|
}
|
|
mlx_stream_tostring_ptr = GET_SYM(handle, "mlx_stream_tostring");
|
|
if (mlx_stream_tostring_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_tostring\n");
|
|
return -1;
|
|
}
|
|
mlx_stream_equal_ptr = GET_SYM(handle, "mlx_stream_equal");
|
|
if (mlx_stream_equal_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_equal\n");
|
|
return -1;
|
|
}
|
|
mlx_stream_get_device_ptr = GET_SYM(handle, "mlx_stream_get_device");
|
|
if (mlx_stream_get_device_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_get_device\n");
|
|
return -1;
|
|
}
|
|
mlx_stream_get_index_ptr = GET_SYM(handle, "mlx_stream_get_index");
|
|
if (mlx_stream_get_index_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_get_index\n");
|
|
return -1;
|
|
}
|
|
mlx_synchronize_ptr = GET_SYM(handle, "mlx_synchronize");
|
|
if (mlx_synchronize_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_synchronize\n");
|
|
return -1;
|
|
}
|
|
mlx_get_default_stream_ptr = GET_SYM(handle, "mlx_get_default_stream");
|
|
if (mlx_get_default_stream_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_get_default_stream\n");
|
|
return -1;
|
|
}
|
|
mlx_set_default_stream_ptr = GET_SYM(handle, "mlx_set_default_stream");
|
|
if (mlx_set_default_stream_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_default_stream\n");
|
|
return -1;
|
|
}
|
|
mlx_default_cpu_stream_new_ptr = GET_SYM(handle, "mlx_default_cpu_stream_new");
|
|
if (mlx_default_cpu_stream_new_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_default_cpu_stream_new\n");
|
|
return -1;
|
|
}
|
|
mlx_default_gpu_stream_new_ptr = GET_SYM(handle, "mlx_default_gpu_stream_new");
|
|
if (mlx_default_gpu_stream_new_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_default_gpu_stream_new\n");
|
|
return -1;
|
|
}
|
|
mlx_string_new_ptr = GET_SYM(handle, "mlx_string_new");
|
|
if (mlx_string_new_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_string_new\n");
|
|
return -1;
|
|
}
|
|
mlx_string_new_data_ptr = GET_SYM(handle, "mlx_string_new_data");
|
|
if (mlx_string_new_data_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_string_new_data\n");
|
|
return -1;
|
|
}
|
|
mlx_string_set_ptr = GET_SYM(handle, "mlx_string_set");
|
|
if (mlx_string_set_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_string_set\n");
|
|
return -1;
|
|
}
|
|
mlx_string_data_ptr = GET_SYM(handle, "mlx_string_data");
|
|
if (mlx_string_data_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_string_data\n");
|
|
return -1;
|
|
}
|
|
mlx_string_free_ptr = GET_SYM(handle, "mlx_string_free");
|
|
if (mlx_string_free_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_string_free\n");
|
|
return -1;
|
|
}
|
|
mlx_async_eval_ptr = GET_SYM(handle, "mlx_async_eval");
|
|
if (mlx_async_eval_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_async_eval\n");
|
|
return -1;
|
|
}
|
|
mlx_checkpoint_ptr = GET_SYM(handle, "mlx_checkpoint");
|
|
if (mlx_checkpoint_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_checkpoint\n");
|
|
return -1;
|
|
}
|
|
mlx_custom_function_ptr = GET_SYM(handle, "mlx_custom_function");
|
|
if (mlx_custom_function_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_custom_function\n");
|
|
return -1;
|
|
}
|
|
mlx_custom_vjp_ptr = GET_SYM(handle, "mlx_custom_vjp");
|
|
if (mlx_custom_vjp_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_custom_vjp\n");
|
|
return -1;
|
|
}
|
|
mlx_eval_ptr = GET_SYM(handle, "mlx_eval");
|
|
if (mlx_eval_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_eval\n");
|
|
return -1;
|
|
}
|
|
mlx_jvp_ptr = GET_SYM(handle, "mlx_jvp");
|
|
if (mlx_jvp_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_jvp\n");
|
|
return -1;
|
|
}
|
|
mlx_value_and_grad_ptr = GET_SYM(handle, "mlx_value_and_grad");
|
|
if (mlx_value_and_grad_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_value_and_grad\n");
|
|
return -1;
|
|
}
|
|
mlx_vjp_ptr = GET_SYM(handle, "mlx_vjp");
|
|
if (mlx_vjp_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_vjp\n");
|
|
return -1;
|
|
}
|
|
mlx_detail_vmap_replace_ptr = GET_SYM(handle, "mlx_detail_vmap_replace");
|
|
if (mlx_detail_vmap_replace_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_detail_vmap_replace\n");
|
|
return -1;
|
|
}
|
|
mlx_detail_vmap_trace_ptr = GET_SYM(handle, "mlx_detail_vmap_trace");
|
|
if (mlx_detail_vmap_trace_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_detail_vmap_trace\n");
|
|
return -1;
|
|
}
|
|
mlx_vector_array_new_ptr = GET_SYM(handle, "mlx_vector_array_new");
|
|
if (mlx_vector_array_new_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_new\n");
|
|
return -1;
|
|
}
|
|
mlx_vector_array_set_ptr = GET_SYM(handle, "mlx_vector_array_set");
|
|
if (mlx_vector_array_set_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_set\n");
|
|
return -1;
|
|
}
|
|
mlx_vector_array_free_ptr = GET_SYM(handle, "mlx_vector_array_free");
|
|
if (mlx_vector_array_free_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_free\n");
|
|
return -1;
|
|
}
|
|
mlx_vector_array_new_data_ptr = GET_SYM(handle, "mlx_vector_array_new_data");
|
|
if (mlx_vector_array_new_data_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_new_data\n");
|
|
return -1;
|
|
}
|
|
mlx_vector_array_new_value_ptr = GET_SYM(handle, "mlx_vector_array_new_value");
|
|
if (mlx_vector_array_new_value_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_new_value\n");
|
|
return -1;
|
|
}
|
|
mlx_vector_array_set_data_ptr = GET_SYM(handle, "mlx_vector_array_set_data");
|
|
if (mlx_vector_array_set_data_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_set_data\n");
|
|
return -1;
|
|
}
|
|
mlx_vector_array_set_value_ptr = GET_SYM(handle, "mlx_vector_array_set_value");
|
|
if (mlx_vector_array_set_value_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_set_value\n");
|
|
return -1;
|
|
}
|
|
mlx_vector_array_append_data_ptr = GET_SYM(handle, "mlx_vector_array_append_data");
|
|
if (mlx_vector_array_append_data_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_append_data\n");
|
|
return -1;
|
|
}
|
|
mlx_vector_array_append_value_ptr = GET_SYM(handle, "mlx_vector_array_append_value");
|
|
if (mlx_vector_array_append_value_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_append_value\n");
|
|
return -1;
|
|
}
|
|
mlx_vector_array_size_ptr = GET_SYM(handle, "mlx_vector_array_size");
|
|
if (mlx_vector_array_size_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_size\n");
|
|
return -1;
|
|
}
|
|
mlx_vector_array_get_ptr = GET_SYM(handle, "mlx_vector_array_get");
|
|
if (mlx_vector_array_get_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_get\n");
|
|
return -1;
|
|
}
|
|
mlx_vector_vector_array_new_ptr = GET_SYM(handle, "mlx_vector_vector_array_new");
|
|
if (mlx_vector_vector_array_new_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_new\n");
|
|
return -1;
|
|
}
|
|
mlx_vector_vector_array_set_ptr = GET_SYM(handle, "mlx_vector_vector_array_set");
|
|
if (mlx_vector_vector_array_set_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_set\n");
|
|
return -1;
|
|
}
|
|
mlx_vector_vector_array_free_ptr = GET_SYM(handle, "mlx_vector_vector_array_free");
|
|
if (mlx_vector_vector_array_free_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_free\n");
|
|
return -1;
|
|
}
|
|
mlx_vector_vector_array_new_data_ptr = GET_SYM(handle, "mlx_vector_vector_array_new_data");
|
|
if (mlx_vector_vector_array_new_data_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_new_data\n");
|
|
return -1;
|
|
}
|
|
mlx_vector_vector_array_new_value_ptr = GET_SYM(handle, "mlx_vector_vector_array_new_value");
|
|
if (mlx_vector_vector_array_new_value_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_new_value\n");
|
|
return -1;
|
|
}
|
|
mlx_vector_vector_array_set_data_ptr = GET_SYM(handle, "mlx_vector_vector_array_set_data");
|
|
if (mlx_vector_vector_array_set_data_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_set_data\n");
|
|
return -1;
|
|
}
|
|
mlx_vector_vector_array_set_value_ptr = GET_SYM(handle, "mlx_vector_vector_array_set_value");
|
|
if (mlx_vector_vector_array_set_value_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_set_value\n");
|
|
return -1;
|
|
}
|
|
mlx_vector_vector_array_append_data_ptr = GET_SYM(handle, "mlx_vector_vector_array_append_data");
|
|
if (mlx_vector_vector_array_append_data_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_append_data\n");
|
|
return -1;
|
|
}
|
|
mlx_vector_vector_array_append_value_ptr = GET_SYM(handle, "mlx_vector_vector_array_append_value");
|
|
if (mlx_vector_vector_array_append_value_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_append_value\n");
|
|
return -1;
|
|
}
|
|
mlx_vector_vector_array_size_ptr = GET_SYM(handle, "mlx_vector_vector_array_size");
|
|
if (mlx_vector_vector_array_size_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_size\n");
|
|
return -1;
|
|
}
|
|
mlx_vector_vector_array_get_ptr = GET_SYM(handle, "mlx_vector_vector_array_get");
|
|
if (mlx_vector_vector_array_get_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_get\n");
|
|
return -1;
|
|
}
|
|
mlx_vector_int_new_ptr = GET_SYM(handle, "mlx_vector_int_new");
|
|
if (mlx_vector_int_new_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_new\n");
|
|
return -1;
|
|
}
|
|
mlx_vector_int_set_ptr = GET_SYM(handle, "mlx_vector_int_set");
|
|
if (mlx_vector_int_set_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_set\n");
|
|
return -1;
|
|
}
|
|
mlx_vector_int_free_ptr = GET_SYM(handle, "mlx_vector_int_free");
|
|
if (mlx_vector_int_free_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_free\n");
|
|
return -1;
|
|
}
|
|
mlx_vector_int_new_data_ptr = GET_SYM(handle, "mlx_vector_int_new_data");
|
|
if (mlx_vector_int_new_data_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_new_data\n");
|
|
return -1;
|
|
}
|
|
mlx_vector_int_new_value_ptr = GET_SYM(handle, "mlx_vector_int_new_value");
|
|
if (mlx_vector_int_new_value_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_new_value\n");
|
|
return -1;
|
|
}
|
|
mlx_vector_int_set_data_ptr = GET_SYM(handle, "mlx_vector_int_set_data");
|
|
if (mlx_vector_int_set_data_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_set_data\n");
|
|
return -1;
|
|
}
|
|
mlx_vector_int_set_value_ptr = GET_SYM(handle, "mlx_vector_int_set_value");
|
|
if (mlx_vector_int_set_value_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_set_value\n");
|
|
return -1;
|
|
}
|
|
mlx_vector_int_append_data_ptr = GET_SYM(handle, "mlx_vector_int_append_data");
|
|
if (mlx_vector_int_append_data_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_append_data\n");
|
|
return -1;
|
|
}
|
|
mlx_vector_int_append_value_ptr = GET_SYM(handle, "mlx_vector_int_append_value");
|
|
if (mlx_vector_int_append_value_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_append_value\n");
|
|
return -1;
|
|
}
|
|
mlx_vector_int_size_ptr = GET_SYM(handle, "mlx_vector_int_size");
|
|
if (mlx_vector_int_size_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_size\n");
|
|
return -1;
|
|
}
|
|
mlx_vector_int_get_ptr = GET_SYM(handle, "mlx_vector_int_get");
|
|
if (mlx_vector_int_get_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_get\n");
|
|
return -1;
|
|
}
|
|
mlx_vector_string_new_ptr = GET_SYM(handle, "mlx_vector_string_new");
|
|
if (mlx_vector_string_new_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_new\n");
|
|
return -1;
|
|
}
|
|
mlx_vector_string_set_ptr = GET_SYM(handle, "mlx_vector_string_set");
|
|
if (mlx_vector_string_set_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_set\n");
|
|
return -1;
|
|
}
|
|
mlx_vector_string_free_ptr = GET_SYM(handle, "mlx_vector_string_free");
|
|
if (mlx_vector_string_free_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_free\n");
|
|
return -1;
|
|
}
|
|
mlx_vector_string_new_data_ptr = GET_SYM(handle, "mlx_vector_string_new_data");
|
|
if (mlx_vector_string_new_data_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_new_data\n");
|
|
return -1;
|
|
}
|
|
mlx_vector_string_new_value_ptr = GET_SYM(handle, "mlx_vector_string_new_value");
|
|
if (mlx_vector_string_new_value_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_new_value\n");
|
|
return -1;
|
|
}
|
|
mlx_vector_string_set_data_ptr = GET_SYM(handle, "mlx_vector_string_set_data");
|
|
if (mlx_vector_string_set_data_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_set_data\n");
|
|
return -1;
|
|
}
|
|
mlx_vector_string_set_value_ptr = GET_SYM(handle, "mlx_vector_string_set_value");
|
|
if (mlx_vector_string_set_value_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_set_value\n");
|
|
return -1;
|
|
}
|
|
mlx_vector_string_append_data_ptr = GET_SYM(handle, "mlx_vector_string_append_data");
|
|
if (mlx_vector_string_append_data_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_append_data\n");
|
|
return -1;
|
|
}
|
|
mlx_vector_string_append_value_ptr = GET_SYM(handle, "mlx_vector_string_append_value");
|
|
if (mlx_vector_string_append_value_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_append_value\n");
|
|
return -1;
|
|
}
|
|
mlx_vector_string_size_ptr = GET_SYM(handle, "mlx_vector_string_size");
|
|
if (mlx_vector_string_size_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_size\n");
|
|
return -1;
|
|
}
|
|
mlx_vector_string_get_ptr = GET_SYM(handle, "mlx_vector_string_get");
|
|
if (mlx_vector_string_get_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_get\n");
|
|
return -1;
|
|
}
|
|
mlx_version_ptr = GET_SYM(handle, "mlx_version");
|
|
if (mlx_version_ptr == NULL) {
|
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_version\n");
|
|
return -1;
|
|
}
|
|
return 0;
|
|
}
|
|
|
|
// Wrapper function implementations that call through function pointers
|
|
/*
 * Thin forwarder: calls the dynamically loaded mlx_dtype_size through
 * mlx_dtype_size_ptr and returns its result unchanged.
 * No NULL check is done here; the symbol loader must have resolved
 * mlx_dtype_size_ptr before this is called.
 */
size_t mlx_dtype_size(mlx_dtype dtype) {
    const size_t result = (*mlx_dtype_size_ptr)(dtype);
    return result;
}
|
|
|
|
/*
 * Thin forwarder: passes str/arr to the loaded mlx_array_tostring via
 * mlx_array_tostring_ptr and returns its int result unchanged.
 * No NULL check is done; the loader must have resolved the pointer first.
 */
int mlx_array_tostring(mlx_string* str, const mlx_array arr) {
    const int result = (*mlx_array_tostring_ptr)(str, arr);
    return result;
}
|
|
|
|
mlx_array mlx_array_new(void) {
|
|
return mlx_array_new_ptr();
|
|
}
|
|
|
|
/*
 * Thin forwarder: hands arr to the loaded mlx_array_free via
 * mlx_array_free_ptr and returns its int result unchanged.
 * No NULL check is done; the loader must have resolved the pointer first.
 */
int mlx_array_free(mlx_array arr) {
    const int result = (*mlx_array_free_ptr)(arr);
    return result;
}
|
|
|
|
mlx_array mlx_array_new_bool(bool val) {
|
|
return mlx_array_new_bool_ptr(val);
|
|
}
|
|
|
|
mlx_array mlx_array_new_int(int val) {
|
|
return mlx_array_new_int_ptr(val);
|
|
}
|
|
|
|
mlx_array mlx_array_new_float32(float val) {
|
|
return mlx_array_new_float32_ptr(val);
|
|
}
|
|
|
|
mlx_array mlx_array_new_float(float val) {
|
|
return mlx_array_new_float_ptr(val);
|
|
}
|
|
|
|
mlx_array mlx_array_new_float64(double val) {
|
|
return mlx_array_new_float64_ptr(val);
|
|
}
|
|
|
|
mlx_array mlx_array_new_double(double val) {
|
|
return mlx_array_new_double_ptr(val);
|
|
}
|
|
|
|
mlx_array mlx_array_new_complex(float real_val, float imag_val) {
|
|
return mlx_array_new_complex_ptr(real_val, imag_val);
|
|
}
|
|
|
|
/*
 * Thin forwarder: passes data/shape/dim/dtype to the loaded
 * mlx_array_new_data via mlx_array_new_data_ptr and returns the
 * resulting mlx_array.
 * No NULL check is done; the loader must have resolved the pointer first.
 */
mlx_array mlx_array_new_data(const void* data, const int* shape, int dim, mlx_dtype dtype) {
    mlx_array result = (*mlx_array_new_data_ptr)(data, shape, dim, dtype);
    return result;
}
|
|
|
|
/*
 * Thin forwarder: passes data/shape/dim/dtype and the dtor callback to the
 * loaded mlx_array_new_data_managed via mlx_array_new_data_managed_ptr
 * and returns the resulting mlx_array.
 * No NULL check is done; the loader must have resolved the pointer first.
 */
mlx_array mlx_array_new_data_managed(void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*)) {
    mlx_array result = (*mlx_array_new_data_managed_ptr)(data, shape, dim, dtype, dtor);
    return result;
}
|
|
|
|
/*
 * Thin forwarder: passes data/shape/dim/dtype, payload and the dtor
 * callback to the loaded mlx_array_new_data_managed_payload via
 * mlx_array_new_data_managed_payload_ptr and returns the resulting mlx_array.
 * No NULL check is done; the loader must have resolved the pointer first.
 */
mlx_array mlx_array_new_data_managed_payload(void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*)) {
    mlx_array result = (*mlx_array_new_data_managed_payload_ptr)(data, shape, dim, dtype, payload, dtor);
    return result;
}
|
|
|
|
/*
 * Thin forwarder: passes arr/src to the loaded mlx_array_set via
 * mlx_array_set_ptr and returns its int result unchanged.
 * No NULL check is done; the loader must have resolved the pointer first.
 */
int mlx_array_set(mlx_array* arr, const mlx_array src) {
    const int result = (*mlx_array_set_ptr)(arr, src);
    return result;
}
|
|
|
|
/*
 * Thin forwarder: passes arr/val to the loaded mlx_array_set_bool via
 * mlx_array_set_bool_ptr and returns its int result unchanged.
 * No NULL check is done; the loader must have resolved the pointer first.
 */
int mlx_array_set_bool(mlx_array* arr, bool val) {
    const int result = (*mlx_array_set_bool_ptr)(arr, val);
    return result;
}
|
|
|
|
/*
 * Thin forwarder: passes arr/val to the loaded mlx_array_set_int via
 * mlx_array_set_int_ptr and returns its int result unchanged.
 * No NULL check is done; the loader must have resolved the pointer first.
 */
int mlx_array_set_int(mlx_array* arr, int val) {
    const int result = (*mlx_array_set_int_ptr)(arr, val);
    return result;
}
|
|
|
|
int mlx_array_set_float32(mlx_array* arr, float val) {
|
|
return mlx_array_set_float32_ptr(arr, val);
|
|
}
|
|
|
|
int mlx_array_set_float(mlx_array* arr, float val) {
|
|
return mlx_array_set_float_ptr(arr, val);
|
|
}
|
|
|
|
int mlx_array_set_float64(mlx_array* arr, double val) {
|
|
return mlx_array_set_float64_ptr(arr, val);
|
|
}
|
|
|
|
int mlx_array_set_double(mlx_array* arr, double val) {
|
|
return mlx_array_set_double_ptr(arr, val);
|
|
}
|
|
|
|
int mlx_array_set_complex(mlx_array* arr, float real_val, float imag_val) {
|
|
return mlx_array_set_complex_ptr(arr, real_val, imag_val);
|
|
}
|
|
|
|
int mlx_array_set_data(mlx_array* arr, const void* data, const int* shape, int dim, mlx_dtype dtype) {
|
|
return mlx_array_set_data_ptr(arr, data, shape, dim, dtype);
|
|
}
|
|
|
|
size_t mlx_array_itemsize(const mlx_array arr) {
|
|
return mlx_array_itemsize_ptr(arr);
|
|
}
|
|
|
|
size_t mlx_array_size(const mlx_array arr) {
|
|
return mlx_array_size_ptr(arr);
|
|
}
|
|
|
|
size_t mlx_array_nbytes(const mlx_array arr) {
|
|
return mlx_array_nbytes_ptr(arr);
|
|
}
|
|
|
|
size_t mlx_array_ndim(const mlx_array arr) {
|
|
return mlx_array_ndim_ptr(arr);
|
|
}
|
|
|
|
const int* mlx_array_shape(const mlx_array arr) {
|
|
return mlx_array_shape_ptr(arr);
|
|
}
|
|
|
|
const size_t* mlx_array_strides(const mlx_array arr) {
|
|
return mlx_array_strides_ptr(arr);
|
|
}
|
|
|
|
int mlx_array_dim(const mlx_array arr, int dim) {
|
|
return mlx_array_dim_ptr(arr, dim);
|
|
}
|
|
|
|
mlx_dtype mlx_array_dtype(const mlx_array arr) {
|
|
return mlx_array_dtype_ptr(arr);
|
|
}
|
|
|
|
int mlx_array_eval(mlx_array arr) {
|
|
return mlx_array_eval_ptr(arr);
|
|
}
|
|
|
|
int mlx_array_item_bool(bool* res, const mlx_array arr) {
|
|
return mlx_array_item_bool_ptr(res, arr);
|
|
}
|
|
|
|
int mlx_array_item_uint8(uint8_t* res, const mlx_array arr) {
|
|
return mlx_array_item_uint8_ptr(res, arr);
|
|
}
|
|
|
|
int mlx_array_item_uint16(uint16_t* res, const mlx_array arr) {
|
|
return mlx_array_item_uint16_ptr(res, arr);
|
|
}
|
|
|
|
int mlx_array_item_uint32(uint32_t* res, const mlx_array arr) {
|
|
return mlx_array_item_uint32_ptr(res, arr);
|
|
}
|
|
|
|
int mlx_array_item_uint64(uint64_t* res, const mlx_array arr) {
|
|
return mlx_array_item_uint64_ptr(res, arr);
|
|
}
|
|
|
|
int mlx_array_item_int8(int8_t* res, const mlx_array arr) {
|
|
return mlx_array_item_int8_ptr(res, arr);
|
|
}
|
|
|
|
int mlx_array_item_int16(int16_t* res, const mlx_array arr) {
|
|
return mlx_array_item_int16_ptr(res, arr);
|
|
}
|
|
|
|
int mlx_array_item_int32(int32_t* res, const mlx_array arr) {
|
|
return mlx_array_item_int32_ptr(res, arr);
|
|
}
|
|
|
|
int mlx_array_item_int64(int64_t* res, const mlx_array arr) {
|
|
return mlx_array_item_int64_ptr(res, arr);
|
|
}
|
|
|
|
int mlx_array_item_float32(float* res, const mlx_array arr) {
|
|
return mlx_array_item_float32_ptr(res, arr);
|
|
}
|
|
|
|
int mlx_array_item_float64(double* res, const mlx_array arr) {
|
|
return mlx_array_item_float64_ptr(res, arr);
|
|
}
|
|
|
|
int mlx_array_item_complex64(mlx_complex64_t* res, const mlx_array arr) {
|
|
return mlx_array_item_complex64_ptr(res, arr);
|
|
}
|
|
|
|
#if defined(__aarch64__) || defined(_M_ARM64)
|
|
int mlx_array_item_float16(float16_t* res, const mlx_array arr) {
|
|
return mlx_array_item_float16_ptr(res, arr);
|
|
}
|
|
#endif
|
|
|
|
#if defined(__aarch64__) || defined(_M_ARM64)
|
|
int mlx_array_item_bfloat16(bfloat16_t* res, const mlx_array arr) {
|
|
return mlx_array_item_bfloat16_ptr(res, arr);
|
|
}
|
|
#endif
|
|
|
|
/*
 * Raw data accessors (data_*) and internal array-state queries (_mlx_array_*).
 *
 * Each function forwards its arguments, unchanged and in order, to the
 * same-named `*_ptr` function pointer defined near the top of this file and
 * returns that call's result. No NULL check is performed: the pointer
 * definitions start out NULL, so these wrappers must only be called after
 * the MLX library's symbols have been resolved into them.
 */
const bool* mlx_array_data_bool(const mlx_array arr) {
    return mlx_array_data_bool_ptr(arr);
}

const uint8_t* mlx_array_data_uint8(const mlx_array arr) {
    return mlx_array_data_uint8_ptr(arr);
}

const uint16_t* mlx_array_data_uint16(const mlx_array arr) {
    return mlx_array_data_uint16_ptr(arr);
}

const uint32_t* mlx_array_data_uint32(const mlx_array arr) {
    return mlx_array_data_uint32_ptr(arr);
}

const uint64_t* mlx_array_data_uint64(const mlx_array arr) {
    return mlx_array_data_uint64_ptr(arr);
}

const int8_t* mlx_array_data_int8(const mlx_array arr) {
    return mlx_array_data_int8_ptr(arr);
}

const int16_t* mlx_array_data_int16(const mlx_array arr) {
    return mlx_array_data_int16_ptr(arr);
}

const int32_t* mlx_array_data_int32(const mlx_array arr) {
    return mlx_array_data_int32_ptr(arr);
}

const int64_t* mlx_array_data_int64(const mlx_array arr) {
    return mlx_array_data_int64_ptr(arr);
}

const float* mlx_array_data_float32(const mlx_array arr) {
    return mlx_array_data_float32_ptr(arr);
}

const double* mlx_array_data_float64(const mlx_array arr) {
    return mlx_array_data_float64_ptr(arr);
}

const mlx_complex64_t* mlx_array_data_complex64(const mlx_array arr) {
    return mlx_array_data_complex64_ptr(arr);
}

/* The half-precision variants are only compiled for ARM64 targets. */
#if defined(__aarch64__) || defined(_M_ARM64)

const float16_t* mlx_array_data_float16(const mlx_array arr) {
    return mlx_array_data_float16_ptr(arr);
}

#endif

#if defined(__aarch64__) || defined(_M_ARM64)

const bfloat16_t* mlx_array_data_bfloat16(const mlx_array arr) {
    return mlx_array_data_bfloat16_ptr(arr);
}

#endif

int _mlx_array_is_available(bool* res, const mlx_array arr) {
    return _mlx_array_is_available_ptr(res, arr);
}

int _mlx_array_wait(const mlx_array arr) {
    return _mlx_array_wait_ptr(arr);
}

int _mlx_array_is_contiguous(bool* res, const mlx_array arr) {
    return _mlx_array_is_contiguous_ptr(res, arr);
}

int _mlx_array_is_row_contiguous(bool* res, const mlx_array arr) {
    return _mlx_array_is_row_contiguous_ptr(res, arr);
}

int _mlx_array_is_col_contiguous(bool* res, const mlx_array arr) {
    return _mlx_array_is_col_contiguous_ptr(res, arr);
}

mlx_closure mlx_closure_new(void) {
|
|
return mlx_closure_new_ptr();
|
|
}
|
|
|
|
int mlx_closure_free(mlx_closure cls) {
|
|
return mlx_closure_free_ptr(cls);
|
|
}
|
|
|
|
mlx_closure mlx_closure_new_func(int (*fun)(mlx_vector_array*, const mlx_vector_array)) {
|
|
return mlx_closure_new_func_ptr(fun);
|
|
}
|
|
|
|
mlx_closure mlx_closure_new_func_payload(int (*fun)(mlx_vector_array*, const mlx_vector_array, void*), void* payload, void (*dtor)(void*)) {
|
|
return mlx_closure_new_func_payload_ptr(fun, payload, dtor);
|
|
}
|
|
|
|
int mlx_closure_set(mlx_closure* cls, const mlx_closure src) {
|
|
return mlx_closure_set_ptr(cls, src);
|
|
}
|
|
|
|
int mlx_closure_apply(mlx_vector_array* res, mlx_closure cls, const mlx_vector_array input) {
|
|
return mlx_closure_apply_ptr(res, cls, input);
|
|
}
|
|
|
|
mlx_closure mlx_closure_new_unary(int (*fun)(mlx_array*, const mlx_array)) {
|
|
return mlx_closure_new_unary_ptr(fun);
|
|
}
|
|
|
|
mlx_closure_kwargs mlx_closure_kwargs_new(void) {
|
|
return mlx_closure_kwargs_new_ptr();
|
|
}
|
|
|
|
int mlx_closure_kwargs_free(mlx_closure_kwargs cls) {
|
|
return mlx_closure_kwargs_free_ptr(cls);
|
|
}
|
|
|
|
mlx_closure_kwargs mlx_closure_kwargs_new_func(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_map_string_to_array)) {
|
|
return mlx_closure_kwargs_new_func_ptr(fun);
|
|
}
|
|
|
|
mlx_closure_kwargs mlx_closure_kwargs_new_func_payload(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_map_string_to_array, void*), void* payload, void (*dtor)(void*)) {
|
|
return mlx_closure_kwargs_new_func_payload_ptr(fun, payload, dtor);
|
|
}
|
|
|
|
int mlx_closure_kwargs_set(mlx_closure_kwargs* cls, const mlx_closure_kwargs src) {
|
|
return mlx_closure_kwargs_set_ptr(cls, src);
|
|
}
|
|
|
|
int mlx_closure_kwargs_apply(mlx_vector_array* res, mlx_closure_kwargs cls, const mlx_vector_array input_0, const mlx_map_string_to_array input_1) {
|
|
return mlx_closure_kwargs_apply_ptr(res, cls, input_0, input_1);
|
|
}
|
|
|
|
mlx_closure_value_and_grad mlx_closure_value_and_grad_new(void) {
|
|
return mlx_closure_value_and_grad_new_ptr();
|
|
}
|
|
|
|
int mlx_closure_value_and_grad_free(mlx_closure_value_and_grad cls) {
|
|
return mlx_closure_value_and_grad_free_ptr(cls);
|
|
}
|
|
|
|
mlx_closure_value_and_grad mlx_closure_value_and_grad_new_func(int (*fun)(mlx_vector_array*, mlx_vector_array*, const mlx_vector_array)) {
|
|
return mlx_closure_value_and_grad_new_func_ptr(fun);
|
|
}
|
|
|
|
mlx_closure_value_and_grad mlx_closure_value_and_grad_new_func_payload(int (*fun)( mlx_vector_array*, mlx_vector_array*, const mlx_vector_array, void*), void* payload, void (*dtor)(void*)) {
|
|
return mlx_closure_value_and_grad_new_func_payload_ptr(fun, payload, dtor);
|
|
}
|
|
|
|
int mlx_closure_value_and_grad_set(mlx_closure_value_and_grad* cls, const mlx_closure_value_and_grad src) {
|
|
return mlx_closure_value_and_grad_set_ptr(cls, src);
|
|
}
|
|
|
|
int mlx_closure_value_and_grad_apply(mlx_vector_array* res_0, mlx_vector_array* res_1, mlx_closure_value_and_grad cls, const mlx_vector_array input) {
|
|
return mlx_closure_value_and_grad_apply_ptr(res_0, res_1, cls, input);
|
|
}
|
|
|
|
mlx_closure_custom mlx_closure_custom_new(void) {
|
|
return mlx_closure_custom_new_ptr();
|
|
}
|
|
|
|
int mlx_closure_custom_free(mlx_closure_custom cls) {
|
|
return mlx_closure_custom_free_ptr(cls);
|
|
}
|
|
|
|
mlx_closure_custom mlx_closure_custom_new_func(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_vector_array, const mlx_vector_array)) {
|
|
return mlx_closure_custom_new_func_ptr(fun);
|
|
}
|
|
|
|
mlx_closure_custom mlx_closure_custom_new_func_payload(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_vector_array, const mlx_vector_array, void*), void* payload, void (*dtor)(void*)) {
|
|
return mlx_closure_custom_new_func_payload_ptr(fun, payload, dtor);
|
|
}
|
|
|
|
int mlx_closure_custom_set(mlx_closure_custom* cls, const mlx_closure_custom src) {
|
|
return mlx_closure_custom_set_ptr(cls, src);
|
|
}
|
|
|
|
int mlx_closure_custom_apply(mlx_vector_array* res, mlx_closure_custom cls, const mlx_vector_array input_0, const mlx_vector_array input_1, const mlx_vector_array input_2) {
|
|
return mlx_closure_custom_apply_ptr(res, cls, input_0, input_1, input_2);
|
|
}
|
|
|
|
mlx_closure_custom_jvp mlx_closure_custom_jvp_new(void) {
|
|
return mlx_closure_custom_jvp_new_ptr();
|
|
}
|
|
|
|
int mlx_closure_custom_jvp_free(mlx_closure_custom_jvp cls) {
|
|
return mlx_closure_custom_jvp_free_ptr(cls);
|
|
}
|
|
|
|
mlx_closure_custom_jvp mlx_closure_custom_jvp_new_func(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_vector_array, const int*, size_t _num)) {
|
|
return mlx_closure_custom_jvp_new_func_ptr(fun);
|
|
}
|
|
|
|
mlx_closure_custom_jvp mlx_closure_custom_jvp_new_func_payload(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_vector_array, const int*, size_t _num, void*), void* payload, void (*dtor)(void*)) {
|
|
return mlx_closure_custom_jvp_new_func_payload_ptr(fun, payload, dtor);
|
|
}
|
|
|
|
int mlx_closure_custom_jvp_set(mlx_closure_custom_jvp* cls, const mlx_closure_custom_jvp src) {
|
|
return mlx_closure_custom_jvp_set_ptr(cls, src);
|
|
}
|
|
|
|
int mlx_closure_custom_jvp_apply(mlx_vector_array* res, mlx_closure_custom_jvp cls, const mlx_vector_array input_0, const mlx_vector_array input_1, const int* input_2, size_t input_2_num) {
|
|
return mlx_closure_custom_jvp_apply_ptr(res, cls, input_0, input_1, input_2, input_2_num);
|
|
}
|
|
|
|
mlx_closure_custom_vmap mlx_closure_custom_vmap_new(void) {
|
|
return mlx_closure_custom_vmap_new_ptr();
|
|
}
|
|
|
|
int mlx_closure_custom_vmap_free(mlx_closure_custom_vmap cls) {
|
|
return mlx_closure_custom_vmap_free_ptr(cls);
|
|
}
|
|
|
|
mlx_closure_custom_vmap mlx_closure_custom_vmap_new_func(int (*fun)( mlx_vector_array*, mlx_vector_int*, const mlx_vector_array, const int*, size_t _num)) {
|
|
return mlx_closure_custom_vmap_new_func_ptr(fun);
|
|
}
|
|
|
|
mlx_closure_custom_vmap mlx_closure_custom_vmap_new_func_payload(int (*fun)( mlx_vector_array*, mlx_vector_int*, const mlx_vector_array, const int*, size_t _num, void*), void* payload, void (*dtor)(void*)) {
|
|
return mlx_closure_custom_vmap_new_func_payload_ptr(fun, payload, dtor);
|
|
}
|
|
|
|
int mlx_closure_custom_vmap_set(mlx_closure_custom_vmap* cls, const mlx_closure_custom_vmap src) {
|
|
return mlx_closure_custom_vmap_set_ptr(cls, src);
|
|
}
|
|
|
|
int mlx_closure_custom_vmap_apply(mlx_vector_array* res_0, mlx_vector_int* res_1, mlx_closure_custom_vmap cls, const mlx_vector_array input_0, const int* input_1, size_t input_1_num) {
|
|
return mlx_closure_custom_vmap_apply_ptr(res_0, res_1, cls, input_0, input_1, input_1_num);
|
|
}
|
|
|
|
int mlx_compile(mlx_closure* res, const mlx_closure fun, bool shapeless) {
|
|
return mlx_compile_ptr(res, fun, shapeless);
|
|
}
|
|
|
|
int mlx_detail_compile(mlx_closure* res, const mlx_closure fun, uintptr_t fun_id, bool shapeless, const uint64_t* constants, size_t constants_num) {
|
|
return mlx_detail_compile_ptr(res, fun, fun_id, shapeless, constants, constants_num);
|
|
}
|
|
|
|
int mlx_detail_compile_clear_cache(void) {
|
|
return mlx_detail_compile_clear_cache_ptr();
|
|
}
|
|
|
|
int mlx_detail_compile_erase(uintptr_t fun_id) {
|
|
return mlx_detail_compile_erase_ptr(fun_id);
|
|
}
|
|
|
|
int mlx_disable_compile(void) {
|
|
return mlx_disable_compile_ptr();
|
|
}
|
|
|
|
int mlx_enable_compile(void) {
|
|
return mlx_enable_compile_ptr();
|
|
}
|
|
|
|
int mlx_set_compile_mode(mlx_compile_mode mode) {
|
|
return mlx_set_compile_mode_ptr(mode);
|
|
}
|
|
|
|
int mlx_cuda_is_available(bool* res) {
|
|
return mlx_cuda_is_available_ptr(res);
|
|
}
|
|
|
|
mlx_device mlx_device_new(void) {
|
|
return mlx_device_new_ptr();
|
|
}
|
|
|
|
mlx_device mlx_device_new_type(mlx_device_type type, int index) {
|
|
return mlx_device_new_type_ptr(type, index);
|
|
}
|
|
|
|
int mlx_device_free(mlx_device dev) {
|
|
return mlx_device_free_ptr(dev);
|
|
}
|
|
|
|
int mlx_device_set(mlx_device* dev, const mlx_device src) {
|
|
return mlx_device_set_ptr(dev, src);
|
|
}
|
|
|
|
int mlx_device_tostring(mlx_string* str, mlx_device dev) {
|
|
return mlx_device_tostring_ptr(str, dev);
|
|
}
|
|
|
|
bool mlx_device_equal(mlx_device lhs, mlx_device rhs) {
|
|
return mlx_device_equal_ptr(lhs, rhs);
|
|
}
|
|
|
|
int mlx_device_get_index(int* index, mlx_device dev) {
|
|
return mlx_device_get_index_ptr(index, dev);
|
|
}
|
|
|
|
int mlx_device_get_type(mlx_device_type* type, mlx_device dev) {
|
|
return mlx_device_get_type_ptr(type, dev);
|
|
}
|
|
|
|
int mlx_get_default_device(mlx_device* dev) {
|
|
return mlx_get_default_device_ptr(dev);
|
|
}
|
|
|
|
int mlx_set_default_device(mlx_device dev) {
|
|
return mlx_set_default_device_ptr(dev);
|
|
}
|
|
|
|
int mlx_device_is_available(bool* avail, mlx_device dev) {
|
|
return mlx_device_is_available_ptr(avail, dev);
|
|
}
|
|
|
|
int mlx_device_count(int* count, mlx_device_type type) {
|
|
return mlx_device_count_ptr(count, type);
|
|
}
|
|
|
|
mlx_device_info mlx_device_info_new(void) {
|
|
return mlx_device_info_new_ptr();
|
|
}
|
|
|
|
int mlx_device_info_get(mlx_device_info* info, mlx_device dev) {
|
|
return mlx_device_info_get_ptr(info, dev);
|
|
}
|
|
|
|
int mlx_device_info_free(mlx_device_info info) {
|
|
return mlx_device_info_free_ptr(info);
|
|
}
|
|
|
|
int mlx_device_info_has_key(bool* exists, mlx_device_info info, const char* key) {
|
|
return mlx_device_info_has_key_ptr(exists, info, key);
|
|
}
|
|
|
|
int mlx_device_info_is_string(bool* is_string, mlx_device_info info, const char* key) {
|
|
return mlx_device_info_is_string_ptr(is_string, info, key);
|
|
}
|
|
|
|
int mlx_device_info_get_string(const char** value, mlx_device_info info, const char* key) {
|
|
return mlx_device_info_get_string_ptr(value, info, key);
|
|
}
|
|
|
|
int mlx_device_info_get_size(size_t* value, mlx_device_info info, const char* key) {
|
|
return mlx_device_info_get_size_ptr(value, info, key);
|
|
}
|
|
|
|
int mlx_device_info_get_keys(mlx_vector_string* keys, mlx_device_info info) {
|
|
return mlx_device_info_get_keys_ptr(keys, info);
|
|
}
|
|
|
|
int mlx_distributed_all_gather(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S) {
|
|
return mlx_distributed_all_gather_ptr(res, x, group, S);
|
|
}
|
|
|
|
int mlx_distributed_all_max(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) {
|
|
return mlx_distributed_all_max_ptr(res, x, group, s);
|
|
}
|
|
|
|
int mlx_distributed_all_min(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) {
|
|
return mlx_distributed_all_min_ptr(res, x, group, s);
|
|
}
|
|
|
|
int mlx_distributed_all_sum(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) {
|
|
return mlx_distributed_all_sum_ptr(res, x, group, s);
|
|
}
|
|
|
|
int mlx_distributed_recv(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, int src, const mlx_distributed_group group , const mlx_stream s) {
|
|
return mlx_distributed_recv_ptr(res, shape, shape_num, dtype, src, group, s);
|
|
}
|
|
|
|
int mlx_distributed_recv_like(mlx_array* res, const mlx_array x, int src, const mlx_distributed_group group , const mlx_stream s) {
|
|
return mlx_distributed_recv_like_ptr(res, x, src, group, s);
|
|
}
|
|
|
|
int mlx_distributed_send(mlx_array* res, const mlx_array x, int dst, const mlx_distributed_group group , const mlx_stream s) {
|
|
return mlx_distributed_send_ptr(res, x, dst, group, s);
|
|
}
|
|
|
|
int mlx_distributed_sum_scatter(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) {
|
|
return mlx_distributed_sum_scatter_ptr(res, x, group, s);
|
|
}
|
|
|
|
int mlx_distributed_group_rank(mlx_distributed_group group) {
|
|
return mlx_distributed_group_rank_ptr(group);
|
|
}
|
|
|
|
int mlx_distributed_group_size(mlx_distributed_group group) {
|
|
return mlx_distributed_group_size_ptr(group);
|
|
}
|
|
|
|
mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, int color, int key) {
|
|
return mlx_distributed_group_split_ptr(group, color, key);
|
|
}
|
|
|
|
bool mlx_distributed_is_available(void) {
|
|
return mlx_distributed_is_available_ptr();
|
|
}
|
|
|
|
mlx_distributed_group mlx_distributed_init(bool strict) {
|
|
return mlx_distributed_init_ptr(strict);
|
|
}
|
|
|
|
/*
 * Installs an error handler by forwarding handler, data and dtor, unchanged,
 * to `mlx_set_error_handler_ptr` (defined near the top of this file).
 * No NULL check is performed on that pointer, so this must only be called
 * after the MLX library's symbols have been resolved.
 */
void mlx_set_error_handler(mlx_error_handler_func handler, void* data, void (*dtor)(void*)) {
    mlx_set_error_handler_ptr(handler, data, dtor);
}

#include <stdarg.h>
#include <stdlib.h>

/*
 * _mlx_error: printf-style error-reporting entry point.
 *
 * A C variadic function cannot hand its own "..." on to another variadic
 * function. The previous body, `_mlx_error_ptr(file, line, fmt);`, therefore
 * silently dropped every argument after fmt. Whenever fmt contained a
 * conversion such as "%s" or "%d", the loaded implementation read arguments
 * that were never supplied. That is undefined behavior: garbage text or a
 * crash while reporting an error.
 *
 * The message is now rendered here with vsnprintf. It is passed to
 * `_mlx_error_ptr` behind a "%s" format, so '%' characters inside the
 * rendered text are not interpreted a second time.
 *
 * Parameters:
 *   file, line - source location of the error; forwarded unchanged.
 *   fmt, ...   - printf-style format string and its arguments.
 *
 * Messages longer than the 1024-byte stack buffer are rendered into a heap
 * buffer. If that allocation fails, the truncated stack copy is reported.
 *
 * NOTE(review): this file is generated by generate_wrappers.go. The
 * generator needs the same special case for variadic functions, or this
 * fix is lost on regeneration. This code assumes `_mlx_error_ptr` is
 * declared variadic, matching this function's own signature; confirm that
 * in the generator.
 */
void _mlx_error(const char* file, const int line, const char* fmt, ...) {
    char stack_buf[1024];
    char* msg = stack_buf;
    va_list args;
    va_list sizing_args;

    va_start(args, fmt);

    /* The first pass renders into the stack buffer and reports the full
     * length required. A copy is used so `args` stays valid for a second
     * pass. */
    va_copy(sizing_args, args);
    int needed = vsnprintf(stack_buf, sizeof stack_buf, fmt, sizing_args);
    va_end(sizing_args);

    if (needed < 0) {
        /* Formatting failed: report the raw format string instead. */
        va_end(args);
        _mlx_error_ptr(file, line, "%s", fmt);
        return;
    }

    if ((size_t)needed >= sizeof stack_buf) {
        char* heap_buf = malloc((size_t)needed + 1);
        if (heap_buf != NULL) {
            vsnprintf(heap_buf, (size_t)needed + 1, fmt, args);
            msg = heap_buf;
        }
    }
    va_end(args);

    _mlx_error_ptr(file, line, "%s", msg);

    if (msg != stack_buf) {
        free(msg);
    }
}

int mlx_export_function(const char* file, const mlx_closure fun, const mlx_vector_array args, bool shapeless) {
|
|
return mlx_export_function_ptr(file, fun, args, shapeless);
|
|
}
|
|
|
|
int mlx_export_function_kwargs(const char* file, const mlx_closure_kwargs fun, const mlx_vector_array args, const mlx_map_string_to_array kwargs, bool shapeless) {
|
|
return mlx_export_function_kwargs_ptr(file, fun, args, kwargs, shapeless);
|
|
}
|
|
|
|
mlx_function_exporter mlx_function_exporter_new(const char* file, const mlx_closure fun, bool shapeless) {
|
|
return mlx_function_exporter_new_ptr(file, fun, shapeless);
|
|
}
|
|
|
|
int mlx_function_exporter_free(mlx_function_exporter xfunc) {
|
|
return mlx_function_exporter_free_ptr(xfunc);
|
|
}
|
|
|
|
int mlx_function_exporter_apply(const mlx_function_exporter xfunc, const mlx_vector_array args) {
|
|
return mlx_function_exporter_apply_ptr(xfunc, args);
|
|
}
|
|
|
|
int mlx_function_exporter_apply_kwargs(const mlx_function_exporter xfunc, const mlx_vector_array args, const mlx_map_string_to_array kwargs) {
|
|
return mlx_function_exporter_apply_kwargs_ptr(xfunc, args, kwargs);
|
|
}
|
|
|
|
mlx_imported_function mlx_imported_function_new(const char* file) {
|
|
return mlx_imported_function_new_ptr(file);
|
|
}
|
|
|
|
int mlx_imported_function_free(mlx_imported_function xfunc) {
|
|
return mlx_imported_function_free_ptr(xfunc);
|
|
}
|
|
|
|
int mlx_imported_function_apply(mlx_vector_array* res, const mlx_imported_function xfunc, const mlx_vector_array args) {
|
|
return mlx_imported_function_apply_ptr(res, xfunc, args);
|
|
}
|
|
|
|
int mlx_imported_function_apply_kwargs(mlx_vector_array* res, const mlx_imported_function xfunc, const mlx_vector_array args, const mlx_map_string_to_array kwargs) {
|
|
return mlx_imported_function_apply_kwargs_ptr(res, xfunc, args, kwargs);
|
|
}
|
|
|
|
mlx_fast_cuda_kernel_config mlx_fast_cuda_kernel_config_new(void) {
|
|
return mlx_fast_cuda_kernel_config_new_ptr();
|
|
}
|
|
|
|
void mlx_fast_cuda_kernel_config_free(mlx_fast_cuda_kernel_config cls) {
|
|
mlx_fast_cuda_kernel_config_free_ptr(cls);
|
|
}
|
|
|
|
int mlx_fast_cuda_kernel_config_add_output_arg(mlx_fast_cuda_kernel_config cls, const int* shape, size_t size, mlx_dtype dtype) {
|
|
return mlx_fast_cuda_kernel_config_add_output_arg_ptr(cls, shape, size, dtype);
|
|
}
|
|
|
|
int mlx_fast_cuda_kernel_config_set_grid(mlx_fast_cuda_kernel_config cls, int grid1, int grid2, int grid3) {
|
|
return mlx_fast_cuda_kernel_config_set_grid_ptr(cls, grid1, grid2, grid3);
|
|
}
|
|
|
|
int mlx_fast_cuda_kernel_config_set_thread_group(mlx_fast_cuda_kernel_config cls, int thread1, int thread2, int thread3) {
|
|
return mlx_fast_cuda_kernel_config_set_thread_group_ptr(cls, thread1, thread2, thread3);
|
|
}
|
|
|
|
int mlx_fast_cuda_kernel_config_set_init_value(mlx_fast_cuda_kernel_config cls, float value) {
|
|
return mlx_fast_cuda_kernel_config_set_init_value_ptr(cls, value);
|
|
}
|
|
|
|
int mlx_fast_cuda_kernel_config_set_verbose(mlx_fast_cuda_kernel_config cls, bool verbose) {
|
|
return mlx_fast_cuda_kernel_config_set_verbose_ptr(cls, verbose);
|
|
}
|
|
|
|
int mlx_fast_cuda_kernel_config_add_template_arg_dtype(mlx_fast_cuda_kernel_config cls, const char* name, mlx_dtype dtype) {
|
|
return mlx_fast_cuda_kernel_config_add_template_arg_dtype_ptr(cls, name, dtype);
|
|
}
|
|
|
|
int mlx_fast_cuda_kernel_config_add_template_arg_int(mlx_fast_cuda_kernel_config cls, const char* name, int value) {
|
|
return mlx_fast_cuda_kernel_config_add_template_arg_int_ptr(cls, name, value);
|
|
}
|
|
|
|
int mlx_fast_cuda_kernel_config_add_template_arg_bool(mlx_fast_cuda_kernel_config cls, const char* name, bool value) {
|
|
return mlx_fast_cuda_kernel_config_add_template_arg_bool_ptr(cls, name, value);
|
|
}
|
|
|
|
mlx_fast_cuda_kernel mlx_fast_cuda_kernel_new(const char* name, const mlx_vector_string input_names, const mlx_vector_string output_names, const char* source, const char* header, bool ensure_row_contiguous, int shared_memory) {
|
|
return mlx_fast_cuda_kernel_new_ptr(name, input_names, output_names, source, header, ensure_row_contiguous, shared_memory);
|
|
}
|
|
|
|
void mlx_fast_cuda_kernel_free(mlx_fast_cuda_kernel cls) {
|
|
mlx_fast_cuda_kernel_free_ptr(cls);
|
|
}
|
|
|
|
int mlx_fast_cuda_kernel_apply(mlx_vector_array* outputs, mlx_fast_cuda_kernel cls, const mlx_vector_array inputs, const mlx_fast_cuda_kernel_config config, const mlx_stream stream) {
|
|
return mlx_fast_cuda_kernel_apply_ptr(outputs, cls, inputs, config, stream);
|
|
}
|
|
|
|
int mlx_fast_layer_norm(mlx_array* res, const mlx_array x, const mlx_array weight , const mlx_array bias , float eps, const mlx_stream s) {
|
|
return mlx_fast_layer_norm_ptr(res, x, weight, bias, eps, s);
|
|
}
|
|
|
|
mlx_fast_metal_kernel_config mlx_fast_metal_kernel_config_new(void) {
|
|
return mlx_fast_metal_kernel_config_new_ptr();
|
|
}
|
|
|
|
void mlx_fast_metal_kernel_config_free(mlx_fast_metal_kernel_config cls) {
|
|
mlx_fast_metal_kernel_config_free_ptr(cls);
|
|
}
|
|
|
|
int mlx_fast_metal_kernel_config_add_output_arg(mlx_fast_metal_kernel_config cls, const int* shape, size_t size, mlx_dtype dtype) {
|
|
return mlx_fast_metal_kernel_config_add_output_arg_ptr(cls, shape, size, dtype);
|
|
}
|
|
|
|
int mlx_fast_metal_kernel_config_set_grid(mlx_fast_metal_kernel_config cls, int grid1, int grid2, int grid3) {
|
|
return mlx_fast_metal_kernel_config_set_grid_ptr(cls, grid1, grid2, grid3);
|
|
}
|
|
|
|
int mlx_fast_metal_kernel_config_set_thread_group(mlx_fast_metal_kernel_config cls, int thread1, int thread2, int thread3) {
|
|
return mlx_fast_metal_kernel_config_set_thread_group_ptr(cls, thread1, thread2, thread3);
|
|
}
|
|
|
|
int mlx_fast_metal_kernel_config_set_init_value(mlx_fast_metal_kernel_config cls, float value) {
|
|
return mlx_fast_metal_kernel_config_set_init_value_ptr(cls, value);
|
|
}
|
|
|
|
int mlx_fast_metal_kernel_config_set_verbose(mlx_fast_metal_kernel_config cls, bool verbose) {
|
|
return mlx_fast_metal_kernel_config_set_verbose_ptr(cls, verbose);
|
|
}
|
|
|
|
int mlx_fast_metal_kernel_config_add_template_arg_dtype(mlx_fast_metal_kernel_config cls, const char* name, mlx_dtype dtype) {
|
|
return mlx_fast_metal_kernel_config_add_template_arg_dtype_ptr(cls, name, dtype);
|
|
}
|
|
|
|
int mlx_fast_metal_kernel_config_add_template_arg_int(mlx_fast_metal_kernel_config cls, const char* name, int value) {
|
|
return mlx_fast_metal_kernel_config_add_template_arg_int_ptr(cls, name, value);
|
|
}
|
|
|
|
int mlx_fast_metal_kernel_config_add_template_arg_bool(mlx_fast_metal_kernel_config cls, const char* name, bool value) {
|
|
return mlx_fast_metal_kernel_config_add_template_arg_bool_ptr(cls, name, value);
|
|
}
|
|
|
|
mlx_fast_metal_kernel mlx_fast_metal_kernel_new(const char* name, const mlx_vector_string input_names, const mlx_vector_string output_names, const char* source, const char* header, bool ensure_row_contiguous, bool atomic_outputs) {
|
|
return mlx_fast_metal_kernel_new_ptr(name, input_names, output_names, source, header, ensure_row_contiguous, atomic_outputs);
|
|
}
|
|
|
|
void mlx_fast_metal_kernel_free(mlx_fast_metal_kernel cls) {
|
|
mlx_fast_metal_kernel_free_ptr(cls);
|
|
}
|
|
|
|
int mlx_fast_metal_kernel_apply(mlx_vector_array* outputs, mlx_fast_metal_kernel cls, const mlx_vector_array inputs, const mlx_fast_metal_kernel_config config, const mlx_stream stream) {
|
|
return mlx_fast_metal_kernel_apply_ptr(outputs, cls, inputs, config, stream);
|
|
}
|
|
|
|
int mlx_fast_rms_norm(mlx_array* res, const mlx_array x, const mlx_array weight , float eps, const mlx_stream s) {
|
|
return mlx_fast_rms_norm_ptr(res, x, weight, eps, s);
|
|
}
|
|
|
|
int mlx_fast_rope(mlx_array* res, const mlx_array x, int dims, bool traditional, mlx_optional_float base, float scale, int offset, const mlx_array freqs , const mlx_stream s) {
|
|
return mlx_fast_rope_ptr(res, x, dims, traditional, base, scale, offset, freqs, s);
|
|
}
|
|
|
|
int mlx_fast_rope_dynamic(mlx_array* res, const mlx_array x, int dims, bool traditional, mlx_optional_float base, float scale, const mlx_array offset, const mlx_array freqs , const mlx_stream s) {
|
|
return mlx_fast_rope_dynamic_ptr(res, x, dims, traditional, base, scale, offset, freqs, s);
|
|
}
|
|
|
|
int mlx_fast_scaled_dot_product_attention(mlx_array* res, const mlx_array queries, const mlx_array keys, const mlx_array values, float scale, const char* mask_mode, const mlx_array mask_arr , const mlx_array sinks , const mlx_stream s) {
|
|
return mlx_fast_scaled_dot_product_attention_ptr(res, queries, keys, values, scale, mask_mode, mask_arr, sinks, s);
|
|
}
|
|
|
|
/*
 * FFT wrappers (fft / ifft / rfft / irfft in 1-D, 2-D and N-D forms, plus
 * fftshift / ifftshift).
 *
 * Each function forwards its arguments, unchanged and in order, to the
 * same-named `*_ptr` function pointer defined near the top of this file and
 * returns that call's result. No NULL check is performed: the pointer
 * definitions start out NULL, so these wrappers must only be called after
 * the MLX library's symbols have been resolved into them.
 */
int mlx_fft_fft(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s) {
    return mlx_fft_fft_ptr(res, a, n, axis, s);
}

int mlx_fft_fft2(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) {
    return mlx_fft_fft2_ptr(res, a, n, n_num, axes, axes_num, s);
}

int mlx_fft_fftn(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) {
    return mlx_fft_fftn_ptr(res, a, n, n_num, axes, axes_num, s);
}

int mlx_fft_fftshift(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s) {
    return mlx_fft_fftshift_ptr(res, a, axes, axes_num, s);
}

int mlx_fft_ifft(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s) {
    return mlx_fft_ifft_ptr(res, a, n, axis, s);
}

int mlx_fft_ifft2(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) {
    return mlx_fft_ifft2_ptr(res, a, n, n_num, axes, axes_num, s);
}

int mlx_fft_ifftn(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) {
    return mlx_fft_ifftn_ptr(res, a, n, n_num, axes, axes_num, s);
}

int mlx_fft_ifftshift(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s) {
    return mlx_fft_ifftshift_ptr(res, a, axes, axes_num, s);
}

int mlx_fft_irfft(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s) {
    return mlx_fft_irfft_ptr(res, a, n, axis, s);
}

int mlx_fft_irfft2(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) {
    return mlx_fft_irfft2_ptr(res, a, n, n_num, axes, axes_num, s);
}

int mlx_fft_irfftn(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) {
    return mlx_fft_irfftn_ptr(res, a, n, n_num, axes, axes_num, s);
}

int mlx_fft_rfft(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s) {
    return mlx_fft_rfft_ptr(res, a, n, axis, s);
}

int mlx_fft_rfft2(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) {
    return mlx_fft_rfft2_ptr(res, a, n, n_num, axes, axes_num, s);
}

int mlx_fft_rfftn(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) {
    return mlx_fft_rfftn_ptr(res, a, n, n_num, axes, axes_num, s);
}

int mlx_load_reader(mlx_array* res, mlx_io_reader in_stream, const mlx_stream s) {
|
|
return mlx_load_reader_ptr(res, in_stream, s);
|
|
}
|
|
|
|
int mlx_load(mlx_array* res, const char* file, const mlx_stream s) {
|
|
return mlx_load_ptr(res, file, s);
|
|
}
|
|
|
|
int mlx_load_safetensors_reader(mlx_map_string_to_array* res_0, mlx_map_string_to_string* res_1, mlx_io_reader in_stream, const mlx_stream s) {
|
|
return mlx_load_safetensors_reader_ptr(res_0, res_1, in_stream, s);
|
|
}
|
|
|
|
int mlx_load_safetensors(mlx_map_string_to_array* res_0, mlx_map_string_to_string* res_1, const char* file, const mlx_stream s) {
|
|
return mlx_load_safetensors_ptr(res_0, res_1, file, s);
|
|
}
|
|
|
|
int mlx_save_writer(mlx_io_writer out_stream, const mlx_array a) {
|
|
return mlx_save_writer_ptr(out_stream, a);
|
|
}
|
|
|
|
int mlx_save(const char* file, const mlx_array a) {
|
|
return mlx_save_ptr(file, a);
|
|
}
|
|
|
|
int mlx_save_safetensors_writer(mlx_io_writer in_stream, const mlx_map_string_to_array param, const mlx_map_string_to_string metadata) {
|
|
return mlx_save_safetensors_writer_ptr(in_stream, param, metadata);
|
|
}
|
|
|
|
int mlx_save_safetensors(const char* file, const mlx_map_string_to_array param, const mlx_map_string_to_string metadata) {
|
|
return mlx_save_safetensors_ptr(file, param, metadata);
|
|
}
|
|
|
|
mlx_io_reader mlx_io_reader_new(void* desc, mlx_io_vtable vtable) {
|
|
return mlx_io_reader_new_ptr(desc, vtable);
|
|
}
|
|
|
|
int mlx_io_reader_descriptor(void** desc_, mlx_io_reader io) {
|
|
return mlx_io_reader_descriptor_ptr(desc_, io);
|
|
}
|
|
|
|
int mlx_io_reader_tostring(mlx_string* str_, mlx_io_reader io) {
|
|
return mlx_io_reader_tostring_ptr(str_, io);
|
|
}
|
|
|
|
int mlx_io_reader_free(mlx_io_reader io) {
|
|
return mlx_io_reader_free_ptr(io);
|
|
}
|
|
|
|
mlx_io_writer mlx_io_writer_new(void* desc, mlx_io_vtable vtable) {
|
|
return mlx_io_writer_new_ptr(desc, vtable);
|
|
}
|
|
|
|
int mlx_io_writer_descriptor(void** desc_, mlx_io_writer io) {
|
|
return mlx_io_writer_descriptor_ptr(desc_, io);
|
|
}
|
|
|
|
int mlx_io_writer_tostring(mlx_string* str_, mlx_io_writer io) {
|
|
return mlx_io_writer_tostring_ptr(str_, io);
|
|
}
|
|
|
|
int mlx_io_writer_free(mlx_io_writer io) {
|
|
return mlx_io_writer_free_ptr(io);
|
|
}
|
|
|
|
int mlx_linalg_cholesky(mlx_array* res, const mlx_array a, bool upper, const mlx_stream s) {
|
|
return mlx_linalg_cholesky_ptr(res, a, upper, s);
|
|
}
|
|
|
|
int mlx_linalg_cholesky_inv(mlx_array* res, const mlx_array a, bool upper, const mlx_stream s) {
|
|
return mlx_linalg_cholesky_inv_ptr(res, a, upper, s);
|
|
}
|
|
|
|
int mlx_linalg_cross(mlx_array* res, const mlx_array a, const mlx_array b, int axis, const mlx_stream s) {
|
|
return mlx_linalg_cross_ptr(res, a, b, axis, s);
|
|
}
|
|
|
|
int mlx_linalg_eig(mlx_array* res_0, mlx_array* res_1, const mlx_array a, const mlx_stream s) {
|
|
return mlx_linalg_eig_ptr(res_0, res_1, a, s);
|
|
}
|
|
|
|
int mlx_linalg_eigh(mlx_array* res_0, mlx_array* res_1, const mlx_array a, const char* UPLO, const mlx_stream s) {
|
|
return mlx_linalg_eigh_ptr(res_0, res_1, a, UPLO, s);
|
|
}
|
|
|
|
int mlx_linalg_eigvals(mlx_array* res, const mlx_array a, const mlx_stream s) {
|
|
return mlx_linalg_eigvals_ptr(res, a, s);
|
|
}
|
|
|
|
int mlx_linalg_eigvalsh(mlx_array* res, const mlx_array a, const char* UPLO, const mlx_stream s) {
|
|
return mlx_linalg_eigvalsh_ptr(res, a, UPLO, s);
|
|
}
|
|
|
|
int mlx_linalg_inv(mlx_array* res, const mlx_array a, const mlx_stream s) {
|
|
return mlx_linalg_inv_ptr(res, a, s);
|
|
}
|
|
|
|
int mlx_linalg_lu(mlx_vector_array* res, const mlx_array a, const mlx_stream s) {
|
|
return mlx_linalg_lu_ptr(res, a, s);
|
|
}
|
|
|
|
int mlx_linalg_lu_factor(mlx_array* res_0, mlx_array* res_1, const mlx_array a, const mlx_stream s) {
|
|
return mlx_linalg_lu_factor_ptr(res_0, res_1, a, s);
|
|
}
|
|
|
|
int mlx_linalg_norm(mlx_array* res, const mlx_array a, double ord, const int* axis , size_t axis_num, bool keepdims, const mlx_stream s) {
|
|
return mlx_linalg_norm_ptr(res, a, ord, axis, axis_num, keepdims, s);
|
|
}
|
|
|
|
int mlx_linalg_norm_matrix(mlx_array* res, const mlx_array a, const char* ord, const int* axis , size_t axis_num, bool keepdims, const mlx_stream s) {
|
|
return mlx_linalg_norm_matrix_ptr(res, a, ord, axis, axis_num, keepdims, s);
|
|
}
|
|
|
|
int mlx_linalg_norm_l2(mlx_array* res, const mlx_array a, const int* axis , size_t axis_num, bool keepdims, const mlx_stream s) {
|
|
return mlx_linalg_norm_l2_ptr(res, a, axis, axis_num, keepdims, s);
|
|
}
|
|
|
|
int mlx_linalg_pinv(mlx_array* res, const mlx_array a, const mlx_stream s) {
|
|
return mlx_linalg_pinv_ptr(res, a, s);
|
|
}
|
|
|
|
int mlx_linalg_qr(mlx_array* res_0, mlx_array* res_1, const mlx_array a, const mlx_stream s) {
|
|
return mlx_linalg_qr_ptr(res_0, res_1, a, s);
|
|
}
|
|
|
|
int mlx_linalg_solve(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
|
|
return mlx_linalg_solve_ptr(res, a, b, s);
|
|
}
|
|
|
|
int mlx_linalg_solve_triangular(mlx_array* res, const mlx_array a, const mlx_array b, bool upper, const mlx_stream s) {
|
|
return mlx_linalg_solve_triangular_ptr(res, a, b, upper, s);
|
|
}
|
|
|
|
int mlx_linalg_svd(mlx_vector_array* res, const mlx_array a, bool compute_uv, const mlx_stream s) {
|
|
return mlx_linalg_svd_ptr(res, a, compute_uv, s);
|
|
}
|
|
|
|
int mlx_linalg_tri_inv(mlx_array* res, const mlx_array a, bool upper, const mlx_stream s) {
|
|
return mlx_linalg_tri_inv_ptr(res, a, upper, s);
|
|
}
|
|
|
|
mlx_map_string_to_array mlx_map_string_to_array_new(void) {
|
|
return mlx_map_string_to_array_new_ptr();
|
|
}
|
|
|
|
int mlx_map_string_to_array_set(mlx_map_string_to_array* map, const mlx_map_string_to_array src) {
|
|
return mlx_map_string_to_array_set_ptr(map, src);
|
|
}
|
|
|
|
int mlx_map_string_to_array_free(mlx_map_string_to_array map) {
|
|
return mlx_map_string_to_array_free_ptr(map);
|
|
}
|
|
|
|
int mlx_map_string_to_array_insert(mlx_map_string_to_array map, const char* key, const mlx_array value) {
|
|
return mlx_map_string_to_array_insert_ptr(map, key, value);
|
|
}
|
|
|
|
int mlx_map_string_to_array_get(mlx_array* value, const mlx_map_string_to_array map, const char* key) {
|
|
return mlx_map_string_to_array_get_ptr(value, map, key);
|
|
}
|
|
|
|
mlx_map_string_to_array_iterator mlx_map_string_to_array_iterator_new(mlx_map_string_to_array map) {
|
|
return mlx_map_string_to_array_iterator_new_ptr(map);
|
|
}
|
|
|
|
int mlx_map_string_to_array_iterator_free(mlx_map_string_to_array_iterator it) {
|
|
return mlx_map_string_to_array_iterator_free_ptr(it);
|
|
}
|
|
|
|
int mlx_map_string_to_array_iterator_next(const char** key, mlx_array* value, mlx_map_string_to_array_iterator it) {
|
|
return mlx_map_string_to_array_iterator_next_ptr(key, value, it);
|
|
}
|
|
|
|
mlx_map_string_to_string mlx_map_string_to_string_new(void) {
|
|
return mlx_map_string_to_string_new_ptr();
|
|
}
|
|
|
|
int mlx_map_string_to_string_set(mlx_map_string_to_string* map, const mlx_map_string_to_string src) {
|
|
return mlx_map_string_to_string_set_ptr(map, src);
|
|
}
|
|
|
|
int mlx_map_string_to_string_free(mlx_map_string_to_string map) {
|
|
return mlx_map_string_to_string_free_ptr(map);
|
|
}
|
|
|
|
int mlx_map_string_to_string_insert(mlx_map_string_to_string map, const char* key, const char* value) {
|
|
return mlx_map_string_to_string_insert_ptr(map, key, value);
|
|
}
|
|
|
|
int mlx_map_string_to_string_get(const char** value, const mlx_map_string_to_string map, const char* key) {
|
|
return mlx_map_string_to_string_get_ptr(value, map, key);
|
|
}
|
|
|
|
mlx_map_string_to_string_iterator mlx_map_string_to_string_iterator_new(mlx_map_string_to_string map) {
|
|
return mlx_map_string_to_string_iterator_new_ptr(map);
|
|
}
|
|
|
|
int mlx_map_string_to_string_iterator_free(mlx_map_string_to_string_iterator it) {
|
|
return mlx_map_string_to_string_iterator_free_ptr(it);
|
|
}
|
|
|
|
int mlx_map_string_to_string_iterator_next(const char** key, const char** value, mlx_map_string_to_string_iterator it) {
|
|
return mlx_map_string_to_string_iterator_next_ptr(key, value, it);
|
|
}
|
|
|
|
int mlx_clear_cache(void) {
|
|
return mlx_clear_cache_ptr();
|
|
}
|
|
|
|
int mlx_get_active_memory(size_t* res) {
|
|
return mlx_get_active_memory_ptr(res);
|
|
}
|
|
|
|
int mlx_get_cache_memory(size_t* res) {
|
|
return mlx_get_cache_memory_ptr(res);
|
|
}
|
|
|
|
int mlx_get_memory_limit(size_t* res) {
|
|
return mlx_get_memory_limit_ptr(res);
|
|
}
|
|
|
|
int mlx_get_peak_memory(size_t* res) {
|
|
return mlx_get_peak_memory_ptr(res);
|
|
}
|
|
|
|
int mlx_reset_peak_memory(void) {
|
|
return mlx_reset_peak_memory_ptr();
|
|
}
|
|
|
|
int mlx_set_cache_limit(size_t* res, size_t limit) {
|
|
return mlx_set_cache_limit_ptr(res, limit);
|
|
}
|
|
|
|
int mlx_set_memory_limit(size_t* res, size_t limit) {
|
|
return mlx_set_memory_limit_ptr(res, limit);
|
|
}
|
|
|
|
int mlx_set_wired_limit(size_t* res, size_t limit) {
|
|
return mlx_set_wired_limit_ptr(res, limit);
|
|
}
|
|
|
|
int mlx_metal_is_available(bool* res) {
|
|
return mlx_metal_is_available_ptr(res);
|
|
}
|
|
|
|
int mlx_metal_start_capture(const char* path) {
|
|
return mlx_metal_start_capture_ptr(path);
|
|
}
|
|
|
|
int mlx_metal_stop_capture(void) {
|
|
return mlx_metal_stop_capture_ptr();
|
|
}
|
|
|
|
int mlx_abs(mlx_array* res, const mlx_array a, const mlx_stream s) {
|
|
return mlx_abs_ptr(res, a, s);
|
|
}
|
|
|
|
int mlx_add(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
|
|
return mlx_add_ptr(res, a, b, s);
|
|
}
|
|
|
|
int mlx_addmm(mlx_array* res, const mlx_array c, const mlx_array a, const mlx_array b, float alpha, float beta, const mlx_stream s) {
|
|
return mlx_addmm_ptr(res, c, a, b, alpha, beta, s);
|
|
}
|
|
|
|
int mlx_all_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) {
|
|
return mlx_all_axes_ptr(res, a, axes, axes_num, keepdims, s);
|
|
}
|
|
|
|
int mlx_all_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) {
|
|
return mlx_all_axis_ptr(res, a, axis, keepdims, s);
|
|
}
|
|
|
|
int mlx_all(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) {
|
|
return mlx_all_ptr(res, a, keepdims, s);
|
|
}
|
|
|
|
int mlx_allclose(mlx_array* res, const mlx_array a, const mlx_array b, double rtol, double atol, bool equal_nan, const mlx_stream s) {
|
|
return mlx_allclose_ptr(res, a, b, rtol, atol, equal_nan, s);
|
|
}
|
|
|
|
int mlx_any_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) {
|
|
return mlx_any_axes_ptr(res, a, axes, axes_num, keepdims, s);
|
|
}
|
|
|
|
int mlx_any_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) {
|
|
return mlx_any_axis_ptr(res, a, axis, keepdims, s);
|
|
}
|
|
|
|
int mlx_any(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) {
|
|
return mlx_any_ptr(res, a, keepdims, s);
|
|
}
|
|
|
|
int mlx_arange(mlx_array* res, double start, double stop, double step, mlx_dtype dtype, const mlx_stream s) {
|
|
return mlx_arange_ptr(res, start, stop, step, dtype, s);
|
|
}
|
|
|
|
int mlx_arccos(mlx_array* res, const mlx_array a, const mlx_stream s) {
|
|
return mlx_arccos_ptr(res, a, s);
|
|
}
|
|
|
|
int mlx_arccosh(mlx_array* res, const mlx_array a, const mlx_stream s) {
|
|
return mlx_arccosh_ptr(res, a, s);
|
|
}
|
|
|
|
int mlx_arcsin(mlx_array* res, const mlx_array a, const mlx_stream s) {
|
|
return mlx_arcsin_ptr(res, a, s);
|
|
}
|
|
|
|
int mlx_arcsinh(mlx_array* res, const mlx_array a, const mlx_stream s) {
|
|
return mlx_arcsinh_ptr(res, a, s);
|
|
}
|
|
|
|
int mlx_arctan(mlx_array* res, const mlx_array a, const mlx_stream s) {
|
|
return mlx_arctan_ptr(res, a, s);
|
|
}
|
|
|
|
int mlx_arctan2(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
|
|
return mlx_arctan2_ptr(res, a, b, s);
|
|
}
|
|
|
|
int mlx_arctanh(mlx_array* res, const mlx_array a, const mlx_stream s) {
|
|
return mlx_arctanh_ptr(res, a, s);
|
|
}
|
|
|
|
int mlx_argmax_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) {
|
|
return mlx_argmax_axis_ptr(res, a, axis, keepdims, s);
|
|
}
|
|
|
|
int mlx_argmax(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) {
|
|
return mlx_argmax_ptr(res, a, keepdims, s);
|
|
}
|
|
|
|
int mlx_argmin_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) {
|
|
return mlx_argmin_axis_ptr(res, a, axis, keepdims, s);
|
|
}
|
|
|
|
int mlx_argmin(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) {
|
|
return mlx_argmin_ptr(res, a, keepdims, s);
|
|
}
|
|
|
|
int mlx_argpartition_axis(mlx_array* res, const mlx_array a, int kth, int axis, const mlx_stream s) {
|
|
return mlx_argpartition_axis_ptr(res, a, kth, axis, s);
|
|
}
|
|
|
|
int mlx_argpartition(mlx_array* res, const mlx_array a, int kth, const mlx_stream s) {
|
|
return mlx_argpartition_ptr(res, a, kth, s);
|
|
}
|
|
|
|
int mlx_argsort_axis(mlx_array* res, const mlx_array a, int axis, const mlx_stream s) {
|
|
return mlx_argsort_axis_ptr(res, a, axis, s);
|
|
}
|
|
|
|
int mlx_argsort(mlx_array* res, const mlx_array a, const mlx_stream s) {
|
|
return mlx_argsort_ptr(res, a, s);
|
|
}
|
|
|
|
int mlx_array_equal(mlx_array* res, const mlx_array a, const mlx_array b, bool equal_nan, const mlx_stream s) {
|
|
return mlx_array_equal_ptr(res, a, b, equal_nan, s);
|
|
}
|
|
|
|
int mlx_as_strided(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const int64_t* strides, size_t strides_num, size_t offset, const mlx_stream s) {
|
|
return mlx_as_strided_ptr(res, a, shape, shape_num, strides, strides_num, offset, s);
|
|
}
|
|
|
|
int mlx_astype(mlx_array* res, const mlx_array a, mlx_dtype dtype, const mlx_stream s) {
|
|
return mlx_astype_ptr(res, a, dtype, s);
|
|
}
|
|
|
|
int mlx_atleast_1d(mlx_array* res, const mlx_array a, const mlx_stream s) {
|
|
return mlx_atleast_1d_ptr(res, a, s);
|
|
}
|
|
|
|
int mlx_atleast_2d(mlx_array* res, const mlx_array a, const mlx_stream s) {
|
|
return mlx_atleast_2d_ptr(res, a, s);
|
|
}
|
|
|
|
int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s) {
|
|
return mlx_atleast_3d_ptr(res, a, s);
|
|
}
|
|
|
|
int mlx_bitwise_and(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
|
|
return mlx_bitwise_and_ptr(res, a, b, s);
|
|
}
|
|
|
|
int mlx_bitwise_invert(mlx_array* res, const mlx_array a, const mlx_stream s) {
|
|
return mlx_bitwise_invert_ptr(res, a, s);
|
|
}
|
|
|
|
int mlx_bitwise_or(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
|
|
return mlx_bitwise_or_ptr(res, a, b, s);
|
|
}
|
|
|
|
int mlx_bitwise_xor(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
|
|
return mlx_bitwise_xor_ptr(res, a, b, s);
|
|
}
|
|
|
|
int mlx_block_masked_mm(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s) {
|
|
return mlx_block_masked_mm_ptr(res, a, b, block_size, mask_out, mask_lhs, mask_rhs, s);
|
|
}
|
|
|
|
int mlx_broadcast_arrays(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s) {
|
|
return mlx_broadcast_arrays_ptr(res, inputs, s);
|
|
}
|
|
|
|
int mlx_broadcast_to(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s) {
|
|
return mlx_broadcast_to_ptr(res, a, shape, shape_num, s);
|
|
}
|
|
|
|
int mlx_ceil(mlx_array* res, const mlx_array a, const mlx_stream s) {
|
|
return mlx_ceil_ptr(res, a, s);
|
|
}
|
|
|
|
int mlx_clip(mlx_array* res, const mlx_array a, const mlx_array a_min , const mlx_array a_max , const mlx_stream s) {
|
|
return mlx_clip_ptr(res, a, a_min, a_max, s);
|
|
}
|
|
|
|
int mlx_concatenate_axis(mlx_array* res, const mlx_vector_array arrays, int axis, const mlx_stream s) {
|
|
return mlx_concatenate_axis_ptr(res, arrays, axis, s);
|
|
}
|
|
|
|
int mlx_concatenate(mlx_array* res, const mlx_vector_array arrays, const mlx_stream s) {
|
|
return mlx_concatenate_ptr(res, arrays, s);
|
|
}
|
|
|
|
int mlx_conjugate(mlx_array* res, const mlx_array a, const mlx_stream s) {
|
|
return mlx_conjugate_ptr(res, a, s);
|
|
}
|
|
|
|
/*
 * Array-op wrappers (contiguous .. degrees): each forwards to the matching
 * *_ptr resolved from the dynamically loaded mlx-c library. The pointers are
 * initialized to NULL and are assumed to remain NULL when a symbol could not be
 * resolved, so each wrapper returns a non-zero status (mlx-c's failure
 * convention) instead of calling through a NULL pointer.
 * NOTE: this file is generated by generate_wrappers.go; port the guard there.
 */
int mlx_contiguous(mlx_array* res, const mlx_array a, bool allow_col_major, const mlx_stream s) {
    if (mlx_contiguous_ptr == NULL) { return 1; }
    return mlx_contiguous_ptr(res, a, allow_col_major, s);
}

int mlx_conv1d(mlx_array* res, const mlx_array input, const mlx_array weight, int stride, int padding, int dilation, int groups, const mlx_stream s) {
    if (mlx_conv1d_ptr == NULL) { return 1; }
    return mlx_conv1d_ptr(res, input, weight, stride, padding, dilation, groups, s);
}

int mlx_conv2d(mlx_array* res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int padding_0, int padding_1, int dilation_0, int dilation_1, int groups, const mlx_stream s) {
    if (mlx_conv2d_ptr == NULL) { return 1; }
    return mlx_conv2d_ptr(res, input, weight, stride_0, stride_1, padding_0, padding_1, dilation_0, dilation_1, groups, s);
}

int mlx_conv3d(mlx_array* res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int stride_2, int padding_0, int padding_1, int padding_2, int dilation_0, int dilation_1, int dilation_2, int groups, const mlx_stream s) {
    if (mlx_conv3d_ptr == NULL) { return 1; }
    return mlx_conv3d_ptr(res, input, weight, stride_0, stride_1, stride_2, padding_0, padding_1, padding_2, dilation_0, dilation_1, dilation_2, groups, s);
}

int mlx_conv_general(mlx_array* res, const mlx_array input, const mlx_array weight, const int* stride, size_t stride_num, const int* padding_lo, size_t padding_lo_num, const int* padding_hi, size_t padding_hi_num, const int* kernel_dilation, size_t kernel_dilation_num, const int* input_dilation, size_t input_dilation_num, int groups, bool flip, const mlx_stream s) {
    if (mlx_conv_general_ptr == NULL) { return 1; }
    return mlx_conv_general_ptr(res, input, weight, stride, stride_num, padding_lo, padding_lo_num, padding_hi, padding_hi_num, kernel_dilation, kernel_dilation_num, input_dilation, input_dilation_num, groups, flip, s);
}

int mlx_conv_transpose1d(mlx_array* res, const mlx_array input, const mlx_array weight, int stride, int padding, int dilation, int output_padding, int groups, const mlx_stream s) {
    if (mlx_conv_transpose1d_ptr == NULL) { return 1; }
    return mlx_conv_transpose1d_ptr(res, input, weight, stride, padding, dilation, output_padding, groups, s);
}

int mlx_conv_transpose2d(mlx_array* res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int padding_0, int padding_1, int dilation_0, int dilation_1, int output_padding_0, int output_padding_1, int groups, const mlx_stream s) {
    if (mlx_conv_transpose2d_ptr == NULL) { return 1; }
    return mlx_conv_transpose2d_ptr(res, input, weight, stride_0, stride_1, padding_0, padding_1, dilation_0, dilation_1, output_padding_0, output_padding_1, groups, s);
}

int mlx_conv_transpose3d(mlx_array* res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int stride_2, int padding_0, int padding_1, int padding_2, int dilation_0, int dilation_1, int dilation_2, int output_padding_0, int output_padding_1, int output_padding_2, int groups, const mlx_stream s) {
    if (mlx_conv_transpose3d_ptr == NULL) { return 1; }
    return mlx_conv_transpose3d_ptr(res, input, weight, stride_0, stride_1, stride_2, padding_0, padding_1, padding_2, dilation_0, dilation_1, dilation_2, output_padding_0, output_padding_1, output_padding_2, groups, s);
}

int mlx_copy(mlx_array* res, const mlx_array a, const mlx_stream s) {
    if (mlx_copy_ptr == NULL) { return 1; }
    return mlx_copy_ptr(res, a, s);
}

int mlx_cos(mlx_array* res, const mlx_array a, const mlx_stream s) {
    if (mlx_cos_ptr == NULL) { return 1; }
    return mlx_cos_ptr(res, a, s);
}

int mlx_cosh(mlx_array* res, const mlx_array a, const mlx_stream s) {
    if (mlx_cosh_ptr == NULL) { return 1; }
    return mlx_cosh_ptr(res, a, s);
}

int mlx_cummax(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) {
    if (mlx_cummax_ptr == NULL) { return 1; }
    return mlx_cummax_ptr(res, a, axis, reverse, inclusive, s);
}

int mlx_cummin(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) {
    if (mlx_cummin_ptr == NULL) { return 1; }
    return mlx_cummin_ptr(res, a, axis, reverse, inclusive, s);
}

int mlx_cumprod(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) {
    if (mlx_cumprod_ptr == NULL) { return 1; }
    return mlx_cumprod_ptr(res, a, axis, reverse, inclusive, s);
}

int mlx_cumsum(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) {
    if (mlx_cumsum_ptr == NULL) { return 1; }
    return mlx_cumsum_ptr(res, a, axis, reverse, inclusive, s);
}

int mlx_degrees(mlx_array* res, const mlx_array a, const mlx_stream s) {
    if (mlx_degrees_ptr == NULL) { return 1; }
    return mlx_degrees_ptr(res, a, s);
}
|
|
|
|
int mlx_depends(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies) {
|
|
return mlx_depends_ptr(res, inputs, dependencies);
|
|
}
|
|
|
|
int mlx_dequantize(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, mlx_optional_dtype dtype, const mlx_stream s) {
|
|
return mlx_dequantize_ptr(res, w, scales, biases, group_size, bits, mode, dtype, s);
|
|
}
|
|
|
|
int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s) {
|
|
return mlx_diag_ptr(res, a, k, s);
|
|
}
|
|
|
|
int mlx_diagonal(mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, const mlx_stream s) {
|
|
return mlx_diagonal_ptr(res, a, offset, axis1, axis2, s);
|
|
}
|
|
|
|
int mlx_divide(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
|
|
return mlx_divide_ptr(res, a, b, s);
|
|
}
|
|
|
|
int mlx_divmod(mlx_vector_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
|
|
return mlx_divmod_ptr(res, a, b, s);
|
|
}
|
|
|
|
int mlx_einsum(mlx_array* res, const char* subscripts, const mlx_vector_array operands, const mlx_stream s) {
|
|
return mlx_einsum_ptr(res, subscripts, operands, s);
|
|
}
|
|
|
|
int mlx_equal(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
|
|
return mlx_equal_ptr(res, a, b, s);
|
|
}
|
|
|
|
int mlx_erf(mlx_array* res, const mlx_array a, const mlx_stream s) {
|
|
return mlx_erf_ptr(res, a, s);
|
|
}
|
|
|
|
int mlx_erfinv(mlx_array* res, const mlx_array a, const mlx_stream s) {
|
|
return mlx_erfinv_ptr(res, a, s);
|
|
}
|
|
|
|
int mlx_exp(mlx_array* res, const mlx_array a, const mlx_stream s) {
|
|
return mlx_exp_ptr(res, a, s);
|
|
}
|
|
|
|
int mlx_expand_dims_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s) {
|
|
return mlx_expand_dims_axes_ptr(res, a, axes, axes_num, s);
|
|
}
|
|
|
|
int mlx_expand_dims(mlx_array* res, const mlx_array a, int axis, const mlx_stream s) {
|
|
return mlx_expand_dims_ptr(res, a, axis, s);
|
|
}
|
|
|
|
int mlx_expm1(mlx_array* res, const mlx_array a, const mlx_stream s) {
|
|
return mlx_expm1_ptr(res, a, s);
|
|
}
|
|
|
|
int mlx_eye(mlx_array* res, int n, int m, int k, mlx_dtype dtype, const mlx_stream s) {
|
|
return mlx_eye_ptr(res, n, m, k, dtype, s);
|
|
}
|
|
|
|
int mlx_flatten(mlx_array* res, const mlx_array a, int start_axis, int end_axis, const mlx_stream s) {
|
|
return mlx_flatten_ptr(res, a, start_axis, end_axis, s);
|
|
}
|
|
|
|
int mlx_floor(mlx_array* res, const mlx_array a, const mlx_stream s) {
|
|
return mlx_floor_ptr(res, a, s);
|
|
}
|
|
|
|
int mlx_floor_divide(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
|
|
return mlx_floor_divide_ptr(res, a, b, s);
|
|
}
|
|
|
|
int mlx_from_fp8(mlx_array* res, const mlx_array x, mlx_dtype dtype, const mlx_stream s) {
|
|
return mlx_from_fp8_ptr(res, x, dtype, s);
|
|
}
|
|
|
|
int mlx_full(mlx_array* res, const int* shape, size_t shape_num, const mlx_array vals, mlx_dtype dtype, const mlx_stream s) {
|
|
return mlx_full_ptr(res, shape, shape_num, vals, dtype, s);
|
|
}
|
|
|
|
int mlx_full_like(mlx_array* res, const mlx_array a, const mlx_array vals, mlx_dtype dtype, const mlx_stream s) {
|
|
return mlx_full_like_ptr(res, a, vals, dtype, s);
|
|
}
|
|
|
|
int mlx_gather(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const int* axes, size_t axes_num, const int* slice_sizes, size_t slice_sizes_num, const mlx_stream s) {
|
|
return mlx_gather_ptr(res, a, indices, axes, axes_num, slice_sizes, slice_sizes_num, s);
|
|
}
|
|
|
|
int mlx_gather_single(mlx_array* res, const mlx_array a, const mlx_array indices, int axis, const int* slice_sizes, size_t slice_sizes_num, const mlx_stream s) {
|
|
return mlx_gather_single_ptr(res, a, indices, axis, slice_sizes, slice_sizes_num, s);
|
|
}
|
|
|
|
int mlx_gather_mm(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_array lhs_indices , const mlx_array rhs_indices , bool sorted_indices, const mlx_stream s) {
|
|
return mlx_gather_mm_ptr(res, a, b, lhs_indices, rhs_indices, sorted_indices, s);
|
|
}
|
|
|
|
int mlx_gather_qmm(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , const mlx_array lhs_indices , const mlx_array rhs_indices , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, bool sorted_indices, const mlx_stream s) {
|
|
return mlx_gather_qmm_ptr(res, x, w, scales, biases, lhs_indices, rhs_indices, transpose, group_size, bits, mode, sorted_indices, s);
|
|
}
|
|
|
|
int mlx_greater(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
|
|
return mlx_greater_ptr(res, a, b, s);
|
|
}
|
|
|
|
int mlx_greater_equal(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
|
|
return mlx_greater_equal_ptr(res, a, b, s);
|
|
}
|
|
|
|
int mlx_hadamard_transform(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s) {
|
|
return mlx_hadamard_transform_ptr(res, a, scale, s);
|
|
}
|
|
|
|
int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) {
|
|
return mlx_identity_ptr(res, n, dtype, s);
|
|
}
|
|
|
|
int mlx_imag(mlx_array* res, const mlx_array a, const mlx_stream s) {
|
|
return mlx_imag_ptr(res, a, s);
|
|
}
|
|
|
|
int mlx_inner(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
|
|
return mlx_inner_ptr(res, a, b, s);
|
|
}
|
|
|
|
int mlx_isclose(mlx_array* res, const mlx_array a, const mlx_array b, double rtol, double atol, bool equal_nan, const mlx_stream s) {
|
|
return mlx_isclose_ptr(res, a, b, rtol, atol, equal_nan, s);
|
|
}
|
|
|
|
int mlx_isfinite(mlx_array* res, const mlx_array a, const mlx_stream s) {
|
|
return mlx_isfinite_ptr(res, a, s);
|
|
}
|
|
|
|
int mlx_isinf(mlx_array* res, const mlx_array a, const mlx_stream s) {
|
|
return mlx_isinf_ptr(res, a, s);
|
|
}
|
|
|
|
int mlx_isnan(mlx_array* res, const mlx_array a, const mlx_stream s) {
|
|
return mlx_isnan_ptr(res, a, s);
|
|
}
|
|
|
|
int mlx_isneginf(mlx_array* res, const mlx_array a, const mlx_stream s) {
|
|
return mlx_isneginf_ptr(res, a, s);
|
|
}
|
|
|
|
int mlx_isposinf(mlx_array* res, const mlx_array a, const mlx_stream s) {
|
|
return mlx_isposinf_ptr(res, a, s);
|
|
}
|
|
|
|
int mlx_kron(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
|
|
return mlx_kron_ptr(res, a, b, s);
|
|
}
|
|
|
|
int mlx_left_shift(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
|
|
return mlx_left_shift_ptr(res, a, b, s);
|
|
}
|
|
|
|
int mlx_less(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
|
|
return mlx_less_ptr(res, a, b, s);
|
|
}
|
|
|
|
int mlx_less_equal(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
|
|
return mlx_less_equal_ptr(res, a, b, s);
|
|
}
|
|
|
|
int mlx_linspace(mlx_array* res, double start, double stop, int num, mlx_dtype dtype, const mlx_stream s) {
|
|
return mlx_linspace_ptr(res, start, stop, num, dtype, s);
|
|
}
|
|
|
|
int mlx_log(mlx_array* res, const mlx_array a, const mlx_stream s) {
|
|
return mlx_log_ptr(res, a, s);
|
|
}
|
|
|
|
int mlx_log10(mlx_array* res, const mlx_array a, const mlx_stream s) {
|
|
return mlx_log10_ptr(res, a, s);
|
|
}
|
|
|
|
int mlx_log1p(mlx_array* res, const mlx_array a, const mlx_stream s) {
|
|
return mlx_log1p_ptr(res, a, s);
|
|
}
|
|
|
|
int mlx_log2(mlx_array* res, const mlx_array a, const mlx_stream s) {
|
|
return mlx_log2_ptr(res, a, s);
|
|
}
|
|
|
|
int mlx_logaddexp(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
|
|
return mlx_logaddexp_ptr(res, a, b, s);
|
|
}
|
|
|
|
int mlx_logcumsumexp(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) {
|
|
return mlx_logcumsumexp_ptr(res, a, axis, reverse, inclusive, s);
|
|
}
|
|
|
|
int mlx_logical_and(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
|
|
return mlx_logical_and_ptr(res, a, b, s);
|
|
}
|
|
|
|
int mlx_logical_not(mlx_array* res, const mlx_array a, const mlx_stream s) {
|
|
return mlx_logical_not_ptr(res, a, s);
|
|
}
|
|
|
|
int mlx_logical_or(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
|
|
return mlx_logical_or_ptr(res, a, b, s);
|
|
}
|
|
|
|
int mlx_logsumexp_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) {
|
|
return mlx_logsumexp_axes_ptr(res, a, axes, axes_num, keepdims, s);
|
|
}
|
|
|
|
int mlx_logsumexp_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) {
|
|
return mlx_logsumexp_axis_ptr(res, a, axis, keepdims, s);
|
|
}
|
|
|
|
int mlx_logsumexp(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) {
|
|
return mlx_logsumexp_ptr(res, a, keepdims, s);
|
|
}
|
|
|
|
int mlx_masked_scatter(mlx_array* res, const mlx_array a, const mlx_array mask, const mlx_array src, const mlx_stream s) {
|
|
return mlx_masked_scatter_ptr(res, a, mask, src, s);
|
|
}
|
|
|
|
int mlx_matmul(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
|
|
return mlx_matmul_ptr(res, a, b, s);
|
|
}
|
|
|
|
int mlx_max_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) {
|
|
return mlx_max_axes_ptr(res, a, axes, axes_num, keepdims, s);
|
|
}
|
|
|
|
int mlx_max_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) {
|
|
return mlx_max_axis_ptr(res, a, axis, keepdims, s);
|
|
}
|
|
|
|
int mlx_max(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) {
|
|
return mlx_max_ptr(res, a, keepdims, s);
|
|
}
|
|
|
|
int mlx_maximum(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
|
|
return mlx_maximum_ptr(res, a, b, s);
|
|
}
|
|
|
|
int mlx_mean_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) {
|
|
return mlx_mean_axes_ptr(res, a, axes, axes_num, keepdims, s);
|
|
}
|
|
|
|
int mlx_mean_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) {
|
|
return mlx_mean_axis_ptr(res, a, axis, keepdims, s);
|
|
}
|
|
|
|
int mlx_mean(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) {
|
|
return mlx_mean_ptr(res, a, keepdims, s);
|
|
}
|
|
|
|
int mlx_median(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) {
|
|
return mlx_median_ptr(res, a, axes, axes_num, keepdims, s);
|
|
}
|
|
|
|
int mlx_meshgrid(mlx_vector_array* res, const mlx_vector_array arrays, bool sparse, const char* indexing, const mlx_stream s) {
|
|
return mlx_meshgrid_ptr(res, arrays, sparse, indexing, s);
|
|
}
|
|
|
|
int mlx_min_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) {
|
|
return mlx_min_axes_ptr(res, a, axes, axes_num, keepdims, s);
|
|
}
|
|
|
|
int mlx_min_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) {
|
|
return mlx_min_axis_ptr(res, a, axis, keepdims, s);
|
|
}
|
|
|
|
int mlx_min(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) {
|
|
return mlx_min_ptr(res, a, keepdims, s);
|
|
}
|
|
|
|
int mlx_minimum(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
|
|
return mlx_minimum_ptr(res, a, b, s);
|
|
}
|
|
|
|
int mlx_moveaxis(mlx_array* res, const mlx_array a, int source, int destination, const mlx_stream s) {
|
|
return mlx_moveaxis_ptr(res, a, source, destination, s);
|
|
}
|
|
|
|
int mlx_multiply(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
|
|
return mlx_multiply_ptr(res, a, b, s);
|
|
}
|
|
|
|
int mlx_nan_to_num(mlx_array* res, const mlx_array a, float nan, mlx_optional_float posinf, mlx_optional_float neginf, const mlx_stream s) {
|
|
return mlx_nan_to_num_ptr(res, a, nan, posinf, neginf, s);
|
|
}
|
|
|
|
int mlx_negative(mlx_array* res, const mlx_array a, const mlx_stream s) {
|
|
return mlx_negative_ptr(res, a, s);
|
|
}
|
|
|
|
int mlx_not_equal(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
|
|
return mlx_not_equal_ptr(res, a, b, s);
|
|
}
|
|
|
|
int mlx_number_of_elements(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool inverted, mlx_dtype dtype, const mlx_stream s) {
|
|
return mlx_number_of_elements_ptr(res, a, axes, axes_num, inverted, dtype, s);
|
|
}
|
|
|
|
int mlx_ones(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_stream s) {
|
|
return mlx_ones_ptr(res, shape, shape_num, dtype, s);
|
|
}
|
|
|
|
// Forwards to the dynamically loaded mlx_ones_like via mlx_ones_like_ptr; the result is returned unchanged.
int mlx_ones_like(mlx_array* res, const mlx_array a, const mlx_stream s)
{
    const int result = mlx_ones_like_ptr(res, a, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_outer via mlx_outer_ptr; the result is returned unchanged.
int mlx_outer(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s)
{
    const int result = mlx_outer_ptr(res, a, b, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_pad via mlx_pad_ptr; the result is returned unchanged.
int mlx_pad(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const int* low_pad_size, size_t low_pad_size_num, const int* high_pad_size, size_t high_pad_size_num, const mlx_array pad_value, const char* mode, const mlx_stream s)
{
    const int result = mlx_pad_ptr(res, a, axes, axes_num, low_pad_size, low_pad_size_num, high_pad_size, high_pad_size_num, pad_value, mode, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_pad_symmetric via mlx_pad_symmetric_ptr; the result is returned unchanged.
int mlx_pad_symmetric(mlx_array* res, const mlx_array a, int pad_width, const mlx_array pad_value, const char* mode, const mlx_stream s)
{
    const int result = mlx_pad_symmetric_ptr(res, a, pad_width, pad_value, mode, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_partition_axis via mlx_partition_axis_ptr; the result is returned unchanged.
int mlx_partition_axis(mlx_array* res, const mlx_array a, int kth, int axis, const mlx_stream s)
{
    const int result = mlx_partition_axis_ptr(res, a, kth, axis, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_partition via mlx_partition_ptr; the result is returned unchanged.
int mlx_partition(mlx_array* res, const mlx_array a, int kth, const mlx_stream s)
{
    const int result = mlx_partition_ptr(res, a, kth, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_power via mlx_power_ptr; the result is returned unchanged.
int mlx_power(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s)
{
    const int result = mlx_power_ptr(res, a, b, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_prod_axes via mlx_prod_axes_ptr; the result is returned unchanged.
int mlx_prod_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s)
{
    const int result = mlx_prod_axes_ptr(res, a, axes, axes_num, keepdims, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_prod_axis via mlx_prod_axis_ptr; the result is returned unchanged.
int mlx_prod_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s)
{
    const int result = mlx_prod_axis_ptr(res, a, axis, keepdims, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_prod via mlx_prod_ptr; the result is returned unchanged.
int mlx_prod(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s)
{
    const int result = mlx_prod_ptr(res, a, keepdims, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_put_along_axis via mlx_put_along_axis_ptr; the result is returned unchanged.
int mlx_put_along_axis(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s)
{
    const int result = mlx_put_along_axis_ptr(res, a, indices, values, axis, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_qqmm via mlx_qqmm_ptr; the result is returned unchanged.
int mlx_qqmm(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s)
{
    const int result = mlx_qqmm_ptr(res, x, w, w_scales, group_size, bits, mode, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_quantize via mlx_quantize_ptr; the result is returned unchanged.
int mlx_quantize(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s)
{
    const int result = mlx_quantize_ptr(res, w, group_size, bits, mode, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_quantized_matmul via mlx_quantized_matmul_ptr; the result is returned unchanged.
int mlx_quantized_matmul(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s)
{
    const int result = mlx_quantized_matmul_ptr(res, x, w, scales, biases, transpose, group_size, bits, mode, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_radians via mlx_radians_ptr; the result is returned unchanged.
int mlx_radians(mlx_array* res, const mlx_array a, const mlx_stream s)
{
    const int result = mlx_radians_ptr(res, a, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_real via mlx_real_ptr; the result is returned unchanged.
int mlx_real(mlx_array* res, const mlx_array a, const mlx_stream s)
{
    const int result = mlx_real_ptr(res, a, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_reciprocal via mlx_reciprocal_ptr; the result is returned unchanged.
int mlx_reciprocal(mlx_array* res, const mlx_array a, const mlx_stream s)
{
    const int result = mlx_reciprocal_ptr(res, a, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_remainder via mlx_remainder_ptr; the result is returned unchanged.
int mlx_remainder(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s)
{
    const int result = mlx_remainder_ptr(res, a, b, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_repeat_axis via mlx_repeat_axis_ptr; the result is returned unchanged.
int mlx_repeat_axis(mlx_array* res, const mlx_array arr, int repeats, int axis, const mlx_stream s)
{
    const int result = mlx_repeat_axis_ptr(res, arr, repeats, axis, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_repeat via mlx_repeat_ptr; the result is returned unchanged.
int mlx_repeat(mlx_array* res, const mlx_array arr, int repeats, const mlx_stream s)
{
    const int result = mlx_repeat_ptr(res, arr, repeats, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_reshape via mlx_reshape_ptr; the result is returned unchanged.
int mlx_reshape(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s)
{
    const int result = mlx_reshape_ptr(res, a, shape, shape_num, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_right_shift via mlx_right_shift_ptr; the result is returned unchanged.
int mlx_right_shift(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s)
{
    const int result = mlx_right_shift_ptr(res, a, b, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_roll_axis via mlx_roll_axis_ptr; the result is returned unchanged.
int mlx_roll_axis(mlx_array* res, const mlx_array a, const int* shift, size_t shift_num, int axis, const mlx_stream s)
{
    const int result = mlx_roll_axis_ptr(res, a, shift, shift_num, axis, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_roll_axes via mlx_roll_axes_ptr; the result is returned unchanged.
int mlx_roll_axes(mlx_array* res, const mlx_array a, const int* shift, size_t shift_num, const int* axes, size_t axes_num, const mlx_stream s)
{
    const int result = mlx_roll_axes_ptr(res, a, shift, shift_num, axes, axes_num, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_roll via mlx_roll_ptr; the result is returned unchanged.
int mlx_roll(mlx_array* res, const mlx_array a, const int* shift, size_t shift_num, const mlx_stream s)
{
    const int result = mlx_roll_ptr(res, a, shift, shift_num, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_round via mlx_round_ptr; the result is returned unchanged.
int mlx_round(mlx_array* res, const mlx_array a, int decimals, const mlx_stream s)
{
    const int result = mlx_round_ptr(res, a, decimals, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_rsqrt via mlx_rsqrt_ptr; the result is returned unchanged.
int mlx_rsqrt(mlx_array* res, const mlx_array a, const mlx_stream s)
{
    const int result = mlx_rsqrt_ptr(res, a, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_scatter via mlx_scatter_ptr; the result is returned unchanged.
int mlx_scatter(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s)
{
    const int result = mlx_scatter_ptr(res, a, indices, updates, axes, axes_num, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_scatter_single via mlx_scatter_single_ptr; the result is returned unchanged.
int mlx_scatter_single(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s)
{
    const int result = mlx_scatter_single_ptr(res, a, indices, updates, axis, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_scatter_add via mlx_scatter_add_ptr; the result is returned unchanged.
int mlx_scatter_add(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s)
{
    const int result = mlx_scatter_add_ptr(res, a, indices, updates, axes, axes_num, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_scatter_add_single via mlx_scatter_add_single_ptr; the result is returned unchanged.
int mlx_scatter_add_single(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s)
{
    const int result = mlx_scatter_add_single_ptr(res, a, indices, updates, axis, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_scatter_add_axis via mlx_scatter_add_axis_ptr; the result is returned unchanged.
int mlx_scatter_add_axis(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s)
{
    const int result = mlx_scatter_add_axis_ptr(res, a, indices, values, axis, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_scatter_max via mlx_scatter_max_ptr; the result is returned unchanged.
int mlx_scatter_max(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s)
{
    const int result = mlx_scatter_max_ptr(res, a, indices, updates, axes, axes_num, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_scatter_max_single via mlx_scatter_max_single_ptr; the result is returned unchanged.
int mlx_scatter_max_single(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s)
{
    const int result = mlx_scatter_max_single_ptr(res, a, indices, updates, axis, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_scatter_min via mlx_scatter_min_ptr; the result is returned unchanged.
int mlx_scatter_min(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s)
{
    const int result = mlx_scatter_min_ptr(res, a, indices, updates, axes, axes_num, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_scatter_min_single via mlx_scatter_min_single_ptr; the result is returned unchanged.
int mlx_scatter_min_single(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s)
{
    const int result = mlx_scatter_min_single_ptr(res, a, indices, updates, axis, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_scatter_prod via mlx_scatter_prod_ptr; the result is returned unchanged.
int mlx_scatter_prod(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s)
{
    const int result = mlx_scatter_prod_ptr(res, a, indices, updates, axes, axes_num, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_scatter_prod_single via mlx_scatter_prod_single_ptr; the result is returned unchanged.
int mlx_scatter_prod_single(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s)
{
    const int result = mlx_scatter_prod_single_ptr(res, a, indices, updates, axis, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_segmented_mm via mlx_segmented_mm_ptr; the result is returned unchanged.
int mlx_segmented_mm(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_array segments, const mlx_stream s)
{
    const int result = mlx_segmented_mm_ptr(res, a, b, segments, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_sigmoid via mlx_sigmoid_ptr; the result is returned unchanged.
int mlx_sigmoid(mlx_array* res, const mlx_array a, const mlx_stream s)
{
    const int result = mlx_sigmoid_ptr(res, a, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_sign via mlx_sign_ptr; the result is returned unchanged.
int mlx_sign(mlx_array* res, const mlx_array a, const mlx_stream s)
{
    const int result = mlx_sign_ptr(res, a, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_sin via mlx_sin_ptr; the result is returned unchanged.
int mlx_sin(mlx_array* res, const mlx_array a, const mlx_stream s)
{
    const int result = mlx_sin_ptr(res, a, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_sinh via mlx_sinh_ptr; the result is returned unchanged.
int mlx_sinh(mlx_array* res, const mlx_array a, const mlx_stream s)
{
    const int result = mlx_sinh_ptr(res, a, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_slice via mlx_slice_ptr; the result is returned unchanged.
int mlx_slice(mlx_array* res, const mlx_array a, const int* start, size_t start_num, const int* stop, size_t stop_num, const int* strides, size_t strides_num, const mlx_stream s)
{
    const int result = mlx_slice_ptr(res, a, start, start_num, stop, stop_num, strides, strides_num, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_slice_dynamic via mlx_slice_dynamic_ptr; the result is returned unchanged.
int mlx_slice_dynamic(mlx_array* res, const mlx_array a, const mlx_array start, const int* axes, size_t axes_num, const int* slice_size, size_t slice_size_num, const mlx_stream s)
{
    const int result = mlx_slice_dynamic_ptr(res, a, start, axes, axes_num, slice_size, slice_size_num, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_slice_update via mlx_slice_update_ptr; the result is returned unchanged.
int mlx_slice_update(mlx_array* res, const mlx_array src, const mlx_array update, const int* start, size_t start_num, const int* stop, size_t stop_num, const int* strides, size_t strides_num, const mlx_stream s)
{
    const int result = mlx_slice_update_ptr(res, src, update, start, start_num, stop, stop_num, strides, strides_num, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_slice_update_dynamic via mlx_slice_update_dynamic_ptr; the result is returned unchanged.
int mlx_slice_update_dynamic(mlx_array* res, const mlx_array src, const mlx_array update, const mlx_array start, const int* axes, size_t axes_num, const mlx_stream s)
{
    const int result = mlx_slice_update_dynamic_ptr(res, src, update, start, axes, axes_num, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_softmax_axes via mlx_softmax_axes_ptr; the result is returned unchanged.
int mlx_softmax_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool precise, const mlx_stream s)
{
    const int result = mlx_softmax_axes_ptr(res, a, axes, axes_num, precise, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_softmax_axis via mlx_softmax_axis_ptr; the result is returned unchanged.
int mlx_softmax_axis(mlx_array* res, const mlx_array a, int axis, bool precise, const mlx_stream s)
{
    const int result = mlx_softmax_axis_ptr(res, a, axis, precise, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_softmax via mlx_softmax_ptr; the result is returned unchanged.
int mlx_softmax(mlx_array* res, const mlx_array a, bool precise, const mlx_stream s)
{
    const int result = mlx_softmax_ptr(res, a, precise, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_sort_axis via mlx_sort_axis_ptr; the result is returned unchanged.
int mlx_sort_axis(mlx_array* res, const mlx_array a, int axis, const mlx_stream s)
{
    const int result = mlx_sort_axis_ptr(res, a, axis, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_sort via mlx_sort_ptr; the result is returned unchanged.
int mlx_sort(mlx_array* res, const mlx_array a, const mlx_stream s)
{
    const int result = mlx_sort_ptr(res, a, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_split via mlx_split_ptr; the result is returned unchanged.
int mlx_split(mlx_vector_array* res, const mlx_array a, int num_splits, int axis, const mlx_stream s)
{
    const int result = mlx_split_ptr(res, a, num_splits, axis, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_split_sections via mlx_split_sections_ptr; the result is returned unchanged.
int mlx_split_sections(mlx_vector_array* res, const mlx_array a, const int* indices, size_t indices_num, int axis, const mlx_stream s)
{
    const int result = mlx_split_sections_ptr(res, a, indices, indices_num, axis, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_sqrt via mlx_sqrt_ptr; the result is returned unchanged.
int mlx_sqrt(mlx_array* res, const mlx_array a, const mlx_stream s)
{
    const int result = mlx_sqrt_ptr(res, a, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_square via mlx_square_ptr; the result is returned unchanged.
int mlx_square(mlx_array* res, const mlx_array a, const mlx_stream s)
{
    const int result = mlx_square_ptr(res, a, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_squeeze_axes via mlx_squeeze_axes_ptr; the result is returned unchanged.
int mlx_squeeze_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s)
{
    const int result = mlx_squeeze_axes_ptr(res, a, axes, axes_num, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_squeeze_axis via mlx_squeeze_axis_ptr; the result is returned unchanged.
int mlx_squeeze_axis(mlx_array* res, const mlx_array a, int axis, const mlx_stream s)
{
    const int result = mlx_squeeze_axis_ptr(res, a, axis, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_squeeze via mlx_squeeze_ptr; the result is returned unchanged.
int mlx_squeeze(mlx_array* res, const mlx_array a, const mlx_stream s)
{
    const int result = mlx_squeeze_ptr(res, a, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_stack_axis via mlx_stack_axis_ptr; the result is returned unchanged.
int mlx_stack_axis(mlx_array* res, const mlx_vector_array arrays, int axis, const mlx_stream s)
{
    const int result = mlx_stack_axis_ptr(res, arrays, axis, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_stack via mlx_stack_ptr; the result is returned unchanged.
int mlx_stack(mlx_array* res, const mlx_vector_array arrays, const mlx_stream s)
{
    const int result = mlx_stack_ptr(res, arrays, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_std_axes via mlx_std_axes_ptr; the result is returned unchanged.
int mlx_std_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, int ddof, const mlx_stream s)
{
    const int result = mlx_std_axes_ptr(res, a, axes, axes_num, keepdims, ddof, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_std_axis via mlx_std_axis_ptr; the result is returned unchanged.
int mlx_std_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, int ddof, const mlx_stream s)
{
    const int result = mlx_std_axis_ptr(res, a, axis, keepdims, ddof, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_std via mlx_std_ptr; the result is returned unchanged.
int mlx_std(mlx_array* res, const mlx_array a, bool keepdims, int ddof, const mlx_stream s)
{
    const int result = mlx_std_ptr(res, a, keepdims, ddof, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_stop_gradient via mlx_stop_gradient_ptr; the result is returned unchanged.
int mlx_stop_gradient(mlx_array* res, const mlx_array a, const mlx_stream s)
{
    const int result = mlx_stop_gradient_ptr(res, a, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_subtract via mlx_subtract_ptr; the result is returned unchanged.
int mlx_subtract(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s)
{
    const int result = mlx_subtract_ptr(res, a, b, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_sum_axes via mlx_sum_axes_ptr; the result is returned unchanged.
int mlx_sum_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s)
{
    const int result = mlx_sum_axes_ptr(res, a, axes, axes_num, keepdims, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_sum_axis via mlx_sum_axis_ptr; the result is returned unchanged.
int mlx_sum_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s)
{
    const int result = mlx_sum_axis_ptr(res, a, axis, keepdims, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_sum via mlx_sum_ptr; the result is returned unchanged.
int mlx_sum(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s)
{
    const int result = mlx_sum_ptr(res, a, keepdims, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_swapaxes via mlx_swapaxes_ptr; the result is returned unchanged.
int mlx_swapaxes(mlx_array* res, const mlx_array a, int axis1, int axis2, const mlx_stream s)
{
    const int result = mlx_swapaxes_ptr(res, a, axis1, axis2, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_take_axis via mlx_take_axis_ptr; the result is returned unchanged.
int mlx_take_axis(mlx_array* res, const mlx_array a, const mlx_array indices, int axis, const mlx_stream s)
{
    const int result = mlx_take_axis_ptr(res, a, indices, axis, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_take via mlx_take_ptr; the result is returned unchanged.
int mlx_take(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_stream s)
{
    const int result = mlx_take_ptr(res, a, indices, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_take_along_axis via mlx_take_along_axis_ptr; the result is returned unchanged.
int mlx_take_along_axis(mlx_array* res, const mlx_array a, const mlx_array indices, int axis, const mlx_stream s)
{
    const int result = mlx_take_along_axis_ptr(res, a, indices, axis, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_tan via mlx_tan_ptr; the result is returned unchanged.
int mlx_tan(mlx_array* res, const mlx_array a, const mlx_stream s)
{
    const int result = mlx_tan_ptr(res, a, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_tanh via mlx_tanh_ptr; the result is returned unchanged.
int mlx_tanh(mlx_array* res, const mlx_array a, const mlx_stream s)
{
    const int result = mlx_tanh_ptr(res, a, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_tensordot via mlx_tensordot_ptr; the result is returned unchanged.
int mlx_tensordot(mlx_array* res, const mlx_array a, const mlx_array b, const int* axes_a, size_t axes_a_num, const int* axes_b, size_t axes_b_num, const mlx_stream s)
{
    const int result = mlx_tensordot_ptr(res, a, b, axes_a, axes_a_num, axes_b, axes_b_num, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_tensordot_axis via mlx_tensordot_axis_ptr; the result is returned unchanged.
int mlx_tensordot_axis(mlx_array* res, const mlx_array a, const mlx_array b, int axis, const mlx_stream s)
{
    const int result = mlx_tensordot_axis_ptr(res, a, b, axis, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_tile via mlx_tile_ptr; the result is returned unchanged.
int mlx_tile(mlx_array* res, const mlx_array arr, const int* reps, size_t reps_num, const mlx_stream s)
{
    const int result = mlx_tile_ptr(res, arr, reps, reps_num, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_to_fp8 via mlx_to_fp8_ptr; the result is returned unchanged.
int mlx_to_fp8(mlx_array* res, const mlx_array x, const mlx_stream s)
{
    const int result = mlx_to_fp8_ptr(res, x, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_topk_axis via mlx_topk_axis_ptr; the result is returned unchanged.
int mlx_topk_axis(mlx_array* res, const mlx_array a, int k, int axis, const mlx_stream s)
{
    const int result = mlx_topk_axis_ptr(res, a, k, axis, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_topk via mlx_topk_ptr; the result is returned unchanged.
int mlx_topk(mlx_array* res, const mlx_array a, int k, const mlx_stream s)
{
    const int result = mlx_topk_ptr(res, a, k, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_trace via mlx_trace_ptr; the result is returned unchanged.
int mlx_trace(mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, mlx_dtype dtype, const mlx_stream s)
{
    const int result = mlx_trace_ptr(res, a, offset, axis1, axis2, dtype, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_transpose_axes via mlx_transpose_axes_ptr; the result is returned unchanged.
int mlx_transpose_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s)
{
    const int result = mlx_transpose_axes_ptr(res, a, axes, axes_num, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_transpose via mlx_transpose_ptr; the result is returned unchanged.
int mlx_transpose(mlx_array* res, const mlx_array a, const mlx_stream s)
{
    const int result = mlx_transpose_ptr(res, a, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_tri via mlx_tri_ptr; the result is returned unchanged.
int mlx_tri(mlx_array* res, int n, int m, int k, mlx_dtype type, const mlx_stream s)
{
    const int result = mlx_tri_ptr(res, n, m, k, type, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_tril via mlx_tril_ptr; the result is returned unchanged.
int mlx_tril(mlx_array* res, const mlx_array x, int k, const mlx_stream s)
{
    const int result = mlx_tril_ptr(res, x, k, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_triu via mlx_triu_ptr; the result is returned unchanged.
int mlx_triu(mlx_array* res, const mlx_array x, int k, const mlx_stream s)
{
    const int result = mlx_triu_ptr(res, x, k, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_unflatten via mlx_unflatten_ptr; the result is returned unchanged.
int mlx_unflatten(mlx_array* res, const mlx_array a, int axis, const int* shape, size_t shape_num, const mlx_stream s)
{
    const int result = mlx_unflatten_ptr(res, a, axis, shape, shape_num, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_var_axes via mlx_var_axes_ptr; the result is returned unchanged.
int mlx_var_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, int ddof, const mlx_stream s)
{
    const int result = mlx_var_axes_ptr(res, a, axes, axes_num, keepdims, ddof, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_var_axis via mlx_var_axis_ptr; the result is returned unchanged.
int mlx_var_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, int ddof, const mlx_stream s)
{
    const int result = mlx_var_axis_ptr(res, a, axis, keepdims, ddof, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_var via mlx_var_ptr; the result is returned unchanged.
int mlx_var(mlx_array* res, const mlx_array a, bool keepdims, int ddof, const mlx_stream s)
{
    const int result = mlx_var_ptr(res, a, keepdims, ddof, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_view via mlx_view_ptr; the result is returned unchanged.
int mlx_view(mlx_array* res, const mlx_array a, mlx_dtype dtype, const mlx_stream s)
{
    const int result = mlx_view_ptr(res, a, dtype, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_where via mlx_where_ptr; the result is returned unchanged.
int mlx_where(mlx_array* res, const mlx_array condition, const mlx_array x, const mlx_array y, const mlx_stream s)
{
    const int result = mlx_where_ptr(res, condition, x, y, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_zeros via mlx_zeros_ptr; the result is returned unchanged.
int mlx_zeros(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_stream s)
{
    const int result = mlx_zeros_ptr(res, shape, shape_num, dtype, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_zeros_like via mlx_zeros_like_ptr; the result is returned unchanged.
int mlx_zeros_like(mlx_array* res, const mlx_array a, const mlx_stream s)
{
    const int result = mlx_zeros_like_ptr(res, a, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_random_bernoulli via mlx_random_bernoulli_ptr; the result is returned unchanged.
int mlx_random_bernoulli(mlx_array* res, const mlx_array p, const int* shape, size_t shape_num, const mlx_array key , const mlx_stream s)
{
    const int result = mlx_random_bernoulli_ptr(res, p, shape, shape_num, key, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_random_bits via mlx_random_bits_ptr; the result is returned unchanged.
int mlx_random_bits(mlx_array* res, const int* shape, size_t shape_num, int width, const mlx_array key , const mlx_stream s)
{
    const int result = mlx_random_bits_ptr(res, shape, shape_num, width, key, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_random_categorical_shape via mlx_random_categorical_shape_ptr; the result is returned unchanged.
int mlx_random_categorical_shape(mlx_array* res, const mlx_array logits, int axis, const int* shape, size_t shape_num, const mlx_array key , const mlx_stream s)
{
    const int result = mlx_random_categorical_shape_ptr(res, logits, axis, shape, shape_num, key, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_random_categorical_num_samples via mlx_random_categorical_num_samples_ptr; the result is returned unchanged.
int mlx_random_categorical_num_samples(mlx_array* res, const mlx_array logits_, int axis, int num_samples, const mlx_array key , const mlx_stream s)
{
    const int result = mlx_random_categorical_num_samples_ptr(res, logits_, axis, num_samples, key, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_random_categorical via mlx_random_categorical_ptr; the result is returned unchanged.
int mlx_random_categorical(mlx_array* res, const mlx_array logits, int axis, const mlx_array key , const mlx_stream s)
{
    const int result = mlx_random_categorical_ptr(res, logits, axis, key, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_random_gumbel via mlx_random_gumbel_ptr; the result is returned unchanged.
int mlx_random_gumbel(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key , const mlx_stream s)
{
    const int result = mlx_random_gumbel_ptr(res, shape, shape_num, dtype, key, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_random_key via mlx_random_key_ptr; the result is returned unchanged.
int mlx_random_key(mlx_array* res, uint64_t seed)
{
    const int result = mlx_random_key_ptr(res, seed);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_random_laplace via mlx_random_laplace_ptr; the result is returned unchanged.
int mlx_random_laplace(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, float loc, float scale, const mlx_array key , const mlx_stream s)
{
    const int result = mlx_random_laplace_ptr(res, shape, shape_num, dtype, loc, scale, key, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_random_multivariate_normal via mlx_random_multivariate_normal_ptr; the result is returned unchanged.
int mlx_random_multivariate_normal(mlx_array* res, const mlx_array mean, const mlx_array cov, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key , const mlx_stream s)
{
    const int result = mlx_random_multivariate_normal_ptr(res, mean, cov, shape, shape_num, dtype, key, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_random_normal_broadcast via mlx_random_normal_broadcast_ptr; the result is returned unchanged.
int mlx_random_normal_broadcast(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array loc , const mlx_array scale , const mlx_array key , const mlx_stream s)
{
    const int result = mlx_random_normal_broadcast_ptr(res, shape, shape_num, dtype, loc, scale, key, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_random_normal via mlx_random_normal_ptr; the result is returned unchanged.
int mlx_random_normal(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, float loc, float scale, const mlx_array key , const mlx_stream s)
{
    const int result = mlx_random_normal_ptr(res, shape, shape_num, dtype, loc, scale, key, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_random_permutation via mlx_random_permutation_ptr; the result is returned unchanged.
int mlx_random_permutation(mlx_array* res, const mlx_array x, int axis, const mlx_array key , const mlx_stream s)
{
    const int result = mlx_random_permutation_ptr(res, x, axis, key, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_random_permutation_arange via mlx_random_permutation_arange_ptr; the result is returned unchanged.
int mlx_random_permutation_arange(mlx_array* res, int x, const mlx_array key , const mlx_stream s)
{
    const int result = mlx_random_permutation_arange_ptr(res, x, key, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_random_randint via mlx_random_randint_ptr; the result is returned unchanged.
int mlx_random_randint(mlx_array* res, const mlx_array low, const mlx_array high, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key , const mlx_stream s)
{
    const int result = mlx_random_randint_ptr(res, low, high, shape, shape_num, dtype, key, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_random_seed via mlx_random_seed_ptr; the result is returned unchanged.
int mlx_random_seed(uint64_t seed)
{
    const int result = mlx_random_seed_ptr(seed);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_random_split_num via mlx_random_split_num_ptr; the result is returned unchanged.
int mlx_random_split_num(mlx_array* res, const mlx_array key, int num, const mlx_stream s)
{
    const int result = mlx_random_split_num_ptr(res, key, num, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_random_split via mlx_random_split_ptr; the result is returned unchanged.
int mlx_random_split(mlx_array* res_0, mlx_array* res_1, const mlx_array key, const mlx_stream s)
{
    const int result = mlx_random_split_ptr(res_0, res_1, key, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_random_truncated_normal via mlx_random_truncated_normal_ptr; the result is returned unchanged.
int mlx_random_truncated_normal(mlx_array* res, const mlx_array lower, const mlx_array upper, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key , const mlx_stream s)
{
    const int result = mlx_random_truncated_normal_ptr(res, lower, upper, shape, shape_num, dtype, key, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_random_uniform via mlx_random_uniform_ptr; the result is returned unchanged.
int mlx_random_uniform(mlx_array* res, const mlx_array low, const mlx_array high, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key , const mlx_stream s)
{
    const int result = mlx_random_uniform_ptr(res, low, high, shape, shape_num, dtype, key, s);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_stream_new via mlx_stream_new_ptr; the result is returned unchanged.
mlx_stream mlx_stream_new(void)
{
    const mlx_stream result = mlx_stream_new_ptr();
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_stream_new_device via mlx_stream_new_device_ptr; the result is returned unchanged.
mlx_stream mlx_stream_new_device(mlx_device dev)
{
    const mlx_stream result = mlx_stream_new_device_ptr(dev);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_stream_set via mlx_stream_set_ptr; the result is returned unchanged.
int mlx_stream_set(mlx_stream* stream, const mlx_stream src)
{
    const int result = mlx_stream_set_ptr(stream, src);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_stream_free via mlx_stream_free_ptr; the result is returned unchanged.
int mlx_stream_free(mlx_stream stream)
{
    const int result = mlx_stream_free_ptr(stream);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_stream_tostring via mlx_stream_tostring_ptr; the result is returned unchanged.
int mlx_stream_tostring(mlx_string* str, mlx_stream stream)
{
    const int result = mlx_stream_tostring_ptr(str, stream);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_stream_equal via mlx_stream_equal_ptr; the result is returned unchanged.
bool mlx_stream_equal(mlx_stream lhs, mlx_stream rhs)
{
    const bool result = mlx_stream_equal_ptr(lhs, rhs);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_stream_get_device via mlx_stream_get_device_ptr; the result is returned unchanged.
int mlx_stream_get_device(mlx_device* dev, mlx_stream stream)
{
    const int result = mlx_stream_get_device_ptr(dev, stream);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_stream_get_index via mlx_stream_get_index_ptr; the result is returned unchanged.
int mlx_stream_get_index(int* index, mlx_stream stream)
{
    const int result = mlx_stream_get_index_ptr(index, stream);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_synchronize via mlx_synchronize_ptr; the result is returned unchanged.
int mlx_synchronize(mlx_stream stream)
{
    const int result = mlx_synchronize_ptr(stream);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_get_default_stream via mlx_get_default_stream_ptr; the result is returned unchanged.
int mlx_get_default_stream(mlx_stream* stream, mlx_device dev)
{
    const int result = mlx_get_default_stream_ptr(stream, dev);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_set_default_stream via mlx_set_default_stream_ptr; the result is returned unchanged.
int mlx_set_default_stream(mlx_stream stream)
{
    const int result = mlx_set_default_stream_ptr(stream);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_default_cpu_stream_new via mlx_default_cpu_stream_new_ptr; the result is returned unchanged.
mlx_stream mlx_default_cpu_stream_new(void)
{
    const mlx_stream result = mlx_default_cpu_stream_new_ptr();
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_default_gpu_stream_new via mlx_default_gpu_stream_new_ptr; the result is returned unchanged.
mlx_stream mlx_default_gpu_stream_new(void)
{
    const mlx_stream result = mlx_default_gpu_stream_new_ptr();
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_string_new via mlx_string_new_ptr; the result is returned unchanged.
mlx_string mlx_string_new(void)
{
    const mlx_string result = mlx_string_new_ptr();
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_string_new_data via mlx_string_new_data_ptr; the result is returned unchanged.
mlx_string mlx_string_new_data(const char* str)
{
    const mlx_string result = mlx_string_new_data_ptr(str);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_string_set via mlx_string_set_ptr; the result is returned unchanged.
int mlx_string_set(mlx_string* str, const mlx_string src)
{
    const int result = mlx_string_set_ptr(str, src);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_string_data via mlx_string_data_ptr; the pointer is returned unchanged.
const char* mlx_string_data(mlx_string str)
{
    const char* result = mlx_string_data_ptr(str);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_string_free via mlx_string_free_ptr; the result is returned unchanged.
int mlx_string_free(mlx_string str)
{
    const int result = mlx_string_free_ptr(str);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_async_eval via mlx_async_eval_ptr; the result is returned unchanged.
int mlx_async_eval(const mlx_vector_array outputs)
{
    const int result = mlx_async_eval_ptr(outputs);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_checkpoint via mlx_checkpoint_ptr; the result is returned unchanged.
int mlx_checkpoint(mlx_closure* res, const mlx_closure fun)
{
    const int result = mlx_checkpoint_ptr(res, fun);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_custom_function via mlx_custom_function_ptr; the result is returned unchanged.
int mlx_custom_function(mlx_closure* res, const mlx_closure fun, const mlx_closure_custom fun_vjp , const mlx_closure_custom_jvp fun_jvp , const mlx_closure_custom_vmap fun_vmap)
{
    const int result = mlx_custom_function_ptr(res, fun, fun_vjp, fun_jvp, fun_vmap);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_custom_vjp via mlx_custom_vjp_ptr; the result is returned unchanged.
int mlx_custom_vjp(mlx_closure* res, const mlx_closure fun, const mlx_closure_custom fun_vjp)
{
    const int result = mlx_custom_vjp_ptr(res, fun, fun_vjp);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_eval via mlx_eval_ptr; the result is returned unchanged.
int mlx_eval(const mlx_vector_array outputs)
{
    const int result = mlx_eval_ptr(outputs);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_jvp via mlx_jvp_ptr; the result is returned unchanged.
int mlx_jvp(mlx_vector_array* res_0, mlx_vector_array* res_1, const mlx_closure fun, const mlx_vector_array primals, const mlx_vector_array tangents)
{
    const int result = mlx_jvp_ptr(res_0, res_1, fun, primals, tangents);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_value_and_grad via mlx_value_and_grad_ptr; the result is returned unchanged.
int mlx_value_and_grad(mlx_closure_value_and_grad* res, const mlx_closure fun, const int* argnums, size_t argnums_num)
{
    const int result = mlx_value_and_grad_ptr(res, fun, argnums, argnums_num);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_vjp via mlx_vjp_ptr; the result is returned unchanged.
int mlx_vjp(mlx_vector_array* res_0, mlx_vector_array* res_1, const mlx_closure fun, const mlx_vector_array primals, const mlx_vector_array cotangents)
{
    const int result = mlx_vjp_ptr(res_0, res_1, fun, primals, cotangents);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_detail_vmap_replace via mlx_detail_vmap_replace_ptr; the result is returned unchanged.
int mlx_detail_vmap_replace(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array s_inputs, const mlx_vector_array s_outputs, const int* in_axes, size_t in_axes_num, const int* out_axes, size_t out_axes_num)
{
    const int result = mlx_detail_vmap_replace_ptr(res, inputs, s_inputs, s_outputs, in_axes, in_axes_num, out_axes, out_axes_num);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_detail_vmap_trace via mlx_detail_vmap_trace_ptr; the result is returned unchanged.
int mlx_detail_vmap_trace(mlx_vector_array* res_0, mlx_vector_array* res_1, const mlx_closure fun, const mlx_vector_array inputs, const int* in_axes, size_t in_axes_num)
{
    const int result = mlx_detail_vmap_trace_ptr(res_0, res_1, fun, inputs, in_axes, in_axes_num);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_vector_array_new via mlx_vector_array_new_ptr; the result is returned unchanged.
mlx_vector_array mlx_vector_array_new(void)
{
    const mlx_vector_array result = mlx_vector_array_new_ptr();
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_vector_array_set via mlx_vector_array_set_ptr; the result is returned unchanged.
int mlx_vector_array_set(mlx_vector_array* vec, const mlx_vector_array src)
{
    const int result = mlx_vector_array_set_ptr(vec, src);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_vector_array_free via mlx_vector_array_free_ptr; the result is returned unchanged.
int mlx_vector_array_free(mlx_vector_array vec)
{
    const int result = mlx_vector_array_free_ptr(vec);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_vector_array_new_data via mlx_vector_array_new_data_ptr; the result is returned unchanged.
mlx_vector_array mlx_vector_array_new_data(const mlx_array* data, size_t size)
{
    const mlx_vector_array result = mlx_vector_array_new_data_ptr(data, size);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_vector_array_new_value via mlx_vector_array_new_value_ptr; the result is returned unchanged.
mlx_vector_array mlx_vector_array_new_value(const mlx_array val)
{
    const mlx_vector_array result = mlx_vector_array_new_value_ptr(val);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_vector_array_set_data via mlx_vector_array_set_data_ptr; the result is returned unchanged.
int mlx_vector_array_set_data(mlx_vector_array* vec, const mlx_array* data, size_t size)
{
    const int result = mlx_vector_array_set_data_ptr(vec, data, size);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_vector_array_set_value via mlx_vector_array_set_value_ptr; the result is returned unchanged.
int mlx_vector_array_set_value(mlx_vector_array* vec, const mlx_array val)
{
    const int result = mlx_vector_array_set_value_ptr(vec, val);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_vector_array_append_data via mlx_vector_array_append_data_ptr; the result is returned unchanged.
int mlx_vector_array_append_data(mlx_vector_array vec, const mlx_array* data, size_t size)
{
    const int result = mlx_vector_array_append_data_ptr(vec, data, size);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_vector_array_append_value via mlx_vector_array_append_value_ptr; the result is returned unchanged.
int mlx_vector_array_append_value(mlx_vector_array vec, const mlx_array val)
{
    const int result = mlx_vector_array_append_value_ptr(vec, val);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_vector_array_size via mlx_vector_array_size_ptr; the count is returned unchanged.
size_t mlx_vector_array_size(mlx_vector_array vec)
{
    const size_t result = mlx_vector_array_size_ptr(vec);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_vector_array_get via mlx_vector_array_get_ptr; the result is returned unchanged.
int mlx_vector_array_get(mlx_array* res, const mlx_vector_array vec, size_t idx)
{
    const int result = mlx_vector_array_get_ptr(res, vec, idx);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_vector_vector_array_new via mlx_vector_vector_array_new_ptr; the result is returned unchanged.
mlx_vector_vector_array mlx_vector_vector_array_new(void)
{
    const mlx_vector_vector_array result = mlx_vector_vector_array_new_ptr();
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_vector_vector_array_set via mlx_vector_vector_array_set_ptr; the result is returned unchanged.
int mlx_vector_vector_array_set(mlx_vector_vector_array* vec, const mlx_vector_vector_array src)
{
    const int result = mlx_vector_vector_array_set_ptr(vec, src);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_vector_vector_array_free via mlx_vector_vector_array_free_ptr; the result is returned unchanged.
int mlx_vector_vector_array_free(mlx_vector_vector_array vec)
{
    const int result = mlx_vector_vector_array_free_ptr(vec);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_vector_vector_array_new_data via mlx_vector_vector_array_new_data_ptr; the result is returned unchanged.
mlx_vector_vector_array mlx_vector_vector_array_new_data(const mlx_vector_array* data, size_t size)
{
    const mlx_vector_vector_array result = mlx_vector_vector_array_new_data_ptr(data, size);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_vector_vector_array_new_value via mlx_vector_vector_array_new_value_ptr; the result is returned unchanged.
mlx_vector_vector_array mlx_vector_vector_array_new_value(const mlx_vector_array val)
{
    const mlx_vector_vector_array result = mlx_vector_vector_array_new_value_ptr(val);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_vector_vector_array_set_data via mlx_vector_vector_array_set_data_ptr; the result is returned unchanged.
int mlx_vector_vector_array_set_data(mlx_vector_vector_array* vec, const mlx_vector_array* data, size_t size)
{
    const int result = mlx_vector_vector_array_set_data_ptr(vec, data, size);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_vector_vector_array_set_value via mlx_vector_vector_array_set_value_ptr; the result is returned unchanged.
int mlx_vector_vector_array_set_value(mlx_vector_vector_array* vec, const mlx_vector_array val)
{
    const int result = mlx_vector_vector_array_set_value_ptr(vec, val);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_vector_vector_array_append_data via mlx_vector_vector_array_append_data_ptr; the result is returned unchanged.
int mlx_vector_vector_array_append_data(mlx_vector_vector_array vec, const mlx_vector_array* data, size_t size)
{
    const int result = mlx_vector_vector_array_append_data_ptr(vec, data, size);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_vector_vector_array_append_value via mlx_vector_vector_array_append_value_ptr; the result is returned unchanged.
int mlx_vector_vector_array_append_value(mlx_vector_vector_array vec, const mlx_vector_array val)
{
    const int result = mlx_vector_vector_array_append_value_ptr(vec, val);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_vector_vector_array_size via mlx_vector_vector_array_size_ptr; the count is returned unchanged.
size_t mlx_vector_vector_array_size(mlx_vector_vector_array vec)
{
    const size_t result = mlx_vector_vector_array_size_ptr(vec);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_vector_vector_array_get via mlx_vector_vector_array_get_ptr; the result is returned unchanged.
int mlx_vector_vector_array_get(mlx_vector_array* res, const mlx_vector_vector_array vec, size_t idx)
{
    const int result = mlx_vector_vector_array_get_ptr(res, vec, idx);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_vector_int_new via mlx_vector_int_new_ptr; the result is returned unchanged.
mlx_vector_int mlx_vector_int_new(void)
{
    const mlx_vector_int result = mlx_vector_int_new_ptr();
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_vector_int_set via mlx_vector_int_set_ptr; the result is returned unchanged.
int mlx_vector_int_set(mlx_vector_int* vec, const mlx_vector_int src)
{
    const int result = mlx_vector_int_set_ptr(vec, src);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_vector_int_free via mlx_vector_int_free_ptr; the result is returned unchanged.
int mlx_vector_int_free(mlx_vector_int vec)
{
    const int result = mlx_vector_int_free_ptr(vec);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_vector_int_new_data via mlx_vector_int_new_data_ptr; the result is returned unchanged.
mlx_vector_int mlx_vector_int_new_data(int* data, size_t size)
{
    const mlx_vector_int result = mlx_vector_int_new_data_ptr(data, size);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_vector_int_new_value via mlx_vector_int_new_value_ptr; the result is returned unchanged.
mlx_vector_int mlx_vector_int_new_value(int val)
{
    const mlx_vector_int result = mlx_vector_int_new_value_ptr(val);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_vector_int_set_data via mlx_vector_int_set_data_ptr; the result is returned unchanged.
int mlx_vector_int_set_data(mlx_vector_int* vec, int* data, size_t size)
{
    const int result = mlx_vector_int_set_data_ptr(vec, data, size);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_vector_int_set_value via mlx_vector_int_set_value_ptr; the result is returned unchanged.
int mlx_vector_int_set_value(mlx_vector_int* vec, int val)
{
    const int result = mlx_vector_int_set_value_ptr(vec, val);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_vector_int_append_data via mlx_vector_int_append_data_ptr; the result is returned unchanged.
int mlx_vector_int_append_data(mlx_vector_int vec, int* data, size_t size)
{
    const int result = mlx_vector_int_append_data_ptr(vec, data, size);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_vector_int_append_value via mlx_vector_int_append_value_ptr; the result is returned unchanged.
int mlx_vector_int_append_value(mlx_vector_int vec, int val)
{
    const int result = mlx_vector_int_append_value_ptr(vec, val);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_vector_int_size via mlx_vector_int_size_ptr; the count is returned unchanged.
size_t mlx_vector_int_size(mlx_vector_int vec)
{
    const size_t result = mlx_vector_int_size_ptr(vec);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_vector_int_get via mlx_vector_int_get_ptr; the result is returned unchanged.
int mlx_vector_int_get(int* res, const mlx_vector_int vec, size_t idx)
{
    const int result = mlx_vector_int_get_ptr(res, vec, idx);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_vector_string_new via mlx_vector_string_new_ptr; the result is returned unchanged.
mlx_vector_string mlx_vector_string_new(void)
{
    const mlx_vector_string result = mlx_vector_string_new_ptr();
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_vector_string_set via mlx_vector_string_set_ptr; the result is returned unchanged.
int mlx_vector_string_set(mlx_vector_string* vec, const mlx_vector_string src)
{
    const int result = mlx_vector_string_set_ptr(vec, src);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_vector_string_free via mlx_vector_string_free_ptr; the result is returned unchanged.
int mlx_vector_string_free(mlx_vector_string vec)
{
    const int result = mlx_vector_string_free_ptr(vec);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_vector_string_new_data via mlx_vector_string_new_data_ptr; the result is returned unchanged.
mlx_vector_string mlx_vector_string_new_data(const char** data, size_t size)
{
    const mlx_vector_string result = mlx_vector_string_new_data_ptr(data, size);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_vector_string_new_value via mlx_vector_string_new_value_ptr; the result is returned unchanged.
mlx_vector_string mlx_vector_string_new_value(const char* val)
{
    const mlx_vector_string result = mlx_vector_string_new_value_ptr(val);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_vector_string_set_data via mlx_vector_string_set_data_ptr; the result is returned unchanged.
int mlx_vector_string_set_data(mlx_vector_string* vec, const char** data, size_t size)
{
    const int result = mlx_vector_string_set_data_ptr(vec, data, size);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_vector_string_set_value via mlx_vector_string_set_value_ptr; the result is returned unchanged.
int mlx_vector_string_set_value(mlx_vector_string* vec, const char* val)
{
    const int result = mlx_vector_string_set_value_ptr(vec, val);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_vector_string_append_data via mlx_vector_string_append_data_ptr; the result is returned unchanged.
int mlx_vector_string_append_data(mlx_vector_string vec, const char** data, size_t size)
{
    const int result = mlx_vector_string_append_data_ptr(vec, data, size);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_vector_string_append_value via mlx_vector_string_append_value_ptr; the result is returned unchanged.
int mlx_vector_string_append_value(mlx_vector_string vec, const char* val)
{
    const int result = mlx_vector_string_append_value_ptr(vec, val);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_vector_string_size via mlx_vector_string_size_ptr; the count is returned unchanged.
size_t mlx_vector_string_size(mlx_vector_string vec)
{
    const size_t result = mlx_vector_string_size_ptr(vec);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_vector_string_get via mlx_vector_string_get_ptr; the result is returned unchanged.
int mlx_vector_string_get(char** res, const mlx_vector_string vec, size_t idx)
{
    const int result = mlx_vector_string_get_ptr(res, vec, idx);
    return result;
}
|
|
|
|
// Forwards to the dynamically loaded mlx_version via mlx_version_ptr; the result is returned unchanged.
int mlx_version(mlx_string* str_)
{
    const int result = mlx_version_ptr(str_);
    return result;
}
|
|
|