// This code is auto-generated; DO NOT EDIT.
//
// Because this file is generated, change the generator rather than editing
// it by hand.
//
// Each `#define X X_mlx_gen_orig_` below is an object-like macro. After that
// line, the preprocessor rewrites every occurrence of X in the including
// translation unit to X_mlx_gen_orig_, including occurrences in headers
// included later. This lasts until X is #undef'd.
// "dynamic.h" is included BEFORE any rename, so its contents still see the
// plain names.
//
// NOTE(review): this header does not show why the names are renamed.
// Presumably the mlx-c declarations are included after this point under the
// suffixed names, so that a dynamic-loading layer can supply its own symbols
// under the plain names. Confirm against the generator.
//
// NOTE(review): a few mirrored names begin with an underscore
// (_mlx_array_*, _mlx_error). C11 7.1.3 reserves identifiers that begin with
// an underscore for file-scope use by the implementation, and defining a
// reserved identifier as a macro is technically undefined behavior. These
// presumably mirror upstream API names and cannot be changed here.
#ifndef MLX_GENERATED_H
#define MLX_GENERATED_H
#include "dynamic.h"

// mlx_dtype_*
#define mlx_dtype_size mlx_dtype_size_mlx_gen_orig_

// mlx_array_* (tostring, new*, free, set*)
#define mlx_array_tostring mlx_array_tostring_mlx_gen_orig_
#define mlx_array_new mlx_array_new_mlx_gen_orig_
#define mlx_array_free mlx_array_free_mlx_gen_orig_
#define mlx_array_new_bool mlx_array_new_bool_mlx_gen_orig_
#define mlx_array_new_int mlx_array_new_int_mlx_gen_orig_
#define mlx_array_new_float32 mlx_array_new_float32_mlx_gen_orig_
#define mlx_array_new_float mlx_array_new_float_mlx_gen_orig_
#define mlx_array_new_float64 mlx_array_new_float64_mlx_gen_orig_
#define mlx_array_new_double mlx_array_new_double_mlx_gen_orig_
#define mlx_array_new_complex mlx_array_new_complex_mlx_gen_orig_
#define mlx_array_new_data mlx_array_new_data_mlx_gen_orig_
#define mlx_array_new_data_managed mlx_array_new_data_managed_mlx_gen_orig_
#define mlx_array_new_data_managed_payload mlx_array_new_data_managed_payload_mlx_gen_orig_
#define mlx_array_set mlx_array_set_mlx_gen_orig_
#define mlx_array_set_bool mlx_array_set_bool_mlx_gen_orig_
#define mlx_array_set_int mlx_array_set_int_mlx_gen_orig_
#define mlx_array_set_float32 mlx_array_set_float32_mlx_gen_orig_
#define mlx_array_set_float mlx_array_set_float_mlx_gen_orig_
#define mlx_array_set_float64 mlx_array_set_float64_mlx_gen_orig_
#define mlx_array_set_double mlx_array_set_double_mlx_gen_orig_
#define mlx_array_set_complex mlx_array_set_complex_mlx_gen_orig_
#define mlx_array_set_data mlx_array_set_data_mlx_gen_orig_

// mlx_array_* (itemsize .. eval)
#define mlx_array_itemsize mlx_array_itemsize_mlx_gen_orig_
#define mlx_array_size mlx_array_size_mlx_gen_orig_
#define mlx_array_nbytes mlx_array_nbytes_mlx_gen_orig_
#define mlx_array_ndim mlx_array_ndim_mlx_gen_orig_
#define mlx_array_shape mlx_array_shape_mlx_gen_orig_
#define mlx_array_strides mlx_array_strides_mlx_gen_orig_
#define mlx_array_dim mlx_array_dim_mlx_gen_orig_
#define mlx_array_dtype mlx_array_dtype_mlx_gen_orig_
#define mlx_array_eval mlx_array_eval_mlx_gen_orig_

// mlx_array_item_<dtype>
#define mlx_array_item_bool mlx_array_item_bool_mlx_gen_orig_
#define mlx_array_item_uint8 mlx_array_item_uint8_mlx_gen_orig_
#define mlx_array_item_uint16 mlx_array_item_uint16_mlx_gen_orig_
#define mlx_array_item_uint32 mlx_array_item_uint32_mlx_gen_orig_
#define mlx_array_item_uint64 mlx_array_item_uint64_mlx_gen_orig_
#define mlx_array_item_int8 mlx_array_item_int8_mlx_gen_orig_
#define mlx_array_item_int16 mlx_array_item_int16_mlx_gen_orig_
#define mlx_array_item_int32 mlx_array_item_int32_mlx_gen_orig_
#define mlx_array_item_int64 mlx_array_item_int64_mlx_gen_orig_
#define mlx_array_item_float32 mlx_array_item_float32_mlx_gen_orig_
#define mlx_array_item_float64 mlx_array_item_float64_mlx_gen_orig_
#define mlx_array_item_complex64 mlx_array_item_complex64_mlx_gen_orig_
#define mlx_array_item_float16 mlx_array_item_float16_mlx_gen_orig_
#define mlx_array_item_bfloat16 mlx_array_item_bfloat16_mlx_gen_orig_

// mlx_array_data_<dtype>
#define mlx_array_data_bool mlx_array_data_bool_mlx_gen_orig_
#define mlx_array_data_uint8 mlx_array_data_uint8_mlx_gen_orig_
#define mlx_array_data_uint16 mlx_array_data_uint16_mlx_gen_orig_
#define mlx_array_data_uint32 mlx_array_data_uint32_mlx_gen_orig_
#define mlx_array_data_uint64 mlx_array_data_uint64_mlx_gen_orig_
#define mlx_array_data_int8 mlx_array_data_int8_mlx_gen_orig_
#define mlx_array_data_int16 mlx_array_data_int16_mlx_gen_orig_
#define mlx_array_data_int32 mlx_array_data_int32_mlx_gen_orig_
#define mlx_array_data_int64 mlx_array_data_int64_mlx_gen_orig_
#define mlx_array_data_float32 mlx_array_data_float32_mlx_gen_orig_
#define mlx_array_data_float64 mlx_array_data_float64_mlx_gen_orig_
#define mlx_array_data_complex64 mlx_array_data_complex64_mlx_gen_orig_
#define mlx_array_data_float16 mlx_array_data_float16_mlx_gen_orig_
#define mlx_array_data_bfloat16 mlx_array_data_bfloat16_mlx_gen_orig_

// _mlx_array_* (leading underscore; see the reserved-identifier NOTE above)
#define _mlx_array_is_available _mlx_array_is_available_mlx_gen_orig_
#define _mlx_array_wait _mlx_array_wait_mlx_gen_orig_
#define _mlx_array_is_contiguous _mlx_array_is_contiguous_mlx_gen_orig_
#define _mlx_array_is_row_contiguous _mlx_array_is_row_contiguous_mlx_gen_orig_
#define _mlx_array_is_col_contiguous _mlx_array_is_col_contiguous_mlx_gen_orig_

// mlx_closure_*
#define mlx_closure_new mlx_closure_new_mlx_gen_orig_
#define mlx_closure_free mlx_closure_free_mlx_gen_orig_
#define mlx_closure_new_func mlx_closure_new_func_mlx_gen_orig_
#define mlx_closure_new_func_payload mlx_closure_new_func_payload_mlx_gen_orig_
#define mlx_closure_set mlx_closure_set_mlx_gen_orig_
#define mlx_closure_apply mlx_closure_apply_mlx_gen_orig_
#define mlx_closure_new_unary mlx_closure_new_unary_mlx_gen_orig_

// mlx_closure_kwargs_*
#define mlx_closure_kwargs_new mlx_closure_kwargs_new_mlx_gen_orig_
#define mlx_closure_kwargs_free mlx_closure_kwargs_free_mlx_gen_orig_
#define mlx_closure_kwargs_new_func mlx_closure_kwargs_new_func_mlx_gen_orig_
#define mlx_closure_kwargs_new_func_payload mlx_closure_kwargs_new_func_payload_mlx_gen_orig_
#define mlx_closure_kwargs_set mlx_closure_kwargs_set_mlx_gen_orig_
#define mlx_closure_kwargs_apply mlx_closure_kwargs_apply_mlx_gen_orig_

// mlx_closure_value_and_grad_*
#define mlx_closure_value_and_grad_new mlx_closure_value_and_grad_new_mlx_gen_orig_
#define mlx_closure_value_and_grad_free mlx_closure_value_and_grad_free_mlx_gen_orig_
#define mlx_closure_value_and_grad_new_func mlx_closure_value_and_grad_new_func_mlx_gen_orig_
#define mlx_closure_value_and_grad_new_func_payload mlx_closure_value_and_grad_new_func_payload_mlx_gen_orig_
#define mlx_closure_value_and_grad_set mlx_closure_value_and_grad_set_mlx_gen_orig_
#define mlx_closure_value_and_grad_apply mlx_closure_value_and_grad_apply_mlx_gen_orig_

// mlx_closure_custom_* (new / free / new_func / new_func_payload)
#define mlx_closure_custom_new mlx_closure_custom_new_mlx_gen_orig_
#define mlx_closure_custom_free mlx_closure_custom_free_mlx_gen_orig_
#define mlx_closure_custom_new_func mlx_closure_custom_new_func_mlx_gen_orig_
#define mlx_closure_custom_new_func_payload mlx_closure_custom_new_func_payload_mlx_gen_orig_
// Generated renames (X -> X_mlx_gen_orig_); do not edit by hand.

// mlx_closure_custom_* (set / apply)
#define mlx_closure_custom_set mlx_closure_custom_set_mlx_gen_orig_
#define mlx_closure_custom_apply mlx_closure_custom_apply_mlx_gen_orig_

// mlx_closure_custom_jvp_*
#define mlx_closure_custom_jvp_new mlx_closure_custom_jvp_new_mlx_gen_orig_
#define mlx_closure_custom_jvp_free mlx_closure_custom_jvp_free_mlx_gen_orig_
#define mlx_closure_custom_jvp_new_func mlx_closure_custom_jvp_new_func_mlx_gen_orig_
#define mlx_closure_custom_jvp_new_func_payload mlx_closure_custom_jvp_new_func_payload_mlx_gen_orig_
#define mlx_closure_custom_jvp_set mlx_closure_custom_jvp_set_mlx_gen_orig_
#define mlx_closure_custom_jvp_apply mlx_closure_custom_jvp_apply_mlx_gen_orig_

// mlx_closure_custom_vmap_*
#define mlx_closure_custom_vmap_new mlx_closure_custom_vmap_new_mlx_gen_orig_
#define mlx_closure_custom_vmap_free mlx_closure_custom_vmap_free_mlx_gen_orig_
#define mlx_closure_custom_vmap_new_func mlx_closure_custom_vmap_new_func_mlx_gen_orig_
#define mlx_closure_custom_vmap_new_func_payload mlx_closure_custom_vmap_new_func_payload_mlx_gen_orig_
#define mlx_closure_custom_vmap_set mlx_closure_custom_vmap_set_mlx_gen_orig_
#define mlx_closure_custom_vmap_apply mlx_closure_custom_vmap_apply_mlx_gen_orig_

// compile-related names
#define mlx_compile mlx_compile_mlx_gen_orig_
#define mlx_detail_compile mlx_detail_compile_mlx_gen_orig_
#define mlx_detail_compile_clear_cache mlx_detail_compile_clear_cache_mlx_gen_orig_
#define mlx_detail_compile_erase mlx_detail_compile_erase_mlx_gen_orig_
#define mlx_disable_compile mlx_disable_compile_mlx_gen_orig_
#define mlx_enable_compile mlx_enable_compile_mlx_gen_orig_
#define mlx_set_compile_mode mlx_set_compile_mode_mlx_gen_orig_

// mlx_cuda_*
#define mlx_cuda_is_available mlx_cuda_is_available_mlx_gen_orig_

// mlx_device_*
#define mlx_device_new mlx_device_new_mlx_gen_orig_
#define mlx_device_new_type mlx_device_new_type_mlx_gen_orig_
#define mlx_device_free mlx_device_free_mlx_gen_orig_
#define mlx_device_set mlx_device_set_mlx_gen_orig_
#define mlx_device_tostring mlx_device_tostring_mlx_gen_orig_
#define mlx_device_equal mlx_device_equal_mlx_gen_orig_
// Generated renames (X -> X_mlx_gen_orig_); do not edit by hand.

// mlx_device_* (get_index .. count) and default-device names
#define mlx_device_get_index mlx_device_get_index_mlx_gen_orig_
#define mlx_device_get_type mlx_device_get_type_mlx_gen_orig_
#define mlx_get_default_device mlx_get_default_device_mlx_gen_orig_
#define mlx_set_default_device mlx_set_default_device_mlx_gen_orig_
#define mlx_device_is_available mlx_device_is_available_mlx_gen_orig_
#define mlx_device_count mlx_device_count_mlx_gen_orig_

// mlx_device_info_*
#define mlx_device_info_new mlx_device_info_new_mlx_gen_orig_
#define mlx_device_info_get mlx_device_info_get_mlx_gen_orig_
#define mlx_device_info_free mlx_device_info_free_mlx_gen_orig_
#define mlx_device_info_has_key mlx_device_info_has_key_mlx_gen_orig_
#define mlx_device_info_is_string mlx_device_info_is_string_mlx_gen_orig_
#define mlx_device_info_get_string mlx_device_info_get_string_mlx_gen_orig_
#define mlx_device_info_get_size mlx_device_info_get_size_mlx_gen_orig_
#define mlx_device_info_get_keys mlx_device_info_get_keys_mlx_gen_orig_

// mlx_distributed_*
#define mlx_distributed_all_gather mlx_distributed_all_gather_mlx_gen_orig_
#define mlx_distributed_all_max mlx_distributed_all_max_mlx_gen_orig_
#define mlx_distributed_all_min mlx_distributed_all_min_mlx_gen_orig_
#define mlx_distributed_all_sum mlx_distributed_all_sum_mlx_gen_orig_
#define mlx_distributed_recv mlx_distributed_recv_mlx_gen_orig_
#define mlx_distributed_recv_like mlx_distributed_recv_like_mlx_gen_orig_
#define mlx_distributed_send mlx_distributed_send_mlx_gen_orig_
#define mlx_distributed_sum_scatter mlx_distributed_sum_scatter_mlx_gen_orig_
#define mlx_distributed_group_rank mlx_distributed_group_rank_mlx_gen_orig_
#define mlx_distributed_group_size mlx_distributed_group_size_mlx_gen_orig_
#define mlx_distributed_group_split mlx_distributed_group_split_mlx_gen_orig_
#define mlx_distributed_is_available mlx_distributed_is_available_mlx_gen_orig_
#define mlx_distributed_init mlx_distributed_init_mlx_gen_orig_

// error-handler names
#define mlx_set_error_handler mlx_set_error_handler_mlx_gen_orig_
// NOTE(review): a leading underscore makes this identifier reserved at file
// scope (C11 7.1.3). It presumably mirrors an upstream name.
#define _mlx_error _mlx_error_mlx_gen_orig_

// function export / import names
#define mlx_export_function mlx_export_function_mlx_gen_orig_
#define mlx_export_function_kwargs mlx_export_function_kwargs_mlx_gen_orig_
#define mlx_function_exporter_new mlx_function_exporter_new_mlx_gen_orig_
#define mlx_function_exporter_free mlx_function_exporter_free_mlx_gen_orig_
#define mlx_function_exporter_apply mlx_function_exporter_apply_mlx_gen_orig_
#define mlx_function_exporter_apply_kwargs mlx_function_exporter_apply_kwargs_mlx_gen_orig_
#define mlx_imported_function_new mlx_imported_function_new_mlx_gen_orig_
#define mlx_imported_function_free mlx_imported_function_free_mlx_gen_orig_
#define mlx_imported_function_apply mlx_imported_function_apply_mlx_gen_orig_
#define mlx_imported_function_apply_kwargs mlx_imported_function_apply_kwargs_mlx_gen_orig_

// mlx_fast_cuda_kernel_*
#define mlx_fast_cuda_kernel_config_new mlx_fast_cuda_kernel_config_new_mlx_gen_orig_
#define mlx_fast_cuda_kernel_config_free mlx_fast_cuda_kernel_config_free_mlx_gen_orig_
#define mlx_fast_cuda_kernel_config_add_output_arg mlx_fast_cuda_kernel_config_add_output_arg_mlx_gen_orig_
#define mlx_fast_cuda_kernel_config_set_grid mlx_fast_cuda_kernel_config_set_grid_mlx_gen_orig_
#define mlx_fast_cuda_kernel_config_set_thread_group mlx_fast_cuda_kernel_config_set_thread_group_mlx_gen_orig_
#define mlx_fast_cuda_kernel_config_set_init_value mlx_fast_cuda_kernel_config_set_init_value_mlx_gen_orig_
#define mlx_fast_cuda_kernel_config_set_verbose mlx_fast_cuda_kernel_config_set_verbose_mlx_gen_orig_
#define mlx_fast_cuda_kernel_config_add_template_arg_dtype mlx_fast_cuda_kernel_config_add_template_arg_dtype_mlx_gen_orig_
#define mlx_fast_cuda_kernel_config_add_template_arg_int mlx_fast_cuda_kernel_config_add_template_arg_int_mlx_gen_orig_
#define mlx_fast_cuda_kernel_config_add_template_arg_bool mlx_fast_cuda_kernel_config_add_template_arg_bool_mlx_gen_orig_
#define mlx_fast_cuda_kernel_new mlx_fast_cuda_kernel_new_mlx_gen_orig_
#define mlx_fast_cuda_kernel_free mlx_fast_cuda_kernel_free_mlx_gen_orig_
#define mlx_fast_cuda_kernel_apply mlx_fast_cuda_kernel_apply_mlx_gen_orig_

// mlx_fast_layer_norm
#define mlx_fast_layer_norm mlx_fast_layer_norm_mlx_gen_orig_

// mlx_fast_metal_kernel_*
#define mlx_fast_metal_kernel_config_new mlx_fast_metal_kernel_config_new_mlx_gen_orig_
#define mlx_fast_metal_kernel_config_free mlx_fast_metal_kernel_config_free_mlx_gen_orig_
#define mlx_fast_metal_kernel_config_add_output_arg mlx_fast_metal_kernel_config_add_output_arg_mlx_gen_orig_
#define mlx_fast_metal_kernel_config_set_grid mlx_fast_metal_kernel_config_set_grid_mlx_gen_orig_
#define mlx_fast_metal_kernel_config_set_thread_group mlx_fast_metal_kernel_config_set_thread_group_mlx_gen_orig_
#define mlx_fast_metal_kernel_config_set_init_value mlx_fast_metal_kernel_config_set_init_value_mlx_gen_orig_
#define mlx_fast_metal_kernel_config_set_verbose mlx_fast_metal_kernel_config_set_verbose_mlx_gen_orig_
#define mlx_fast_metal_kernel_config_add_template_arg_dtype mlx_fast_metal_kernel_config_add_template_arg_dtype_mlx_gen_orig_
#define mlx_fast_metal_kernel_config_add_template_arg_int mlx_fast_metal_kernel_config_add_template_arg_int_mlx_gen_orig_
#define mlx_fast_metal_kernel_config_add_template_arg_bool mlx_fast_metal_kernel_config_add_template_arg_bool_mlx_gen_orig_
#define mlx_fast_metal_kernel_new mlx_fast_metal_kernel_new_mlx_gen_orig_
#define mlx_fast_metal_kernel_free mlx_fast_metal_kernel_free_mlx_gen_orig_
#define mlx_fast_metal_kernel_apply mlx_fast_metal_kernel_apply_mlx_gen_orig_

// mlx_fast_* (rms_norm, rope, attention)
#define mlx_fast_rms_norm mlx_fast_rms_norm_mlx_gen_orig_
#define mlx_fast_rope mlx_fast_rope_mlx_gen_orig_
#define mlx_fast_rope_dynamic mlx_fast_rope_dynamic_mlx_gen_orig_
#define mlx_fast_scaled_dot_product_attention mlx_fast_scaled_dot_product_attention_mlx_gen_orig_

// mlx_fft_*
#define mlx_fft_fft mlx_fft_fft_mlx_gen_orig_
#define mlx_fft_fft2 mlx_fft_fft2_mlx_gen_orig_
#define mlx_fft_fftn mlx_fft_fftn_mlx_gen_orig_
#define mlx_fft_fftshift mlx_fft_fftshift_mlx_gen_orig_
#define mlx_fft_ifft mlx_fft_ifft_mlx_gen_orig_
#define mlx_fft_ifft2 mlx_fft_ifft2_mlx_gen_orig_
#define mlx_fft_ifftn mlx_fft_ifftn_mlx_gen_orig_
#define mlx_fft_ifftshift mlx_fft_ifftshift_mlx_gen_orig_
#define mlx_fft_irfft mlx_fft_irfft_mlx_gen_orig_
#define mlx_fft_irfft2 mlx_fft_irfft2_mlx_gen_orig_
#define mlx_fft_irfftn mlx_fft_irfftn_mlx_gen_orig_
#define mlx_fft_rfft mlx_fft_rfft_mlx_gen_orig_
#define mlx_fft_rfft2 mlx_fft_rfft2_mlx_gen_orig_
#define mlx_fft_rfftn mlx_fft_rfftn_mlx_gen_orig_

// mlx_load* / mlx_save*
#define mlx_load_reader mlx_load_reader_mlx_gen_orig_
#define mlx_load mlx_load_mlx_gen_orig_
#define mlx_load_safetensors_reader mlx_load_safetensors_reader_mlx_gen_orig_
#define mlx_load_safetensors mlx_load_safetensors_mlx_gen_orig_
#define mlx_save_writer mlx_save_writer_mlx_gen_orig_
#define mlx_save mlx_save_mlx_gen_orig_
#define mlx_save_safetensors_writer mlx_save_safetensors_writer_mlx_gen_orig_
#define mlx_save_safetensors mlx_save_safetensors_mlx_gen_orig_

// mlx_io_reader_* / mlx_io_writer_*
#define mlx_io_reader_new mlx_io_reader_new_mlx_gen_orig_
#define mlx_io_reader_descriptor mlx_io_reader_descriptor_mlx_gen_orig_
#define mlx_io_reader_tostring mlx_io_reader_tostring_mlx_gen_orig_
#define mlx_io_reader_free mlx_io_reader_free_mlx_gen_orig_
#define mlx_io_writer_new mlx_io_writer_new_mlx_gen_orig_
#define mlx_io_writer_descriptor mlx_io_writer_descriptor_mlx_gen_orig_
#define mlx_io_writer_tostring mlx_io_writer_tostring_mlx_gen_orig_
#define mlx_io_writer_free mlx_io_writer_free_mlx_gen_orig_

// mlx_linalg_*
#define mlx_linalg_cholesky mlx_linalg_cholesky_mlx_gen_orig_
#define mlx_linalg_cholesky_inv mlx_linalg_cholesky_inv_mlx_gen_orig_
#define mlx_linalg_cross mlx_linalg_cross_mlx_gen_orig_
#define mlx_linalg_eig mlx_linalg_eig_mlx_gen_orig_
#define mlx_linalg_eigh mlx_linalg_eigh_mlx_gen_orig_
#define mlx_linalg_eigvals mlx_linalg_eigvals_mlx_gen_orig_
#define mlx_linalg_eigvalsh mlx_linalg_eigvalsh_mlx_gen_orig_
#define mlx_linalg_inv mlx_linalg_inv_mlx_gen_orig_
#define mlx_linalg_lu mlx_linalg_lu_mlx_gen_orig_
#define mlx_linalg_lu_factor mlx_linalg_lu_factor_mlx_gen_orig_
#define mlx_linalg_norm mlx_linalg_norm_mlx_gen_orig_
#define mlx_linalg_norm_matrix mlx_linalg_norm_matrix_mlx_gen_orig_
#define mlx_linalg_norm_l2 mlx_linalg_norm_l2_mlx_gen_orig_
#define mlx_linalg_pinv mlx_linalg_pinv_mlx_gen_orig_
#define mlx_linalg_qr mlx_linalg_qr_mlx_gen_orig_
#define mlx_linalg_solve mlx_linalg_solve_mlx_gen_orig_
#define mlx_linalg_solve_triangular mlx_linalg_solve_triangular_mlx_gen_orig_
#define mlx_linalg_svd mlx_linalg_svd_mlx_gen_orig_
#define mlx_linalg_tri_inv mlx_linalg_tri_inv_mlx_gen_orig_

// mlx_map_string_to_array_*
#define mlx_map_string_to_array_new mlx_map_string_to_array_new_mlx_gen_orig_
#define mlx_map_string_to_array_set mlx_map_string_to_array_set_mlx_gen_orig_
#define mlx_map_string_to_array_free mlx_map_string_to_array_free_mlx_gen_orig_
#define mlx_map_string_to_array_insert mlx_map_string_to_array_insert_mlx_gen_orig_
#define mlx_map_string_to_array_get mlx_map_string_to_array_get_mlx_gen_orig_
#define mlx_map_string_to_array_iterator_new mlx_map_string_to_array_iterator_new_mlx_gen_orig_
#define mlx_map_string_to_array_iterator_free mlx_map_string_to_array_iterator_free_mlx_gen_orig_
#define mlx_map_string_to_array_iterator_next mlx_map_string_to_array_iterator_next_mlx_gen_orig_

// mlx_map_string_to_string_*
#define mlx_map_string_to_string_new mlx_map_string_to_string_new_mlx_gen_orig_
#define mlx_map_string_to_string_set mlx_map_string_to_string_set_mlx_gen_orig_
#define mlx_map_string_to_string_free mlx_map_string_to_string_free_mlx_gen_orig_
#define mlx_map_string_to_string_insert mlx_map_string_to_string_insert_mlx_gen_orig_
#define mlx_map_string_to_string_get mlx_map_string_to_string_get_mlx_gen_orig_
#define mlx_map_string_to_string_iterator_new mlx_map_string_to_string_iterator_new_mlx_gen_orig_
#define mlx_map_string_to_string_iterator_free mlx_map_string_to_string_iterator_free_mlx_gen_orig_
#define mlx_map_string_to_string_iterator_next mlx_map_string_to_string_iterator_next_mlx_gen_orig_

// memory / cache / limit names
#define mlx_clear_cache mlx_clear_cache_mlx_gen_orig_
#define mlx_get_active_memory mlx_get_active_memory_mlx_gen_orig_
#define mlx_get_cache_memory mlx_get_cache_memory_mlx_gen_orig_
#define mlx_get_memory_limit mlx_get_memory_limit_mlx_gen_orig_
#define mlx_get_peak_memory mlx_get_peak_memory_mlx_gen_orig_
#define mlx_reset_peak_memory mlx_reset_peak_memory_mlx_gen_orig_
#define mlx_set_cache_limit mlx_set_cache_limit_mlx_gen_orig_
#define mlx_set_memory_limit mlx_set_memory_limit_mlx_gen_orig_
#define mlx_set_wired_limit mlx_set_wired_limit_mlx_gen_orig_

// mlx_metal_*
#define mlx_metal_is_available mlx_metal_is_available_mlx_gen_orig_
#define mlx_metal_start_capture mlx_metal_start_capture_mlx_gen_orig_
#define mlx_metal_stop_capture mlx_metal_stop_capture_mlx_gen_orig_

