Files
ollama/x/mlxrunner/mlx/include/mlx/c/random.h
Daniel Hiltgen 10e51c5177 MLX: add header vendoring and remove go build tag (#14642)
* prefer rocm v6 on windows

Avoid building with v7 - more changes are needed

* MLX: add header vendoring and remove go build tag

This switches to using a vendoring approach for the mlx-c headers so that Go
can build without requiring a cmake first.  This enables building the new MLX
based code by default.  Every time cmake runs, the headers are refreshed, so we
can easily keep them in sync when we bump mlx versions.  Basic Windows
and Linux support are verified.

* ci: harden for flaky choco repo servers

CI sometimes fails because choco does not actually install ccache. Since ccache only speeds up the build, we can proceed without it.

* review comments
2026-03-09 17:24:45 -07:00

167 lines
3.9 KiB
C

/* Copyright © 2023-2024 Apple Inc. */
/* */
/* This file is auto-generated. Do not edit manually. */
/* */
#ifndef MLX_RANDOM_H
#define MLX_RANDOM_H
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include "mlx/c/array.h"
#include "mlx/c/closure.h"
#include "mlx/c/distributed_group.h"
#include "mlx/c/io_types.h"
#include "mlx/c/map.h"
#include "mlx/c/stream.h"
#include "mlx/c/string.h"
#include "mlx/c/vector.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup random Random number operations
*/
/**@{*/
/**
 * \brief Draws Bernoulli samples with success probability `p`.
 *
 * NOTE(review): this header is auto-generated and, per the vendoring commit,
 * refreshed on every cmake run; documentation added here will be overwritten
 * unless it is upstreamed to mlx-c.
 *
 * \param res       Out-parameter that receives the sampled array.
 * \param p         Probability array (presumably broadcast against `shape`
 *                  — confirm against mlx-c).
 * \param shape     Output shape; points to `shape_num` dimension sizes.
 * \param shape_num Number of entries in `shape`.
 * \param key       PRNG key; may be null (presumably falls back to the global
 *                  generator state set by mlx_random_seed — confirm).
 * \param s         Stream the operation is issued on.
 * \return int status code (NOTE(review): presumably 0 on success — confirm).
 */
int mlx_random_bernoulli(
mlx_array* res,
const mlx_array p,
const int* shape,
size_t shape_num,
const mlx_array key /* may be null */,
const mlx_stream s);
/**
 * \brief Generates raw random bits in an array of the given shape.
 *
 * \param res       Out-parameter that receives the result.
 * \param shape     Output shape; points to `shape_num` dimension sizes.
 * \param shape_num Number of entries in `shape`.
 * \param width     Element width (NOTE(review): presumably bytes per element,
 *                  as in mlx.core.random.bits — confirm).
 * \param key       PRNG key; may be null.
 * \param s         Stream the operation is issued on.
 * \return int status code (presumably 0 on success — confirm).
 */
int mlx_random_bits(
mlx_array* res,
const int* shape,
size_t shape_num,
int width,
const mlx_array key /* may be null */,
const mlx_stream s);
/**
 * \brief Samples category indices from `logits` along `axis`, producing an
 *        array of an explicitly requested shape.
 *
 * \param res       Out-parameter that receives the sampled indices.
 * \param logits    Per-category scores (presumably unnormalized
 *                  log-probabilities — confirm).
 * \param axis      Axis of `logits` that indexes the categories.
 * \param shape     Output shape; points to `shape_num` dimension sizes.
 * \param shape_num Number of entries in `shape`.
 * \param key       PRNG key; may be null.
 * \param s         Stream the operation is issued on.
 * \return int status code (presumably 0 on success — confirm).
 */
int mlx_random_categorical_shape(
mlx_array* res,
const mlx_array logits,
int axis,
const int* shape,
size_t shape_num,
const mlx_array key /* may be null */,
const mlx_stream s);
/**
 * \brief Categorical sampling variant that draws `num_samples` samples from
 *        `logits_` along `axis`.
 *
 * NOTE(review): the parameter is spelled `logits_` here but `logits` in the
 * sibling categorical declarations. Prototype parameter names do not affect
 * callers, but the inconsistency should be fixed in the upstream generator.
 *
 * \param res         Out-parameter that receives the sampled indices.
 * \param logits_     Per-category scores (presumably unnormalized
 *                    log-probabilities — confirm).
 * \param axis        Axis of `logits_` that indexes the categories.
 * \param num_samples Number of samples to draw.
 * \param key         PRNG key; may be null.
 * \param s           Stream the operation is issued on.
 * \return int status code (presumably 0 on success — confirm).
 */
int mlx_random_categorical_num_samples(
mlx_array* res,
const mlx_array logits_,
int axis,
int num_samples,
const mlx_array key /* may be null */,
const mlx_stream s);
/**
 * \brief Categorical sampling from `logits` along `axis`, without an explicit
 *        output shape or sample count.
 *
 * NOTE(review): the output shape is presumably `logits` with `axis` removed
 * (mirrors mlx.core.random.categorical) — confirm.
 *
 * \param res    Out-parameter that receives the sampled indices.
 * \param logits Per-category scores (presumably unnormalized
 *               log-probabilities — confirm).
 * \param axis   Axis of `logits` that indexes the categories.
 * \param key    PRNG key; may be null.
 * \param s      Stream the operation is issued on.
 * \return int status code (presumably 0 on success — confirm).
 */
int mlx_random_categorical(
mlx_array* res,
const mlx_array logits,
int axis,
const mlx_array key /* may be null */,
const mlx_stream s);
/**
 * \brief Draws Gumbel-distributed samples of the given shape and dtype.
 *
 * \param res       Out-parameter that receives the sampled array.
 * \param shape     Output shape; points to `shape_num` dimension sizes.
 * \param shape_num Number of entries in `shape`.
 * \param dtype     Element type of the result.
 * \param key       PRNG key; may be null.
 * \param s         Stream the operation is issued on.
 * \return int status code (presumably 0 on success — confirm).
 */
int mlx_random_gumbel(
mlx_array* res,
const int* shape,
size_t shape_num,
mlx_dtype dtype,
const mlx_array key /* may be null */,
const mlx_stream s);
/**
 * \brief Builds a PRNG key array from a 64-bit `seed`.
 *
 * Unlike the sampling functions in this group, it takes no stream argument.
 * The resulting key can be passed as the `key` argument of the other functions.
 *
 * \param res  Out-parameter that receives the key array.
 * \param seed Seed value for the key.
 * \return int status code (presumably 0 on success — confirm).
 */
int mlx_random_key(mlx_array* res, uint64_t seed);
/**
 * \brief Draws Laplace-distributed samples with scalar location and scale.
 *
 * \param res       Out-parameter that receives the sampled array.
 * \param shape     Output shape; points to `shape_num` dimension sizes.
 * \param shape_num Number of entries in `shape`.
 * \param dtype     Element type of the result.
 * \param loc       Location parameter (scalar `float`).
 * \param scale     Scale parameter (scalar `float`).
 * \param key       PRNG key; may be null.
 * \param s         Stream the operation is issued on.
 * \return int status code (presumably 0 on success — confirm).
 */
int mlx_random_laplace(
mlx_array* res,
const int* shape,
size_t shape_num,
mlx_dtype dtype,
float loc,
float scale,
const mlx_array key /* may be null */,
const mlx_stream s);
/**
 * \brief Draws samples from a multivariate normal with the given mean and
 *        covariance arrays.
 *
 * \param res       Out-parameter that receives the sampled array.
 * \param mean      Mean vector(s) (assumed trailing dim = event size — confirm).
 * \param cov       Covariance matrix/matrices (assumed square in the trailing
 *                  two dims — confirm).
 * \param shape     Sample (batch) shape; points to `shape_num` dimension sizes.
 * \param shape_num Number of entries in `shape`.
 * \param dtype     Element type of the result.
 * \param key       PRNG key; may be null.
 * \param s         Stream the operation is issued on.
 * \return int status code (presumably 0 on success — confirm).
 */
int mlx_random_multivariate_normal(
mlx_array* res,
const mlx_array mean,
const mlx_array cov,
const int* shape,
size_t shape_num,
mlx_dtype dtype,
const mlx_array key /* may be null */,
const mlx_stream s);
/**
 * \brief Normal sampling with array-valued location and scale.
 *
 * This is the array-parameter counterpart of mlx_random_normal, whose `loc`
 * and `scale` are plain floats.
 *
 * \param res       Out-parameter that receives the sampled array.
 * \param shape     Output shape; points to `shape_num` dimension sizes.
 * \param shape_num Number of entries in `shape`.
 * \param dtype     Element type of the result.
 * \param loc       Location array; may be null (presumably defaults to 0 —
 *                  confirm). Presumably broadcast against `shape`.
 * \param scale     Scale array; may be null (presumably defaults to 1 —
 *                  confirm). Presumably broadcast against `shape`.
 * \param key       PRNG key; may be null.
 * \param s         Stream the operation is issued on.
 * \return int status code (presumably 0 on success — confirm).
 */
int mlx_random_normal_broadcast(
mlx_array* res,
const int* shape,
size_t shape_num,
mlx_dtype dtype,
const mlx_array loc /* may be null */,
const mlx_array scale /* may be null */,
const mlx_array key /* may be null */,
const mlx_stream s);
/**
 * \brief Draws normally distributed samples with scalar location and scale.
 *
 * \param res       Out-parameter that receives the sampled array.
 * \param shape     Output shape; points to `shape_num` dimension sizes.
 * \param shape_num Number of entries in `shape`.
 * \param dtype     Element type of the result.
 * \param loc       Location (mean) as a scalar `float`.
 * \param scale     Scale (standard deviation) as a scalar `float`.
 * \param key       PRNG key; may be null.
 * \param s         Stream the operation is issued on.
 * \return int status code (presumably 0 on success — confirm).
 */
int mlx_random_normal(
mlx_array* res,
const int* shape,
size_t shape_num,
mlx_dtype dtype,
float loc,
float scale,
const mlx_array key /* may be null */,
const mlx_stream s);
/**
 * \brief Randomly permutes the entries of array `x` along `axis`.
 *
 * \param res  Out-parameter that receives the permuted array.
 * \param x    Array to permute.
 * \param axis Axis along which entries are shuffled.
 * \param key  PRNG key; may be null.
 * \param s    Stream the operation is issued on.
 * \return int status code (presumably 0 on success — confirm).
 */
int mlx_random_permutation(
mlx_array* res,
const mlx_array x,
int axis,
const mlx_array key /* may be null */,
const mlx_stream s);
/**
 * \brief Random permutation variant that takes an integer count instead of
 *        an input array.
 *
 * NOTE(review): presumably returns a shuffled `arange(x)`, i.e. the integers
 * 0..x-1 in random order — confirm against mlx-c.
 *
 * \param res Out-parameter that receives the permutation.
 * \param x   Number of elements to permute (plain `int`, not an array).
 * \param key PRNG key; may be null.
 * \param s   Stream the operation is issued on.
 * \return int status code (presumably 0 on success — confirm).
 */
int mlx_random_permutation_arange(
mlx_array* res,
int x,
const mlx_array key /* may be null */,
const mlx_stream s);
/**
 * \brief Draws random integers between the `low` and `high` bounds.
 *
 * NOTE(review): interval is presumably half-open [low, high), as in
 * mlx.core.random.randint — confirm.
 *
 * \param res       Out-parameter that receives the sampled array.
 * \param low       Lower bound array (presumably broadcast against `shape`).
 * \param high      Upper bound array (presumably broadcast against `shape`).
 * \param shape     Output shape; points to `shape_num` dimension sizes.
 * \param shape_num Number of entries in `shape`.
 * \param dtype     Element type of the result.
 * \param key       PRNG key; may be null.
 * \param s         Stream the operation is issued on.
 * \return int status code (presumably 0 on success — confirm).
 */
int mlx_random_randint(
mlx_array* res,
const mlx_array low,
const mlx_array high,
const int* shape,
size_t shape_num,
mlx_dtype dtype,
const mlx_array key /* may be null */,
const mlx_stream s);
/**
 * \brief Seeds the random number generator from a 64-bit value.
 *
 * It takes no result, key, or stream argument. NOTE(review): it presumably
 * seeds the global generator state used when `key` is null in the sampling
 * functions — confirm.
 *
 * \param seed Seed value.
 * \return int status code (presumably 0 on success — confirm).
 */
int mlx_random_seed(uint64_t seed);
/**
 * \brief Splits a PRNG key into `num` new keys.
 *
 * Unlike the sampling functions, `key` is not annotated as nullable here.
 *
 * \param res Out-parameter that receives the array of new keys.
 * \param key Key to split (not annotated as nullable).
 * \param num Number of keys to produce.
 * \param s   Stream the operation is issued on.
 * \return int status code (presumably 0 on success — confirm).
 */
int mlx_random_split_num(
mlx_array* res,
const mlx_array key,
int num,
const mlx_stream s);
/**
 * \brief Splits a PRNG key into two keys returned through separate
 *        out-parameters.
 *
 * \param res_0 Out-parameter that receives the first key.
 * \param res_1 Out-parameter that receives the second key.
 * \param key   Key to split (not annotated as nullable).
 * \param s     Stream the operation is issued on.
 * \return int status code (presumably 0 on success — confirm).
 */
int mlx_random_split(
mlx_array* res_0,
mlx_array* res_1,
const mlx_array key,
const mlx_stream s);
/**
 * \brief Draws normally distributed samples truncated to the `lower` and
 *        `upper` bounds.
 *
 * \param res       Out-parameter that receives the sampled array.
 * \param lower     Lower truncation bound (array; presumably broadcast
 *                  against `shape` — confirm).
 * \param upper     Upper truncation bound (array; presumably broadcast
 *                  against `shape` — confirm).
 * \param shape     Output shape; points to `shape_num` dimension sizes.
 * \param shape_num Number of entries in `shape`.
 * \param dtype     Element type of the result.
 * \param key       PRNG key; may be null.
 * \param s         Stream the operation is issued on.
 * \return int status code (presumably 0 on success — confirm).
 */
int mlx_random_truncated_normal(
mlx_array* res,
const mlx_array lower,
const mlx_array upper,
const int* shape,
size_t shape_num,
mlx_dtype dtype,
const mlx_array key /* may be null */,
const mlx_stream s);
/**
 * \brief Draws uniformly distributed samples between `low` and `high`.
 *
 * NOTE(review): interval is presumably half-open [low, high), as in
 * mlx.core.random.uniform — confirm.
 *
 * \param res       Out-parameter that receives the sampled array.
 * \param low       Lower bound array (presumably broadcast against `shape`).
 * \param high      Upper bound array (presumably broadcast against `shape`).
 * \param shape     Output shape; points to `shape_num` dimension sizes.
 * \param shape_num Number of entries in `shape`.
 * \param dtype     Element type of the result.
 * \param key       PRNG key; may be null.
 * \param s         Stream the operation is issued on.
 * \return int status code (presumably 0 on success — confirm).
 */
int mlx_random_uniform(
mlx_array* res,
const mlx_array low,
const mlx_array high,
const int* shape,
size_t shape_num,
mlx_dtype dtype,
const mlx_array key /* may be null */,
const mlx_stream s);
/**@}*/
#ifdef __cplusplus
}
#endif
#endif