// operation names (mlx_abs .. mlx_prod)
#define mlx_abs mlx_abs_mlx_gen_orig_
#define mlx_add mlx_add_mlx_gen_orig_
#define mlx_addmm mlx_addmm_mlx_gen_orig_
#define mlx_all_axes mlx_all_axes_mlx_gen_orig_
#define mlx_all_axis mlx_all_axis_mlx_gen_orig_
#define mlx_all mlx_all_mlx_gen_orig_
#define mlx_allclose mlx_allclose_mlx_gen_orig_
#define mlx_any_axes mlx_any_axes_mlx_gen_orig_
#define mlx_any_axis mlx_any_axis_mlx_gen_orig_
#define mlx_any mlx_any_mlx_gen_orig_
#define mlx_arange mlx_arange_mlx_gen_orig_
#define mlx_arccos mlx_arccos_mlx_gen_orig_
#define mlx_arccosh mlx_arccosh_mlx_gen_orig_
#define mlx_arcsin mlx_arcsin_mlx_gen_orig_
#define mlx_arcsinh mlx_arcsinh_mlx_gen_orig_
#define mlx_arctan mlx_arctan_mlx_gen_orig_
#define mlx_arctan2 mlx_arctan2_mlx_gen_orig_
#define mlx_arctanh mlx_arctanh_mlx_gen_orig_
#define mlx_argmax_axis mlx_argmax_axis_mlx_gen_orig_
#define mlx_argmax mlx_argmax_mlx_gen_orig_
#define mlx_argmin_axis mlx_argmin_axis_mlx_gen_orig_
#define mlx_argmin mlx_argmin_mlx_gen_orig_
#define mlx_argpartition_axis mlx_argpartition_axis_mlx_gen_orig_
#define mlx_argpartition mlx_argpartition_mlx_gen_orig_
#define mlx_argsort_axis mlx_argsort_axis_mlx_gen_orig_
#define mlx_argsort mlx_argsort_mlx_gen_orig_
#define mlx_array_equal mlx_array_equal_mlx_gen_orig_
#define mlx_as_strided mlx_as_strided_mlx_gen_orig_
#define mlx_astype mlx_astype_mlx_gen_orig_
#define mlx_atleast_1d mlx_atleast_1d_mlx_gen_orig_
#define mlx_atleast_2d mlx_atleast_2d_mlx_gen_orig_
#define mlx_atleast_3d mlx_atleast_3d_mlx_gen_orig_
#define mlx_bartlett mlx_bartlett_mlx_gen_orig_
#define mlx_bitwise_and mlx_bitwise_and_mlx_gen_orig_
#define mlx_bitwise_invert mlx_bitwise_invert_mlx_gen_orig_
#define mlx_bitwise_or mlx_bitwise_or_mlx_gen_orig_
#define mlx_bitwise_xor mlx_bitwise_xor_mlx_gen_orig_
#define mlx_blackman mlx_blackman_mlx_gen_orig_
#define mlx_block_masked_mm mlx_block_masked_mm_mlx_gen_orig_
#define mlx_broadcast_arrays mlx_broadcast_arrays_mlx_gen_orig_
#define mlx_broadcast_to mlx_broadcast_to_mlx_gen_orig_
#define mlx_ceil mlx_ceil_mlx_gen_orig_
#define mlx_clip mlx_clip_mlx_gen_orig_
#define mlx_concatenate_axis mlx_concatenate_axis_mlx_gen_orig_
#define mlx_concatenate mlx_concatenate_mlx_gen_orig_
#define mlx_conjugate mlx_conjugate_mlx_gen_orig_
#define mlx_contiguous mlx_contiguous_mlx_gen_orig_
#define mlx_conv1d mlx_conv1d_mlx_gen_orig_
#define mlx_conv2d mlx_conv2d_mlx_gen_orig_
#define mlx_conv3d mlx_conv3d_mlx_gen_orig_
#define mlx_conv_general mlx_conv_general_mlx_gen_orig_
#define mlx_conv_transpose1d mlx_conv_transpose1d_mlx_gen_orig_
#define mlx_conv_transpose2d mlx_conv_transpose2d_mlx_gen_orig_
#define mlx_conv_transpose3d mlx_conv_transpose3d_mlx_gen_orig_
#define mlx_copy mlx_copy_mlx_gen_orig_
#define mlx_cos mlx_cos_mlx_gen_orig_
#define mlx_cosh mlx_cosh_mlx_gen_orig_
#define mlx_cummax mlx_cummax_mlx_gen_orig_
#define mlx_cummin mlx_cummin_mlx_gen_orig_
#define mlx_cumprod mlx_cumprod_mlx_gen_orig_
#define mlx_cumsum mlx_cumsum_mlx_gen_orig_
#define mlx_degrees mlx_degrees_mlx_gen_orig_
#define mlx_depends mlx_depends_mlx_gen_orig_
#define mlx_dequantize mlx_dequantize_mlx_gen_orig_
#define mlx_diag mlx_diag_mlx_gen_orig_
#define mlx_diagonal mlx_diagonal_mlx_gen_orig_
#define mlx_divide mlx_divide_mlx_gen_orig_
#define mlx_divmod mlx_divmod_mlx_gen_orig_
#define mlx_einsum mlx_einsum_mlx_gen_orig_
#define mlx_equal mlx_equal_mlx_gen_orig_
#define mlx_erf mlx_erf_mlx_gen_orig_
#define mlx_erfinv mlx_erfinv_mlx_gen_orig_
#define mlx_exp mlx_exp_mlx_gen_orig_
#define mlx_expand_dims_axes mlx_expand_dims_axes_mlx_gen_orig_
#define mlx_expand_dims mlx_expand_dims_mlx_gen_orig_
#define mlx_expm1 mlx_expm1_mlx_gen_orig_
#define mlx_eye mlx_eye_mlx_gen_orig_
#define mlx_flatten mlx_flatten_mlx_gen_orig_
#define mlx_floor mlx_floor_mlx_gen_orig_
#define mlx_floor_divide mlx_floor_divide_mlx_gen_orig_
#define mlx_from_fp8 mlx_from_fp8_mlx_gen_orig_
#define mlx_full mlx_full_mlx_gen_orig_
#define mlx_full_like mlx_full_like_mlx_gen_orig_
#define mlx_gather mlx_gather_mlx_gen_orig_
#define mlx_gather_single mlx_gather_single_mlx_gen_orig_
#define mlx_gather_mm mlx_gather_mm_mlx_gen_orig_
#define mlx_gather_qmm mlx_gather_qmm_mlx_gen_orig_
#define mlx_greater mlx_greater_mlx_gen_orig_
#define mlx_greater_equal mlx_greater_equal_mlx_gen_orig_
#define mlx_hadamard_transform mlx_hadamard_transform_mlx_gen_orig_
#define mlx_hamming mlx_hamming_mlx_gen_orig_
#define mlx_hanning mlx_hanning_mlx_gen_orig_
#define mlx_identity mlx_identity_mlx_gen_orig_
#define mlx_imag mlx_imag_mlx_gen_orig_
#define mlx_inner mlx_inner_mlx_gen_orig_
#define mlx_isclose mlx_isclose_mlx_gen_orig_
#define mlx_isfinite mlx_isfinite_mlx_gen_orig_
#define mlx_isinf mlx_isinf_mlx_gen_orig_
#define mlx_isnan mlx_isnan_mlx_gen_orig_
#define mlx_isneginf mlx_isneginf_mlx_gen_orig_
#define mlx_isposinf mlx_isposinf_mlx_gen_orig_
#define mlx_kron mlx_kron_mlx_gen_orig_
#define mlx_left_shift mlx_left_shift_mlx_gen_orig_
#define mlx_less mlx_less_mlx_gen_orig_
#define mlx_less_equal mlx_less_equal_mlx_gen_orig_
#define mlx_linspace mlx_linspace_mlx_gen_orig_
#define mlx_log mlx_log_mlx_gen_orig_
#define mlx_log10 mlx_log10_mlx_gen_orig_
#define mlx_log1p mlx_log1p_mlx_gen_orig_
#define mlx_log2 mlx_log2_mlx_gen_orig_
#define mlx_logaddexp mlx_logaddexp_mlx_gen_orig_
#define mlx_logcumsumexp mlx_logcumsumexp_mlx_gen_orig_
#define mlx_logical_and mlx_logical_and_mlx_gen_orig_
#define mlx_logical_not mlx_logical_not_mlx_gen_orig_
#define mlx_logical_or mlx_logical_or_mlx_gen_orig_
#define mlx_logsumexp_axes mlx_logsumexp_axes_mlx_gen_orig_
#define mlx_logsumexp_axis mlx_logsumexp_axis_mlx_gen_orig_
#define mlx_logsumexp mlx_logsumexp_mlx_gen_orig_
#define mlx_masked_scatter mlx_masked_scatter_mlx_gen_orig_
#define mlx_matmul mlx_matmul_mlx_gen_orig_
#define mlx_max_axes mlx_max_axes_mlx_gen_orig_
#define mlx_max_axis mlx_max_axis_mlx_gen_orig_
#define mlx_max mlx_max_mlx_gen_orig_
#define mlx_maximum mlx_maximum_mlx_gen_orig_
#define mlx_mean_axes mlx_mean_axes_mlx_gen_orig_
#define mlx_mean_axis mlx_mean_axis_mlx_gen_orig_
#define mlx_mean mlx_mean_mlx_gen_orig_
#define mlx_median mlx_median_mlx_gen_orig_
#define mlx_meshgrid mlx_meshgrid_mlx_gen_orig_
#define mlx_min_axes mlx_min_axes_mlx_gen_orig_
#define mlx_min_axis mlx_min_axis_mlx_gen_orig_
#define mlx_min mlx_min_mlx_gen_orig_
#define mlx_minimum mlx_minimum_mlx_gen_orig_
#define mlx_moveaxis mlx_moveaxis_mlx_gen_orig_
#define mlx_multiply mlx_multiply_mlx_gen_orig_
#define mlx_nan_to_num mlx_nan_to_num_mlx_gen_orig_
#define mlx_negative mlx_negative_mlx_gen_orig_
#define mlx_not_equal mlx_not_equal_mlx_gen_orig_
#define mlx_number_of_elements mlx_number_of_elements_mlx_gen_orig_
#define mlx_ones mlx_ones_mlx_gen_orig_
#define mlx_ones_like mlx_ones_like_mlx_gen_orig_
#define mlx_outer mlx_outer_mlx_gen_orig_
#define mlx_pad mlx_pad_mlx_gen_orig_
#define mlx_pad_symmetric mlx_pad_symmetric_mlx_gen_orig_
#define mlx_partition_axis mlx_partition_axis_mlx_gen_orig_
#define mlx_partition mlx_partition_mlx_gen_orig_
#define mlx_power mlx_power_mlx_gen_orig_
#define mlx_prod_axes mlx_prod_axes_mlx_gen_orig_
#define mlx_prod_axis mlx_prod_axis_mlx_gen_orig_
#define mlx_prod mlx_prod_mlx_gen_orig_
// Generated renames (X -> X_mlx_gen_orig_); do not edit by hand.

// operation names (mlx_put_along_axis .. mlx_zeros_like)
#define mlx_put_along_axis mlx_put_along_axis_mlx_gen_orig_
#define mlx_qqmm mlx_qqmm_mlx_gen_orig_
#define mlx_quantize mlx_quantize_mlx_gen_orig_
#define mlx_quantized_matmul mlx_quantized_matmul_mlx_gen_orig_
#define mlx_radians mlx_radians_mlx_gen_orig_
#define mlx_real mlx_real_mlx_gen_orig_
#define mlx_reciprocal mlx_reciprocal_mlx_gen_orig_
#define mlx_remainder mlx_remainder_mlx_gen_orig_
#define mlx_repeat_axis mlx_repeat_axis_mlx_gen_orig_
#define mlx_repeat mlx_repeat_mlx_gen_orig_
#define mlx_reshape mlx_reshape_mlx_gen_orig_
#define mlx_right_shift mlx_right_shift_mlx_gen_orig_
#define mlx_roll_axis mlx_roll_axis_mlx_gen_orig_
#define mlx_roll_axes mlx_roll_axes_mlx_gen_orig_
#define mlx_roll mlx_roll_mlx_gen_orig_
#define mlx_round mlx_round_mlx_gen_orig_
#define mlx_rsqrt mlx_rsqrt_mlx_gen_orig_
#define mlx_scatter mlx_scatter_mlx_gen_orig_
#define mlx_scatter_single mlx_scatter_single_mlx_gen_orig_
#define mlx_scatter_add mlx_scatter_add_mlx_gen_orig_
#define mlx_scatter_add_single mlx_scatter_add_single_mlx_gen_orig_
#define mlx_scatter_add_axis mlx_scatter_add_axis_mlx_gen_orig_
#define mlx_scatter_max mlx_scatter_max_mlx_gen_orig_
#define mlx_scatter_max_single mlx_scatter_max_single_mlx_gen_orig_
#define mlx_scatter_min mlx_scatter_min_mlx_gen_orig_
#define mlx_scatter_min_single mlx_scatter_min_single_mlx_gen_orig_
#define mlx_scatter_prod mlx_scatter_prod_mlx_gen_orig_
#define mlx_scatter_prod_single mlx_scatter_prod_single_mlx_gen_orig_
#define mlx_segmented_mm mlx_segmented_mm_mlx_gen_orig_
#define mlx_sigmoid mlx_sigmoid_mlx_gen_orig_
#define mlx_sign mlx_sign_mlx_gen_orig_
#define mlx_sin mlx_sin_mlx_gen_orig_
#define mlx_sinh mlx_sinh_mlx_gen_orig_
#define mlx_slice mlx_slice_mlx_gen_orig_
#define mlx_slice_dynamic mlx_slice_dynamic_mlx_gen_orig_
#define mlx_slice_update mlx_slice_update_mlx_gen_orig_
#define mlx_slice_update_dynamic mlx_slice_update_dynamic_mlx_gen_orig_
#define mlx_softmax_axes mlx_softmax_axes_mlx_gen_orig_
#define mlx_softmax_axis mlx_softmax_axis_mlx_gen_orig_
#define mlx_softmax mlx_softmax_mlx_gen_orig_
#define mlx_sort_axis mlx_sort_axis_mlx_gen_orig_
#define mlx_sort mlx_sort_mlx_gen_orig_
#define mlx_split mlx_split_mlx_gen_orig_
#define mlx_split_sections mlx_split_sections_mlx_gen_orig_
#define mlx_sqrt mlx_sqrt_mlx_gen_orig_
#define mlx_square mlx_square_mlx_gen_orig_
#define mlx_squeeze_axes mlx_squeeze_axes_mlx_gen_orig_
#define mlx_squeeze_axis mlx_squeeze_axis_mlx_gen_orig_
#define mlx_squeeze mlx_squeeze_mlx_gen_orig_
#define mlx_stack_axis mlx_stack_axis_mlx_gen_orig_
#define mlx_stack mlx_stack_mlx_gen_orig_
#define mlx_std_axes mlx_std_axes_mlx_gen_orig_
#define mlx_std_axis mlx_std_axis_mlx_gen_orig_
#define mlx_std mlx_std_mlx_gen_orig_
#define mlx_stop_gradient mlx_stop_gradient_mlx_gen_orig_
#define mlx_subtract mlx_subtract_mlx_gen_orig_
#define mlx_sum_axes mlx_sum_axes_mlx_gen_orig_
#define mlx_sum_axis mlx_sum_axis_mlx_gen_orig_
#define mlx_sum mlx_sum_mlx_gen_orig_
#define mlx_swapaxes mlx_swapaxes_mlx_gen_orig_
#define mlx_take_axis mlx_take_axis_mlx_gen_orig_
#define mlx_take mlx_take_mlx_gen_orig_
#define mlx_take_along_axis mlx_take_along_axis_mlx_gen_orig_
#define mlx_tan mlx_tan_mlx_gen_orig_
#define mlx_tanh mlx_tanh_mlx_gen_orig_
#define mlx_tensordot mlx_tensordot_mlx_gen_orig_
#define mlx_tensordot_axis mlx_tensordot_axis_mlx_gen_orig_
#define mlx_tile mlx_tile_mlx_gen_orig_
#define mlx_to_fp8 mlx_to_fp8_mlx_gen_orig_
#define mlx_topk_axis mlx_topk_axis_mlx_gen_orig_
#define mlx_topk mlx_topk_mlx_gen_orig_
#define mlx_trace mlx_trace_mlx_gen_orig_
#define mlx_transpose_axes mlx_transpose_axes_mlx_gen_orig_
#define mlx_transpose mlx_transpose_mlx_gen_orig_
#define mlx_tri mlx_tri_mlx_gen_orig_
#define mlx_tril mlx_tril_mlx_gen_orig_
#define mlx_triu mlx_triu_mlx_gen_orig_
#define mlx_unflatten mlx_unflatten_mlx_gen_orig_
#define mlx_var_axes mlx_var_axes_mlx_gen_orig_
#define mlx_var_axis mlx_var_axis_mlx_gen_orig_
#define mlx_var mlx_var_mlx_gen_orig_
#define mlx_view mlx_view_mlx_gen_orig_
#define mlx_where mlx_where_mlx_gen_orig_
#define mlx_zeros mlx_zeros_mlx_gen_orig_
#define mlx_zeros_like mlx_zeros_like_mlx_gen_orig_

// mlx_random_*
#define mlx_random_bernoulli mlx_random_bernoulli_mlx_gen_orig_
#define mlx_random_bits mlx_random_bits_mlx_gen_orig_
#define mlx_random_categorical_shape mlx_random_categorical_shape_mlx_gen_orig_
#define mlx_random_categorical_num_samples mlx_random_categorical_num_samples_mlx_gen_orig_
#define mlx_random_categorical mlx_random_categorical_mlx_gen_orig_
#define mlx_random_gumbel mlx_random_gumbel_mlx_gen_orig_
#define mlx_random_key mlx_random_key_mlx_gen_orig_
#define mlx_random_laplace mlx_random_laplace_mlx_gen_orig_
#define mlx_random_multivariate_normal mlx_random_multivariate_normal_mlx_gen_orig_
#define mlx_random_normal_broadcast mlx_random_normal_broadcast_mlx_gen_orig_
#define mlx_random_normal mlx_random_normal_mlx_gen_orig_
#define mlx_random_permutation mlx_random_permutation_mlx_gen_orig_
#define mlx_random_permutation_arange mlx_random_permutation_arange_mlx_gen_orig_
#define mlx_random_randint mlx_random_randint_mlx_gen_orig_
#define mlx_random_seed mlx_random_seed_mlx_gen_orig_
#define mlx_random_split_num mlx_random_split_num_mlx_gen_orig_
#define mlx_random_split mlx_random_split_mlx_gen_orig_
#define mlx_random_truncated_normal mlx_random_truncated_normal_mlx_gen_orig_
#define mlx_random_uniform mlx_random_uniform_mlx_gen_orig_

// mlx_stream_* and mlx_synchronize
#define mlx_stream_new mlx_stream_new_mlx_gen_orig_
#define mlx_stream_new_device mlx_stream_new_device_mlx_gen_orig_
#define mlx_stream_set mlx_stream_set_mlx_gen_orig_
#define mlx_stream_free mlx_stream_free_mlx_gen_orig_
#define mlx_stream_tostring mlx_stream_tostring_mlx_gen_orig_
#define mlx_stream_equal mlx_stream_equal_mlx_gen_orig_
#define mlx_stream_get_device mlx_stream_get_device_mlx_gen_orig_
#define mlx_stream_get_index mlx_stream_get_index_mlx_gen_orig_
#define mlx_synchronize mlx_synchronize_mlx_gen_orig_
#define mlx_get_default_stream mlx_get_default_stream_mlx_gen_orig_ #define mlx_set_default_stream mlx_set_default_stream_mlx_gen_orig_ #define mlx_default_cpu_stream_new mlx_default_cpu_stream_new_mlx_gen_orig_ #define mlx_default_gpu_stream_new mlx_default_gpu_stream_new_mlx_gen_orig_ #define mlx_string_new mlx_string_new_mlx_gen_orig_ #define mlx_string_new_data mlx_string_new_data_mlx_gen_orig_ #define mlx_string_set mlx_string_set_mlx_gen_orig_ #define mlx_string_data mlx_string_data_mlx_gen_orig_ #define mlx_string_free mlx_string_free_mlx_gen_orig_ #define mlx_async_eval mlx_async_eval_mlx_gen_orig_ #define mlx_checkpoint mlx_checkpoint_mlx_gen_orig_ #define mlx_custom_function mlx_custom_function_mlx_gen_orig_ #define mlx_custom_vjp mlx_custom_vjp_mlx_gen_orig_ #define mlx_eval mlx_eval_mlx_gen_orig_ #define mlx_jvp mlx_jvp_mlx_gen_orig_ #define mlx_value_and_grad mlx_value_and_grad_mlx_gen_orig_ #define mlx_vjp mlx_vjp_mlx_gen_orig_ #define mlx_detail_vmap_replace mlx_detail_vmap_replace_mlx_gen_orig_ #define mlx_detail_vmap_trace mlx_detail_vmap_trace_mlx_gen_orig_ #define mlx_vector_array_new mlx_vector_array_new_mlx_gen_orig_ #define mlx_vector_array_set mlx_vector_array_set_mlx_gen_orig_ #define mlx_vector_array_free mlx_vector_array_free_mlx_gen_orig_ #define mlx_vector_array_new_data mlx_vector_array_new_data_mlx_gen_orig_ #define mlx_vector_array_new_value mlx_vector_array_new_value_mlx_gen_orig_ #define mlx_vector_array_set_data mlx_vector_array_set_data_mlx_gen_orig_ #define mlx_vector_array_set_value mlx_vector_array_set_value_mlx_gen_orig_ #define mlx_vector_array_append_data mlx_vector_array_append_data_mlx_gen_orig_ #define mlx_vector_array_append_value mlx_vector_array_append_value_mlx_gen_orig_ #define mlx_vector_array_size mlx_vector_array_size_mlx_gen_orig_ #define mlx_vector_array_get mlx_vector_array_get_mlx_gen_orig_ #define mlx_vector_vector_array_new mlx_vector_vector_array_new_mlx_gen_orig_ #define mlx_vector_vector_array_set 
mlx_vector_vector_array_set_mlx_gen_orig_ #define mlx_vector_vector_array_free mlx_vector_vector_array_free_mlx_gen_orig_ #define mlx_vector_vector_array_new_data mlx_vector_vector_array_new_data_mlx_gen_orig_ #define mlx_vector_vector_array_new_value mlx_vector_vector_array_new_value_mlx_gen_orig_ #define mlx_vector_vector_array_set_data mlx_vector_vector_array_set_data_mlx_gen_orig_ #define mlx_vector_vector_array_set_value mlx_vector_vector_array_set_value_mlx_gen_orig_ #define mlx_vector_vector_array_append_data mlx_vector_vector_array_append_data_mlx_gen_orig_ #define mlx_vector_vector_array_append_value mlx_vector_vector_array_append_value_mlx_gen_orig_ #define mlx_vector_vector_array_size mlx_vector_vector_array_size_mlx_gen_orig_ #define mlx_vector_vector_array_get mlx_vector_vector_array_get_mlx_gen_orig_ #define mlx_vector_int_new mlx_vector_int_new_mlx_gen_orig_ #define mlx_vector_int_set mlx_vector_int_set_mlx_gen_orig_ #define mlx_vector_int_free mlx_vector_int_free_mlx_gen_orig_ #define mlx_vector_int_new_data mlx_vector_int_new_data_mlx_gen_orig_ #define mlx_vector_int_new_value mlx_vector_int_new_value_mlx_gen_orig_ #define mlx_vector_int_set_data mlx_vector_int_set_data_mlx_gen_orig_ #define mlx_vector_int_set_value mlx_vector_int_set_value_mlx_gen_orig_ #define mlx_vector_int_append_data mlx_vector_int_append_data_mlx_gen_orig_ #define mlx_vector_int_append_value mlx_vector_int_append_value_mlx_gen_orig_ #define mlx_vector_int_size mlx_vector_int_size_mlx_gen_orig_ #define mlx_vector_int_get mlx_vector_int_get_mlx_gen_orig_ #define mlx_vector_string_new mlx_vector_string_new_mlx_gen_orig_ #define mlx_vector_string_set mlx_vector_string_set_mlx_gen_orig_ #define mlx_vector_string_free mlx_vector_string_free_mlx_gen_orig_ #define mlx_vector_string_new_data mlx_vector_string_new_data_mlx_gen_orig_ #define mlx_vector_string_new_value mlx_vector_string_new_value_mlx_gen_orig_ #define mlx_vector_string_set_data mlx_vector_string_set_data_mlx_gen_orig_ 
#define mlx_vector_string_set_value mlx_vector_string_set_value_mlx_gen_orig_ #define mlx_vector_string_append_data mlx_vector_string_append_data_mlx_gen_orig_ #define mlx_vector_string_append_value mlx_vector_string_append_value_mlx_gen_orig_ #define mlx_vector_string_size mlx_vector_string_size_mlx_gen_orig_ #define mlx_vector_string_get mlx_vector_string_get_mlx_gen_orig_ #define mlx_version mlx_version_mlx_gen_orig_ #include "mlx/c/mlx.h" #undef mlx_dtype_size #undef mlx_array_tostring #undef mlx_array_new #undef mlx_array_free #undef mlx_array_new_bool #undef mlx_array_new_int #undef mlx_array_new_float32 #undef mlx_array_new_float #undef mlx_array_new_float64 #undef mlx_array_new_double #undef mlx_array_new_complex #undef mlx_array_new_data #undef mlx_array_new_data_managed #undef mlx_array_new_data_managed_payload #undef mlx_array_set #undef mlx_array_set_bool #undef mlx_array_set_int #undef mlx_array_set_float32 #undef mlx_array_set_float #undef mlx_array_set_float64 #undef mlx_array_set_double #undef mlx_array_set_complex #undef mlx_array_set_data #undef mlx_array_itemsize #undef mlx_array_size #undef mlx_array_nbytes #undef mlx_array_ndim #undef mlx_array_shape #undef mlx_array_strides #undef mlx_array_dim #undef mlx_array_dtype #undef mlx_array_eval #undef mlx_array_item_bool #undef mlx_array_item_uint8 #undef mlx_array_item_uint16 #undef mlx_array_item_uint32 #undef mlx_array_item_uint64 #undef mlx_array_item_int8 #undef mlx_array_item_int16 #undef mlx_array_item_int32 #undef mlx_array_item_int64 #undef mlx_array_item_float32 #undef mlx_array_item_float64 #undef mlx_array_item_complex64 #undef mlx_array_item_float16 #undef mlx_array_item_bfloat16 #undef mlx_array_data_bool #undef mlx_array_data_uint8 #undef mlx_array_data_uint16 #undef mlx_array_data_uint32 #undef mlx_array_data_uint64 #undef mlx_array_data_int8 #undef mlx_array_data_int16 #undef mlx_array_data_int32 #undef mlx_array_data_int64 #undef mlx_array_data_float32 #undef mlx_array_data_float64 
#undef mlx_array_data_complex64 #undef mlx_array_data_float16 #undef mlx_array_data_bfloat16 #undef _mlx_array_is_available #undef _mlx_array_wait #undef _mlx_array_is_contiguous #undef _mlx_array_is_row_contiguous #undef _mlx_array_is_col_contiguous #undef mlx_closure_new #undef mlx_closure_free #undef mlx_closure_new_func #undef mlx_closure_new_func_payload #undef mlx_closure_set #undef mlx_closure_apply #undef mlx_closure_new_unary #undef mlx_closure_kwargs_new #undef mlx_closure_kwargs_free #undef mlx_closure_kwargs_new_func #undef mlx_closure_kwargs_new_func_payload #undef mlx_closure_kwargs_set #undef mlx_closure_kwargs_apply #undef mlx_closure_value_and_grad_new #undef mlx_closure_value_and_grad_free #undef mlx_closure_value_and_grad_new_func #undef mlx_closure_value_and_grad_new_func_payload #undef mlx_closure_value_and_grad_set #undef mlx_closure_value_and_grad_apply #undef mlx_closure_custom_new #undef mlx_closure_custom_free #undef mlx_closure_custom_new_func #undef mlx_closure_custom_new_func_payload #undef mlx_closure_custom_set #undef mlx_closure_custom_apply #undef mlx_closure_custom_jvp_new #undef mlx_closure_custom_jvp_free #undef mlx_closure_custom_jvp_new_func #undef mlx_closure_custom_jvp_new_func_payload #undef mlx_closure_custom_jvp_set #undef mlx_closure_custom_jvp_apply #undef mlx_closure_custom_vmap_new #undef mlx_closure_custom_vmap_free #undef mlx_closure_custom_vmap_new_func #undef mlx_closure_custom_vmap_new_func_payload #undef mlx_closure_custom_vmap_set #undef mlx_closure_custom_vmap_apply #undef mlx_compile #undef mlx_detail_compile #undef mlx_detail_compile_clear_cache #undef mlx_detail_compile_erase #undef mlx_disable_compile #undef mlx_enable_compile #undef mlx_set_compile_mode #undef mlx_cuda_is_available #undef mlx_device_new #undef mlx_device_new_type #undef mlx_device_free #undef mlx_device_set #undef mlx_device_tostring #undef mlx_device_equal #undef mlx_device_get_index #undef mlx_device_get_type #undef 
mlx_get_default_device #undef mlx_set_default_device #undef mlx_device_is_available #undef mlx_device_count #undef mlx_device_info_new #undef mlx_device_info_get #undef mlx_device_info_free #undef mlx_device_info_has_key #undef mlx_device_info_is_string #undef mlx_device_info_get_string #undef mlx_device_info_get_size #undef mlx_device_info_get_keys #undef mlx_distributed_all_gather #undef mlx_distributed_all_max #undef mlx_distributed_all_min #undef mlx_distributed_all_sum #undef mlx_distributed_recv #undef mlx_distributed_recv_like #undef mlx_distributed_send #undef mlx_distributed_sum_scatter #undef mlx_distributed_group_rank #undef mlx_distributed_group_size #undef mlx_distributed_group_split #undef mlx_distributed_is_available #undef mlx_distributed_init #undef mlx_set_error_handler #undef _mlx_error #undef mlx_export_function #undef mlx_export_function_kwargs #undef mlx_function_exporter_new #undef mlx_function_exporter_free #undef mlx_function_exporter_apply #undef mlx_function_exporter_apply_kwargs #undef mlx_imported_function_new #undef mlx_imported_function_free #undef mlx_imported_function_apply #undef mlx_imported_function_apply_kwargs #undef mlx_fast_cuda_kernel_config_new #undef mlx_fast_cuda_kernel_config_free #undef mlx_fast_cuda_kernel_config_add_output_arg #undef mlx_fast_cuda_kernel_config_set_grid #undef mlx_fast_cuda_kernel_config_set_thread_group #undef mlx_fast_cuda_kernel_config_set_init_value #undef mlx_fast_cuda_kernel_config_set_verbose #undef mlx_fast_cuda_kernel_config_add_template_arg_dtype #undef mlx_fast_cuda_kernel_config_add_template_arg_int #undef mlx_fast_cuda_kernel_config_add_template_arg_bool #undef mlx_fast_cuda_kernel_new #undef mlx_fast_cuda_kernel_free #undef mlx_fast_cuda_kernel_apply #undef mlx_fast_layer_norm #undef mlx_fast_metal_kernel_config_new #undef mlx_fast_metal_kernel_config_free #undef mlx_fast_metal_kernel_config_add_output_arg #undef mlx_fast_metal_kernel_config_set_grid #undef 
mlx_fast_metal_kernel_config_set_thread_group #undef mlx_fast_metal_kernel_config_set_init_value #undef mlx_fast_metal_kernel_config_set_verbose #undef mlx_fast_metal_kernel_config_add_template_arg_dtype #undef mlx_fast_metal_kernel_config_add_template_arg_int #undef mlx_fast_metal_kernel_config_add_template_arg_bool #undef mlx_fast_metal_kernel_new #undef mlx_fast_metal_kernel_free #undef mlx_fast_metal_kernel_apply #undef mlx_fast_rms_norm #undef mlx_fast_rope #undef mlx_fast_rope_dynamic #undef mlx_fast_scaled_dot_product_attention #undef mlx_fft_fft #undef mlx_fft_fft2 #undef mlx_fft_fftn #undef mlx_fft_fftshift #undef mlx_fft_ifft #undef mlx_fft_ifft2 #undef mlx_fft_ifftn #undef mlx_fft_ifftshift #undef mlx_fft_irfft #undef mlx_fft_irfft2 #undef mlx_fft_irfftn #undef mlx_fft_rfft #undef mlx_fft_rfft2 #undef mlx_fft_rfftn #undef mlx_load_reader #undef mlx_load #undef mlx_load_safetensors_reader #undef mlx_load_safetensors #undef mlx_save_writer #undef mlx_save #undef mlx_save_safetensors_writer #undef mlx_save_safetensors #undef mlx_io_reader_new #undef mlx_io_reader_descriptor #undef mlx_io_reader_tostring #undef mlx_io_reader_free #undef mlx_io_writer_new #undef mlx_io_writer_descriptor #undef mlx_io_writer_tostring #undef mlx_io_writer_free #undef mlx_linalg_cholesky #undef mlx_linalg_cholesky_inv #undef mlx_linalg_cross #undef mlx_linalg_eig #undef mlx_linalg_eigh #undef mlx_linalg_eigvals #undef mlx_linalg_eigvalsh #undef mlx_linalg_inv #undef mlx_linalg_lu #undef mlx_linalg_lu_factor #undef mlx_linalg_norm #undef mlx_linalg_norm_matrix #undef mlx_linalg_norm_l2 #undef mlx_linalg_pinv #undef mlx_linalg_qr #undef mlx_linalg_solve #undef mlx_linalg_solve_triangular #undef mlx_linalg_svd #undef mlx_linalg_tri_inv #undef mlx_map_string_to_array_new #undef mlx_map_string_to_array_set #undef mlx_map_string_to_array_free #undef mlx_map_string_to_array_insert #undef mlx_map_string_to_array_get #undef mlx_map_string_to_array_iterator_new #undef 
mlx_map_string_to_array_iterator_free #undef mlx_map_string_to_array_iterator_next #undef mlx_map_string_to_string_new #undef mlx_map_string_to_string_set #undef mlx_map_string_to_string_free #undef mlx_map_string_to_string_insert #undef mlx_map_string_to_string_get #undef mlx_map_string_to_string_iterator_new #undef mlx_map_string_to_string_iterator_free #undef mlx_map_string_to_string_iterator_next #undef mlx_clear_cache #undef mlx_get_active_memory #undef mlx_get_cache_memory #undef mlx_get_memory_limit #undef mlx_get_peak_memory #undef mlx_reset_peak_memory #undef mlx_set_cache_limit #undef mlx_set_memory_limit #undef mlx_set_wired_limit #undef mlx_metal_is_available #undef mlx_metal_start_capture #undef mlx_metal_stop_capture #undef mlx_abs #undef mlx_add #undef mlx_addmm #undef mlx_all_axes #undef mlx_all_axis #undef mlx_all #undef mlx_allclose #undef mlx_any_axes #undef mlx_any_axis #undef mlx_any #undef mlx_arange #undef mlx_arccos #undef mlx_arccosh #undef mlx_arcsin #undef mlx_arcsinh #undef mlx_arctan #undef mlx_arctan2 #undef mlx_arctanh #undef mlx_argmax_axis #undef mlx_argmax #undef mlx_argmin_axis #undef mlx_argmin #undef mlx_argpartition_axis #undef mlx_argpartition #undef mlx_argsort_axis #undef mlx_argsort #undef mlx_array_equal #undef mlx_as_strided #undef mlx_astype #undef mlx_atleast_1d #undef mlx_atleast_2d #undef mlx_atleast_3d #undef mlx_bartlett #undef mlx_bitwise_and #undef mlx_bitwise_invert #undef mlx_bitwise_or #undef mlx_bitwise_xor #undef mlx_blackman #undef mlx_block_masked_mm #undef mlx_broadcast_arrays #undef mlx_broadcast_to #undef mlx_ceil #undef mlx_clip #undef mlx_concatenate_axis #undef mlx_concatenate #undef mlx_conjugate #undef mlx_contiguous #undef mlx_conv1d #undef mlx_conv2d #undef mlx_conv3d #undef mlx_conv_general #undef mlx_conv_transpose1d #undef mlx_conv_transpose2d #undef mlx_conv_transpose3d #undef mlx_copy #undef mlx_cos #undef mlx_cosh #undef mlx_cummax #undef mlx_cummin #undef mlx_cumprod #undef mlx_cumsum 
#undef mlx_degrees #undef mlx_depends #undef mlx_dequantize #undef mlx_diag #undef mlx_diagonal #undef mlx_divide #undef mlx_divmod #undef mlx_einsum #undef mlx_equal #undef mlx_erf #undef mlx_erfinv #undef mlx_exp #undef mlx_expand_dims_axes #undef mlx_expand_dims #undef mlx_expm1 #undef mlx_eye #undef mlx_flatten #undef mlx_floor #undef mlx_floor_divide #undef mlx_from_fp8 #undef mlx_full #undef mlx_full_like #undef mlx_gather #undef mlx_gather_single #undef mlx_gather_mm #undef mlx_gather_qmm #undef mlx_greater #undef mlx_greater_equal #undef mlx_hadamard_transform #undef mlx_hamming #undef mlx_hanning #undef mlx_identity #undef mlx_imag #undef mlx_inner #undef mlx_isclose #undef mlx_isfinite #undef mlx_isinf #undef mlx_isnan #undef mlx_isneginf #undef mlx_isposinf #undef mlx_kron #undef mlx_left_shift #undef mlx_less #undef mlx_less_equal #undef mlx_linspace #undef mlx_log #undef mlx_log10 #undef mlx_log1p #undef mlx_log2 #undef mlx_logaddexp #undef mlx_logcumsumexp #undef mlx_logical_and #undef mlx_logical_not #undef mlx_logical_or #undef mlx_logsumexp_axes #undef mlx_logsumexp_axis #undef mlx_logsumexp #undef mlx_masked_scatter #undef mlx_matmul #undef mlx_max_axes #undef mlx_max_axis #undef mlx_max #undef mlx_maximum #undef mlx_mean_axes #undef mlx_mean_axis #undef mlx_mean #undef mlx_median #undef mlx_meshgrid #undef mlx_min_axes #undef mlx_min_axis #undef mlx_min #undef mlx_minimum #undef mlx_moveaxis #undef mlx_multiply #undef mlx_nan_to_num #undef mlx_negative #undef mlx_not_equal #undef mlx_number_of_elements #undef mlx_ones #undef mlx_ones_like #undef mlx_outer #undef mlx_pad #undef mlx_pad_symmetric #undef mlx_partition_axis #undef mlx_partition #undef mlx_power #undef mlx_prod_axes #undef mlx_prod_axis #undef mlx_prod #undef mlx_put_along_axis #undef mlx_qqmm #undef mlx_quantize #undef mlx_quantized_matmul #undef mlx_radians #undef mlx_real #undef mlx_reciprocal #undef mlx_remainder #undef mlx_repeat_axis #undef mlx_repeat #undef mlx_reshape #undef 
mlx_right_shift #undef mlx_roll_axis #undef mlx_roll_axes #undef mlx_roll #undef mlx_round #undef mlx_rsqrt #undef mlx_scatter #undef mlx_scatter_single #undef mlx_scatter_add #undef mlx_scatter_add_single #undef mlx_scatter_add_axis #undef mlx_scatter_max #undef mlx_scatter_max_single #undef mlx_scatter_min #undef mlx_scatter_min_single #undef mlx_scatter_prod #undef mlx_scatter_prod_single #undef mlx_segmented_mm #undef mlx_sigmoid #undef mlx_sign #undef mlx_sin #undef mlx_sinh #undef mlx_slice #undef mlx_slice_dynamic #undef mlx_slice_update #undef mlx_slice_update_dynamic #undef mlx_softmax_axes #undef mlx_softmax_axis #undef mlx_softmax #undef mlx_sort_axis #undef mlx_sort #undef mlx_split #undef mlx_split_sections #undef mlx_sqrt #undef mlx_square #undef mlx_squeeze_axes #undef mlx_squeeze_axis #undef mlx_squeeze #undef mlx_stack_axis #undef mlx_stack #undef mlx_std_axes #undef mlx_std_axis #undef mlx_std #undef mlx_stop_gradient #undef mlx_subtract #undef mlx_sum_axes #undef mlx_sum_axis #undef mlx_sum #undef mlx_swapaxes #undef mlx_take_axis #undef mlx_take #undef mlx_take_along_axis #undef mlx_tan #undef mlx_tanh #undef mlx_tensordot #undef mlx_tensordot_axis #undef mlx_tile #undef mlx_to_fp8 #undef mlx_topk_axis #undef mlx_topk #undef mlx_trace #undef mlx_transpose_axes #undef mlx_transpose #undef mlx_tri #undef mlx_tril #undef mlx_triu #undef mlx_unflatten #undef mlx_var_axes #undef mlx_var_axis #undef mlx_var #undef mlx_view #undef mlx_where #undef mlx_zeros #undef mlx_zeros_like #undef mlx_random_bernoulli #undef mlx_random_bits #undef mlx_random_categorical_shape #undef mlx_random_categorical_num_samples #undef mlx_random_categorical #undef mlx_random_gumbel #undef mlx_random_key #undef mlx_random_laplace #undef mlx_random_multivariate_normal #undef mlx_random_normal_broadcast #undef mlx_random_normal #undef mlx_random_permutation #undef mlx_random_permutation_arange #undef mlx_random_randint #undef mlx_random_seed #undef mlx_random_split_num #undef 
mlx_random_split #undef mlx_random_truncated_normal #undef mlx_random_uniform #undef mlx_stream_new #undef mlx_stream_new_device #undef mlx_stream_set #undef mlx_stream_free #undef mlx_stream_tostring #undef mlx_stream_equal #undef mlx_stream_get_device #undef mlx_stream_get_index #undef mlx_synchronize #undef mlx_get_default_stream #undef mlx_set_default_stream #undef mlx_default_cpu_stream_new #undef mlx_default_gpu_stream_new #undef mlx_string_new #undef mlx_string_new_data #undef mlx_string_set #undef mlx_string_data #undef mlx_string_free #undef mlx_async_eval #undef mlx_checkpoint #undef mlx_custom_function #undef mlx_custom_vjp #undef mlx_eval #undef mlx_jvp #undef mlx_value_and_grad #undef mlx_vjp #undef mlx_detail_vmap_replace #undef mlx_detail_vmap_trace #undef mlx_vector_array_new #undef mlx_vector_array_set #undef mlx_vector_array_free #undef mlx_vector_array_new_data #undef mlx_vector_array_new_value #undef mlx_vector_array_set_data #undef mlx_vector_array_set_value #undef mlx_vector_array_append_data #undef mlx_vector_array_append_value #undef mlx_vector_array_size #undef mlx_vector_array_get #undef mlx_vector_vector_array_new #undef mlx_vector_vector_array_set #undef mlx_vector_vector_array_free #undef mlx_vector_vector_array_new_data #undef mlx_vector_vector_array_new_value #undef mlx_vector_vector_array_set_data #undef mlx_vector_vector_array_set_value #undef mlx_vector_vector_array_append_data #undef mlx_vector_vector_array_append_value #undef mlx_vector_vector_array_size #undef mlx_vector_vector_array_get #undef mlx_vector_int_new #undef mlx_vector_int_set #undef mlx_vector_int_free #undef mlx_vector_int_new_data #undef mlx_vector_int_new_value #undef mlx_vector_int_set_data #undef mlx_vector_int_set_value #undef mlx_vector_int_append_data #undef mlx_vector_int_append_value #undef mlx_vector_int_size #undef mlx_vector_int_get #undef mlx_vector_string_new #undef mlx_vector_string_set #undef mlx_vector_string_free #undef mlx_vector_string_new_data 
#undef mlx_vector_string_new_value #undef mlx_vector_string_set_data #undef mlx_vector_string_set_value #undef mlx_vector_string_append_data #undef mlx_vector_string_append_value #undef mlx_vector_string_size #undef mlx_vector_string_get #undef mlx_version extern size_t (*mlx_dtype_size_)(mlx_dtype dtype); extern int (*mlx_array_tostring_)(mlx_string* str, const mlx_array arr); extern mlx_array (*mlx_array_new_)(void); extern int (*mlx_array_free_)(mlx_array arr); extern mlx_array (*mlx_array_new_bool_)(bool val); extern mlx_array (*mlx_array_new_int_)(int val); extern mlx_array (*mlx_array_new_float32_)(float val); extern mlx_array (*mlx_array_new_float_)(float val); extern mlx_array (*mlx_array_new_float64_)(double val); extern mlx_array (*mlx_array_new_double_)(double val); extern mlx_array (*mlx_array_new_complex_)(float real_val, float imag_val); extern mlx_array (*mlx_array_new_data_)( const void* data, const int* shape, int dim, mlx_dtype dtype); extern mlx_array (*mlx_array_new_data_managed_)( void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*)); extern mlx_array (*mlx_array_new_data_managed_payload_)( void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*)); extern int (*mlx_array_set_)(mlx_array* arr, const mlx_array src); extern int (*mlx_array_set_bool_)(mlx_array* arr, bool val); extern int (*mlx_array_set_int_)(mlx_array* arr, int val); extern int (*mlx_array_set_float32_)(mlx_array* arr, float val); extern int (*mlx_array_set_float_)(mlx_array* arr, float val); extern int (*mlx_array_set_float64_)(mlx_array* arr, double val); extern int (*mlx_array_set_double_)(mlx_array* arr, double val); extern int (*mlx_array_set_complex_)(mlx_array* arr, float real_val, float imag_val); extern int (*mlx_array_set_data_)( mlx_array* arr, const void* data, const int* shape, int dim, mlx_dtype dtype); extern size_t (*mlx_array_itemsize_)(const mlx_array arr); extern size_t (*mlx_array_size_)(const mlx_array 
arr); extern size_t (*mlx_array_nbytes_)(const mlx_array arr); extern size_t (*mlx_array_ndim_)(const mlx_array arr); extern const int * (*mlx_array_shape_)(const mlx_array arr); extern const size_t * (*mlx_array_strides_)(const mlx_array arr); extern int (*mlx_array_dim_)(const mlx_array arr, int dim); extern mlx_dtype (*mlx_array_dtype_)(const mlx_array arr); extern int (*mlx_array_eval_)(mlx_array arr); extern int (*mlx_array_item_bool_)(bool* res, const mlx_array arr); extern int (*mlx_array_item_uint8_)(uint8_t* res, const mlx_array arr); extern int (*mlx_array_item_uint16_)(uint16_t* res, const mlx_array arr); extern int (*mlx_array_item_uint32_)(uint32_t* res, const mlx_array arr); extern int (*mlx_array_item_uint64_)(uint64_t* res, const mlx_array arr); extern int (*mlx_array_item_int8_)(int8_t* res, const mlx_array arr); extern int (*mlx_array_item_int16_)(int16_t* res, const mlx_array arr); extern int (*mlx_array_item_int32_)(int32_t* res, const mlx_array arr); extern int (*mlx_array_item_int64_)(int64_t* res, const mlx_array arr); extern int (*mlx_array_item_float32_)(float* res, const mlx_array arr); extern int (*mlx_array_item_float64_)(double* res, const mlx_array arr); extern int (*mlx_array_item_complex64_)(mlx_complex64_t* res, const mlx_array arr); extern int (*mlx_array_item_float16_)(float16_t* res, const mlx_array arr); extern int (*mlx_array_item_bfloat16_)(bfloat16_t* res, const mlx_array arr); extern const bool * (*mlx_array_data_bool_)(const mlx_array arr); extern const uint8_t * (*mlx_array_data_uint8_)(const mlx_array arr); extern const uint16_t * (*mlx_array_data_uint16_)(const mlx_array arr); extern const uint32_t * (*mlx_array_data_uint32_)(const mlx_array arr); extern const uint64_t * (*mlx_array_data_uint64_)(const mlx_array arr); extern const int8_t * (*mlx_array_data_int8_)(const mlx_array arr); extern const int16_t * (*mlx_array_data_int16_)(const mlx_array arr); extern const int32_t * (*mlx_array_data_int32_)(const mlx_array 
arr); extern const int64_t * (*mlx_array_data_int64_)(const mlx_array arr); extern const float * (*mlx_array_data_float32_)(const mlx_array arr); extern const double * (*mlx_array_data_float64_)(const mlx_array arr); extern const mlx_complex64_t * (*mlx_array_data_complex64_)(const mlx_array arr); extern const float16_t * (*mlx_array_data_float16_)(const mlx_array arr); extern const bfloat16_t * (*mlx_array_data_bfloat16_)(const mlx_array arr); extern int (*_mlx_array_is_available_)(bool* res, const mlx_array arr); extern int (*_mlx_array_wait_)(const mlx_array arr); extern int (*_mlx_array_is_contiguous_)(bool* res, const mlx_array arr); extern int (*_mlx_array_is_row_contiguous_)(bool* res, const mlx_array arr); extern int (*_mlx_array_is_col_contiguous_)(bool* res, const mlx_array arr); extern mlx_closure (*mlx_closure_new_)(void); extern int (*mlx_closure_free_)(mlx_closure cls); extern mlx_closure (*mlx_closure_new_func_)( int (*fun)(mlx_vector_array*, const mlx_vector_array)); extern mlx_closure (*mlx_closure_new_func_payload_)( int (*fun)(mlx_vector_array*, const mlx_vector_array, void*), void* payload, void (*dtor)(void*)); extern int (*mlx_closure_set_)(mlx_closure* cls, const mlx_closure src); extern int (*mlx_closure_apply_)( mlx_vector_array* res, mlx_closure cls, const mlx_vector_array input); extern mlx_closure (*mlx_closure_new_unary_)(int (*fun)(mlx_array*, const mlx_array)); extern mlx_closure_kwargs (*mlx_closure_kwargs_new_)(void); extern int (*mlx_closure_kwargs_free_)(mlx_closure_kwargs cls); extern mlx_closure_kwargs (*mlx_closure_kwargs_new_func_)( int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_map_string_to_array)); extern mlx_closure_kwargs (*mlx_closure_kwargs_new_func_payload_)( int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_map_string_to_array, void*), void* payload, void (*dtor)(void*)); extern int (*mlx_closure_kwargs_set_)( mlx_closure_kwargs* cls, const mlx_closure_kwargs src); extern int 
(*mlx_closure_kwargs_apply_)( mlx_vector_array* res, mlx_closure_kwargs cls, const mlx_vector_array input_0, const mlx_map_string_to_array input_1); extern mlx_closure_value_and_grad (*mlx_closure_value_and_grad_new_)(void); extern int (*mlx_closure_value_and_grad_free_)(mlx_closure_value_and_grad cls); extern mlx_closure_value_and_grad (*mlx_closure_value_and_grad_new_func_)( int (*fun)(mlx_vector_array*, mlx_vector_array*, const mlx_vector_array)); extern mlx_closure_value_and_grad (*mlx_closure_value_and_grad_new_func_payload_)( int (*fun)( mlx_vector_array*, mlx_vector_array*, const mlx_vector_array, void*), void* payload, void (*dtor)(void*)); extern int (*mlx_closure_value_and_grad_set_)( mlx_closure_value_and_grad* cls, const mlx_closure_value_and_grad src); extern int (*mlx_closure_value_and_grad_apply_)( mlx_vector_array* res_0, mlx_vector_array* res_1, mlx_closure_value_and_grad cls, const mlx_vector_array input); extern mlx_closure_custom (*mlx_closure_custom_new_)(void); extern int (*mlx_closure_custom_free_)(mlx_closure_custom cls); extern mlx_closure_custom (*mlx_closure_custom_new_func_)( int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_vector_array, const mlx_vector_array)); extern mlx_closure_custom (*mlx_closure_custom_new_func_payload_)( int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_vector_array, const mlx_vector_array, void*), void* payload, void (*dtor)(void*)); extern int (*mlx_closure_custom_set_)( mlx_closure_custom* cls, const mlx_closure_custom src); extern int (*mlx_closure_custom_apply_)( mlx_vector_array* res, mlx_closure_custom cls, const mlx_vector_array input_0, const mlx_vector_array input_1, const mlx_vector_array input_2); extern mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_)(void); extern int (*mlx_closure_custom_jvp_free_)(mlx_closure_custom_jvp cls); extern mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_)( int (*fun)( mlx_vector_array*, const mlx_vector_array, const 
mlx_vector_array, const int*, size_t _num)); extern mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_payload_)( int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_vector_array, const int*, size_t _num, void*), void* payload, void (*dtor)(void*)); extern int (*mlx_closure_custom_jvp_set_)( mlx_closure_custom_jvp* cls, const mlx_closure_custom_jvp src); extern int (*mlx_closure_custom_jvp_apply_)( mlx_vector_array* res, mlx_closure_custom_jvp cls, const mlx_vector_array input_0, const mlx_vector_array input_1, const int* input_2, size_t input_2_num); extern mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_)(void); extern int (*mlx_closure_custom_vmap_free_)(mlx_closure_custom_vmap cls); extern mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_)( int (*fun)( mlx_vector_array*, mlx_vector_int*, const mlx_vector_array, const int*, size_t _num)); extern mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_payload_)( int (*fun)( mlx_vector_array*, mlx_vector_int*, const mlx_vector_array, const int*, size_t _num, void*), void* payload, void (*dtor)(void*)); extern int (*mlx_closure_custom_vmap_set_)( mlx_closure_custom_vmap* cls, const mlx_closure_custom_vmap src); extern int (*mlx_closure_custom_vmap_apply_)( mlx_vector_array* res_0, mlx_vector_int* res_1, mlx_closure_custom_vmap cls, const mlx_vector_array input_0, const int* input_1, size_t input_1_num); extern int (*mlx_compile_)(mlx_closure* res, const mlx_closure fun, bool shapeless); extern int (*mlx_detail_compile_)( mlx_closure* res, const mlx_closure fun, uintptr_t fun_id, bool shapeless, const uint64_t* constants, size_t constants_num); extern int (*mlx_detail_compile_clear_cache_)(void); extern int (*mlx_detail_compile_erase_)(uintptr_t fun_id); extern int (*mlx_disable_compile_)(void); extern int (*mlx_enable_compile_)(void); extern int (*mlx_set_compile_mode_)(mlx_compile_mode mode); extern int (*mlx_cuda_is_available_)(bool* res); extern mlx_device 
(*mlx_device_new_)(void); extern mlx_device (*mlx_device_new_type_)(mlx_device_type type, int index); extern int (*mlx_device_free_)(mlx_device dev); extern int (*mlx_device_set_)(mlx_device* dev, const mlx_device src); extern int (*mlx_device_tostring_)(mlx_string* str, mlx_device dev); extern bool (*mlx_device_equal_)(mlx_device lhs, mlx_device rhs); extern int (*mlx_device_get_index_)(int* index, mlx_device dev); extern int (*mlx_device_get_type_)(mlx_device_type* type, mlx_device dev); extern int (*mlx_get_default_device_)(mlx_device* dev); extern int (*mlx_set_default_device_)(mlx_device dev); extern int (*mlx_device_is_available_)(bool* avail, mlx_device dev); extern int (*mlx_device_count_)(int* count, mlx_device_type type); extern mlx_device_info (*mlx_device_info_new_)(void); extern int (*mlx_device_info_get_)(mlx_device_info* info, mlx_device dev); extern int (*mlx_device_info_free_)(mlx_device_info info); extern int (*mlx_device_info_has_key_)( bool* exists, mlx_device_info info, const char* key); extern int (*mlx_device_info_is_string_)( bool* is_string, mlx_device_info info, const char* key); extern int (*mlx_device_info_get_string_)( const char** value, mlx_device_info info, const char* key); extern int (*mlx_device_info_get_size_)( size_t* value, mlx_device_info info, const char* key); extern int (*mlx_device_info_get_keys_)(mlx_vector_string* keys, mlx_device_info info); extern int (*mlx_distributed_all_gather_)( mlx_array* res, const mlx_array x, const mlx_distributed_group group /* may be null */, const mlx_stream S); extern int (*mlx_distributed_all_max_)( mlx_array* res, const mlx_array x, const mlx_distributed_group group /* may be null */, const mlx_stream s); extern int (*mlx_distributed_all_min_)( mlx_array* res, const mlx_array x, const mlx_distributed_group group /* may be null */, const mlx_stream s); extern int (*mlx_distributed_all_sum_)( mlx_array* res, const mlx_array x, const mlx_distributed_group group /* may be null */, const 
mlx_stream s); extern int (*mlx_distributed_recv_)( mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, int src, const mlx_distributed_group group /* may be null */, const mlx_stream s); extern int (*mlx_distributed_recv_like_)( mlx_array* res, const mlx_array x, int src, const mlx_distributed_group group /* may be null */, const mlx_stream s); extern int (*mlx_distributed_send_)( mlx_array* res, const mlx_array x, int dst, const mlx_distributed_group group /* may be null */, const mlx_stream s); extern int (*mlx_distributed_sum_scatter_)( mlx_array* res, const mlx_array x, const mlx_distributed_group group /* may be null */, const mlx_stream s); extern int (*mlx_distributed_group_rank_)(mlx_distributed_group group); extern int (*mlx_distributed_group_size_)(mlx_distributed_group group); extern mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key); extern bool (*mlx_distributed_is_available_)(const char* bk /* may be null */); extern mlx_distributed_group (*mlx_distributed_init_)( bool strict, const char* bk /* may be null */); extern void (*mlx_set_error_handler_)( mlx_error_handler_func handler, void* data, void (*dtor)(void*)); extern void (*_mlx_error_)(const char* file, const int line, const char* fmt, ...); extern int (*mlx_export_function_)( const char* file, const mlx_closure fun, const mlx_vector_array args, bool shapeless); extern int (*mlx_export_function_kwargs_)( const char* file, const mlx_closure_kwargs fun, const mlx_vector_array args, const mlx_map_string_to_array kwargs, bool shapeless); extern mlx_function_exporter (*mlx_function_exporter_new_)( const char* file, const mlx_closure fun, bool shapeless); extern int (*mlx_function_exporter_free_)(mlx_function_exporter xfunc); extern int (*mlx_function_exporter_apply_)( const mlx_function_exporter xfunc, const mlx_vector_array args); extern int (*mlx_function_exporter_apply_kwargs_)( const mlx_function_exporter xfunc, const 
mlx_vector_array args, const mlx_map_string_to_array kwargs); extern mlx_imported_function (*mlx_imported_function_new_)(const char* file); extern int (*mlx_imported_function_free_)(mlx_imported_function xfunc); extern int (*mlx_imported_function_apply_)( mlx_vector_array* res, const mlx_imported_function xfunc, const mlx_vector_array args); extern int (*mlx_imported_function_apply_kwargs_)( mlx_vector_array* res, const mlx_imported_function xfunc, const mlx_vector_array args, const mlx_map_string_to_array kwargs); extern mlx_fast_cuda_kernel_config (*mlx_fast_cuda_kernel_config_new_)(void); extern void (*mlx_fast_cuda_kernel_config_free_)(mlx_fast_cuda_kernel_config cls); extern int (*mlx_fast_cuda_kernel_config_add_output_arg_)( mlx_fast_cuda_kernel_config cls, const int* shape, size_t size, mlx_dtype dtype); extern int (*mlx_fast_cuda_kernel_config_set_grid_)( mlx_fast_cuda_kernel_config cls, int grid1, int grid2, int grid3); extern int (*mlx_fast_cuda_kernel_config_set_thread_group_)( mlx_fast_cuda_kernel_config cls, int thread1, int thread2, int thread3); extern int (*mlx_fast_cuda_kernel_config_set_init_value_)( mlx_fast_cuda_kernel_config cls, float value); extern int (*mlx_fast_cuda_kernel_config_set_verbose_)( mlx_fast_cuda_kernel_config cls, bool verbose); extern int (*mlx_fast_cuda_kernel_config_add_template_arg_dtype_)( mlx_fast_cuda_kernel_config cls, const char* name, mlx_dtype dtype); extern int (*mlx_fast_cuda_kernel_config_add_template_arg_int_)( mlx_fast_cuda_kernel_config cls, const char* name, int value); extern int (*mlx_fast_cuda_kernel_config_add_template_arg_bool_)( mlx_fast_cuda_kernel_config cls, const char* name, bool value); extern mlx_fast_cuda_kernel (*mlx_fast_cuda_kernel_new_)( const char* name, const mlx_vector_string input_names, const mlx_vector_string output_names, const char* source, const char* header, bool ensure_row_contiguous, int shared_memory); extern void (*mlx_fast_cuda_kernel_free_)(mlx_fast_cuda_kernel cls); extern int 
(*mlx_fast_cuda_kernel_apply_)( mlx_vector_array* outputs, mlx_fast_cuda_kernel cls, const mlx_vector_array inputs, const mlx_fast_cuda_kernel_config config, const mlx_stream stream); extern int (*mlx_fast_layer_norm_)( mlx_array* res, const mlx_array x, const mlx_array weight /* may be null */, const mlx_array bias /* may be null */, float eps, const mlx_stream s); extern mlx_fast_metal_kernel_config (*mlx_fast_metal_kernel_config_new_)(void); extern void (*mlx_fast_metal_kernel_config_free_)(mlx_fast_metal_kernel_config cls); extern int (*mlx_fast_metal_kernel_config_add_output_arg_)( mlx_fast_metal_kernel_config cls, const int* shape, size_t size, mlx_dtype dtype); extern int (*mlx_fast_metal_kernel_config_set_grid_)( mlx_fast_metal_kernel_config cls, int grid1, int grid2, int grid3); extern int (*mlx_fast_metal_kernel_config_set_thread_group_)( mlx_fast_metal_kernel_config cls, int thread1, int thread2, int thread3); extern int (*mlx_fast_metal_kernel_config_set_init_value_)( mlx_fast_metal_kernel_config cls, float value); extern int (*mlx_fast_metal_kernel_config_set_verbose_)( mlx_fast_metal_kernel_config cls, bool verbose); extern int (*mlx_fast_metal_kernel_config_add_template_arg_dtype_)( mlx_fast_metal_kernel_config cls, const char* name, mlx_dtype dtype); extern int (*mlx_fast_metal_kernel_config_add_template_arg_int_)( mlx_fast_metal_kernel_config cls, const char* name, int value); extern int (*mlx_fast_metal_kernel_config_add_template_arg_bool_)( mlx_fast_metal_kernel_config cls, const char* name, bool value); extern mlx_fast_metal_kernel (*mlx_fast_metal_kernel_new_)( const char* name, const mlx_vector_string input_names, const mlx_vector_string output_names, const char* source, const char* header, bool ensure_row_contiguous, bool atomic_outputs); extern void (*mlx_fast_metal_kernel_free_)(mlx_fast_metal_kernel cls); extern int (*mlx_fast_metal_kernel_apply_)( mlx_vector_array* outputs, mlx_fast_metal_kernel cls, const mlx_vector_array inputs, const 
mlx_fast_metal_kernel_config config, const mlx_stream stream); extern int (*mlx_fast_rms_norm_)( mlx_array* res, const mlx_array x, const mlx_array weight /* may be null */, float eps, const mlx_stream s); extern int (*mlx_fast_rope_)( mlx_array* res, const mlx_array x, int dims, bool traditional, mlx_optional_float base, float scale, int offset, const mlx_array freqs /* may be null */, const mlx_stream s); extern int (*mlx_fast_rope_dynamic_)( mlx_array* res, const mlx_array x, int dims, bool traditional, mlx_optional_float base, float scale, const mlx_array offset, const mlx_array freqs /* may be null */, const mlx_stream s); extern int (*mlx_fast_scaled_dot_product_attention_)( mlx_array* res, const mlx_array queries, const mlx_array keys, const mlx_array values, float scale, const char* mask_mode, const mlx_array mask_arr /* may be null */, const mlx_array sinks /* may be null */, const mlx_stream s); extern int (*mlx_fft_fft_)( mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s); extern int (*mlx_fft_fft2_)( mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s); extern int (*mlx_fft_fftn_)( mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s); extern int (*mlx_fft_fftshift_)( mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s); extern int (*mlx_fft_ifft_)( mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s); extern int (*mlx_fft_ifft2_)( mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s); extern int (*mlx_fft_ifftn_)( mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s); extern int (*mlx_fft_ifftshift_)( mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s); extern int (*mlx_fft_irfft_)( mlx_array* 
res, const mlx_array a, int n, int axis, const mlx_stream s); extern int (*mlx_fft_irfft2_)( mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s); extern int (*mlx_fft_irfftn_)( mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s); extern int (*mlx_fft_rfft_)( mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s); extern int (*mlx_fft_rfft2_)( mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s); extern int (*mlx_fft_rfftn_)( mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s); extern int (*mlx_load_reader_)( mlx_array* res, mlx_io_reader in_stream, const mlx_stream s); extern int (*mlx_load_)(mlx_array* res, const char* file, const mlx_stream s); extern int (*mlx_load_safetensors_reader_)( mlx_map_string_to_array* res_0, mlx_map_string_to_string* res_1, mlx_io_reader in_stream, const mlx_stream s); extern int (*mlx_load_safetensors_)( mlx_map_string_to_array* res_0, mlx_map_string_to_string* res_1, const char* file, const mlx_stream s); extern int (*mlx_save_writer_)(mlx_io_writer out_stream, const mlx_array a); extern int (*mlx_save_)(const char* file, const mlx_array a); extern int (*mlx_save_safetensors_writer_)( mlx_io_writer in_stream, const mlx_map_string_to_array param, const mlx_map_string_to_string metadata); extern int (*mlx_save_safetensors_)( const char* file, const mlx_map_string_to_array param, const mlx_map_string_to_string metadata); extern mlx_io_reader (*mlx_io_reader_new_)(void* desc, mlx_io_vtable vtable); extern int (*mlx_io_reader_descriptor_)(void** desc_, mlx_io_reader io); extern int (*mlx_io_reader_tostring_)(mlx_string* str_, mlx_io_reader io); extern int (*mlx_io_reader_free_)(mlx_io_reader io); extern mlx_io_writer (*mlx_io_writer_new_)(void* desc, mlx_io_vtable 
vtable); extern int (*mlx_io_writer_descriptor_)(void** desc_, mlx_io_writer io); extern int (*mlx_io_writer_tostring_)(mlx_string* str_, mlx_io_writer io); extern int (*mlx_io_writer_free_)(mlx_io_writer io); extern int (*mlx_linalg_cholesky_)( mlx_array* res, const mlx_array a, bool upper, const mlx_stream s); extern int (*mlx_linalg_cholesky_inv_)( mlx_array* res, const mlx_array a, bool upper, const mlx_stream s); extern int (*mlx_linalg_cross_)( mlx_array* res, const mlx_array a, const mlx_array b, int axis, const mlx_stream s); extern int (*mlx_linalg_eig_)( mlx_array* res_0, mlx_array* res_1, const mlx_array a, const mlx_stream s); extern int (*mlx_linalg_eigh_)( mlx_array* res_0, mlx_array* res_1, const mlx_array a, const char* UPLO, const mlx_stream s); extern int (*mlx_linalg_eigvals_)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_linalg_eigvalsh_)( mlx_array* res, const mlx_array a, const char* UPLO, const mlx_stream s); extern int (*mlx_linalg_inv_)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_linalg_lu_)(mlx_vector_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_linalg_lu_factor_)( mlx_array* res_0, mlx_array* res_1, const mlx_array a, const mlx_stream s); extern int (*mlx_linalg_norm_)( mlx_array* res, const mlx_array a, double ord, const int* axis /* may be null */, size_t axis_num, bool keepdims, const mlx_stream s); extern int (*mlx_linalg_norm_matrix_)( mlx_array* res, const mlx_array a, const char* ord, const int* axis /* may be null */, size_t axis_num, bool keepdims, const mlx_stream s); extern int (*mlx_linalg_norm_l2_)( mlx_array* res, const mlx_array a, const int* axis /* may be null */, size_t axis_num, bool keepdims, const mlx_stream s); extern int (*mlx_linalg_pinv_)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_linalg_qr_)( mlx_array* res_0, mlx_array* res_1, const mlx_array a, const mlx_stream s); extern int (*mlx_linalg_solve_)( 
mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); extern int (*mlx_linalg_solve_triangular_)( mlx_array* res, const mlx_array a, const mlx_array b, bool upper, const mlx_stream s); extern int (*mlx_linalg_svd_)( mlx_vector_array* res, const mlx_array a, bool compute_uv, const mlx_stream s); extern int (*mlx_linalg_tri_inv_)( mlx_array* res, const mlx_array a, bool upper, const mlx_stream s); extern mlx_map_string_to_array (*mlx_map_string_to_array_new_)(void); extern int (*mlx_map_string_to_array_set_)( mlx_map_string_to_array* map, const mlx_map_string_to_array src); extern int (*mlx_map_string_to_array_free_)(mlx_map_string_to_array map); extern int (*mlx_map_string_to_array_insert_)( mlx_map_string_to_array map, const char* key, const mlx_array value); extern int (*mlx_map_string_to_array_get_)( mlx_array* value, const mlx_map_string_to_array map, const char* key); extern mlx_map_string_to_array_iterator (*mlx_map_string_to_array_iterator_new_)( mlx_map_string_to_array map); extern int (*mlx_map_string_to_array_iterator_free_)(mlx_map_string_to_array_iterator it); extern int (*mlx_map_string_to_array_iterator_next_)( const char** key, mlx_array* value, mlx_map_string_to_array_iterator it); extern mlx_map_string_to_string (*mlx_map_string_to_string_new_)(void); extern int (*mlx_map_string_to_string_set_)( mlx_map_string_to_string* map, const mlx_map_string_to_string src); extern int (*mlx_map_string_to_string_free_)(mlx_map_string_to_string map); extern int (*mlx_map_string_to_string_insert_)( mlx_map_string_to_string map, const char* key, const char* value); extern int (*mlx_map_string_to_string_get_)( const char** value, const mlx_map_string_to_string map, const char* key); extern mlx_map_string_to_string_iterator (*mlx_map_string_to_string_iterator_new_)( mlx_map_string_to_string map); extern int (*mlx_map_string_to_string_iterator_free_)( mlx_map_string_to_string_iterator it); extern int (*mlx_map_string_to_string_iterator_next_)( 
const char** key, const char** value, mlx_map_string_to_string_iterator it); extern int (*mlx_clear_cache_)(void); extern int (*mlx_get_active_memory_)(size_t* res); extern int (*mlx_get_cache_memory_)(size_t* res); extern int (*mlx_get_memory_limit_)(size_t* res); extern int (*mlx_get_peak_memory_)(size_t* res); extern int (*mlx_reset_peak_memory_)(void); extern int (*mlx_set_cache_limit_)(size_t* res, size_t limit); extern int (*mlx_set_memory_limit_)(size_t* res, size_t limit); extern int (*mlx_set_wired_limit_)(size_t* res, size_t limit); extern int (*mlx_metal_is_available_)(bool* res); extern int (*mlx_metal_start_capture_)(const char* path); extern int (*mlx_metal_stop_capture_)(void); extern int (*mlx_abs_)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_add_)( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); extern int (*mlx_addmm_)( mlx_array* res, const mlx_array c, const mlx_array a, const mlx_array b, float alpha, float beta, const mlx_stream s); extern int (*mlx_all_axes_)( mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s); extern int (*mlx_all_axis_)( mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s); extern int (*mlx_all_)( mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s); extern int (*mlx_allclose_)( mlx_array* res, const mlx_array a, const mlx_array b, double rtol, double atol, bool equal_nan, const mlx_stream s); extern int (*mlx_any_axes_)( mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s); extern int (*mlx_any_axis_)( mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s); extern int (*mlx_any_)( mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s); extern int (*mlx_arange_)( mlx_array* res, double start, double stop, double step, mlx_dtype dtype, const mlx_stream s); extern int 
(*mlx_arccos_)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_arccosh_)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_arcsin_)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_arcsinh_)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_arctan_)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_arctan2_)( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); extern int (*mlx_arctanh_)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_argmax_axis_)( mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s); extern int (*mlx_argmax_)( mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s); extern int (*mlx_argmin_axis_)( mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s); extern int (*mlx_argmin_)( mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s); extern int (*mlx_argpartition_axis_)( mlx_array* res, const mlx_array a, int kth, int axis, const mlx_stream s); extern int (*mlx_argpartition_)( mlx_array* res, const mlx_array a, int kth, const mlx_stream s); extern int (*mlx_argsort_axis_)( mlx_array* res, const mlx_array a, int axis, const mlx_stream s); extern int (*mlx_argsort_)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_array_equal_)( mlx_array* res, const mlx_array a, const mlx_array b, bool equal_nan, const mlx_stream s); extern int (*mlx_as_strided_)( mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const int64_t* strides, size_t strides_num, size_t offset, const mlx_stream s); extern int (*mlx_astype_)( mlx_array* res, const mlx_array a, mlx_dtype dtype, const mlx_stream s); extern int (*mlx_atleast_1d_)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_atleast_2d_)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int 
(*mlx_atleast_3d_)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_bartlett_)(mlx_array* res, int M, const mlx_stream s); extern int (*mlx_bitwise_and_)( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); extern int (*mlx_bitwise_invert_)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_bitwise_or_)( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); extern int (*mlx_bitwise_xor_)( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); extern int (*mlx_blackman_)(mlx_array* res, int M, const mlx_stream s); extern int (*mlx_block_masked_mm_)( mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out /* may be null */, const mlx_array mask_lhs /* may be null */, const mlx_array mask_rhs /* may be null */, const mlx_stream s); extern int (*mlx_broadcast_arrays_)( mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s); extern int (*mlx_broadcast_to_)( mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s); extern int (*mlx_ceil_)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_clip_)( mlx_array* res, const mlx_array a, const mlx_array a_min /* may be null */, const mlx_array a_max /* may be null */, const mlx_stream s); extern int (*mlx_concatenate_axis_)( mlx_array* res, const mlx_vector_array arrays, int axis, const mlx_stream s); extern int (*mlx_concatenate_)( mlx_array* res, const mlx_vector_array arrays, const mlx_stream s); extern int (*mlx_conjugate_)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_contiguous_)( mlx_array* res, const mlx_array a, bool allow_col_major, const mlx_stream s); extern int (*mlx_conv1d_)( mlx_array* res, const mlx_array input, const mlx_array weight, int stride, int padding, int dilation, int groups, const mlx_stream s); extern int (*mlx_conv2d_)( mlx_array* res, const mlx_array 
input, const mlx_array weight, int stride_0, int stride_1, int padding_0, int padding_1, int dilation_0, int dilation_1, int groups, const mlx_stream s); extern int (*mlx_conv3d_)( mlx_array* res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int stride_2, int padding_0, int padding_1, int padding_2, int dilation_0, int dilation_1, int dilation_2, int groups, const mlx_stream s); extern int (*mlx_conv_general_)( mlx_array* res, const mlx_array input, const mlx_array weight, const int* stride, size_t stride_num, const int* padding_lo, size_t padding_lo_num, const int* padding_hi, size_t padding_hi_num, const int* kernel_dilation, size_t kernel_dilation_num, const int* input_dilation, size_t input_dilation_num, int groups, bool flip, const mlx_stream s); extern int (*mlx_conv_transpose1d_)( mlx_array* res, const mlx_array input, const mlx_array weight, int stride, int padding, int dilation, int output_padding, int groups, const mlx_stream s); extern int (*mlx_conv_transpose2d_)( mlx_array* res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int padding_0, int padding_1, int dilation_0, int dilation_1, int output_padding_0, int output_padding_1, int groups, const mlx_stream s); extern int (*mlx_conv_transpose3d_)( mlx_array* res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int stride_2, int padding_0, int padding_1, int padding_2, int dilation_0, int dilation_1, int dilation_2, int output_padding_0, int output_padding_1, int output_padding_2, int groups, const mlx_stream s); extern int (*mlx_copy_)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_cos_)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_cosh_)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_cummax_)( mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s); extern int (*mlx_cummin_)( mlx_array* res, const 
mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s); extern int (*mlx_cumprod_)( mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s); extern int (*mlx_cumsum_)( mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s); extern int (*mlx_degrees_)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_depends_)( mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies); extern int (*mlx_dequantize_)( mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases /* may be null */, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale /* may be null */, mlx_optional_dtype dtype, const mlx_stream s); extern int (*mlx_diag_)(mlx_array* res, const mlx_array a, int k, const mlx_stream s); extern int (*mlx_diagonal_)( mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, const mlx_stream s); extern int (*mlx_divide_)( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); extern int (*mlx_divmod_)( mlx_vector_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); extern int (*mlx_einsum_)( mlx_array* res, const char* subscripts, const mlx_vector_array operands, const mlx_stream s); extern int (*mlx_equal_)( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); extern int (*mlx_erf_)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_erfinv_)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_exp_)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_expand_dims_axes_)( mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s); extern int (*mlx_expand_dims_)( mlx_array* res, const mlx_array a, int axis, const mlx_stream s); extern int (*mlx_expm1_)(mlx_array* res, const mlx_array a, const 
mlx_stream s); extern int (*mlx_eye_)( mlx_array* res, int n, int m, int k, mlx_dtype dtype, const mlx_stream s); extern int (*mlx_flatten_)( mlx_array* res, const mlx_array a, int start_axis, int end_axis, const mlx_stream s); extern int (*mlx_floor_)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_floor_divide_)( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); extern int (*mlx_from_fp8_)( mlx_array* res, const mlx_array x, mlx_dtype dtype, const mlx_stream s); extern int (*mlx_full_)( mlx_array* res, const int* shape, size_t shape_num, const mlx_array vals, mlx_dtype dtype, const mlx_stream s); extern int (*mlx_full_like_)( mlx_array* res, const mlx_array a, const mlx_array vals, mlx_dtype dtype, const mlx_stream s); extern int (*mlx_gather_)( mlx_array* res, const mlx_array a, const mlx_vector_array indices, const int* axes, size_t axes_num, const int* slice_sizes, size_t slice_sizes_num, const mlx_stream s); extern int (*mlx_gather_single_)( mlx_array* res, const mlx_array a, const mlx_array indices, int axis, const int* slice_sizes, size_t slice_sizes_num, const mlx_stream s); extern int (*mlx_gather_mm_)( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_array lhs_indices /* may be null */, const mlx_array rhs_indices /* may be null */, bool sorted_indices, const mlx_stream s); extern int (*mlx_gather_qmm_)( mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases /* may be null */, const mlx_array lhs_indices /* may be null */, const mlx_array rhs_indices /* may be null */, bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, bool sorted_indices, const mlx_stream s); extern int (*mlx_greater_)( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); extern int (*mlx_greater_equal_)( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); extern int (*mlx_hadamard_transform_)( 
mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s); extern int (*mlx_hamming_)(mlx_array* res, int M, const mlx_stream s); extern int (*mlx_hanning_)(mlx_array* res, int M, const mlx_stream s); extern int (*mlx_identity_)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s); extern int (*mlx_imag_)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_inner_)( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); extern int (*mlx_isclose_)( mlx_array* res, const mlx_array a, const mlx_array b, double rtol, double atol, bool equal_nan, const mlx_stream s); extern int (*mlx_isfinite_)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_isinf_)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_isnan_)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_isneginf_)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_isposinf_)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_kron_)( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); extern int (*mlx_left_shift_)( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); extern int (*mlx_less_)( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); extern int (*mlx_less_equal_)( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); extern int (*mlx_linspace_)( mlx_array* res, double start, double stop, int num, mlx_dtype dtype, const mlx_stream s); extern int (*mlx_log_)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_log10_)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_log1p_)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_log2_)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_logaddexp_)( mlx_array* res, const mlx_array a, const mlx_array b, const 
mlx_stream s); extern int (*mlx_logcumsumexp_)( mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s); extern int (*mlx_logical_and_)( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); extern int (*mlx_logical_not_)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_logical_or_)( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); extern int (*mlx_logsumexp_axes_)( mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s); extern int (*mlx_logsumexp_axis_)( mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s); extern int (*mlx_logsumexp_)( mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s); extern int (*mlx_masked_scatter_)( mlx_array* res, const mlx_array a, const mlx_array mask, const mlx_array src, const mlx_stream s); extern int (*mlx_matmul_)( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); extern int (*mlx_max_axes_)( mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s); extern int (*mlx_max_axis_)( mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s); extern int (*mlx_max_)( mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s); extern int (*mlx_maximum_)( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); extern int (*mlx_mean_axes_)( mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s); extern int (*mlx_mean_axis_)( mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s); extern int (*mlx_mean_)( mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s); extern int (*mlx_median_)( mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s); extern int (*mlx_meshgrid_)( mlx_vector_array* 
res, const mlx_vector_array arrays, bool sparse, const char* indexing, const mlx_stream s); extern int (*mlx_min_axes_)( mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s); extern int (*mlx_min_axis_)( mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s); extern int (*mlx_min_)( mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s); extern int (*mlx_minimum_)( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); extern int (*mlx_moveaxis_)( mlx_array* res, const mlx_array a, int source, int destination, const mlx_stream s); extern int (*mlx_multiply_)( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); extern int (*mlx_nan_to_num_)( mlx_array* res, const mlx_array a, float nan, mlx_optional_float posinf, mlx_optional_float neginf, const mlx_stream s); extern int (*mlx_negative_)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_not_equal_)( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); extern int (*mlx_number_of_elements_)( mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool inverted, mlx_dtype dtype, const mlx_stream s); extern int (*mlx_ones_)( mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_stream s); extern int (*mlx_ones_like_)(mlx_array* res, const mlx_array a, const mlx_stream s); extern int (*mlx_outer_)( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); extern int (*mlx_pad_)( mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const int* low_pad_size, size_t low_pad_size_num, const int* high_pad_size, size_t high_pad_size_num, const mlx_array pad_value, const char* mode, const mlx_stream s); extern int (*mlx_pad_symmetric_)( mlx_array* res, const mlx_array a, int pad_width, const mlx_array pad_value, const char* mode, const mlx_stream s); extern int 
(*mlx_partition_axis_)( mlx_array* res, const mlx_array a, int kth, int axis, const mlx_stream s);
/*
 * Function-pointer slots for the mlx-c C API (continued from above).
 *
 * Each slot is named after an mlx-c entry point plus a trailing underscore
 * and is declared `extern`, so its storage is defined in another
 * translation unit. Conventions visible in these declarations:
 *   - results are written through leading `res` / `res_0` / `res_1`
 *     out-parameters, and most slots return an `int` status;
 *   - pointer+length pairs are spelled `const int* x, size_t x_num`;
 *   - parameters tagged `may be null` accept a null handle.
 * NOTE(review): the trailing `const mlx_stream s` is presumably the stream
 * the op is scheduled on, and the int status presumably follows mlx-c's
 * 0-on-success convention -- confirm against the mlx-c headers.
 */
extern int (*mlx_partition_)( mlx_array* res, const mlx_array a, int kth, const mlx_stream s);
extern int (*mlx_power_)( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
extern int (*mlx_prod_axes_)( mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s);
extern int (*mlx_prod_axis_)( mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s);
extern int (*mlx_prod_)( mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s);
extern int (*mlx_put_along_axis_)( mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s);
extern int (*mlx_qqmm_)( mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales /* may be null */, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale_x /* may be null */, const mlx_array global_scale_w /* may be null */, const mlx_stream s);
/* mlx_quantize_ is the one quantization slot whose result is a vector of arrays. */
extern int (*mlx_quantize_)( mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale /* may be null */, const mlx_stream s);
extern int (*mlx_quantized_matmul_)( mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases /* may be null */, bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s);
extern int (*mlx_radians_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_real_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_reciprocal_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_remainder_)( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
extern int (*mlx_repeat_axis_)( mlx_array* res, const mlx_array arr, int repeats, int axis, const mlx_stream s);
extern int (*mlx_repeat_)( mlx_array* res, const mlx_array arr, int repeats, const mlx_stream s);
extern int (*mlx_reshape_)( mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s);
extern int (*mlx_right_shift_)( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
extern int (*mlx_roll_axis_)( mlx_array* res, const mlx_array a, const int* shift, size_t shift_num, int axis, const mlx_stream s);
extern int (*mlx_roll_axes_)( mlx_array* res, const mlx_array a, const int* shift, size_t shift_num, const int* axes, size_t axes_num, const mlx_stream s);
extern int (*mlx_roll_)( mlx_array* res, const mlx_array a, const int* shift, size_t shift_num, const mlx_stream s);
extern int (*mlx_round_)( mlx_array* res, const mlx_array a, int decimals, const mlx_stream s);
extern int (*mlx_rsqrt_)(mlx_array* res, const mlx_array a, const mlx_stream s);
/*
 * Scatter family: the plain variants take a vector of index arrays plus an
 * axes list; the `_single` variants take one index array and one axis.
 */
extern int (*mlx_scatter_)( mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s);
extern int (*mlx_scatter_single_)( mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s);
extern int (*mlx_scatter_add_)( mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s);
extern int (*mlx_scatter_add_single_)( mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s);
extern int (*mlx_scatter_add_axis_)( mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s);
extern int (*mlx_scatter_max_)( mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s);
extern int (*mlx_scatter_max_single_)( mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s);
extern int (*mlx_scatter_min_)( mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s);
extern int (*mlx_scatter_min_single_)( mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s);
extern int (*mlx_scatter_prod_)( mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s);
extern int (*mlx_scatter_prod_single_)( mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s);
extern int (*mlx_segmented_mm_)( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_array segments, const mlx_stream s);
extern int (*mlx_sigmoid_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_sign_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_sin_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_sinh_)(mlx_array* res, const mlx_array a, const mlx_stream s);
/*
 * Slicing: the static variants take start/stop/strides as int lists; the
 * `_dynamic` variants take `start` as an mlx_array.
 */
extern int (*mlx_slice_)( mlx_array* res, const mlx_array a, const int* start, size_t start_num, const int* stop, size_t stop_num, const int* strides, size_t strides_num, const mlx_stream s);
extern int (*mlx_slice_dynamic_)( mlx_array* res, const mlx_array a, const mlx_array start, const int* axes, size_t axes_num, const int* slice_size, size_t slice_size_num, const mlx_stream s);
extern int (*mlx_slice_update_)( mlx_array* res, const mlx_array src, const mlx_array update, const int* start, size_t start_num, const int* stop, size_t stop_num, const int* strides, size_t strides_num, const mlx_stream s);
extern int (*mlx_slice_update_dynamic_)( mlx_array* res, const mlx_array src, const mlx_array update, const mlx_array start, const int* axes, size_t axes_num, const mlx_stream s);
extern int (*mlx_softmax_axes_)( mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool precise, const mlx_stream s);
extern int (*mlx_softmax_axis_)( mlx_array* res, const mlx_array a, int axis, bool precise, const mlx_stream s);
extern int (*mlx_softmax_)( mlx_array* res, const mlx_array a, bool precise, const mlx_stream s);
extern int (*mlx_sort_axis_)( mlx_array* res, const mlx_array a, int axis, const mlx_stream s);
extern int (*mlx_sort_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_split_)( mlx_vector_array* res, const mlx_array a, int num_splits, int axis, const mlx_stream s);
extern int (*mlx_split_sections_)( mlx_vector_array* res, const mlx_array a, const int* indices, size_t indices_num, int axis, const mlx_stream s);
extern int (*mlx_sqrt_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_square_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_squeeze_axes_)( mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s);
extern int (*mlx_squeeze_axis_)( mlx_array* res, const mlx_array a, int axis, const mlx_stream s);
extern int (*mlx_squeeze_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_stack_axis_)( mlx_array* res, const mlx_vector_array arrays, int axis, const mlx_stream s);
extern int (*mlx_stack_)( mlx_array* res, const mlx_vector_array arrays, const mlx_stream s);
extern int (*mlx_std_axes_)( mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, int ddof, const mlx_stream s);
extern int (*mlx_std_axis_)( mlx_array* res, const mlx_array a, int axis, bool keepdims, int ddof, const mlx_stream s);
extern int (*mlx_std_)( mlx_array* res, const mlx_array a, bool keepdims, int ddof, const mlx_stream s);
extern int (*mlx_stop_gradient_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_subtract_)( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
extern int (*mlx_sum_axes_)( mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s);
extern int (*mlx_sum_axis_)( mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s);
extern int (*mlx_sum_)( mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s);
extern int (*mlx_swapaxes_)( mlx_array* res, const mlx_array a, int axis1, int axis2, const mlx_stream s);
extern int (*mlx_take_axis_)( mlx_array* res, const mlx_array a, const mlx_array indices, int axis, const mlx_stream s);
extern int (*mlx_take_)( mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_stream s);
extern int (*mlx_take_along_axis_)( mlx_array* res, const mlx_array a, const mlx_array indices, int axis, const mlx_stream s);
extern int (*mlx_tan_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_tanh_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_tensordot_)( mlx_array* res, const mlx_array a, const mlx_array b, const int* axes_a, size_t axes_a_num, const int* axes_b, size_t axes_b_num, const mlx_stream s);
extern int (*mlx_tensordot_axis_)( mlx_array* res, const mlx_array a, const mlx_array b, int axis, const mlx_stream s);
extern int (*mlx_tile_)( mlx_array* res, const mlx_array arr, const int* reps, size_t reps_num, const mlx_stream s);
extern int (*mlx_to_fp8_)(mlx_array* res, const mlx_array x, const mlx_stream s);
extern int (*mlx_topk_axis_)( mlx_array* res, const mlx_array a, int k, int axis, const mlx_stream s);
extern int (*mlx_topk_)(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
extern int (*mlx_trace_)( mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, mlx_dtype dtype, const mlx_stream s);
extern int (*mlx_transpose_axes_)( mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s);
extern int (*mlx_transpose_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_tri_)( mlx_array* res, int n, int m, int k, mlx_dtype type, const mlx_stream s);
extern int (*mlx_tril_)(mlx_array* res, const mlx_array x, int k, const mlx_stream s);
extern int (*mlx_triu_)(mlx_array* res, const mlx_array x, int k, const mlx_stream s);
extern int (*mlx_unflatten_)( mlx_array* res, const mlx_array a, int axis, const int* shape, size_t shape_num, const mlx_stream s);
extern int (*mlx_var_axes_)( mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, int ddof, const mlx_stream s);
extern int (*mlx_var_axis_)( mlx_array* res, const mlx_array a, int axis, bool keepdims, int ddof, const mlx_stream s);
extern int (*mlx_var_)( mlx_array* res, const mlx_array a, bool keepdims, int ddof, const mlx_stream s);
extern int (*mlx_view_)( mlx_array* res, const mlx_array a, mlx_dtype dtype, const mlx_stream s);
extern int (*mlx_where_)( mlx_array* res, const mlx_array condition, const mlx_array x, const mlx_array y, const mlx_stream s);
extern int (*mlx_zeros_)( mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_stream s);
extern int (*mlx_zeros_like_)(mlx_array* res, const mlx_array a, const mlx_stream s);
/*
 * Random number generation (mlx_random_*). Most slots take an optional
 * `key` array (tagged `may be null`).
 * NOTE(review): presumably a null key selects the library's global PRNG
 * state (see mlx_random_seed_) -- confirm with the mlx-c docs.
 */
extern int (*mlx_random_bernoulli_)( mlx_array* res, const mlx_array p, const int* shape, size_t shape_num, const mlx_array key /* may be null */, const mlx_stream s);
extern int (*mlx_random_bits_)( mlx_array* res, const int* shape, size_t shape_num, int width, const mlx_array key /* may be null */, const mlx_stream s);
extern int (*mlx_random_categorical_shape_)( mlx_array* res, const mlx_array logits, int axis, const int* shape, size_t shape_num, const mlx_array key /* may be null */, const mlx_stream s);
extern int (*mlx_random_categorical_num_samples_)( mlx_array* res, const mlx_array logits_, int axis, int num_samples, const mlx_array key /* may be null */, const mlx_stream s);
extern int (*mlx_random_categorical_)( mlx_array* res, const mlx_array logits, int axis, const mlx_array key /* may be null */, const mlx_stream s);
extern int (*mlx_random_gumbel_)( mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key /* may be null */, const mlx_stream s);
extern int (*mlx_random_key_)(mlx_array* res, uint64_t seed);
extern int (*mlx_random_laplace_)( mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, float loc, float scale, const mlx_array key /* may be null */, const mlx_stream s);
extern int (*mlx_random_multivariate_normal_)( mlx_array* res, const mlx_array mean, const mlx_array cov, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key /* may be null */, const mlx_stream s);
extern int (*mlx_random_normal_broadcast_)( mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array loc /* may be null */, const mlx_array scale /* may be null */, const mlx_array key /* may be null */, const mlx_stream s);
extern int (*mlx_random_normal_)( mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, float loc, float scale, const mlx_array key /* may be null */, const mlx_stream s);
extern int (*mlx_random_permutation_)( mlx_array* res, const mlx_array x, int axis, const mlx_array key /* may be null */, const mlx_stream s);
extern int (*mlx_random_permutation_arange_)( mlx_array* res, int x, const mlx_array key /* may be null */, const mlx_stream s);
extern int (*mlx_random_randint_)( mlx_array* res, const mlx_array low, const mlx_array high, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key /* may be null */, const mlx_stream s);
extern int (*mlx_random_seed_)(uint64_t seed);
extern int (*mlx_random_split_num_)( mlx_array* res, const mlx_array key, int num, const mlx_stream s);
/* Writes two results, through res_0 and res_1. */
extern int (*mlx_random_split_)( mlx_array* res_0, mlx_array* res_1, const mlx_array key, const mlx_stream s);
extern int (*mlx_random_truncated_normal_)( mlx_array* res, const mlx_array lower, const mlx_array upper, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key /* may be null */, const mlx_stream s);
extern int (*mlx_random_uniform_)( mlx_array* res, const mlx_array low, const mlx_array high, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key /* may be null */, const mlx_stream s);
/*
 * Streams. Constructors return the handle by value; the other slots
 * return an int status.
 */
extern mlx_stream (*mlx_stream_new_)(void);
extern mlx_stream (*mlx_stream_new_device_)(mlx_device dev);
extern int (*mlx_stream_set_)(mlx_stream* stream, const mlx_stream src);
extern int (*mlx_stream_free_)(mlx_stream stream);
extern int (*mlx_stream_tostring_)(mlx_string* str, mlx_stream stream);
extern bool (*mlx_stream_equal_)(mlx_stream lhs, mlx_stream rhs);
extern int (*mlx_stream_get_device_)(mlx_device* dev, mlx_stream stream);
extern int (*mlx_stream_get_index_)(int* index, mlx_stream stream);
extern int (*mlx_synchronize_)(mlx_stream stream);
extern int (*mlx_get_default_stream_)(mlx_stream* stream, mlx_device dev);
extern int (*mlx_set_default_stream_)(mlx_stream stream);
extern mlx_stream (*mlx_default_cpu_stream_new_)(void);
extern mlx_stream (*mlx_default_gpu_stream_new_)(void);
/* Strings (mlx_string handles). */
extern mlx_string (*mlx_string_new_)(void);
extern mlx_string (*mlx_string_new_data_)(const char* str);
extern int (*mlx_string_set_)(mlx_string* str, const mlx_string src);
/*
 * Returns a borrowed `const char *`.
 * NOTE(review): presumably valid only while `str` is alive -- confirm.
 */
extern const char * (*mlx_string_data_)(mlx_string str);
extern int (*mlx_string_free_)(mlx_string str);
/* Evaluation and function transforms (eval, grad, jvp/vjp, vmap). */
extern int (*mlx_async_eval_)(const mlx_vector_array outputs);
extern int (*mlx_checkpoint_)(mlx_closure* res, const mlx_closure fun);
extern int (*mlx_custom_function_)( mlx_closure* res, const mlx_closure fun, const mlx_closure_custom fun_vjp /* may be null */, const mlx_closure_custom_jvp fun_jvp /* may be null */, const mlx_closure_custom_vmap fun_vmap /* may be null */);
extern int (*mlx_custom_vjp_)( mlx_closure* res, const mlx_closure fun, const mlx_closure_custom fun_vjp);
extern int (*mlx_eval_)(const mlx_vector_array outputs);
extern int (*mlx_jvp_)( mlx_vector_array* res_0, mlx_vector_array* res_1, const mlx_closure fun, const mlx_vector_array primals, const mlx_vector_array tangents);
extern int (*mlx_value_and_grad_)( mlx_closure_value_and_grad* res, const mlx_closure fun, const int* argnums, size_t argnums_num);
extern int (*mlx_vjp_)( mlx_vector_array* res_0, mlx_vector_array* res_1, const mlx_closure fun, const mlx_vector_array primals, const mlx_vector_array cotangents);
extern int (*mlx_detail_vmap_replace_)( mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array s_inputs, const mlx_vector_array s_outputs, const int* in_axes, size_t in_axes_num, const int* out_axes, size_t out_axes_num);
extern int (*mlx_detail_vmap_trace_)( mlx_vector_array* res_0, mlx_vector_array* res_1, const mlx_closure fun, const mlx_vector_array inputs, const int* in_axes, size_t in_axes_num);
/*
 * Vector containers (mlx_vector_array, mlx_vector_vector_array,
 * mlx_vector_int, mlx_vector_string).
 * Each container family provides the same set of slots:
 *   new / set / free / new_data / new_value / set_data / set_value /
 *   append_data / append_value / size / get.
 */
extern mlx_vector_array (*mlx_vector_array_new_)(void);
extern int (*mlx_vector_array_set_)(mlx_vector_array* vec, const mlx_vector_array src);
extern int (*mlx_vector_array_free_)(mlx_vector_array vec);
extern mlx_vector_array (*mlx_vector_array_new_data_)(const mlx_array* data, size_t size);
extern mlx_vector_array (*mlx_vector_array_new_value_)(const mlx_array val);
extern int (*mlx_vector_array_set_data_)( mlx_vector_array* vec, const mlx_array* data, size_t size);
extern int (*mlx_vector_array_set_value_)(mlx_vector_array* vec, const mlx_array val);
extern int (*mlx_vector_array_append_data_)( mlx_vector_array vec, const mlx_array* data, size_t size);
extern int (*mlx_vector_array_append_value_)(mlx_vector_array vec, const mlx_array val);
extern size_t (*mlx_vector_array_size_)(mlx_vector_array vec);
extern int (*mlx_vector_array_get_)( mlx_array* res, const mlx_vector_array vec, size_t idx);
extern mlx_vector_vector_array (*mlx_vector_vector_array_new_)(void);
extern int (*mlx_vector_vector_array_set_)( mlx_vector_vector_array* vec, const mlx_vector_vector_array src);
extern int (*mlx_vector_vector_array_free_)(mlx_vector_vector_array vec);
extern mlx_vector_vector_array (*mlx_vector_vector_array_new_data_)( const mlx_vector_array* data, size_t size);
extern mlx_vector_vector_array (*mlx_vector_vector_array_new_value_)( const mlx_vector_array val);
extern int (*mlx_vector_vector_array_set_data_)( mlx_vector_vector_array* vec, const mlx_vector_array* data, size_t size);
extern int (*mlx_vector_vector_array_set_value_)( mlx_vector_vector_array* vec, const mlx_vector_array val);
extern int (*mlx_vector_vector_array_append_data_)( mlx_vector_vector_array vec, const mlx_vector_array* data, size_t size);
extern int (*mlx_vector_vector_array_append_value_)( mlx_vector_vector_array vec, const mlx_vector_array val);
extern size_t (*mlx_vector_vector_array_size_)(mlx_vector_vector_array vec);
extern int (*mlx_vector_vector_array_get_)( mlx_vector_array* res, const mlx_vector_vector_array vec, size_t idx);
/*
 * The int-vector data slots take a non-const `int*`, unlike the
 * array-vector slots' `const mlx_array*`. Kept as generated so the
 * declarations still match the loaded library's signatures.
 */
extern mlx_vector_int (*mlx_vector_int_new_)(void);
extern int (*mlx_vector_int_set_)(mlx_vector_int* vec, const mlx_vector_int src);
extern int (*mlx_vector_int_free_)(mlx_vector_int vec);
extern mlx_vector_int (*mlx_vector_int_new_data_)(int* data, size_t size);
extern mlx_vector_int (*mlx_vector_int_new_value_)(int val);
extern int (*mlx_vector_int_set_data_)(mlx_vector_int* vec, int* data, size_t size);
extern int (*mlx_vector_int_set_value_)(mlx_vector_int* vec, int val);
extern int (*mlx_vector_int_append_data_)(mlx_vector_int vec, int* data, size_t size);
extern int (*mlx_vector_int_append_value_)(mlx_vector_int vec, int val);
extern size_t (*mlx_vector_int_size_)(mlx_vector_int vec);
extern int (*mlx_vector_int_get_)(int* res, const mlx_vector_int vec, size_t idx);
extern mlx_vector_string (*mlx_vector_string_new_)(void);
extern int (*mlx_vector_string_set_)(mlx_vector_string* vec, const mlx_vector_string src);
extern int (*mlx_vector_string_free_)(mlx_vector_string vec);
extern mlx_vector_string (*mlx_vector_string_new_data_)(const char** data, size_t size);
extern mlx_vector_string (*mlx_vector_string_new_value_)(const char* val);
extern int (*mlx_vector_string_set_data_)( mlx_vector_string* vec, const char** data, size_t size);
extern int (*mlx_vector_string_set_value_)(mlx_vector_string* vec, const char* val);
extern int (*mlx_vector_string_append_data_)( mlx_vector_string vec, const char** data, size_t size);
extern int (*mlx_vector_string_append_value_)(mlx_vector_string vec, const char* val);
extern size_t (*mlx_vector_string_size_)(mlx_vector_string vec);
/*
 * Writes a `char *` through `res`.
 * NOTE(review): presumably this pointer is borrowed from `vec` and must
 * not be freed by the caller -- confirm with mlx-c docs.
 */
extern int (*mlx_vector_string_get_)(char** res, const mlx_vector_string vec, size_t idx);
/* Library version string. */
extern int (*mlx_version_)(mlx_string* str_);
/*
 * Loader entry point. Its definition is not in this header.
 * NOTE(review): presumably it resolves every `*_` slot declared in this
 * header from `handle` and returns an mlx-c style int status. Confirm the
 * return convention, and whether a partial failure can leave slots NULL.
 */
int mlx_dynamic_load_symbols(mlx_dynamic_handle handle);
/*
 * static inline wrappers that expose the public mlx-c names on top of the
 * slots. Each wrapper passes its arguments through unchanged to the slot
 * with the same name plus a trailing underscore, and returns that slot's
 * result. There is no NULL check, so a slot must already hold a valid
 * function pointer (presumably set by a successful
 * mlx_dynamic_load_symbols() call) before its wrapper is used.
 * NOTE(review): HEAD #defines each public name to a `_mlx_gen_orig_`
 * alias. These wrappers only carry the public names if those macros are
 * #undef'd before this point, which is not visible in this chunk.
 */
static inline size_t mlx_dtype_size(mlx_dtype dtype) { return mlx_dtype_size_(dtype); }
/* array: string form, construction and destruction */
static inline int mlx_array_tostring(mlx_string* str, const mlx_array arr) { return mlx_array_tostring_(str, arr); }
static inline mlx_array mlx_array_new(void) { return mlx_array_new_(); }
static inline int mlx_array_free(mlx_array arr) { return mlx_array_free_(arr); }
static inline mlx_array mlx_array_new_bool(bool val) { return mlx_array_new_bool_(val); }
static inline mlx_array mlx_array_new_int(int val) { return mlx_array_new_int_(val); }
static inline mlx_array mlx_array_new_float32(float val) { return mlx_array_new_float32_(val); }
static inline mlx_array mlx_array_new_float(float val) { return mlx_array_new_float_(val); }
static inline mlx_array mlx_array_new_float64(double val) { return mlx_array_new_float64_(val); }
static inline mlx_array mlx_array_new_double(double val) { return mlx_array_new_double_(val); }
static inline mlx_array mlx_array_new_complex(float real_val, float imag_val) { return mlx_array_new_complex_(real_val, imag_val); }
static inline mlx_array mlx_array_new_data( const void* data, const int* shape, int dim, mlx_dtype dtype) { return mlx_array_new_data_(data, shape, dim, dtype); }
/*
 * The managed constructors forward `dtor` (and, for _payload, `payload`)
 * unchanged.
 * NOTE(review): presumably the library calls dtor when it releases the
 * buffer -- confirm which pointer is passed to it.
 */
static inline mlx_array mlx_array_new_data_managed( void* data, const int* shape, int dim, mlx_dtype dtype, void (*dtor)(void*)) { return mlx_array_new_data_managed_(data, shape, dim, dtype, dtor); }
static inline mlx_array mlx_array_new_data_managed_payload( void* data, const int* shape, int dim, mlx_dtype dtype, void* payload, void (*dtor)(void*)) { return mlx_array_new_data_managed_payload_(data, shape, dim, dtype, payload, dtor); }
/* array: assignment into an existing handle passed as `mlx_array* arr` */
static inline int mlx_array_set(mlx_array* arr, const mlx_array src) { return mlx_array_set_(arr, src); }
static inline int mlx_array_set_bool(mlx_array* arr, bool val) { return mlx_array_set_bool_(arr, val); }
static inline int mlx_array_set_int(mlx_array* arr, int val) { return mlx_array_set_int_(arr, val); }
static inline int mlx_array_set_float32(mlx_array* arr, float val) { return mlx_array_set_float32_(arr, val); }
static inline int mlx_array_set_float(mlx_array* arr, float val) { return mlx_array_set_float_(arr, val); }
static inline int mlx_array_set_float64(mlx_array* arr, double val) { return mlx_array_set_float64_(arr, val); }
static inline int mlx_array_set_double(mlx_array* arr, double val) { return mlx_array_set_double_(arr, val); }
static inline int mlx_array_set_complex(mlx_array* arr, float real_val, float imag_val) { return mlx_array_set_complex_(arr, real_val, imag_val); }
static inline int mlx_array_set_data( mlx_array* arr, const void* data, const int* shape, int dim, mlx_dtype dtype) { return mlx_array_set_data_(arr, data, shape, dim, dtype); }
/* array: metadata queries */
static inline size_t mlx_array_itemsize(const mlx_array arr) { return mlx_array_itemsize_(arr); }
static inline size_t mlx_array_size(const mlx_array arr) { return mlx_array_size_(arr); }
static inline size_t mlx_array_nbytes(const mlx_array arr) { return mlx_array_nbytes_(arr); }
static inline size_t mlx_array_ndim(const mlx_array arr) { return mlx_array_ndim_(arr); }
static inline const int * mlx_array_shape(const mlx_array arr) { return mlx_array_shape_(arr); }
static inline const size_t * mlx_array_strides(const mlx_array arr) { return mlx_array_strides_(arr); }
/*
 * Forwarding wrappers (continued). Each wrapper passes its arguments
 * unchanged to the slot with the same name plus a trailing underscore,
 * and returns that slot's result. There is no NULL check on the slot.
 */
/* array: metadata (continued) and evaluation */
static inline int mlx_array_dim(const mlx_array arr, int dim) { return mlx_array_dim_(arr, dim); }
static inline mlx_dtype mlx_array_dtype(const mlx_array arr) { return mlx_array_dtype_(arr); }
static inline int mlx_array_eval(mlx_array arr) { return mlx_array_eval_(arr); }
/*
 * array: scalar extraction. Each wrapper writes one value through `res`
 * and returns an int status.
 * NOTE(review): presumably `arr` must hold a single element of the
 * matching dtype -- confirm with the mlx-c docs.
 */
static inline int mlx_array_item_bool(bool* res, const mlx_array arr) { return mlx_array_item_bool_(res, arr); }
static inline int mlx_array_item_uint8(uint8_t* res, const mlx_array arr) { return mlx_array_item_uint8_(res, arr); }
static inline int mlx_array_item_uint16(uint16_t* res, const mlx_array arr) { return mlx_array_item_uint16_(res, arr); }
static inline int mlx_array_item_uint32(uint32_t* res, const mlx_array arr) { return mlx_array_item_uint32_(res, arr); }
static inline int mlx_array_item_uint64(uint64_t* res, const mlx_array arr) { return mlx_array_item_uint64_(res, arr); }
static inline int mlx_array_item_int8(int8_t* res, const mlx_array arr) { return mlx_array_item_int8_(res, arr); }
static inline int mlx_array_item_int16(int16_t* res, const mlx_array arr) { return mlx_array_item_int16_(res, arr); }
static inline int mlx_array_item_int32(int32_t* res, const mlx_array arr) { return mlx_array_item_int32_(res, arr); }
static inline int mlx_array_item_int64(int64_t* res, const mlx_array arr) { return mlx_array_item_int64_(res, arr); }
static inline int mlx_array_item_float32(float* res, const mlx_array arr) { return mlx_array_item_float32_(res, arr); }
static inline int mlx_array_item_float64(double* res, const mlx_array arr) { return mlx_array_item_float64_(res, arr); }
static inline int mlx_array_item_complex64(mlx_complex64_t* res, const mlx_array arr) { return mlx_array_item_complex64_(res, arr); }
static inline int mlx_array_item_float16(float16_t* res, const mlx_array arr) { return mlx_array_item_float16_(res, arr); }
static inline int mlx_array_item_bfloat16(bfloat16_t* res, const mlx_array arr) { return mlx_array_item_bfloat16_(res, arr); }
/*
 * array: typed read-only pointers to the element buffer.
 * NOTE(review): mlx-c documents these as requiring an evaluated array
 * (see mlx_array_eval) -- confirm for the loaded library version.
 */
static inline const bool * mlx_array_data_bool(const mlx_array arr) { return mlx_array_data_bool_(arr); }
static inline const uint8_t * mlx_array_data_uint8(const mlx_array arr) { return mlx_array_data_uint8_(arr); }
static inline const uint16_t * mlx_array_data_uint16(const mlx_array arr) { return mlx_array_data_uint16_(arr); }
static inline const uint32_t * mlx_array_data_uint32(const mlx_array arr) { return mlx_array_data_uint32_(arr); }
static inline const uint64_t * mlx_array_data_uint64(const mlx_array arr) { return mlx_array_data_uint64_(arr); }
static inline const int8_t * mlx_array_data_int8(const mlx_array arr) { return mlx_array_data_int8_(arr); }
static inline const int16_t * mlx_array_data_int16(const mlx_array arr) { return mlx_array_data_int16_(arr); }
static inline const int32_t * mlx_array_data_int32(const mlx_array arr) { return mlx_array_data_int32_(arr); }
static inline const int64_t * mlx_array_data_int64(const mlx_array arr) { return mlx_array_data_int64_(arr); }
static inline const float * mlx_array_data_float32(const mlx_array arr) { return mlx_array_data_float32_(arr); }
static inline const double * mlx_array_data_float64(const mlx_array arr) { return mlx_array_data_float64_(arr); }
static inline const mlx_complex64_t * mlx_array_data_complex64(const mlx_array arr) { return mlx_array_data_complex64_(arr); }
static inline const float16_t * mlx_array_data_float16(const mlx_array arr) { return mlx_array_data_float16_(arr); }
static inline const bfloat16_t * mlx_array_data_bfloat16(const mlx_array arr) { return mlx_array_data_bfloat16_(arr); }
/*
 * array: leading-underscore helpers (availability, wait, contiguity
 * checks). The boolean queries write their answer through `res`.
 * NOTE(review): the underscore presumably marks these as non-public
 * mlx-c API -- confirm before depending on them.
 */
static inline int _mlx_array_is_available(bool* res, const mlx_array arr) { return _mlx_array_is_available_(res, arr); }
static inline int _mlx_array_wait(const mlx_array arr) { return _mlx_array_wait_(arr); }
static inline int _mlx_array_is_contiguous(bool* res, const mlx_array arr) { return _mlx_array_is_contiguous_(res, arr); }
static inline int _mlx_array_is_row_contiguous(bool* res, const mlx_array arr) { return _mlx_array_is_row_contiguous_(res, arr); }
static inline int _mlx_array_is_col_contiguous(bool* res, const mlx_array arr) { return _mlx_array_is_col_contiguous_(res, arr); }
/*
 * Closures. Each family (mlx_closure, _kwargs, _value_and_grad, _custom,
 * _custom_jvp, _custom_vmap) provides the same wrappers:
 *   new / free / new_func / new_func_payload / set / apply.
 * The _payload constructors forward `payload` and `dtor` unchanged.
 * NOTE(review): presumably dtor(payload) runs when the closure is
 * released -- confirm with the mlx-c docs.
 */
static inline mlx_closure mlx_closure_new(void) { return mlx_closure_new_(); }
static inline int mlx_closure_free(mlx_closure cls) { return mlx_closure_free_(cls); }
static inline mlx_closure mlx_closure_new_func( int (*fun)(mlx_vector_array*, const mlx_vector_array)) { return mlx_closure_new_func_(fun); }
static inline mlx_closure mlx_closure_new_func_payload( int (*fun)(mlx_vector_array*, const mlx_vector_array, void*), void* payload, void (*dtor)(void*)) { return mlx_closure_new_func_payload_(fun, payload, dtor); }
static inline int mlx_closure_set(mlx_closure* cls, const mlx_closure src) { return mlx_closure_set_(cls, src); }
static inline int mlx_closure_apply( mlx_vector_array* res, mlx_closure cls, const mlx_vector_array input) { return mlx_closure_apply_(res, cls, input); }
static inline mlx_closure mlx_closure_new_unary(int (*fun)(mlx_array*, const mlx_array)) { return mlx_closure_new_unary_(fun); }
static inline mlx_closure_kwargs mlx_closure_kwargs_new(void) { return mlx_closure_kwargs_new_(); }
static inline int mlx_closure_kwargs_free(mlx_closure_kwargs cls) { return mlx_closure_kwargs_free_(cls); }
static inline mlx_closure_kwargs mlx_closure_kwargs_new_func( int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_map_string_to_array)) { return mlx_closure_kwargs_new_func_(fun); }
static inline mlx_closure_kwargs mlx_closure_kwargs_new_func_payload( int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_map_string_to_array, void*), void* payload, void (*dtor)(void*)) { return mlx_closure_kwargs_new_func_payload_(fun, payload, dtor); }
static inline int mlx_closure_kwargs_set( mlx_closure_kwargs* cls, const mlx_closure_kwargs src) { return mlx_closure_kwargs_set_(cls, src); }
static inline int mlx_closure_kwargs_apply( mlx_vector_array* res, mlx_closure_kwargs cls, const mlx_vector_array input_0, const mlx_map_string_to_array input_1) { return mlx_closure_kwargs_apply_(res, cls, input_0, input_1); }
static inline mlx_closure_value_and_grad mlx_closure_value_and_grad_new(void) { return mlx_closure_value_and_grad_new_(); }
static inline int mlx_closure_value_and_grad_free(mlx_closure_value_and_grad cls) { return mlx_closure_value_and_grad_free_(cls); }
static inline mlx_closure_value_and_grad mlx_closure_value_and_grad_new_func( int (*fun)(mlx_vector_array*, mlx_vector_array*, const mlx_vector_array)) { return mlx_closure_value_and_grad_new_func_(fun); }
static inline mlx_closure_value_and_grad mlx_closure_value_and_grad_new_func_payload( int (*fun)( mlx_vector_array*, mlx_vector_array*, const mlx_vector_array, void*), void* payload, void (*dtor)(void*)) { return mlx_closure_value_and_grad_new_func_payload_(fun, payload, dtor); }
static inline int mlx_closure_value_and_grad_set( mlx_closure_value_and_grad* cls, const mlx_closure_value_and_grad src) { return mlx_closure_value_and_grad_set_(cls, src); }
/* Writes two results, through res_0 and res_1. */
static inline int mlx_closure_value_and_grad_apply( mlx_vector_array* res_0, mlx_vector_array* res_1, mlx_closure_value_and_grad cls, const mlx_vector_array input) { return mlx_closure_value_and_grad_apply_(res_0, res_1, cls, input); }
static inline mlx_closure_custom mlx_closure_custom_new(void) { return mlx_closure_custom_new_(); }
static inline int mlx_closure_custom_free(mlx_closure_custom cls) { return mlx_closure_custom_free_(cls); }
static inline mlx_closure_custom mlx_closure_custom_new_func( int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_vector_array, const mlx_vector_array)) { return mlx_closure_custom_new_func_(fun); }
static inline mlx_closure_custom mlx_closure_custom_new_func_payload( int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_vector_array, const mlx_vector_array, void*), void* payload, void (*dtor)(void*)) { return mlx_closure_custom_new_func_payload_(fun, payload, dtor); }
static inline int mlx_closure_custom_set( mlx_closure_custom* cls, const mlx_closure_custom src) { return mlx_closure_custom_set_(cls, src); }
static inline int mlx_closure_custom_apply( mlx_vector_array* res, mlx_closure_custom cls, const mlx_vector_array input_0, const mlx_vector_array input_1, const mlx_vector_array input_2) { return mlx_closure_custom_apply_(res, cls, input_0, input_1, input_2); }
/*
 * The jvp and vmap callbacks also receive an int list. In the callback
 * type it is spelled `const int*, size_t _num`, i.e. pointer plus count.
 */
static inline mlx_closure_custom_jvp mlx_closure_custom_jvp_new(void) { return mlx_closure_custom_jvp_new_(); }
static inline int mlx_closure_custom_jvp_free(mlx_closure_custom_jvp cls) { return mlx_closure_custom_jvp_free_(cls); }
static inline mlx_closure_custom_jvp mlx_closure_custom_jvp_new_func( int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_vector_array, const int*, size_t _num)) { return mlx_closure_custom_jvp_new_func_(fun); }
static inline mlx_closure_custom_jvp mlx_closure_custom_jvp_new_func_payload( int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_vector_array, const int*, size_t _num, void*), void* payload, void (*dtor)(void*)) { return mlx_closure_custom_jvp_new_func_payload_(fun, payload, dtor); }
static inline int mlx_closure_custom_jvp_set( mlx_closure_custom_jvp* cls, const mlx_closure_custom_jvp src) { return mlx_closure_custom_jvp_set_(cls, src); }
static inline int mlx_closure_custom_jvp_apply( mlx_vector_array* res, mlx_closure_custom_jvp cls, const mlx_vector_array input_0, const mlx_vector_array input_1, const int* input_2, size_t input_2_num) { return mlx_closure_custom_jvp_apply_(res, cls, input_0, input_1, input_2, input_2_num); }
static inline mlx_closure_custom_vmap mlx_closure_custom_vmap_new(void) { return mlx_closure_custom_vmap_new_(); }
static inline int mlx_closure_custom_vmap_free(mlx_closure_custom_vmap cls) { return mlx_closure_custom_vmap_free_(cls); }
static inline mlx_closure_custom_vmap mlx_closure_custom_vmap_new_func( int (*fun)( mlx_vector_array*, mlx_vector_int*, const mlx_vector_array, const int*, size_t _num)) { return mlx_closure_custom_vmap_new_func_(fun); }
static inline mlx_closure_custom_vmap mlx_closure_custom_vmap_new_func_payload( int (*fun)( mlx_vector_array*, mlx_vector_int*, const mlx_vector_array, const int*, size_t _num, void*), void* payload, void (*dtor)(void*)) { return mlx_closure_custom_vmap_new_func_payload_(fun, payload, dtor); }
static inline int mlx_closure_custom_vmap_set( mlx_closure_custom_vmap* cls, const mlx_closure_custom_vmap src) { return mlx_closure_custom_vmap_set_(cls, src); }
static inline int mlx_closure_custom_vmap_apply( mlx_vector_array* res_0, mlx_vector_int* res_1, mlx_closure_custom_vmap cls, const mlx_vector_array input_0, const int* input_1, size_t input_1_num) { return mlx_closure_custom_vmap_apply_(res_0, res_1, cls, input_0, input_1, input_1_num); }
/* Compilation controls */
static inline int mlx_compile(mlx_closure* res, const mlx_closure fun, bool shapeless) { return mlx_compile_(res, fun, shapeless); }
static inline int mlx_detail_compile( mlx_closure* res, const mlx_closure fun, uintptr_t fun_id, bool shapeless, const uint64_t* constants, size_t constants_num) { return mlx_detail_compile_(res, fun, fun_id, shapeless, constants, constants_num); }
static inline int mlx_detail_compile_clear_cache(void) { return mlx_detail_compile_clear_cache_(); }
static inline int mlx_detail_compile_erase(uintptr_t fun_id) { return mlx_detail_compile_erase_(fun_id); }
static inline int mlx_disable_compile(void) { return mlx_disable_compile_(); }
static inline int mlx_enable_compile(void) { return mlx_enable_compile_(); }
static inline int mlx_set_compile_mode(mlx_compile_mode mode) { return mlx_set_compile_mode_(mode); }
/* CUDA availability query; the answer is written through `res`. */
static inline int mlx_cuda_is_available(bool* res) { return mlx_cuda_is_available_(res); }
/* Devices */
static inline mlx_device mlx_device_new(void) { return mlx_device_new_(); }
static inline mlx_device mlx_device_new_type(mlx_device_type type, int index) { return mlx_device_new_type_(type, index); }
static inline int mlx_device_free(mlx_device dev) { return mlx_device_free_(dev); }
static inline int mlx_device_set(mlx_device* dev, const mlx_device src) { return mlx_device_set_(dev, src); }
static inline int mlx_device_tostring(mlx_string* str, mlx_device dev) { return mlx_device_tostring_(str, dev); }
static inline bool mlx_device_equal(mlx_device lhs, mlx_device rhs) { return mlx_device_equal_(lhs, rhs); }
static inline int mlx_device_get_index(int* index, mlx_device dev) { return mlx_device_get_index_(index, dev); }
static inline int mlx_device_get_type(mlx_device_type* type, mlx_device dev) { return mlx_device_get_type_(type, dev); }
static inline int mlx_get_default_device(mlx_device* dev) { return mlx_get_default_device_(dev); }
static inline int mlx_set_default_device(mlx_device dev) { return mlx_set_default_device_(dev); }
static inline int mlx_device_is_available(bool* avail, mlx_device dev) { return mlx_device_is_available_(avail, dev); }
static inline int mlx_device_count(int* count, mlx_device_type type) { return mlx_device_count_(count, type); }
/*
 * Device info: a handle queried by string key. Every query writes its
 * answer through the leading out-parameter.
 */
static inline mlx_device_info mlx_device_info_new(void) { return mlx_device_info_new_(); }
static inline int mlx_device_info_get(mlx_device_info* info, mlx_device dev) { return mlx_device_info_get_(info, dev); }
static inline int mlx_device_info_free(mlx_device_info info) { return mlx_device_info_free_(info); }
static inline int mlx_device_info_has_key( bool* exists, mlx_device_info info, const char* key) { return mlx_device_info_has_key_(exists, info, key); }
static inline int mlx_device_info_is_string( bool* is_string, mlx_device_info info, const char* key) { return mlx_device_info_is_string_(is_string, info, key); }
static inline int mlx_device_info_get_string( const char** value, mlx_device_info info, const char* key) { return mlx_device_info_get_string_(value, info, key); }
static inline int mlx_device_info_get_size( size_t* value, mlx_device_info info, const char* key) { return mlx_device_info_get_size_(value, info,
key); } static inline int mlx_device_info_get_keys(mlx_vector_string* keys, mlx_device_info info) { return mlx_device_info_get_keys_(keys, info); } static inline int mlx_distributed_all_gather( mlx_array* res, const mlx_array x, const mlx_distributed_group group /* may be null */, const mlx_stream S) { return mlx_distributed_all_gather_(res, x, group, S); } static inline int mlx_distributed_all_max( mlx_array* res, const mlx_array x, const mlx_distributed_group group /* may be null */, const mlx_stream s) { return mlx_distributed_all_max_(res, x, group, s); } static inline int mlx_distributed_all_min( mlx_array* res, const mlx_array x, const mlx_distributed_group group /* may be null */, const mlx_stream s) { return mlx_distributed_all_min_(res, x, group, s); } static inline int mlx_distributed_all_sum( mlx_array* res, const mlx_array x, const mlx_distributed_group group /* may be null */, const mlx_stream s) { return mlx_distributed_all_sum_(res, x, group, s); } static inline int mlx_distributed_recv( mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, int src, const mlx_distributed_group group /* may be null */, const mlx_stream s) { return mlx_distributed_recv_(res, shape, shape_num, dtype, src, group, s); } static inline int mlx_distributed_recv_like( mlx_array* res, const mlx_array x, int src, const mlx_distributed_group group /* may be null */, const mlx_stream s) { return mlx_distributed_recv_like_(res, x, src, group, s); } static inline int mlx_distributed_send( mlx_array* res, const mlx_array x, int dst, const mlx_distributed_group group /* may be null */, const mlx_stream s) { return mlx_distributed_send_(res, x, dst, group, s); } static inline int mlx_distributed_sum_scatter( mlx_array* res, const mlx_array x, const mlx_distributed_group group /* may be null */, const mlx_stream s) { return mlx_distributed_sum_scatter_(res, x, group, s); } static inline int mlx_distributed_group_rank(mlx_distributed_group group) { return 
mlx_distributed_group_rank_(group); } static inline int mlx_distributed_group_size(mlx_distributed_group group) { return mlx_distributed_group_size_(group); } static inline mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, int color, int key) { return mlx_distributed_group_split_(group, color, key); } static inline bool mlx_distributed_is_available(const char* bk /* may be null */) { return mlx_distributed_is_available_(bk); } static inline mlx_distributed_group mlx_distributed_init( bool strict, const char* bk /* may be null */) { return mlx_distributed_init_(strict, bk); } static inline void mlx_set_error_handler( mlx_error_handler_func handler, void* data, void (*dtor)(void*)) { return mlx_set_error_handler_(handler, data, dtor); } static inline void _mlx_error(const char* file, const int line, const char* fmt, ...) { return _mlx_error_(file, line, fmt); } static inline int mlx_export_function( const char* file, const mlx_closure fun, const mlx_vector_array args, bool shapeless) { return mlx_export_function_(file, fun, args, shapeless); } static inline int mlx_export_function_kwargs( const char* file, const mlx_closure_kwargs fun, const mlx_vector_array args, const mlx_map_string_to_array kwargs, bool shapeless) { return mlx_export_function_kwargs_(file, fun, args, kwargs, shapeless); } static inline mlx_function_exporter mlx_function_exporter_new( const char* file, const mlx_closure fun, bool shapeless) { return mlx_function_exporter_new_(file, fun, shapeless); } static inline int mlx_function_exporter_free(mlx_function_exporter xfunc) { return mlx_function_exporter_free_(xfunc); } static inline int mlx_function_exporter_apply( const mlx_function_exporter xfunc, const mlx_vector_array args) { return mlx_function_exporter_apply_(xfunc, args); } static inline int mlx_function_exporter_apply_kwargs( const mlx_function_exporter xfunc, const mlx_vector_array args, const mlx_map_string_to_array kwargs) { return 
mlx_function_exporter_apply_kwargs_(xfunc, args, kwargs); } static inline mlx_imported_function mlx_imported_function_new(const char* file) { return mlx_imported_function_new_(file); } static inline int mlx_imported_function_free(mlx_imported_function xfunc) { return mlx_imported_function_free_(xfunc); } static inline int mlx_imported_function_apply( mlx_vector_array* res, const mlx_imported_function xfunc, const mlx_vector_array args) { return mlx_imported_function_apply_(res, xfunc, args); } static inline int mlx_imported_function_apply_kwargs( mlx_vector_array* res, const mlx_imported_function xfunc, const mlx_vector_array args, const mlx_map_string_to_array kwargs) { return mlx_imported_function_apply_kwargs_(res, xfunc, args, kwargs); } static inline mlx_fast_cuda_kernel_config mlx_fast_cuda_kernel_config_new(void) { return mlx_fast_cuda_kernel_config_new_(); } static inline void mlx_fast_cuda_kernel_config_free(mlx_fast_cuda_kernel_config cls) { return mlx_fast_cuda_kernel_config_free_(cls); } static inline int mlx_fast_cuda_kernel_config_add_output_arg( mlx_fast_cuda_kernel_config cls, const int* shape, size_t size, mlx_dtype dtype) { return mlx_fast_cuda_kernel_config_add_output_arg_(cls, shape, size, dtype); } static inline int mlx_fast_cuda_kernel_config_set_grid( mlx_fast_cuda_kernel_config cls, int grid1, int grid2, int grid3) { return mlx_fast_cuda_kernel_config_set_grid_(cls, grid1, grid2, grid3); } static inline int mlx_fast_cuda_kernel_config_set_thread_group( mlx_fast_cuda_kernel_config cls, int thread1, int thread2, int thread3) { return mlx_fast_cuda_kernel_config_set_thread_group_(cls, thread1, thread2, thread3); } static inline int mlx_fast_cuda_kernel_config_set_init_value( mlx_fast_cuda_kernel_config cls, float value) { return mlx_fast_cuda_kernel_config_set_init_value_(cls, value); } static inline int mlx_fast_cuda_kernel_config_set_verbose( mlx_fast_cuda_kernel_config cls, bool verbose) { return mlx_fast_cuda_kernel_config_set_verbose_(cls, 
verbose); } static inline int mlx_fast_cuda_kernel_config_add_template_arg_dtype( mlx_fast_cuda_kernel_config cls, const char* name, mlx_dtype dtype) { return mlx_fast_cuda_kernel_config_add_template_arg_dtype_(cls, name, dtype); } static inline int mlx_fast_cuda_kernel_config_add_template_arg_int( mlx_fast_cuda_kernel_config cls, const char* name, int value) { return mlx_fast_cuda_kernel_config_add_template_arg_int_(cls, name, value); } static inline int mlx_fast_cuda_kernel_config_add_template_arg_bool( mlx_fast_cuda_kernel_config cls, const char* name, bool value) { return mlx_fast_cuda_kernel_config_add_template_arg_bool_(cls, name, value); } static inline mlx_fast_cuda_kernel mlx_fast_cuda_kernel_new( const char* name, const mlx_vector_string input_names, const mlx_vector_string output_names, const char* source, const char* header, bool ensure_row_contiguous, int shared_memory) { return mlx_fast_cuda_kernel_new_(name, input_names, output_names, source, header, ensure_row_contiguous, shared_memory); } static inline void mlx_fast_cuda_kernel_free(mlx_fast_cuda_kernel cls) { return mlx_fast_cuda_kernel_free_(cls); } static inline int mlx_fast_cuda_kernel_apply( mlx_vector_array* outputs, mlx_fast_cuda_kernel cls, const mlx_vector_array inputs, const mlx_fast_cuda_kernel_config config, const mlx_stream stream) { return mlx_fast_cuda_kernel_apply_(outputs, cls, inputs, config, stream); } static inline int mlx_fast_layer_norm( mlx_array* res, const mlx_array x, const mlx_array weight /* may be null */, const mlx_array bias /* may be null */, float eps, const mlx_stream s) { return mlx_fast_layer_norm_(res, x, weight, bias, eps, s); } static inline mlx_fast_metal_kernel_config mlx_fast_metal_kernel_config_new(void) { return mlx_fast_metal_kernel_config_new_(); } static inline void mlx_fast_metal_kernel_config_free(mlx_fast_metal_kernel_config cls) { return mlx_fast_metal_kernel_config_free_(cls); } static inline int mlx_fast_metal_kernel_config_add_output_arg( 
mlx_fast_metal_kernel_config cls, const int* shape, size_t size, mlx_dtype dtype) { return mlx_fast_metal_kernel_config_add_output_arg_(cls, shape, size, dtype); } static inline int mlx_fast_metal_kernel_config_set_grid( mlx_fast_metal_kernel_config cls, int grid1, int grid2, int grid3) { return mlx_fast_metal_kernel_config_set_grid_(cls, grid1, grid2, grid3); } static inline int mlx_fast_metal_kernel_config_set_thread_group( mlx_fast_metal_kernel_config cls, int thread1, int thread2, int thread3) { return mlx_fast_metal_kernel_config_set_thread_group_(cls, thread1, thread2, thread3); } static inline int mlx_fast_metal_kernel_config_set_init_value( mlx_fast_metal_kernel_config cls, float value) { return mlx_fast_metal_kernel_config_set_init_value_(cls, value); } static inline int mlx_fast_metal_kernel_config_set_verbose( mlx_fast_metal_kernel_config cls, bool verbose) { return mlx_fast_metal_kernel_config_set_verbose_(cls, verbose); } static inline int mlx_fast_metal_kernel_config_add_template_arg_dtype( mlx_fast_metal_kernel_config cls, const char* name, mlx_dtype dtype) { return mlx_fast_metal_kernel_config_add_template_arg_dtype_(cls, name, dtype); } static inline int mlx_fast_metal_kernel_config_add_template_arg_int( mlx_fast_metal_kernel_config cls, const char* name, int value) { return mlx_fast_metal_kernel_config_add_template_arg_int_(cls, name, value); } static inline int mlx_fast_metal_kernel_config_add_template_arg_bool( mlx_fast_metal_kernel_config cls, const char* name, bool value) { return mlx_fast_metal_kernel_config_add_template_arg_bool_(cls, name, value); } static inline mlx_fast_metal_kernel mlx_fast_metal_kernel_new( const char* name, const mlx_vector_string input_names, const mlx_vector_string output_names, const char* source, const char* header, bool ensure_row_contiguous, bool atomic_outputs) { return mlx_fast_metal_kernel_new_(name, input_names, output_names, source, header, ensure_row_contiguous, atomic_outputs); } static inline void 
mlx_fast_metal_kernel_free(mlx_fast_metal_kernel cls) { return mlx_fast_metal_kernel_free_(cls); } static inline int mlx_fast_metal_kernel_apply( mlx_vector_array* outputs, mlx_fast_metal_kernel cls, const mlx_vector_array inputs, const mlx_fast_metal_kernel_config config, const mlx_stream stream) { return mlx_fast_metal_kernel_apply_(outputs, cls, inputs, config, stream); } static inline int mlx_fast_rms_norm( mlx_array* res, const mlx_array x, const mlx_array weight /* may be null */, float eps, const mlx_stream s) { return mlx_fast_rms_norm_(res, x, weight, eps, s); } static inline int mlx_fast_rope( mlx_array* res, const mlx_array x, int dims, bool traditional, mlx_optional_float base, float scale, int offset, const mlx_array freqs /* may be null */, const mlx_stream s) { return mlx_fast_rope_(res, x, dims, traditional, base, scale, offset, freqs, s); } static inline int mlx_fast_rope_dynamic( mlx_array* res, const mlx_array x, int dims, bool traditional, mlx_optional_float base, float scale, const mlx_array offset, const mlx_array freqs /* may be null */, const mlx_stream s) { return mlx_fast_rope_dynamic_(res, x, dims, traditional, base, scale, offset, freqs, s); } static inline int mlx_fast_scaled_dot_product_attention( mlx_array* res, const mlx_array queries, const mlx_array keys, const mlx_array values, float scale, const char* mask_mode, const mlx_array mask_arr /* may be null */, const mlx_array sinks /* may be null */, const mlx_stream s) { return mlx_fast_scaled_dot_product_attention_(res, queries, keys, values, scale, mask_mode, mask_arr, sinks, s); } static inline int mlx_fft_fft( mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s) { return mlx_fft_fft_(res, a, n, axis, s); } static inline int mlx_fft_fft2( mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) { return mlx_fft_fft2_(res, a, n, n_num, axes, axes_num, s); } static inline int mlx_fft_fftn( mlx_array* res, 
const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) { return mlx_fft_fftn_(res, a, n, n_num, axes, axes_num, s); } static inline int mlx_fft_fftshift( mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s) { return mlx_fft_fftshift_(res, a, axes, axes_num, s); } static inline int mlx_fft_ifft( mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s) { return mlx_fft_ifft_(res, a, n, axis, s); } static inline int mlx_fft_ifft2( mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) { return mlx_fft_ifft2_(res, a, n, n_num, axes, axes_num, s); } static inline int mlx_fft_ifftn( mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) { return mlx_fft_ifftn_(res, a, n, n_num, axes, axes_num, s); } static inline int mlx_fft_ifftshift( mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s) { return mlx_fft_ifftshift_(res, a, axes, axes_num, s); } static inline int mlx_fft_irfft( mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s) { return mlx_fft_irfft_(res, a, n, axis, s); } static inline int mlx_fft_irfft2( mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) { return mlx_fft_irfft2_(res, a, n, n_num, axes, axes_num, s); } static inline int mlx_fft_irfftn( mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) { return mlx_fft_irfftn_(res, a, n, n_num, axes, axes_num, s); } static inline int mlx_fft_rfft( mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s) { return mlx_fft_rfft_(res, a, n, axis, s); } static inline int mlx_fft_rfft2( mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) { 
return mlx_fft_rfft2_(res, a, n, n_num, axes, axes_num, s); } static inline int mlx_fft_rfftn( mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) { return mlx_fft_rfftn_(res, a, n, n_num, axes, axes_num, s); } static inline int mlx_load_reader( mlx_array* res, mlx_io_reader in_stream, const mlx_stream s) { return mlx_load_reader_(res, in_stream, s); } static inline int mlx_load(mlx_array* res, const char* file, const mlx_stream s) { return mlx_load_(res, file, s); } static inline int mlx_load_safetensors_reader( mlx_map_string_to_array* res_0, mlx_map_string_to_string* res_1, mlx_io_reader in_stream, const mlx_stream s) { return mlx_load_safetensors_reader_(res_0, res_1, in_stream, s); } static inline int mlx_load_safetensors( mlx_map_string_to_array* res_0, mlx_map_string_to_string* res_1, const char* file, const mlx_stream s) { return mlx_load_safetensors_(res_0, res_1, file, s); } static inline int mlx_save_writer(mlx_io_writer out_stream, const mlx_array a) { return mlx_save_writer_(out_stream, a); } static inline int mlx_save(const char* file, const mlx_array a) { return mlx_save_(file, a); } static inline int mlx_save_safetensors_writer( mlx_io_writer in_stream, const mlx_map_string_to_array param, const mlx_map_string_to_string metadata) { return mlx_save_safetensors_writer_(in_stream, param, metadata); } static inline int mlx_save_safetensors( const char* file, const mlx_map_string_to_array param, const mlx_map_string_to_string metadata) { return mlx_save_safetensors_(file, param, metadata); } static inline mlx_io_reader mlx_io_reader_new(void* desc, mlx_io_vtable vtable) { return mlx_io_reader_new_(desc, vtable); } static inline int mlx_io_reader_descriptor(void** desc_, mlx_io_reader io) { return mlx_io_reader_descriptor_(desc_, io); } static inline int mlx_io_reader_tostring(mlx_string* str_, mlx_io_reader io) { return mlx_io_reader_tostring_(str_, io); } static inline int 
mlx_io_reader_free(mlx_io_reader io) { return mlx_io_reader_free_(io); } static inline mlx_io_writer mlx_io_writer_new(void* desc, mlx_io_vtable vtable) { return mlx_io_writer_new_(desc, vtable); } static inline int mlx_io_writer_descriptor(void** desc_, mlx_io_writer io) { return mlx_io_writer_descriptor_(desc_, io); } static inline int mlx_io_writer_tostring(mlx_string* str_, mlx_io_writer io) { return mlx_io_writer_tostring_(str_, io); } static inline int mlx_io_writer_free(mlx_io_writer io) { return mlx_io_writer_free_(io); } static inline int mlx_linalg_cholesky( mlx_array* res, const mlx_array a, bool upper, const mlx_stream s) { return mlx_linalg_cholesky_(res, a, upper, s); } static inline int mlx_linalg_cholesky_inv( mlx_array* res, const mlx_array a, bool upper, const mlx_stream s) { return mlx_linalg_cholesky_inv_(res, a, upper, s); } static inline int mlx_linalg_cross( mlx_array* res, const mlx_array a, const mlx_array b, int axis, const mlx_stream s) { return mlx_linalg_cross_(res, a, b, axis, s); } static inline int mlx_linalg_eig( mlx_array* res_0, mlx_array* res_1, const mlx_array a, const mlx_stream s) { return mlx_linalg_eig_(res_0, res_1, a, s); } static inline int mlx_linalg_eigh( mlx_array* res_0, mlx_array* res_1, const mlx_array a, const char* UPLO, const mlx_stream s) { return mlx_linalg_eigh_(res_0, res_1, a, UPLO, s); } static inline int mlx_linalg_eigvals(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_linalg_eigvals_(res, a, s); } static inline int mlx_linalg_eigvalsh( mlx_array* res, const mlx_array a, const char* UPLO, const mlx_stream s) { return mlx_linalg_eigvalsh_(res, a, UPLO, s); } static inline int mlx_linalg_inv(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_linalg_inv_(res, a, s); } static inline int mlx_linalg_lu(mlx_vector_array* res, const mlx_array a, const mlx_stream s) { return mlx_linalg_lu_(res, a, s); } static inline int mlx_linalg_lu_factor( mlx_array* res_0, mlx_array* res_1, 
const mlx_array a, const mlx_stream s) { return mlx_linalg_lu_factor_(res_0, res_1, a, s); } static inline int mlx_linalg_norm( mlx_array* res, const mlx_array a, double ord, const int* axis /* may be null */, size_t axis_num, bool keepdims, const mlx_stream s) { return mlx_linalg_norm_(res, a, ord, axis, axis_num, keepdims, s); } static inline int mlx_linalg_norm_matrix( mlx_array* res, const mlx_array a, const char* ord, const int* axis /* may be null */, size_t axis_num, bool keepdims, const mlx_stream s) { return mlx_linalg_norm_matrix_(res, a, ord, axis, axis_num, keepdims, s); } static inline int mlx_linalg_norm_l2( mlx_array* res, const mlx_array a, const int* axis /* may be null */, size_t axis_num, bool keepdims, const mlx_stream s) { return mlx_linalg_norm_l2_(res, a, axis, axis_num, keepdims, s); } static inline int mlx_linalg_pinv(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_linalg_pinv_(res, a, s); } static inline int mlx_linalg_qr( mlx_array* res_0, mlx_array* res_1, const mlx_array a, const mlx_stream s) { return mlx_linalg_qr_(res_0, res_1, a, s); } static inline int mlx_linalg_solve( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_linalg_solve_(res, a, b, s); } static inline int mlx_linalg_solve_triangular( mlx_array* res, const mlx_array a, const mlx_array b, bool upper, const mlx_stream s) { return mlx_linalg_solve_triangular_(res, a, b, upper, s); } static inline int mlx_linalg_svd( mlx_vector_array* res, const mlx_array a, bool compute_uv, const mlx_stream s) { return mlx_linalg_svd_(res, a, compute_uv, s); } static inline int mlx_linalg_tri_inv( mlx_array* res, const mlx_array a, bool upper, const mlx_stream s) { return mlx_linalg_tri_inv_(res, a, upper, s); } static inline mlx_map_string_to_array mlx_map_string_to_array_new(void) { return mlx_map_string_to_array_new_(); } static inline int mlx_map_string_to_array_set( mlx_map_string_to_array* map, const mlx_map_string_to_array src) 
{ return mlx_map_string_to_array_set_(map, src); } static inline int mlx_map_string_to_array_free(mlx_map_string_to_array map) { return mlx_map_string_to_array_free_(map); } static inline int mlx_map_string_to_array_insert( mlx_map_string_to_array map, const char* key, const mlx_array value) { return mlx_map_string_to_array_insert_(map, key, value); } static inline int mlx_map_string_to_array_get( mlx_array* value, const mlx_map_string_to_array map, const char* key) { return mlx_map_string_to_array_get_(value, map, key); } static inline mlx_map_string_to_array_iterator mlx_map_string_to_array_iterator_new( mlx_map_string_to_array map) { return mlx_map_string_to_array_iterator_new_(map); } static inline int mlx_map_string_to_array_iterator_free(mlx_map_string_to_array_iterator it) { return mlx_map_string_to_array_iterator_free_(it); } static inline int mlx_map_string_to_array_iterator_next( const char** key, mlx_array* value, mlx_map_string_to_array_iterator it) { return mlx_map_string_to_array_iterator_next_(key, value, it); } static inline mlx_map_string_to_string mlx_map_string_to_string_new(void) { return mlx_map_string_to_string_new_(); } static inline int mlx_map_string_to_string_set( mlx_map_string_to_string* map, const mlx_map_string_to_string src) { return mlx_map_string_to_string_set_(map, src); } static inline int mlx_map_string_to_string_free(mlx_map_string_to_string map) { return mlx_map_string_to_string_free_(map); } static inline int mlx_map_string_to_string_insert( mlx_map_string_to_string map, const char* key, const char* value) { return mlx_map_string_to_string_insert_(map, key, value); } static inline int mlx_map_string_to_string_get( const char** value, const mlx_map_string_to_string map, const char* key) { return mlx_map_string_to_string_get_(value, map, key); } static inline mlx_map_string_to_string_iterator mlx_map_string_to_string_iterator_new( mlx_map_string_to_string map) { return mlx_map_string_to_string_iterator_new_(map); } static inline 
int mlx_map_string_to_string_iterator_free( mlx_map_string_to_string_iterator it) { return mlx_map_string_to_string_iterator_free_(it); } static inline int mlx_map_string_to_string_iterator_next( const char** key, const char** value, mlx_map_string_to_string_iterator it) { return mlx_map_string_to_string_iterator_next_(key, value, it); } static inline int mlx_clear_cache(void) { return mlx_clear_cache_(); } static inline int mlx_get_active_memory(size_t* res) { return mlx_get_active_memory_(res); } static inline int mlx_get_cache_memory(size_t* res) { return mlx_get_cache_memory_(res); } static inline int mlx_get_memory_limit(size_t* res) { return mlx_get_memory_limit_(res); } static inline int mlx_get_peak_memory(size_t* res) { return mlx_get_peak_memory_(res); } static inline int mlx_reset_peak_memory(void) { return mlx_reset_peak_memory_(); } static inline int mlx_set_cache_limit(size_t* res, size_t limit) { return mlx_set_cache_limit_(res, limit); } static inline int mlx_set_memory_limit(size_t* res, size_t limit) { return mlx_set_memory_limit_(res, limit); } static inline int mlx_set_wired_limit(size_t* res, size_t limit) { return mlx_set_wired_limit_(res, limit); } static inline int mlx_metal_is_available(bool* res) { return mlx_metal_is_available_(res); } static inline int mlx_metal_start_capture(const char* path) { return mlx_metal_start_capture_(path); } static inline int mlx_metal_stop_capture(void) { return mlx_metal_stop_capture_(); } static inline int mlx_abs(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_abs_(res, a, s); } static inline int mlx_add( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_add_(res, a, b, s); } static inline int mlx_addmm( mlx_array* res, const mlx_array c, const mlx_array a, const mlx_array b, float alpha, float beta, const mlx_stream s) { return mlx_addmm_(res, c, a, b, alpha, beta, s); } static inline int mlx_all_axes( mlx_array* res, const mlx_array a, const int* 
axes, size_t axes_num, bool keepdims, const mlx_stream s) { return mlx_all_axes_(res, a, axes, axes_num, keepdims, s); } static inline int mlx_all_axis( mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) { return mlx_all_axis_(res, a, axis, keepdims, s); } static inline int mlx_all( mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) { return mlx_all_(res, a, keepdims, s); } static inline int mlx_allclose( mlx_array* res, const mlx_array a, const mlx_array b, double rtol, double atol, bool equal_nan, const mlx_stream s) { return mlx_allclose_(res, a, b, rtol, atol, equal_nan, s); } static inline int mlx_any_axes( mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) { return mlx_any_axes_(res, a, axes, axes_num, keepdims, s); } static inline int mlx_any_axis( mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) { return mlx_any_axis_(res, a, axis, keepdims, s); } static inline int mlx_any( mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) { return mlx_any_(res, a, keepdims, s); } static inline int mlx_arange( mlx_array* res, double start, double stop, double step, mlx_dtype dtype, const mlx_stream s) { return mlx_arange_(res, start, stop, step, dtype, s); } static inline int mlx_arccos(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_arccos_(res, a, s); } static inline int mlx_arccosh(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_arccosh_(res, a, s); } static inline int mlx_arcsin(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_arcsin_(res, a, s); } static inline int mlx_arcsinh(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_arcsinh_(res, a, s); } static inline int mlx_arctan(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_arctan_(res, a, s); } static inline int mlx_arctan2( mlx_array* res, const mlx_array a, const 
mlx_array b, const mlx_stream s) { return mlx_arctan2_(res, a, b, s); } static inline int mlx_arctanh(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_arctanh_(res, a, s); } static inline int mlx_argmax_axis( mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) { return mlx_argmax_axis_(res, a, axis, keepdims, s); } static inline int mlx_argmax( mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) { return mlx_argmax_(res, a, keepdims, s); } static inline int mlx_argmin_axis( mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) { return mlx_argmin_axis_(res, a, axis, keepdims, s); } static inline int mlx_argmin( mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) { return mlx_argmin_(res, a, keepdims, s); } static inline int mlx_argpartition_axis( mlx_array* res, const mlx_array a, int kth, int axis, const mlx_stream s) { return mlx_argpartition_axis_(res, a, kth, axis, s); } static inline int mlx_argpartition( mlx_array* res, const mlx_array a, int kth, const mlx_stream s) { return mlx_argpartition_(res, a, kth, s); } static inline int mlx_argsort_axis( mlx_array* res, const mlx_array a, int axis, const mlx_stream s) { return mlx_argsort_axis_(res, a, axis, s); } static inline int mlx_argsort(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_argsort_(res, a, s); } static inline int mlx_array_equal( mlx_array* res, const mlx_array a, const mlx_array b, bool equal_nan, const mlx_stream s) { return mlx_array_equal_(res, a, b, equal_nan, s); } static inline int mlx_as_strided( mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const int64_t* strides, size_t strides_num, size_t offset, const mlx_stream s) { return mlx_as_strided_(res, a, shape, shape_num, strides, strides_num, offset, s); } static inline int mlx_astype( mlx_array* res, const mlx_array a, mlx_dtype dtype, const mlx_stream s) { return mlx_astype_(res, a, dtype, 
s); } static inline int mlx_atleast_1d(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_atleast_1d_(res, a, s); } static inline int mlx_atleast_2d(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_atleast_2d_(res, a, s); } static inline int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_atleast_3d_(res, a, s); } static inline int mlx_bartlett(mlx_array* res, int M, const mlx_stream s) { return mlx_bartlett_(res, M, s); } static inline int mlx_bitwise_and( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_bitwise_and_(res, a, b, s); } static inline int mlx_bitwise_invert(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_bitwise_invert_(res, a, s); } static inline int mlx_bitwise_or( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_bitwise_or_(res, a, b, s); } static inline int mlx_bitwise_xor( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_bitwise_xor_(res, a, b, s); } static inline int mlx_blackman(mlx_array* res, int M, const mlx_stream s) { return mlx_blackman_(res, M, s); } static inline int mlx_block_masked_mm( mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out /* may be null */, const mlx_array mask_lhs /* may be null */, const mlx_array mask_rhs /* may be null */, const mlx_stream s) { return mlx_block_masked_mm_(res, a, b, block_size, mask_out, mask_lhs, mask_rhs, s); } static inline int mlx_broadcast_arrays( mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s) { return mlx_broadcast_arrays_(res, inputs, s); } static inline int mlx_broadcast_to( mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s) { return mlx_broadcast_to_(res, a, shape, shape_num, s); } static inline int mlx_ceil(mlx_array* res, const mlx_array a, const mlx_stream s) { return 
mlx_ceil_(res, a, s); } static inline int mlx_clip( mlx_array* res, const mlx_array a, const mlx_array a_min /* may be null */, const mlx_array a_max /* may be null */, const mlx_stream s) { return mlx_clip_(res, a, a_min, a_max, s); } static inline int mlx_concatenate_axis( mlx_array* res, const mlx_vector_array arrays, int axis, const mlx_stream s) { return mlx_concatenate_axis_(res, arrays, axis, s); } static inline int mlx_concatenate( mlx_array* res, const mlx_vector_array arrays, const mlx_stream s) { return mlx_concatenate_(res, arrays, s); } static inline int mlx_conjugate(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_conjugate_(res, a, s); } static inline int mlx_contiguous( mlx_array* res, const mlx_array a, bool allow_col_major, const mlx_stream s) { return mlx_contiguous_(res, a, allow_col_major, s); } static inline int mlx_conv1d( mlx_array* res, const mlx_array input, const mlx_array weight, int stride, int padding, int dilation, int groups, const mlx_stream s) { return mlx_conv1d_(res, input, weight, stride, padding, dilation, groups, s); } static inline int mlx_conv2d( mlx_array* res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int padding_0, int padding_1, int dilation_0, int dilation_1, int groups, const mlx_stream s) { return mlx_conv2d_(res, input, weight, stride_0, stride_1, padding_0, padding_1, dilation_0, dilation_1, groups, s); } static inline int mlx_conv3d( mlx_array* res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int stride_2, int padding_0, int padding_1, int padding_2, int dilation_0, int dilation_1, int dilation_2, int groups, const mlx_stream s) { return mlx_conv3d_(res, input, weight, stride_0, stride_1, stride_2, padding_0, padding_1, padding_2, dilation_0, dilation_1, dilation_2, groups, s); } static inline int mlx_conv_general( mlx_array* res, const mlx_array input, const mlx_array weight, const int* stride, size_t stride_num, const int* 
padding_lo, size_t padding_lo_num, const int* padding_hi, size_t padding_hi_num, const int* kernel_dilation, size_t kernel_dilation_num, const int* input_dilation, size_t input_dilation_num, int groups, bool flip, const mlx_stream s) { return mlx_conv_general_(res, input, weight, stride, stride_num, padding_lo, padding_lo_num, padding_hi, padding_hi_num, kernel_dilation, kernel_dilation_num, input_dilation, input_dilation_num, groups, flip, s); } static inline int mlx_conv_transpose1d( mlx_array* res, const mlx_array input, const mlx_array weight, int stride, int padding, int dilation, int output_padding, int groups, const mlx_stream s) { return mlx_conv_transpose1d_(res, input, weight, stride, padding, dilation, output_padding, groups, s); } static inline int mlx_conv_transpose2d( mlx_array* res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int padding_0, int padding_1, int dilation_0, int dilation_1, int output_padding_0, int output_padding_1, int groups, const mlx_stream s) { return mlx_conv_transpose2d_(res, input, weight, stride_0, stride_1, padding_0, padding_1, dilation_0, dilation_1, output_padding_0, output_padding_1, groups, s); } static inline int mlx_conv_transpose3d( mlx_array* res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int stride_2, int padding_0, int padding_1, int padding_2, int dilation_0, int dilation_1, int dilation_2, int output_padding_0, int output_padding_1, int output_padding_2, int groups, const mlx_stream s) { return mlx_conv_transpose3d_(res, input, weight, stride_0, stride_1, stride_2, padding_0, padding_1, padding_2, dilation_0, dilation_1, dilation_2, output_padding_0, output_padding_1, output_padding_2, groups, s); } static inline int mlx_copy(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_copy_(res, a, s); } static inline int mlx_cos(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_cos_(res, a, s); } static inline int 
mlx_cosh(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_cosh_(res, a, s); } static inline int mlx_cummax( mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) { return mlx_cummax_(res, a, axis, reverse, inclusive, s); } static inline int mlx_cummin( mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) { return mlx_cummin_(res, a, axis, reverse, inclusive, s); } static inline int mlx_cumprod( mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) { return mlx_cumprod_(res, a, axis, reverse, inclusive, s); } static inline int mlx_cumsum( mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) { return mlx_cumsum_(res, a, axis, reverse, inclusive, s); } static inline int mlx_degrees(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_degrees_(res, a, s); } static inline int mlx_depends( mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies) { return mlx_depends_(res, inputs, dependencies); } static inline int mlx_dequantize( mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases /* may be null */, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale /* may be null */, mlx_optional_dtype dtype, const mlx_stream s) { return mlx_dequantize_(res, w, scales, biases, group_size, bits, mode, global_scale, dtype, s); } static inline int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s) { return mlx_diag_(res, a, k, s); } static inline int mlx_diagonal( mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, const mlx_stream s) { return mlx_diagonal_(res, a, offset, axis1, axis2, s); } static inline int mlx_divide( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_divide_(res, a, b, s); } static inline 
int mlx_divmod( mlx_vector_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_divmod_(res, a, b, s); } static inline int mlx_einsum( mlx_array* res, const char* subscripts, const mlx_vector_array operands, const mlx_stream s) { return mlx_einsum_(res, subscripts, operands, s); } static inline int mlx_equal( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_equal_(res, a, b, s); } static inline int mlx_erf(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_erf_(res, a, s); } static inline int mlx_erfinv(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_erfinv_(res, a, s); } static inline int mlx_exp(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_exp_(res, a, s); } static inline int mlx_expand_dims_axes( mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s) { return mlx_expand_dims_axes_(res, a, axes, axes_num, s); } static inline int mlx_expand_dims( mlx_array* res, const mlx_array a, int axis, const mlx_stream s) { return mlx_expand_dims_(res, a, axis, s); } static inline int mlx_expm1(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_expm1_(res, a, s); } static inline int mlx_eye( mlx_array* res, int n, int m, int k, mlx_dtype dtype, const mlx_stream s) { return mlx_eye_(res, n, m, k, dtype, s); } static inline int mlx_flatten( mlx_array* res, const mlx_array a, int start_axis, int end_axis, const mlx_stream s) { return mlx_flatten_(res, a, start_axis, end_axis, s); } static inline int mlx_floor(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_floor_(res, a, s); } static inline int mlx_floor_divide( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_floor_divide_(res, a, b, s); } static inline int mlx_from_fp8( mlx_array* res, const mlx_array x, mlx_dtype dtype, const mlx_stream s) { return mlx_from_fp8_(res, x, dtype, s); } static 
inline int mlx_full( mlx_array* res, const int* shape, size_t shape_num, const mlx_array vals, mlx_dtype dtype, const mlx_stream s) { return mlx_full_(res, shape, shape_num, vals, dtype, s); } static inline int mlx_full_like( mlx_array* res, const mlx_array a, const mlx_array vals, mlx_dtype dtype, const mlx_stream s) { return mlx_full_like_(res, a, vals, dtype, s); } static inline int mlx_gather( mlx_array* res, const mlx_array a, const mlx_vector_array indices, const int* axes, size_t axes_num, const int* slice_sizes, size_t slice_sizes_num, const mlx_stream s) { return mlx_gather_(res, a, indices, axes, axes_num, slice_sizes, slice_sizes_num, s); } static inline int mlx_gather_single( mlx_array* res, const mlx_array a, const mlx_array indices, int axis, const int* slice_sizes, size_t slice_sizes_num, const mlx_stream s) { return mlx_gather_single_(res, a, indices, axis, slice_sizes, slice_sizes_num, s); } static inline int mlx_gather_mm( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_array lhs_indices /* may be null */, const mlx_array rhs_indices /* may be null */, bool sorted_indices, const mlx_stream s) { return mlx_gather_mm_(res, a, b, lhs_indices, rhs_indices, sorted_indices, s); } static inline int mlx_gather_qmm( mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases /* may be null */, const mlx_array lhs_indices /* may be null */, const mlx_array rhs_indices /* may be null */, bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, bool sorted_indices, const mlx_stream s) { return mlx_gather_qmm_(res, x, w, scales, biases, lhs_indices, rhs_indices, transpose, group_size, bits, mode, sorted_indices, s); } static inline int mlx_greater( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_greater_(res, a, b, s); } static inline int mlx_greater_equal( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { 
return mlx_greater_equal_(res, a, b, s); } static inline int mlx_hadamard_transform( mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s) { return mlx_hadamard_transform_(res, a, scale, s); } static inline int mlx_hamming(mlx_array* res, int M, const mlx_stream s) { return mlx_hamming_(res, M, s); } static inline int mlx_hanning(mlx_array* res, int M, const mlx_stream s) { return mlx_hanning_(res, M, s); } static inline int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) { return mlx_identity_(res, n, dtype, s); } static inline int mlx_imag(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_imag_(res, a, s); } static inline int mlx_inner( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_inner_(res, a, b, s); } static inline int mlx_isclose( mlx_array* res, const mlx_array a, const mlx_array b, double rtol, double atol, bool equal_nan, const mlx_stream s) { return mlx_isclose_(res, a, b, rtol, atol, equal_nan, s); } static inline int mlx_isfinite(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_isfinite_(res, a, s); } static inline int mlx_isinf(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_isinf_(res, a, s); } static inline int mlx_isnan(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_isnan_(res, a, s); } static inline int mlx_isneginf(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_isneginf_(res, a, s); } static inline int mlx_isposinf(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_isposinf_(res, a, s); } static inline int mlx_kron( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_kron_(res, a, b, s); } static inline int mlx_left_shift( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_left_shift_(res, a, b, s); } static inline int mlx_less( mlx_array* res, const mlx_array a, const 
mlx_array b, const mlx_stream s) { return mlx_less_(res, a, b, s); } static inline int mlx_less_equal( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_less_equal_(res, a, b, s); } static inline int mlx_linspace( mlx_array* res, double start, double stop, int num, mlx_dtype dtype, const mlx_stream s) { return mlx_linspace_(res, start, stop, num, dtype, s); } static inline int mlx_log(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_log_(res, a, s); } static inline int mlx_log10(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_log10_(res, a, s); } static inline int mlx_log1p(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_log1p_(res, a, s); } static inline int mlx_log2(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_log2_(res, a, s); } static inline int mlx_logaddexp( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_logaddexp_(res, a, b, s); } static inline int mlx_logcumsumexp( mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) { return mlx_logcumsumexp_(res, a, axis, reverse, inclusive, s); } static inline int mlx_logical_and( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_logical_and_(res, a, b, s); } static inline int mlx_logical_not(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_logical_not_(res, a, s); } static inline int mlx_logical_or( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_logical_or_(res, a, b, s); } static inline int mlx_logsumexp_axes( mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) { return mlx_logsumexp_axes_(res, a, axes, axes_num, keepdims, s); } static inline int mlx_logsumexp_axis( mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) { return 
mlx_logsumexp_axis_(res, a, axis, keepdims, s); } static inline int mlx_logsumexp( mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) { return mlx_logsumexp_(res, a, keepdims, s); } static inline int mlx_masked_scatter( mlx_array* res, const mlx_array a, const mlx_array mask, const mlx_array src, const mlx_stream s) { return mlx_masked_scatter_(res, a, mask, src, s); } static inline int mlx_matmul( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_matmul_(res, a, b, s); } static inline int mlx_max_axes( mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) { return mlx_max_axes_(res, a, axes, axes_num, keepdims, s); } static inline int mlx_max_axis( mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) { return mlx_max_axis_(res, a, axis, keepdims, s); } static inline int mlx_max( mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) { return mlx_max_(res, a, keepdims, s); } static inline int mlx_maximum( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_maximum_(res, a, b, s); } static inline int mlx_mean_axes( mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) { return mlx_mean_axes_(res, a, axes, axes_num, keepdims, s); } static inline int mlx_mean_axis( mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) { return mlx_mean_axis_(res, a, axis, keepdims, s); } static inline int mlx_mean( mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) { return mlx_mean_(res, a, keepdims, s); } static inline int mlx_median( mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) { return mlx_median_(res, a, axes, axes_num, keepdims, s); } static inline int mlx_meshgrid( mlx_vector_array* res, const mlx_vector_array arrays, bool sparse, const char* 
indexing, const mlx_stream s) { return mlx_meshgrid_(res, arrays, sparse, indexing, s); } static inline int mlx_min_axes( mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) { return mlx_min_axes_(res, a, axes, axes_num, keepdims, s); } static inline int mlx_min_axis( mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) { return mlx_min_axis_(res, a, axis, keepdims, s); } static inline int mlx_min( mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) { return mlx_min_(res, a, keepdims, s); } static inline int mlx_minimum( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_minimum_(res, a, b, s); } static inline int mlx_moveaxis( mlx_array* res, const mlx_array a, int source, int destination, const mlx_stream s) { return mlx_moveaxis_(res, a, source, destination, s); } static inline int mlx_multiply( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_multiply_(res, a, b, s); } static inline int mlx_nan_to_num( mlx_array* res, const mlx_array a, float nan, mlx_optional_float posinf, mlx_optional_float neginf, const mlx_stream s) { return mlx_nan_to_num_(res, a, nan, posinf, neginf, s); } static inline int mlx_negative(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_negative_(res, a, s); } static inline int mlx_not_equal( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_not_equal_(res, a, b, s); } static inline int mlx_number_of_elements( mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool inverted, mlx_dtype dtype, const mlx_stream s) { return mlx_number_of_elements_(res, a, axes, axes_num, inverted, dtype, s); } static inline int mlx_ones( mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_stream s) { return mlx_ones_(res, shape, shape_num, dtype, s); } static inline int mlx_ones_like(mlx_array* 
res, const mlx_array a, const mlx_stream s) { return mlx_ones_like_(res, a, s); } static inline int mlx_outer( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_outer_(res, a, b, s); } static inline int mlx_pad( mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const int* low_pad_size, size_t low_pad_size_num, const int* high_pad_size, size_t high_pad_size_num, const mlx_array pad_value, const char* mode, const mlx_stream s) { return mlx_pad_(res, a, axes, axes_num, low_pad_size, low_pad_size_num, high_pad_size, high_pad_size_num, pad_value, mode, s); } static inline int mlx_pad_symmetric( mlx_array* res, const mlx_array a, int pad_width, const mlx_array pad_value, const char* mode, const mlx_stream s) { return mlx_pad_symmetric_(res, a, pad_width, pad_value, mode, s); } static inline int mlx_partition_axis( mlx_array* res, const mlx_array a, int kth, int axis, const mlx_stream s) { return mlx_partition_axis_(res, a, kth, axis, s); } static inline int mlx_partition( mlx_array* res, const mlx_array a, int kth, const mlx_stream s) { return mlx_partition_(res, a, kth, s); } static inline int mlx_power( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_power_(res, a, b, s); } static inline int mlx_prod_axes( mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) { return mlx_prod_axes_(res, a, axes, axes_num, keepdims, s); } static inline int mlx_prod_axis( mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) { return mlx_prod_axis_(res, a, axis, keepdims, s); } static inline int mlx_prod( mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) { return mlx_prod_(res, a, keepdims, s); } static inline int mlx_put_along_axis( mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s) { return mlx_put_along_axis_(res, a, indices, values, 
axis, s); } static inline int mlx_qqmm( mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales /* may be null */, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale_x /* may be null */, const mlx_array global_scale_w /* may be null */, const mlx_stream s) { return mlx_qqmm_(res, x, w, w_scales, group_size, bits, mode, global_scale_x, global_scale_w, s); } static inline int mlx_quantize( mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale /* may be null */, const mlx_stream s) { return mlx_quantize_(res, w, group_size, bits, mode, global_scale, s); } static inline int mlx_quantized_matmul( mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases /* may be null */, bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) { return mlx_quantized_matmul_(res, x, w, scales, biases, transpose, group_size, bits, mode, s); } static inline int mlx_radians(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_radians_(res, a, s); } static inline int mlx_real(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_real_(res, a, s); } static inline int mlx_reciprocal(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_reciprocal_(res, a, s); } static inline int mlx_remainder( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_remainder_(res, a, b, s); } static inline int mlx_repeat_axis( mlx_array* res, const mlx_array arr, int repeats, int axis, const mlx_stream s) { return mlx_repeat_axis_(res, arr, repeats, axis, s); } static inline int mlx_repeat( mlx_array* res, const mlx_array arr, int repeats, const mlx_stream s) { return mlx_repeat_(res, arr, repeats, s); } static inline int mlx_reshape( mlx_array* res, const mlx_array a, const int* shape, 
size_t shape_num, const mlx_stream s) { return mlx_reshape_(res, a, shape, shape_num, s); } static inline int mlx_right_shift( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_right_shift_(res, a, b, s); } static inline int mlx_roll_axis( mlx_array* res, const mlx_array a, const int* shift, size_t shift_num, int axis, const mlx_stream s) { return mlx_roll_axis_(res, a, shift, shift_num, axis, s); } static inline int mlx_roll_axes( mlx_array* res, const mlx_array a, const int* shift, size_t shift_num, const int* axes, size_t axes_num, const mlx_stream s) { return mlx_roll_axes_(res, a, shift, shift_num, axes, axes_num, s); } static inline int mlx_roll( mlx_array* res, const mlx_array a, const int* shift, size_t shift_num, const mlx_stream s) { return mlx_roll_(res, a, shift, shift_num, s); } static inline int mlx_round( mlx_array* res, const mlx_array a, int decimals, const mlx_stream s) { return mlx_round_(res, a, decimals, s); } static inline int mlx_rsqrt(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_rsqrt_(res, a, s); } static inline int mlx_scatter( mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s) { return mlx_scatter_(res, a, indices, updates, axes, axes_num, s); } static inline int mlx_scatter_single( mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s) { return mlx_scatter_single_(res, a, indices, updates, axis, s); } static inline int mlx_scatter_add( mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s) { return mlx_scatter_add_(res, a, indices, updates, axes, axes_num, s); } static inline int mlx_scatter_add_single( mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s) { return 
mlx_scatter_add_single_(res, a, indices, updates, axis, s); } static inline int mlx_scatter_add_axis( mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s) { return mlx_scatter_add_axis_(res, a, indices, values, axis, s); } static inline int mlx_scatter_max( mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s) { return mlx_scatter_max_(res, a, indices, updates, axes, axes_num, s); } static inline int mlx_scatter_max_single( mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s) { return mlx_scatter_max_single_(res, a, indices, updates, axis, s); } static inline int mlx_scatter_min( mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s) { return mlx_scatter_min_(res, a, indices, updates, axes, axes_num, s); } static inline int mlx_scatter_min_single( mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s) { return mlx_scatter_min_single_(res, a, indices, updates, axis, s); } static inline int mlx_scatter_prod( mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s) { return mlx_scatter_prod_(res, a, indices, updates, axes, axes_num, s); } static inline int mlx_scatter_prod_single( mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s) { return mlx_scatter_prod_single_(res, a, indices, updates, axis, s); } static inline int mlx_segmented_mm( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_array segments, const mlx_stream s) { return mlx_segmented_mm_(res, a, b, segments, s); } static inline int mlx_sigmoid(mlx_array* res, const mlx_array a, const 
mlx_stream s) { return mlx_sigmoid_(res, a, s); } static inline int mlx_sign(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_sign_(res, a, s); } static inline int mlx_sin(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_sin_(res, a, s); } static inline int mlx_sinh(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_sinh_(res, a, s); } static inline int mlx_slice( mlx_array* res, const mlx_array a, const int* start, size_t start_num, const int* stop, size_t stop_num, const int* strides, size_t strides_num, const mlx_stream s) { return mlx_slice_(res, a, start, start_num, stop, stop_num, strides, strides_num, s); } static inline int mlx_slice_dynamic( mlx_array* res, const mlx_array a, const mlx_array start, const int* axes, size_t axes_num, const int* slice_size, size_t slice_size_num, const mlx_stream s) { return mlx_slice_dynamic_(res, a, start, axes, axes_num, slice_size, slice_size_num, s); } static inline int mlx_slice_update( mlx_array* res, const mlx_array src, const mlx_array update, const int* start, size_t start_num, const int* stop, size_t stop_num, const int* strides, size_t strides_num, const mlx_stream s) { return mlx_slice_update_(res, src, update, start, start_num, stop, stop_num, strides, strides_num, s); } static inline int mlx_slice_update_dynamic( mlx_array* res, const mlx_array src, const mlx_array update, const mlx_array start, const int* axes, size_t axes_num, const mlx_stream s) { return mlx_slice_update_dynamic_(res, src, update, start, axes, axes_num, s); } static inline int mlx_softmax_axes( mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool precise, const mlx_stream s) { return mlx_softmax_axes_(res, a, axes, axes_num, precise, s); } static inline int mlx_softmax_axis( mlx_array* res, const mlx_array a, int axis, bool precise, const mlx_stream s) { return mlx_softmax_axis_(res, a, axis, precise, s); } static inline int mlx_softmax( mlx_array* res, const 
mlx_array a, bool precise, const mlx_stream s) { return mlx_softmax_(res, a, precise, s); } static inline int mlx_sort_axis( mlx_array* res, const mlx_array a, int axis, const mlx_stream s) { return mlx_sort_axis_(res, a, axis, s); } static inline int mlx_sort(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_sort_(res, a, s); } static inline int mlx_split( mlx_vector_array* res, const mlx_array a, int num_splits, int axis, const mlx_stream s) { return mlx_split_(res, a, num_splits, axis, s); } static inline int mlx_split_sections( mlx_vector_array* res, const mlx_array a, const int* indices, size_t indices_num, int axis, const mlx_stream s) { return mlx_split_sections_(res, a, indices, indices_num, axis, s); } static inline int mlx_sqrt(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_sqrt_(res, a, s); } static inline int mlx_square(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_square_(res, a, s); } static inline int mlx_squeeze_axes( mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s) { return mlx_squeeze_axes_(res, a, axes, axes_num, s); } static inline int mlx_squeeze_axis( mlx_array* res, const mlx_array a, int axis, const mlx_stream s) { return mlx_squeeze_axis_(res, a, axis, s); } static inline int mlx_squeeze(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_squeeze_(res, a, s); } static inline int mlx_stack_axis( mlx_array* res, const mlx_vector_array arrays, int axis, const mlx_stream s) { return mlx_stack_axis_(res, arrays, axis, s); } static inline int mlx_stack( mlx_array* res, const mlx_vector_array arrays, const mlx_stream s) { return mlx_stack_(res, arrays, s); } static inline int mlx_std_axes( mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, int ddof, const mlx_stream s) { return mlx_std_axes_(res, a, axes, axes_num, keepdims, ddof, s); } static inline int mlx_std_axis( mlx_array* res, const mlx_array 
a, int axis, bool keepdims, int ddof, const mlx_stream s) { return mlx_std_axis_(res, a, axis, keepdims, ddof, s); } static inline int mlx_std( mlx_array* res, const mlx_array a, bool keepdims, int ddof, const mlx_stream s) { return mlx_std_(res, a, keepdims, ddof, s); } static inline int mlx_stop_gradient(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_stop_gradient_(res, a, s); } static inline int mlx_subtract( mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { return mlx_subtract_(res, a, b, s); } static inline int mlx_sum_axes( mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) { return mlx_sum_axes_(res, a, axes, axes_num, keepdims, s); } static inline int mlx_sum_axis( mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) { return mlx_sum_axis_(res, a, axis, keepdims, s); } static inline int mlx_sum( mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) { return mlx_sum_(res, a, keepdims, s); } static inline int mlx_swapaxes( mlx_array* res, const mlx_array a, int axis1, int axis2, const mlx_stream s) { return mlx_swapaxes_(res, a, axis1, axis2, s); } static inline int mlx_take_axis( mlx_array* res, const mlx_array a, const mlx_array indices, int axis, const mlx_stream s) { return mlx_take_axis_(res, a, indices, axis, s); } static inline int mlx_take( mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_stream s) { return mlx_take_(res, a, indices, s); } static inline int mlx_take_along_axis( mlx_array* res, const mlx_array a, const mlx_array indices, int axis, const mlx_stream s) { return mlx_take_along_axis_(res, a, indices, axis, s); } static inline int mlx_tan(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_tan_(res, a, s); } static inline int mlx_tanh(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_tanh_(res, a, s); } static inline int mlx_tensordot( mlx_array* 
res, const mlx_array a, const mlx_array b, const int* axes_a, size_t axes_a_num, const int* axes_b, size_t axes_b_num, const mlx_stream s) { return mlx_tensordot_(res, a, b, axes_a, axes_a_num, axes_b, axes_b_num, s); } static inline int mlx_tensordot_axis( mlx_array* res, const mlx_array a, const mlx_array b, int axis, const mlx_stream s) { return mlx_tensordot_axis_(res, a, b, axis, s); } static inline int mlx_tile( mlx_array* res, const mlx_array arr, const int* reps, size_t reps_num, const mlx_stream s) { return mlx_tile_(res, arr, reps, reps_num, s); } static inline int mlx_to_fp8(mlx_array* res, const mlx_array x, const mlx_stream s) { return mlx_to_fp8_(res, x, s); } static inline int mlx_topk_axis( mlx_array* res, const mlx_array a, int k, int axis, const mlx_stream s) { return mlx_topk_axis_(res, a, k, axis, s); } static inline int mlx_topk(mlx_array* res, const mlx_array a, int k, const mlx_stream s) { return mlx_topk_(res, a, k, s); } static inline int mlx_trace( mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, mlx_dtype dtype, const mlx_stream s) { return mlx_trace_(res, a, offset, axis1, axis2, dtype, s); } static inline int mlx_transpose_axes( mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s) { return mlx_transpose_axes_(res, a, axes, axes_num, s); } static inline int mlx_transpose(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_transpose_(res, a, s); } static inline int mlx_tri( mlx_array* res, int n, int m, int k, mlx_dtype type, const mlx_stream s) { return mlx_tri_(res, n, m, k, type, s); } static inline int mlx_tril(mlx_array* res, const mlx_array x, int k, const mlx_stream s) { return mlx_tril_(res, x, k, s); } static inline int mlx_triu(mlx_array* res, const mlx_array x, int k, const mlx_stream s) { return mlx_triu_(res, x, k, s); } static inline int mlx_unflatten( mlx_array* res, const mlx_array a, int axis, const int* shape, size_t shape_num, const mlx_stream s) { 
return mlx_unflatten_(res, a, axis, shape, shape_num, s); } static inline int mlx_var_axes( mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, int ddof, const mlx_stream s) { return mlx_var_axes_(res, a, axes, axes_num, keepdims, ddof, s); } static inline int mlx_var_axis( mlx_array* res, const mlx_array a, int axis, bool keepdims, int ddof, const mlx_stream s) { return mlx_var_axis_(res, a, axis, keepdims, ddof, s); } static inline int mlx_var( mlx_array* res, const mlx_array a, bool keepdims, int ddof, const mlx_stream s) { return mlx_var_(res, a, keepdims, ddof, s); } static inline int mlx_view( mlx_array* res, const mlx_array a, mlx_dtype dtype, const mlx_stream s) { return mlx_view_(res, a, dtype, s); } static inline int mlx_where( mlx_array* res, const mlx_array condition, const mlx_array x, const mlx_array y, const mlx_stream s) { return mlx_where_(res, condition, x, y, s); } static inline int mlx_zeros( mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_stream s) { return mlx_zeros_(res, shape, shape_num, dtype, s); } static inline int mlx_zeros_like(mlx_array* res, const mlx_array a, const mlx_stream s) { return mlx_zeros_like_(res, a, s); } static inline int mlx_random_bernoulli( mlx_array* res, const mlx_array p, const int* shape, size_t shape_num, const mlx_array key /* may be null */, const mlx_stream s) { return mlx_random_bernoulli_(res, p, shape, shape_num, key, s); } static inline int mlx_random_bits( mlx_array* res, const int* shape, size_t shape_num, int width, const mlx_array key /* may be null */, const mlx_stream s) { return mlx_random_bits_(res, shape, shape_num, width, key, s); } static inline int mlx_random_categorical_shape( mlx_array* res, const mlx_array logits, int axis, const int* shape, size_t shape_num, const mlx_array key /* may be null */, const mlx_stream s) { return mlx_random_categorical_shape_(res, logits, axis, shape, shape_num, key, s); } static inline int 
mlx_random_categorical_num_samples( mlx_array* res, const mlx_array logits_, int axis, int num_samples, const mlx_array key /* may be null */, const mlx_stream s) { return mlx_random_categorical_num_samples_(res, logits_, axis, num_samples, key, s); } static inline int mlx_random_categorical( mlx_array* res, const mlx_array logits, int axis, const mlx_array key /* may be null */, const mlx_stream s) { return mlx_random_categorical_(res, logits, axis, key, s); } static inline int mlx_random_gumbel( mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key /* may be null */, const mlx_stream s) { return mlx_random_gumbel_(res, shape, shape_num, dtype, key, s); } static inline int mlx_random_key(mlx_array* res, uint64_t seed) { return mlx_random_key_(res, seed); } static inline int mlx_random_laplace( mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, float loc, float scale, const mlx_array key /* may be null */, const mlx_stream s) { return mlx_random_laplace_(res, shape, shape_num, dtype, loc, scale, key, s); } static inline int mlx_random_multivariate_normal( mlx_array* res, const mlx_array mean, const mlx_array cov, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key /* may be null */, const mlx_stream s) { return mlx_random_multivariate_normal_(res, mean, cov, shape, shape_num, dtype, key, s); } static inline int mlx_random_normal_broadcast( mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array loc /* may be null */, const mlx_array scale /* may be null */, const mlx_array key /* may be null */, const mlx_stream s) { return mlx_random_normal_broadcast_(res, shape, shape_num, dtype, loc, scale, key, s); } static inline int mlx_random_normal( mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, float loc, float scale, const mlx_array key /* may be null */, const mlx_stream s) { return mlx_random_normal_(res, shape, shape_num, dtype, loc, scale, key, 
s); } static inline int mlx_random_permutation( mlx_array* res, const mlx_array x, int axis, const mlx_array key /* may be null */, const mlx_stream s) { return mlx_random_permutation_(res, x, axis, key, s); } static inline int mlx_random_permutation_arange( mlx_array* res, int x, const mlx_array key /* may be null */, const mlx_stream s) { return mlx_random_permutation_arange_(res, x, key, s); } static inline int mlx_random_randint( mlx_array* res, const mlx_array low, const mlx_array high, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key /* may be null */, const mlx_stream s) { return mlx_random_randint_(res, low, high, shape, shape_num, dtype, key, s); } static inline int mlx_random_seed(uint64_t seed) { return mlx_random_seed_(seed); } static inline int mlx_random_split_num( mlx_array* res, const mlx_array key, int num, const mlx_stream s) { return mlx_random_split_num_(res, key, num, s); } static inline int mlx_random_split( mlx_array* res_0, mlx_array* res_1, const mlx_array key, const mlx_stream s) { return mlx_random_split_(res_0, res_1, key, s); } static inline int mlx_random_truncated_normal( mlx_array* res, const mlx_array lower, const mlx_array upper, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key /* may be null */, const mlx_stream s) { return mlx_random_truncated_normal_(res, lower, upper, shape, shape_num, dtype, key, s); } static inline int mlx_random_uniform( mlx_array* res, const mlx_array low, const mlx_array high, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key /* may be null */, const mlx_stream s) { return mlx_random_uniform_(res, low, high, shape, shape_num, dtype, key, s); } static inline mlx_stream mlx_stream_new(void) { return mlx_stream_new_(); } static inline mlx_stream mlx_stream_new_device(mlx_device dev) { return mlx_stream_new_device_(dev); } static inline int mlx_stream_set(mlx_stream* stream, const mlx_stream src) { return mlx_stream_set_(stream, src); } 
static inline int mlx_stream_free(mlx_stream stream) { return mlx_stream_free_(stream); } static inline int mlx_stream_tostring(mlx_string* str, mlx_stream stream) { return mlx_stream_tostring_(str, stream); } static inline bool mlx_stream_equal(mlx_stream lhs, mlx_stream rhs) { return mlx_stream_equal_(lhs, rhs); } static inline int mlx_stream_get_device(mlx_device* dev, mlx_stream stream) { return mlx_stream_get_device_(dev, stream); } static inline int mlx_stream_get_index(int* index, mlx_stream stream) { return mlx_stream_get_index_(index, stream); } static inline int mlx_synchronize(mlx_stream stream) { return mlx_synchronize_(stream); } static inline int mlx_get_default_stream(mlx_stream* stream, mlx_device dev) { return mlx_get_default_stream_(stream, dev); } static inline int mlx_set_default_stream(mlx_stream stream) { return mlx_set_default_stream_(stream); } static inline mlx_stream mlx_default_cpu_stream_new(void) { return mlx_default_cpu_stream_new_(); } static inline mlx_stream mlx_default_gpu_stream_new(void) { return mlx_default_gpu_stream_new_(); } static inline mlx_string mlx_string_new(void) { return mlx_string_new_(); } static inline mlx_string mlx_string_new_data(const char* str) { return mlx_string_new_data_(str); } static inline int mlx_string_set(mlx_string* str, const mlx_string src) { return mlx_string_set_(str, src); } static inline const char * mlx_string_data(mlx_string str) { return mlx_string_data_(str); } static inline int mlx_string_free(mlx_string str) { return mlx_string_free_(str); } static inline int mlx_async_eval(const mlx_vector_array outputs) { return mlx_async_eval_(outputs); } static inline int mlx_checkpoint(mlx_closure* res, const mlx_closure fun) { return mlx_checkpoint_(res, fun); } static inline int mlx_custom_function( mlx_closure* res, const mlx_closure fun, const mlx_closure_custom fun_vjp /* may be null */, const mlx_closure_custom_jvp fun_jvp /* may be null */, const mlx_closure_custom_vmap fun_vmap /* may be null 
*/) { return mlx_custom_function_(res, fun, fun_vjp, fun_jvp, fun_vmap); } static inline int mlx_custom_vjp( mlx_closure* res, const mlx_closure fun, const mlx_closure_custom fun_vjp) { return mlx_custom_vjp_(res, fun, fun_vjp); } static inline int mlx_eval(const mlx_vector_array outputs) { return mlx_eval_(outputs); } static inline int mlx_jvp( mlx_vector_array* res_0, mlx_vector_array* res_1, const mlx_closure fun, const mlx_vector_array primals, const mlx_vector_array tangents) { return mlx_jvp_(res_0, res_1, fun, primals, tangents); } static inline int mlx_value_and_grad( mlx_closure_value_and_grad* res, const mlx_closure fun, const int* argnums, size_t argnums_num) { return mlx_value_and_grad_(res, fun, argnums, argnums_num); } static inline int mlx_vjp( mlx_vector_array* res_0, mlx_vector_array* res_1, const mlx_closure fun, const mlx_vector_array primals, const mlx_vector_array cotangents) { return mlx_vjp_(res_0, res_1, fun, primals, cotangents); } static inline int mlx_detail_vmap_replace( mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array s_inputs, const mlx_vector_array s_outputs, const int* in_axes, size_t in_axes_num, const int* out_axes, size_t out_axes_num) { return mlx_detail_vmap_replace_(res, inputs, s_inputs, s_outputs, in_axes, in_axes_num, out_axes, out_axes_num); } static inline int mlx_detail_vmap_trace( mlx_vector_array* res_0, mlx_vector_array* res_1, const mlx_closure fun, const mlx_vector_array inputs, const int* in_axes, size_t in_axes_num) { return mlx_detail_vmap_trace_(res_0, res_1, fun, inputs, in_axes, in_axes_num); } static inline mlx_vector_array mlx_vector_array_new(void) { return mlx_vector_array_new_(); } static inline int mlx_vector_array_set(mlx_vector_array* vec, const mlx_vector_array src) { return mlx_vector_array_set_(vec, src); } static inline int mlx_vector_array_free(mlx_vector_array vec) { return mlx_vector_array_free_(vec); } static inline mlx_vector_array mlx_vector_array_new_data(const 
mlx_array* data, size_t size) { return mlx_vector_array_new_data_(data, size); } static inline mlx_vector_array mlx_vector_array_new_value(const mlx_array val) { return mlx_vector_array_new_value_(val); } static inline int mlx_vector_array_set_data( mlx_vector_array* vec, const mlx_array* data, size_t size) { return mlx_vector_array_set_data_(vec, data, size); } static inline int mlx_vector_array_set_value(mlx_vector_array* vec, const mlx_array val) { return mlx_vector_array_set_value_(vec, val); } static inline int mlx_vector_array_append_data( mlx_vector_array vec, const mlx_array* data, size_t size) { return mlx_vector_array_append_data_(vec, data, size); } static inline int mlx_vector_array_append_value(mlx_vector_array vec, const mlx_array val) { return mlx_vector_array_append_value_(vec, val); } static inline size_t mlx_vector_array_size(mlx_vector_array vec) { return mlx_vector_array_size_(vec); } static inline int mlx_vector_array_get( mlx_array* res, const mlx_vector_array vec, size_t idx) { return mlx_vector_array_get_(res, vec, idx); } static inline mlx_vector_vector_array mlx_vector_vector_array_new(void) { return mlx_vector_vector_array_new_(); } static inline int mlx_vector_vector_array_set( mlx_vector_vector_array* vec, const mlx_vector_vector_array src) { return mlx_vector_vector_array_set_(vec, src); } static inline int mlx_vector_vector_array_free(mlx_vector_vector_array vec) { return mlx_vector_vector_array_free_(vec); } static inline mlx_vector_vector_array mlx_vector_vector_array_new_data( const mlx_vector_array* data, size_t size) { return mlx_vector_vector_array_new_data_(data, size); } static inline mlx_vector_vector_array mlx_vector_vector_array_new_value( const mlx_vector_array val) { return mlx_vector_vector_array_new_value_(val); } static inline int mlx_vector_vector_array_set_data( mlx_vector_vector_array* vec, const mlx_vector_array* data, size_t size) { return mlx_vector_vector_array_set_data_(vec, data, size); } static inline int 
mlx_vector_vector_array_set_value( mlx_vector_vector_array* vec, const mlx_vector_array val) { return mlx_vector_vector_array_set_value_(vec, val); } static inline int mlx_vector_vector_array_append_data( mlx_vector_vector_array vec, const mlx_vector_array* data, size_t size) { return mlx_vector_vector_array_append_data_(vec, data, size); } static inline int mlx_vector_vector_array_append_value( mlx_vector_vector_array vec, const mlx_vector_array val) { return mlx_vector_vector_array_append_value_(vec, val); } static inline size_t mlx_vector_vector_array_size(mlx_vector_vector_array vec) { return mlx_vector_vector_array_size_(vec); } static inline int mlx_vector_vector_array_get( mlx_vector_array* res, const mlx_vector_vector_array vec, size_t idx) { return mlx_vector_vector_array_get_(res, vec, idx); } static inline mlx_vector_int mlx_vector_int_new(void) { return mlx_vector_int_new_(); } static inline int mlx_vector_int_set(mlx_vector_int* vec, const mlx_vector_int src) { return mlx_vector_int_set_(vec, src); } static inline int mlx_vector_int_free(mlx_vector_int vec) { return mlx_vector_int_free_(vec); } static inline mlx_vector_int mlx_vector_int_new_data(int* data, size_t size) { return mlx_vector_int_new_data_(data, size); } static inline mlx_vector_int mlx_vector_int_new_value(int val) { return mlx_vector_int_new_value_(val); } static inline int mlx_vector_int_set_data(mlx_vector_int* vec, int* data, size_t size) { return mlx_vector_int_set_data_(vec, data, size); } static inline int mlx_vector_int_set_value(mlx_vector_int* vec, int val) { return mlx_vector_int_set_value_(vec, val); } static inline int mlx_vector_int_append_data(mlx_vector_int vec, int* data, size_t size) { return mlx_vector_int_append_data_(vec, data, size); } static inline int mlx_vector_int_append_value(mlx_vector_int vec, int val) { return mlx_vector_int_append_value_(vec, val); } static inline size_t mlx_vector_int_size(mlx_vector_int vec) { return mlx_vector_int_size_(vec); } static 
inline int mlx_vector_int_get(int* res, const mlx_vector_int vec, size_t idx) { return mlx_vector_int_get_(res, vec, idx); } static inline mlx_vector_string mlx_vector_string_new(void) { return mlx_vector_string_new_(); } static inline int mlx_vector_string_set(mlx_vector_string* vec, const mlx_vector_string src) { return mlx_vector_string_set_(vec, src); } static inline int mlx_vector_string_free(mlx_vector_string vec) { return mlx_vector_string_free_(vec); } static inline mlx_vector_string mlx_vector_string_new_data(const char** data, size_t size) { return mlx_vector_string_new_data_(data, size); } static inline mlx_vector_string mlx_vector_string_new_value(const char* val) { return mlx_vector_string_new_value_(val); } static inline int mlx_vector_string_set_data( mlx_vector_string* vec, const char** data, size_t size) { return mlx_vector_string_set_data_(vec, data, size); } static inline int mlx_vector_string_set_value(mlx_vector_string* vec, const char* val) { return mlx_vector_string_set_value_(vec, val); } static inline int mlx_vector_string_append_data( mlx_vector_string vec, const char** data, size_t size) { return mlx_vector_string_append_data_(vec, data, size); } static inline int mlx_vector_string_append_value(mlx_vector_string vec, const char* val) { return mlx_vector_string_append_value_(vec, val); } static inline size_t mlx_vector_string_size(mlx_vector_string vec) { return mlx_vector_string_size_(vec); } static inline int mlx_vector_string_get(char** res, const mlx_vector_string vec, size_t idx) { return mlx_vector_string_get_(res, vec, idx); } static inline int mlx_version(mlx_string* str_) { return mlx_version_(str_); } #endif // MLX_GENERATED_H