mirror of
https://github.com/ollama/ollama.git
synced 2026-04-17 15:53:27 +02:00
* prefer rocm v6 on windows Avoid building with v7 - more changes are needed * MLX: add header vendoring and remove go build tag This switches to using a vendoring approach for the mlx-c headers so that Go can build without requiring a cmake first. This enables building the new MLX based code by default. Every time cmake runs, the headers are refreshed, so we can easily keep them in sync when we bump mlx versions. Basic Windows and Linux support are verified. * ci: harden for flaky choco repo servers CI sometimes fails due to choco not actually installing cache. Since it just speeds up the build, we can proceed without. * review comments
2367 lines
68 KiB
Go
2367 lines
68 KiB
Go
package mlx
|
|
|
|
/*
|
|
#cgo CFLAGS: -O3 -I${SRCDIR}/../../mlxrunner/mlx/include -I${SRCDIR}
|
|
#cgo darwin LDFLAGS: -lc++ -framework Metal -framework Foundation -framework Accelerate
|
|
#cgo linux LDFLAGS: -lstdc++ -ldl
|
|
#cgo windows LDFLAGS: -lstdc++
|
|
|
|
// Use generated wrappers instead of direct MLX headers
|
|
#include "mlx.h"
|
|
#include <stdlib.h>
|
|
#include <stdint.h>
|
|
#include <string.h>
|
|
|
|
// Forward declare cpu_stream
|
|
static mlx_stream cpu_stream();
|
|
|
|
// Cached default GPU stream for all ops
|
|
static mlx_stream _default_stream = {0};
|
|
static mlx_stream _cpu_stream = {0};
|
|
|
|
static inline mlx_stream default_stream() {
|
|
if (_default_stream.ctx == NULL) {
|
|
_default_stream = mlx_default_gpu_stream_new();
|
|
}
|
|
return _default_stream;
|
|
}
|
|
|
|
static inline void set_default_stream(mlx_stream s) {
|
|
_default_stream = s;
|
|
}
|
|
|
|
// CPU stream for operations that only support CPU evaluation
|
|
static inline mlx_stream cpu_stream() {
|
|
if (_cpu_stream.ctx == NULL) {
|
|
_cpu_stream = mlx_default_cpu_stream_new();
|
|
}
|
|
return _cpu_stream;
|
|
}
|
|
|
|
// CGO noescape/nocallback hints to reduce CGO overhead
|
|
// noescape: pointers won't escape, no heap allocation needed
|
|
// nocallback: function won't call back into Go
|
|
*/
|
|
import "C"
|
|
|
|
import (
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"reflect"
|
|
"runtime"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
"unsafe"
|
|
)
|
|
|
|
// Dtype represents MLX data types
|
|
type Dtype int
|
|
|
|
const (
|
|
DtypeBool Dtype = C.MLX_BOOL
|
|
DtypeUint8 Dtype = C.MLX_UINT8
|
|
DtypeUint16 Dtype = C.MLX_UINT16
|
|
DtypeUint32 Dtype = C.MLX_UINT32
|
|
DtypeUint64 Dtype = C.MLX_UINT64
|
|
DtypeInt8 Dtype = C.MLX_INT8
|
|
DtypeInt16 Dtype = C.MLX_INT16
|
|
DtypeInt32 Dtype = C.MLX_INT32
|
|
DtypeInt64 Dtype = C.MLX_INT64
|
|
DtypeFloat16 Dtype = C.MLX_FLOAT16
|
|
DtypeFloat32 Dtype = C.MLX_FLOAT32
|
|
DtypeFloat64 Dtype = C.MLX_FLOAT64
|
|
DtypeBFloat16 Dtype = C.MLX_BFLOAT16
|
|
DtypeComplex64 Dtype = C.MLX_COMPLEX64
|
|
)
|
|
|
|
// String implements fmt.Stringer for Dtype
|
|
func (d Dtype) String() string {
|
|
switch d {
|
|
case DtypeBool:
|
|
return "bool"
|
|
case DtypeUint8:
|
|
return "u8"
|
|
case DtypeUint16:
|
|
return "u16"
|
|
case DtypeUint32:
|
|
return "u32"
|
|
case DtypeUint64:
|
|
return "u64"
|
|
case DtypeInt8:
|
|
return "i8"
|
|
case DtypeInt16:
|
|
return "i16"
|
|
case DtypeInt32:
|
|
return "i32"
|
|
case DtypeInt64:
|
|
return "i64"
|
|
case DtypeFloat16:
|
|
return "f16"
|
|
case DtypeFloat32:
|
|
return "f32"
|
|
case DtypeFloat64:
|
|
return "f64"
|
|
case DtypeBFloat16:
|
|
return "bf16"
|
|
case DtypeComplex64:
|
|
return "c64"
|
|
default:
|
|
return "unknown"
|
|
}
|
|
}
|
|
|
|
// Memory Management:
|
|
//
|
|
// All arrays are automatically tracked for cleanup. On Eval(), non-kept arrays are freed.
|
|
//
|
|
// x := mlx.Matmul(input, weight) // x is tracked for cleanup
|
|
// mlx.Keep(x) // mark x as persistent
|
|
// mlx.Eval(x) // eval + free non-kept arrays
|
|
//
|
|
// Use Keep() for arrays that should persist (weights, caches).
|
|
// Use Free() to mark a kept array for cleanup on next Eval().
|
|
//
|
|
// Note: Not goroutine-safe. Use from a single goroutine.
|
|
|
|
// Array wraps an MLX array handle.
|
|
// Arrays are freed via Eval() cleanup (deterministic) or GC (fallback).
|
|
type Array struct {
|
|
c C.mlx_array
|
|
freed bool // Prevents double-free
|
|
kept bool // If true, survives Eval() cleanup
|
|
}
|
|
|
|
// arrays tracks all live arrays. On Eval(), non-kept arrays are freed.
|
|
// Not goroutine-safe.
|
|
var arrays = make([]*Array, 0, 4096)
|
|
|
|
// evalHandles is a pre-allocated slice for passing arrays to MLX eval.
|
|
var evalHandles = make([]C.mlx_array, 0, 64)
|
|
|
|
// arrayPool reduces allocations for intermediate arrays
|
|
var arrayPool = sync.Pool{
|
|
New: func() any { return &Array{} },
|
|
}
|
|
|
|
func newArray(array C.mlx_array) *Array {
|
|
// In compiled closures, MLX manages memory - skip Go tracking
|
|
if InClosureCallback() {
|
|
return &Array{c: array}
|
|
}
|
|
|
|
// Use pooled Array struct for efficiency
|
|
a := arrayPool.Get().(*Array)
|
|
a.c = array
|
|
a.freed = false
|
|
a.kept = false
|
|
|
|
// Track in global list
|
|
arrays = append(arrays, a)
|
|
|
|
return a
|
|
}
|
|
|
|
// Collect uses reflection to find all *Array fields in a struct (recursively).
|
|
// Use this to automatically gather model weights, cache state, etc.
|
|
func Collect(v any) []*Array {
|
|
var arrays []*Array
|
|
seen := make(map[uintptr]bool)
|
|
collect(reflect.ValueOf(v), &arrays, seen)
|
|
return arrays
|
|
}
|
|
|
|
func collect(v reflect.Value, arrays *[]*Array, seen map[uintptr]bool) {
|
|
if !v.IsValid() {
|
|
return
|
|
}
|
|
|
|
// Handle pointers
|
|
if v.Kind() == reflect.Ptr {
|
|
if v.IsNil() {
|
|
return
|
|
}
|
|
// Avoid infinite loops
|
|
ptr := v.Pointer()
|
|
if seen[ptr] {
|
|
return
|
|
}
|
|
seen[ptr] = true
|
|
|
|
// Check if it's *Array
|
|
if arr, ok := v.Interface().(*Array); ok {
|
|
if arr != nil && arr.c.ctx != nil {
|
|
*arrays = append(*arrays, arr)
|
|
}
|
|
return
|
|
}
|
|
collect(v.Elem(), arrays, seen)
|
|
return
|
|
}
|
|
|
|
// Handle structs
|
|
if v.Kind() == reflect.Struct {
|
|
for i := 0; i < v.NumField(); i++ {
|
|
field := v.Field(i)
|
|
if field.CanInterface() {
|
|
collect(field, arrays, seen)
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
// Handle slices
|
|
if v.Kind() == reflect.Slice {
|
|
for i := 0; i < v.Len(); i++ {
|
|
collect(v.Index(i), arrays, seen)
|
|
}
|
|
return
|
|
}
|
|
|
|
// Handle maps
|
|
if v.Kind() == reflect.Map {
|
|
for _, key := range v.MapKeys() {
|
|
collect(v.MapIndex(key), arrays, seen)
|
|
}
|
|
return
|
|
}
|
|
|
|
// Handle interfaces
|
|
if v.Kind() == reflect.Interface {
|
|
if !v.IsNil() {
|
|
collect(v.Elem(), arrays, seen)
|
|
}
|
|
return
|
|
}
|
|
}
|
|
|
|
// FreeStruct releases all *Array fields in a struct (recursively).
|
|
// Use this to free model weights when unloading a model.
|
|
func FreeStruct(v any) {
|
|
for _, arr := range Collect(v) {
|
|
arr.Free()
|
|
}
|
|
}
|
|
|
|
// Keep marks arrays to persist across Eval() cleanup.
|
|
// Kept arrays will NOT be freed when Eval() runs cleanup.
|
|
func Keep(arrays ...*Array) {
|
|
for _, a := range arrays {
|
|
if a != nil {
|
|
a.kept = true
|
|
}
|
|
}
|
|
}
|
|
|
|
// cleanup frees non-kept arrays and compacts the live array list.
|
|
// Returns number of arrays freed.
|
|
func cleanup() int {
|
|
freed := 0
|
|
n := 0
|
|
for _, a := range arrays {
|
|
if a.kept {
|
|
arrays[n] = a
|
|
n++
|
|
} else if a.c.ctx != nil && !a.freed {
|
|
C.mlx_array_free(a.c)
|
|
a.c.ctx = nil
|
|
arrayPool.Put(a)
|
|
freed++
|
|
}
|
|
}
|
|
arrays = arrays[:n]
|
|
return freed
|
|
}
|
|
|
|
// DebugArrays prints summary info about all tracked arrays.
|
|
func DebugArrays() {
|
|
var totalBytes int64
|
|
var keptCount, unkeptCount int
|
|
for _, a := range arrays {
|
|
if a.kept {
|
|
keptCount++
|
|
} else {
|
|
unkeptCount++
|
|
}
|
|
totalBytes += a.Nbytes()
|
|
}
|
|
fmt.Printf("[DEBUG] Arrays: %d kept, %d unkept, %.2f GB total\n",
|
|
keptCount, unkeptCount, float64(totalBytes)/(1024*1024*1024))
|
|
}
|
|
|
|
// DebugArraysVerbose prints detailed info about all tracked arrays, sorted by size.
|
|
func DebugArraysVerbose(topN int) {
|
|
type arrayInfo struct {
|
|
shape []int32
|
|
dtype Dtype
|
|
bytes int64
|
|
kept bool
|
|
}
|
|
|
|
var infos []arrayInfo
|
|
var totalBytes int64
|
|
for _, a := range arrays {
|
|
bytes := a.Nbytes()
|
|
infos = append(infos, arrayInfo{
|
|
shape: a.Shape(),
|
|
dtype: a.Dtype(),
|
|
bytes: bytes,
|
|
kept: a.kept,
|
|
})
|
|
totalBytes += bytes
|
|
}
|
|
|
|
// Sort by size descending
|
|
for i := 0; i < len(infos)-1; i++ {
|
|
for j := i + 1; j < len(infos); j++ {
|
|
if infos[j].bytes > infos[i].bytes {
|
|
infos[i], infos[j] = infos[j], infos[i]
|
|
}
|
|
}
|
|
}
|
|
|
|
fmt.Printf("[DEBUG] %d arrays, %.2f GB total:\n", len(infos), float64(totalBytes)/(1024*1024*1024))
|
|
for i, info := range infos {
|
|
if i >= topN {
|
|
break
|
|
}
|
|
keptStr := ""
|
|
if info.kept {
|
|
keptStr = " [kept]"
|
|
}
|
|
fmt.Printf(" %3d. %8.2f MB %v %v%s\n",
|
|
i+1, float64(info.bytes)/(1024*1024), info.shape, info.dtype, keptStr)
|
|
}
|
|
}
|
|
|
|
// Eval synchronously evaluates arrays and cleans up non-kept arrays.
|
|
// Outputs are automatically kept (survive cleanup). Returns them for chaining.
|
|
func Eval(outputs ...*Array) []*Array {
|
|
// Keep outputs so cleanup doesn't free them
|
|
for _, o := range outputs {
|
|
if o != nil {
|
|
o.kept = true
|
|
}
|
|
}
|
|
|
|
// Cleanup non-kept arrays
|
|
cleanup()
|
|
|
|
// Then evaluate
|
|
if len(outputs) > 0 {
|
|
evalHandles = evalHandles[:0]
|
|
for _, o := range outputs {
|
|
if o != nil {
|
|
evalHandles = append(evalHandles, o.c)
|
|
}
|
|
}
|
|
if len(evalHandles) > 0 {
|
|
vec := C.mlx_vector_array_new_data(&evalHandles[0], C.size_t(len(evalHandles)))
|
|
C.mlx_eval(vec)
|
|
C.mlx_vector_array_free(vec)
|
|
}
|
|
}
|
|
return outputs
|
|
}
|
|
|
|
// AsyncEval dispatches async evaluation and cleans up non-kept arrays.
|
|
// Outputs are automatically kept (survive cleanup).
|
|
func AsyncEval(outputs ...*Array) {
|
|
// Keep outputs so cleanup doesn't free them
|
|
for _, o := range outputs {
|
|
if o != nil {
|
|
o.kept = true
|
|
}
|
|
}
|
|
|
|
// Cleanup non-kept arrays
|
|
cleanup()
|
|
|
|
// Then dispatch async eval
|
|
if len(outputs) > 0 {
|
|
evalHandles = evalHandles[:0]
|
|
for _, o := range outputs {
|
|
if o != nil {
|
|
evalHandles = append(evalHandles, o.c)
|
|
}
|
|
}
|
|
if len(evalHandles) > 0 {
|
|
vec := C.mlx_vector_array_new_data(&evalHandles[0], C.size_t(len(evalHandles)))
|
|
C.mlx_async_eval(vec)
|
|
C.mlx_vector_array_free(vec)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Sync waits for all async operations to complete (no cleanup).
|
|
func Sync() {
|
|
C.mlx_synchronize(C.default_stream())
|
|
}
|
|
|
|
// Free marks this array for cleanup on the next Eval().
|
|
// The array is not immediately freed - cleanup happens during Eval().
|
|
//
|
|
// Pattern for loops:
|
|
//
|
|
// oldLatents.Free() // mark for cleanup
|
|
// mlx.Eval(newLatents) // frees old, evals new
|
|
func (a *Array) Free() {
|
|
if a != nil {
|
|
a.kept = false
|
|
}
|
|
}
|
|
|
|
// Eval evaluates this single array and runs cleanup.
|
|
func (a *Array) Eval() *Array {
|
|
Eval(a)
|
|
return a
|
|
}
|
|
|
|
// Valid returns true if the array hasn't been freed.
|
|
func (a *Array) Valid() bool {
|
|
return a != nil && a.c.ctx != nil
|
|
}
|
|
|
|
// Kept returns true if the array is marked to survive Eval() cleanup.
|
|
func (a *Array) Kept() bool {
|
|
return a != nil && a.kept
|
|
}
|
|
|
|
func int32ToCInt(s []int32) *C.int {
|
|
if len(s) == 0 {
|
|
return nil
|
|
}
|
|
return (*C.int)(unsafe.Pointer(&s[0]))
|
|
}
|
|
|
|
// NewArray creates a new MLX array from float32 data
|
|
func NewArray(data []float32, shape []int32) *Array {
|
|
handle := C.mlx_array_new_data(
|
|
unsafe.Pointer(&data[0]),
|
|
int32ToCInt(shape),
|
|
C.int(len(shape)),
|
|
C.MLX_FLOAT32,
|
|
)
|
|
return newArray(handle)
|
|
}
|
|
|
|
// NewArrayInt32 creates a new MLX array from int32 data
|
|
func NewArrayInt32(data []int32, shape []int32) *Array {
|
|
handle := C.mlx_array_new_data(
|
|
unsafe.Pointer(&data[0]),
|
|
int32ToCInt(shape),
|
|
C.int(len(shape)),
|
|
C.MLX_INT32,
|
|
)
|
|
return newArray(handle)
|
|
}
|
|
|
|
// NewArrayFloat32 creates a new float32 array from data
|
|
func NewArrayFloat32(data []float32, shape []int32) *Array {
|
|
return NewArray(data, shape)
|
|
}
|
|
|
|
// Zeros creates an array of zeros with optional dtype (default float32)
|
|
func Zeros(shape []int32, dtype ...Dtype) *Array {
|
|
res := C.mlx_array_new()
|
|
dt := DtypeFloat32
|
|
if len(dtype) > 0 {
|
|
dt = dtype[0]
|
|
}
|
|
C.mlx_zeros(&res, int32ToCInt(shape), C.size_t(len(shape)), C.mlx_dtype(dt), C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// ZerosLike creates a zeros array with the same dtype as a.
|
|
// If shape is provided, uses that shape; otherwise uses a's shape.
|
|
func ZerosLike(a *Array, shape ...int32) *Array {
|
|
res := C.mlx_array_new()
|
|
if len(shape) == 0 {
|
|
C.mlx_zeros_like(&res, a.c, C.default_stream())
|
|
} else {
|
|
dtype := a.Dtype()
|
|
C.mlx_zeros(&res, int32ToCInt(shape), C.size_t(len(shape)), C.mlx_dtype(dtype), C.default_stream())
|
|
}
|
|
return newArray(res)
|
|
}
|
|
|
|
// Ones creates an array of ones
|
|
func Ones(shape ...int32) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_ones(&res, int32ToCInt(shape), C.size_t(len(shape)), C.MLX_FLOAT32, C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// Full creates an array filled with a value
|
|
func Full(value float32, shape ...int32) *Array {
|
|
vals := C.mlx_array_new_float(C.float(value))
|
|
res := C.mlx_array_new()
|
|
C.mlx_full(&res, int32ToCInt(shape), C.size_t(len(shape)), vals, C.MLX_FLOAT32, C.default_stream())
|
|
C.mlx_array_free(vals)
|
|
return newArray(res)
|
|
}
|
|
|
|
// Arange creates a range of values
|
|
func Arange(start, stop, step float32) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_arange(&res, C.double(start), C.double(stop), C.double(step), C.MLX_FLOAT32, C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// Linspace creates evenly spaced values
|
|
func Linspace(start, stop float32, steps int32) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_linspace(&res, C.double(start), C.double(stop), C.int(steps), C.MLX_FLOAT32, C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// ============ Math Operations ============
|
|
|
|
// Add adds two arrays element-wise
|
|
func Add(a, b *Array) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_add(&res, a.c, b.c, C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// AddRaw is like Add - kept for API compatibility (now identical to Add)
|
|
func AddRaw(a, b *Array) *Array {
|
|
return Add(a, b)
|
|
}
|
|
|
|
// Sub subtracts two arrays element-wise
|
|
func Sub(a, b *Array) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_subtract(&res, a.c, b.c, C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// Mul multiplies two arrays element-wise
|
|
func Mul(a, b *Array) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_multiply(&res, a.c, b.c, C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// Div divides two arrays element-wise
|
|
func Div(a, b *Array) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_divide(&res, a.c, b.c, C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// Matmul performs matrix multiplication
|
|
func Matmul(a, b *Array) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_matmul(&res, a.c, b.c, C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// AddMM computes: result = beta*c + alpha*(a @ b)
|
|
// This fuses bias addition with matmul into a single op.
|
|
func AddMM(c, a, b *Array, alpha, beta float32) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_addmm(&res, c.c, a.c, b.c, C.float(alpha), C.float(beta), C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// Linear performs matrix multiplication: a @ weight
|
|
func Linear(a, weight *Array) *Array {
|
|
return Matmul(a, weight)
|
|
}
|
|
|
|
// Sqrt computes element-wise square root
|
|
func Sqrt(a *Array) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_sqrt(&res, a.c, C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// RSqrt computes element-wise reciprocal square root
|
|
func RSqrt(a *Array) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_rsqrt(&res, a.c, C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// Erf computes element-wise error function
|
|
func Erf(a *Array) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_erf(&res, a.c, C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// Exp computes element-wise exponential
|
|
func Exp(a *Array) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_exp(&res, a.c, C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// Log computes element-wise natural logarithm
|
|
func Log(a *Array) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_log(&res, a.c, C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// Sin computes element-wise sine
|
|
func Sin(a *Array) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_sin(&res, a.c, C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// Cos computes element-wise cosine
|
|
func Cos(a *Array) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_cos(&res, a.c, C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// Neg negates the array
|
|
func Neg(a *Array) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_negative(&res, a.c, C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// Abs computes element-wise absolute value
|
|
func Abs(a *Array) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_abs(&res, a.c, C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// Square computes element-wise square
|
|
func Square(a *Array) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_square(&res, a.c, C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// Pow raises a to the power of b element-wise
|
|
func Pow(a, b *Array) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_power(&res, a.c, b.c, C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// Max computes element-wise maximum
|
|
func Max(a, b *Array) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_maximum(&res, a.c, b.c, C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// Min computes element-wise minimum
|
|
func Min(a, b *Array) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_minimum(&res, a.c, b.c, C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// scalarWithDtype creates a scalar array matching the dtype of a (critical for graph fusion!)
|
|
func scalarWithDtype(s float32, a *Array) C.mlx_array {
|
|
// Create float32 scalar, then cast to match input dtype
|
|
f32 := C.mlx_array_new_float(C.float(s))
|
|
dtype := a.Dtype()
|
|
if dtype == DtypeFloat32 {
|
|
return f32 // No cast needed
|
|
}
|
|
// Cast to match input dtype
|
|
casted := C.mlx_array_new()
|
|
C.mlx_astype(&casted, f32, C.mlx_dtype(dtype), C.default_stream())
|
|
C.mlx_array_free(f32)
|
|
return casted
|
|
}
|
|
|
|
// AddScalar adds a scalar to an array (matches dtype for graph fusion)
|
|
func AddScalar(a *Array, s float32) *Array {
|
|
scalar := scalarWithDtype(s, a)
|
|
res := C.mlx_array_new()
|
|
C.mlx_add(&res, a.c, scalar, C.default_stream())
|
|
C.mlx_array_free(scalar)
|
|
return newArray(res)
|
|
}
|
|
|
|
// MulScalar multiplies an array by a scalar (matches dtype for graph fusion)
|
|
func MulScalar(a *Array, s float32) *Array {
|
|
scalar := scalarWithDtype(s, a)
|
|
res := C.mlx_array_new()
|
|
C.mlx_multiply(&res, a.c, scalar, C.default_stream())
|
|
C.mlx_array_free(scalar)
|
|
return newArray(res)
|
|
}
|
|
|
|
// DivScalar divides an array by a scalar (matches dtype for graph fusion)
|
|
func DivScalar(a *Array, s float32) *Array {
|
|
scalar := scalarWithDtype(s, a)
|
|
res := C.mlx_array_new()
|
|
C.mlx_divide(&res, a.c, scalar, C.default_stream())
|
|
C.mlx_array_free(scalar)
|
|
return newArray(res)
|
|
}
|
|
|
|
// DivScalarInt divides an int array by an int scalar (regular division, may return float)
|
|
func DivScalarInt(a *Array, s int32) *Array {
|
|
scalar := C.mlx_array_new_int(C.int(s))
|
|
res := C.mlx_array_new()
|
|
C.mlx_divide(&res, a.c, scalar, C.default_stream())
|
|
C.mlx_array_free(scalar)
|
|
return newArray(res)
|
|
}
|
|
|
|
// FloorDivideScalar performs integer floor division (a // s), preserving int dtype
|
|
func FloorDivideScalar(a *Array, s int32) *Array {
|
|
scalar := C.mlx_array_new_int(C.int(s))
|
|
res := C.mlx_array_new()
|
|
C.mlx_floor_divide(&res, a.c, scalar, C.default_stream())
|
|
C.mlx_array_free(scalar)
|
|
return newArray(res)
|
|
}
|
|
|
|
// ============ Reduction Operations ============
|
|
|
|
// Sum reduces along an axis
|
|
func Sum(a *Array, axis int, keepdims bool) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_sum_axis(&res, a.c, C.int(axis), C._Bool(keepdims), C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// SumAll reduces the entire array to a scalar
|
|
func SumAll(a *Array) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_sum(&res, a.c, false, C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// Mean reduces along an axis
|
|
func Mean(a *Array, axis int, keepdims bool) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_mean_axis(&res, a.c, C.int(axis), C._Bool(keepdims), C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// MeanAll reduces the entire array to a scalar
|
|
func MeanAll(a *Array) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_mean(&res, a.c, false, C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// Var computes variance along an axis
|
|
func Var(a *Array, axis int, keepdims bool) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_var_axis(&res, a.c, C.int(axis), C._Bool(keepdims), 0, C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// Argmax returns indices of maximum values along an axis
|
|
func Argmax(a *Array, axis int, keepdims bool) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_argmax_axis(&res, a.c, C.int(axis), C._Bool(keepdims), C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// ArgmaxAll returns the index of the maximum element (flattened).
|
|
// Triggers cleanup of non-kept arrays.
|
|
func ArgmaxAll(a *Array) int32 {
|
|
cleanup()
|
|
// Flatten, then argmax with keepdims=false
|
|
flat := C.mlx_array_new()
|
|
C.mlx_flatten(&flat, a.c, 0, -1, C.default_stream())
|
|
res := C.mlx_array_new()
|
|
C.mlx_argmax(&res, flat, false, C.default_stream())
|
|
C.mlx_array_eval(res)
|
|
var val C.int32_t
|
|
C.mlx_array_item_int32(&val, res)
|
|
C.mlx_array_free(flat)
|
|
C.mlx_array_free(res)
|
|
return int32(val)
|
|
}
|
|
|
|
// Reshape reshapes the array
|
|
func Reshape(a *Array, shape ...int32) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_reshape(&res, a.c, int32ToCInt(shape), C.size_t(len(shape)), C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// Transpose permutes the dimensions
|
|
func Transpose(a *Array, axes ...int) *Array {
|
|
cAxes := make([]C.int, len(axes))
|
|
for i, ax := range axes {
|
|
cAxes[i] = C.int(ax)
|
|
}
|
|
res := C.mlx_array_new()
|
|
C.mlx_transpose_axes(&res, a.c, &cAxes[0], C.size_t(len(axes)), C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// AsStrided creates a view with custom strides. Useful for fusing reshape+transpose.
|
|
func AsStrided(a *Array, shape []int32, strides []int64, offset int64) *Array {
|
|
cShape := make([]C.int, len(shape))
|
|
for i, s := range shape {
|
|
cShape[i] = C.int(s)
|
|
}
|
|
cStrides := make([]C.int64_t, len(strides))
|
|
for i, s := range strides {
|
|
cStrides[i] = C.int64_t(s)
|
|
}
|
|
res := C.mlx_array_new()
|
|
C.mlx_as_strided(&res, a.c, &cShape[0], C.size_t(len(shape)), &cStrides[0], C.size_t(len(strides)), C.size_t(offset), C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// ExpandDims adds a dimension at the specified axis
|
|
func ExpandDims(a *Array, axis int) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_expand_dims(&res, a.c, C.int(axis), C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// Squeeze removes a dimension at the specified axis
|
|
func Squeeze(a *Array, axis int) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_squeeze_axis(&res, a.c, C.int(axis), C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// Flatten flattens the array to 1D
|
|
func Flatten(a *Array) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_flatten(&res, a.c, 0, -1, C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// FlattenRange flattens consecutive axes from start_axis to end_axis (intermediates)
|
|
func FlattenRange(a *Array, startAxis, endAxis int) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_flatten(&res, a.c, C.int(startAxis), C.int(endAxis), C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// View reinterprets the array with a new dtype (no data copy)
|
|
func View(a *Array, dtype int) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_view(&res, a.c, C.mlx_dtype(dtype), C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// Contiguous returns a contiguous copy of the array (row-major)
|
|
func Contiguous(a *Array) *Array {
|
|
res := C.mlx_array_new()
|
|
// Use allow_col=false to force row-major contiguous layout
|
|
C.mlx_contiguous(&res, a.c, false, C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// Clip clips values to [min, max]. Pass nil for no bound on that side.
|
|
func Clip(a *Array, aMin, aMax *Array) *Array {
|
|
res := C.mlx_array_new()
|
|
var minH, maxH C.mlx_array
|
|
if aMin != nil {
|
|
minH = aMin.c
|
|
}
|
|
if aMax != nil {
|
|
maxH = aMax.c
|
|
}
|
|
C.mlx_clip(&res, a.c, minH, maxH, C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// ClipScalar clips array values using scalar bounds (matches dtype for graph fusion)
|
|
// Pass math.NaN() or set hasMin/hasMax to false for unbounded
|
|
func ClipScalar(a *Array, minVal, maxVal float32, hasMin, hasMax bool) *Array {
|
|
var minArr, maxArr C.mlx_array
|
|
if hasMin {
|
|
minArr = scalarWithDtype(minVal, a)
|
|
}
|
|
if hasMax {
|
|
maxArr = scalarWithDtype(maxVal, a)
|
|
}
|
|
res := C.mlx_array_new()
|
|
C.mlx_clip(&res, a.c, minArr, maxArr, C.default_stream())
|
|
if hasMin {
|
|
C.mlx_array_free(minArr)
|
|
}
|
|
if hasMax {
|
|
C.mlx_array_free(maxArr)
|
|
}
|
|
return newArray(res)
|
|
}
|
|
|
|
// GreaterEqual returns element-wise a >= b
|
|
func GreaterEqual(a, b *Array) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_greater_equal(&res, a.c, b.c, C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// LessArray returns element-wise a < b
|
|
func LessArray(a, b *Array) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_less(&res, a.c, b.c, C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// LogicalAnd returns element-wise a && b
|
|
func LogicalAnd(a, b *Array) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_logical_and(&res, a.c, b.c, C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// AllClose returns true if all elements of a and b are within tolerance.
|
|
// Uses rtol (relative tolerance) and atol (absolute tolerance):
|
|
// |a - b| <= atol + rtol * |b|
|
|
func AllClose(a, b *Array, rtol, atol float64) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_allclose(&res, a.c, b.c, C.double(rtol), C.double(atol), C.bool(false), C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// AllCloseEqualNaN is like AllClose but treats NaN as equal to NaN.
|
|
func AllCloseEqualNaN(a, b *Array, rtol, atol float64) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_allclose(&res, a.c, b.c, C.double(rtol), C.double(atol), C.bool(true), C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// ArrayEqual returns true if arrays have same shape and all elements are equal.
|
|
func ArrayEqual(a, b *Array) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_array_equal(&res, a.c, b.c, C.bool(false), C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// ArrayEqualNaN is like ArrayEqual but treats NaN as equal to NaN.
|
|
func ArrayEqualNaN(a, b *Array) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_array_equal(&res, a.c, b.c, C.bool(true), C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// IsClose returns element-wise bool array indicating if values are within tolerance.
|
|
// |a - b| <= atol + rtol * |b|
|
|
func IsClose(a, b *Array, rtol, atol float64) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_isclose(&res, a.c, b.c, C.double(rtol), C.double(atol), C.bool(false), C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// IsCloseEqualNaN is like IsClose but treats NaN as equal to NaN.
|
|
func IsCloseEqualNaN(a, b *Array, rtol, atol float64) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_isclose(&res, a.c, b.c, C.double(rtol), C.double(atol), C.bool(true), C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// ReduceMax reduces array to max value over all dimensions.
|
|
func ReduceMax(a *Array) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_max(&res, a.c, C.bool(false), C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// ArangeInt creates an array with values from start to stop with step and specified dtype
|
|
func ArangeInt(start, stop, step int32, dtype Dtype) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_arange(&res, C.double(start), C.double(stop), C.double(step), C.mlx_dtype(dtype), C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// Concatenate concatenates arrays along an axis
|
|
func Concatenate(arrays []*Array, axis int) *Array {
|
|
handles := make([]C.mlx_array, len(arrays))
|
|
for i, arr := range arrays {
|
|
handles[i] = arr.c
|
|
}
|
|
vec := C.mlx_vector_array_new_data(&handles[0], C.size_t(len(handles)))
|
|
res := C.mlx_array_new()
|
|
C.mlx_concatenate_axis(&res, vec, C.int(axis), C.default_stream())
|
|
C.mlx_vector_array_free(vec)
|
|
return newArray(res)
|
|
}
|
|
|
|
// Concat is a convenience function to concatenate two arrays
|
|
func Concat(a, b *Array, axis int) *Array {
|
|
return Concatenate([]*Array{a, b}, axis)
|
|
}
|
|
|
|
// Stack stacks arrays along a new axis (axis 0 by default)
|
|
func Stack(arrays []*Array, axis int) *Array {
|
|
handles := make([]C.mlx_array, len(arrays))
|
|
for i, arr := range arrays {
|
|
handles[i] = arr.c
|
|
}
|
|
vec := C.mlx_vector_array_new_data(&handles[0], C.size_t(len(handles)))
|
|
res := C.mlx_array_new()
|
|
C.mlx_stack_axis(&res, vec, C.int(axis), C.default_stream())
|
|
C.mlx_vector_array_free(vec)
|
|
return newArray(res)
|
|
}
|
|
|
|
// Slice slices the array
|
|
func Slice(a *Array, start, stop []int32) *Array {
|
|
n := len(start)
|
|
cStart := make([]C.int, n)
|
|
cStop := make([]C.int, n)
|
|
cStrides := make([]C.int, n)
|
|
for i := 0; i < n; i++ {
|
|
cStart[i] = C.int(start[i])
|
|
cStop[i] = C.int(stop[i])
|
|
cStrides[i] = 1 // Default stride of 1
|
|
}
|
|
res := C.mlx_array_new()
|
|
C.mlx_slice(&res, a.c, &cStart[0], C.size_t(n), &cStop[0], C.size_t(n), &cStrides[0], C.size_t(n), C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// SliceStride slices with start:stop:stride like Python a[start:stop:stride]
|
|
func SliceStride(a *Array, start, stop, strides []int32) *Array {
|
|
cStart := make([]C.int, len(start))
|
|
cStop := make([]C.int, len(stop))
|
|
cStrides := make([]C.int, len(strides))
|
|
for i := range start {
|
|
cStart[i] = C.int(start[i])
|
|
cStop[i] = C.int(stop[i])
|
|
cStrides[i] = C.int(strides[i])
|
|
}
|
|
res := C.mlx_array_new()
|
|
C.mlx_slice(&res, a.c, &cStart[0], C.size_t(len(start)), &cStop[0], C.size_t(len(stop)), &cStrides[0], C.size_t(len(strides)), C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// Tile repeats the array along each dimension
|
|
func Tile(a *Array, reps []int32) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_tile(&res, a.c, int32ToCInt(reps), C.size_t(len(reps)), C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// BroadcastTo broadcasts an array to a given shape
|
|
func BroadcastTo(a *Array, shape []int32) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_broadcast_to(&res, a.c, int32ToCInt(shape), C.size_t(len(shape)), C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// ============ Neural Network Operations ============
|
|
|
|
// Softmax computes softmax along an axis
|
|
func Softmax(a *Array, axis int) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_softmax_axis(&res, a.c, C.int(axis), false, C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// Take gathers elements along an axis using indices
|
|
func Take(a *Array, indices *Array, axis int) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_take_axis(&res, a.c, indices.c, C.int(axis), C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// Argsort returns indices that would sort the array along an axis
|
|
func Argsort(a *Array, axis int) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_argsort_axis(&res, a.c, C.int(axis), C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// Sigmoid computes element-wise sigmoid
|
|
func Sigmoid(a *Array) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_sigmoid(&res, a.c, C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// ReLU computes element-wise ReLU: max(0, x)
|
|
func ReLU(a *Array) *Array {
|
|
// ReLU = maximum(x, 0) - mlx-c doesn't have mlx_relu, but we can use maximum
|
|
zero := C.mlx_array_new_float(0.0)
|
|
res := C.mlx_array_new()
|
|
C.mlx_maximum(&res, a.c, zero, C.default_stream())
|
|
C.mlx_array_free(zero)
|
|
return newArray(res)
|
|
}
|
|
|
|
// SiLU computes element-wise SiLU (Swish): x * sigmoid(x)
|
|
func SiLU(a *Array) *Array {
|
|
// SiLU = x * sigmoid(x)
|
|
sig := C.mlx_array_new()
|
|
C.mlx_sigmoid(&sig, a.c, C.default_stream())
|
|
res := C.mlx_array_new()
|
|
C.mlx_multiply(&res, a.c, sig, C.default_stream())
|
|
C.mlx_array_free(sig)
|
|
return newArray(res)
|
|
}
|
|
|
|
// GELU computes element-wise GELU (Gaussian Error Linear Unit)
|
|
// GELU(x) = x * 0.5 * (1 + erf(x / sqrt(2)))
|
|
func GELU(a *Array) *Array {
|
|
sqrt2 := C.mlx_array_new_float(1.4142135623730951)
|
|
scaled := C.mlx_array_new()
|
|
C.mlx_divide(&scaled, a.c, sqrt2, C.default_stream())
|
|
erfd := C.mlx_array_new()
|
|
C.mlx_erf(&erfd, scaled, C.default_stream())
|
|
one := C.mlx_array_new_float(1.0)
|
|
erfdPlusOne := C.mlx_array_new()
|
|
C.mlx_add(&erfdPlusOne, erfd, one, C.default_stream())
|
|
half := C.mlx_array_new_float(0.5)
|
|
halfErfdPlusOne := C.mlx_array_new()
|
|
C.mlx_multiply(&halfErfdPlusOne, half, erfdPlusOne, C.default_stream())
|
|
res := C.mlx_array_new()
|
|
C.mlx_multiply(&res, a.c, halfErfdPlusOne, C.default_stream())
|
|
C.mlx_array_free(sqrt2)
|
|
C.mlx_array_free(scaled)
|
|
C.mlx_array_free(erfd)
|
|
C.mlx_array_free(one)
|
|
C.mlx_array_free(erfdPlusOne)
|
|
C.mlx_array_free(half)
|
|
C.mlx_array_free(halfErfdPlusOne)
|
|
return newArray(res)
|
|
}
|
|
|
|
// Tanh computes element-wise tanh
|
|
func Tanh(a *Array) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_tanh(&res, a.c, C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// RMSNorm computes RMS normalization using mlx.fast
|
|
func RMSNorm(x, weight *Array, eps float32) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_fast_rms_norm(&res, x.c, weight.c, C.float(eps), C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// RMSNormNoWeight applies RMS normalization without a weight
|
|
// x * rsqrt(mean(x^2) + eps)
|
|
// Uses mlx_fast_rms_norm with ones weight for f32 accumulation precision
|
|
func RMSNormNoWeight(x *Array, eps float32) *Array {
|
|
// Create weight of ones matching last dimension
|
|
lastDim := x.Shape()[len(x.Shape())-1]
|
|
ones := AsType(Full(1.0, lastDim), x.Dtype())
|
|
return RMSNorm(x, ones, eps)
|
|
}
|
|
|
|
// LayerNorm applies layer normalization without learnable params
|
|
// (x - mean) / sqrt(var + eps)
|
|
func LayerNorm(x *Array, eps float32) *Array {
|
|
return LayerNormWithWeightBias(x, nil, nil, eps)
|
|
}
|
|
|
|
// LayerNormWithWeightBias computes layer normalization using mlx.fast
|
|
// weight and bias can be nil for elementwise_affine=False
|
|
func LayerNormWithWeightBias(x, weight, bias *Array, eps float32) *Array {
|
|
res := C.mlx_array_new()
|
|
var wc, bc C.mlx_array
|
|
if weight != nil {
|
|
wc = weight.c
|
|
}
|
|
if bias != nil {
|
|
bc = bias.c
|
|
}
|
|
C.mlx_fast_layer_norm(&res, x.c, wc, bc, C.float(eps), C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// RoPE applies rotary position embeddings using mlx.fast
|
|
func RoPE(x *Array, dims int, traditional bool, base, scale float32, offset int) *Array {
|
|
res := C.mlx_array_new()
|
|
optBase := C.mlx_optional_float{value: C.float(base), has_value: true}
|
|
C.mlx_fast_rope(&res, x.c, C.int(dims), C._Bool(traditional), optBase, C.float(scale), C.int(offset), C.mlx_array{}, C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// RoPEWithFreqs applies rotary position embeddings with custom frequencies (for YaRN)
|
|
// freqs is required - use RoPE() if you don't have custom frequencies
|
|
func RoPEWithFreqs(x, freqs *Array, dims int, traditional bool, scale float32, offset int) *Array {
|
|
res := C.mlx_array_new()
|
|
optBase := C.mlx_optional_float{has_value: false} // No base when using freqs
|
|
C.mlx_fast_rope(&res, x.c, C.int(dims), C._Bool(traditional), optBase, C.float(scale), C.int(offset), freqs.c, C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// ============ Indexing ============
|
|
|
|
// EmbeddingLookup performs embedding lookup (gathers from table)
|
|
// table: [vocab_size, hidden_size], indices: [batch, seq_len]
|
|
// returns: [batch, seq_len, hidden_size]
|
|
func EmbeddingLookup(table, indices *Array) *Array {
|
|
return Take(table, indices, 0)
|
|
}
|
|
|
|
// Gather gathers elements using indices - simplified to use take axis 0
|
|
func Gather(a, indices *Array) *Array {
|
|
return Take(a, indices, 0)
|
|
}
|
|
|
|
// ============ Array Properties ============
|
|
|
|
// Ndim returns the number of dimensions
|
|
func (a *Array) Ndim() int {
|
|
return int(C.mlx_array_ndim(a.c))
|
|
}
|
|
|
|
// Size returns the total number of elements
|
|
func (a *Array) Size() int {
|
|
return int(C.mlx_array_size(a.c))
|
|
}
|
|
|
|
// IsContiguous returns whether the array's data is contiguous in memory.
|
|
// Non-contiguous arrays (e.g., from SliceStride) must call Contiguous() before Data().
|
|
func (a *Array) IsContiguous() bool {
|
|
var res C.bool
|
|
C._mlx_array_is_contiguous(&res, a.c)
|
|
return bool(res)
|
|
}
|
|
|
|
// Dim returns the size of a dimension
|
|
func (a *Array) Dim(axis int) int32 {
|
|
return int32(C.mlx_array_dim(a.c, C.int(axis)))
|
|
}
|
|
|
|
// Shape returns the shape as a slice
|
|
func (a *Array) Shape() []int32 {
|
|
ndim := a.Ndim()
|
|
shape := make([]int32, ndim)
|
|
for i := 0; i < ndim; i++ {
|
|
shape[i] = a.Dim(i)
|
|
}
|
|
return shape
|
|
}
|
|
|
|
// IsValid returns true if the array hasn't been freed
|
|
func (a *Array) IsValid() bool {
|
|
return a != nil && a.c.ctx != nil
|
|
}
|
|
|
|
// Dtype returns the data type
|
|
func (a *Array) Dtype() Dtype {
|
|
return Dtype(C.mlx_array_dtype(a.c))
|
|
}
|
|
|
|
// Nbytes returns the total size in bytes
|
|
func (a *Array) Nbytes() int64 {
|
|
return int64(a.Size()) * a.Dtype().ItemSize()
|
|
}
|
|
|
|
// ItemSize returns the size in bytes of one element for this dtype
|
|
func (d Dtype) ItemSize() int64 {
|
|
switch d {
|
|
case DtypeBool, DtypeUint8, DtypeInt8:
|
|
return 1
|
|
case DtypeUint16, DtypeInt16, DtypeFloat16, DtypeBFloat16:
|
|
return 2
|
|
case DtypeUint32, DtypeInt32, DtypeFloat32:
|
|
return 4
|
|
case DtypeUint64, DtypeInt64, DtypeFloat64, DtypeComplex64:
|
|
return 8
|
|
default:
|
|
return 4
|
|
}
|
|
}
|
|
|
|
// ============ Data Access ============
|
|
|
|
// Data copies the float32 data out of the array.
|
|
// Note: For non-contiguous arrays (e.g., from SliceStride), call Contiguous() first.
|
|
// Note: Arrays of other dtypes (bf16, f16, etc) are automatically converted to float32.
|
|
// Note: Triggers cleanup of non-kept arrays.
|
|
func (a *Array) Data() []float32 {
|
|
cleanup()
|
|
size := a.Size()
|
|
if size == 0 {
|
|
return nil
|
|
}
|
|
|
|
arr := a
|
|
if a.Dtype() != DtypeFloat32 {
|
|
arr = AsType(a, DtypeFloat32)
|
|
arr.Eval()
|
|
// Cast array will be cleaned up on next Eval
|
|
}
|
|
|
|
ptr := C.mlx_array_data_float32(arr.c)
|
|
if ptr == nil {
|
|
return nil
|
|
}
|
|
data := make([]float32, size)
|
|
copy(data, unsafe.Slice((*float32)(unsafe.Pointer(ptr)), size))
|
|
return data
|
|
}
|
|
|
|
// Item returns the scalar value from a 0-dimensional array.
|
|
// Converts to float32 if necessary. Triggers cleanup.
|
|
func (a *Array) Item() float32 {
|
|
data := a.Data() // Data() calls cleanup()
|
|
if len(data) == 0 {
|
|
return 0
|
|
}
|
|
return data[0]
|
|
}
|
|
|
|
// DataInt32 copies the int32 data out of the array.
|
|
// Note: For non-contiguous arrays (e.g., from SliceStride), call Contiguous() first.
|
|
// Note: Triggers cleanup of non-kept arrays.
|
|
func (a *Array) DataInt32() []int32 {
|
|
cleanup()
|
|
size := a.Size()
|
|
if size == 0 {
|
|
return nil
|
|
}
|
|
ptr := C.mlx_array_data_int32(a.c)
|
|
if ptr == nil {
|
|
return nil
|
|
}
|
|
data := make([]int32, size)
|
|
copy(data, unsafe.Slice((*int32)(unsafe.Pointer(ptr)), size))
|
|
return data
|
|
}
|
|
|
|
// ItemInt32 gets a single scalar value efficiently (no array copy).
|
|
// Note: Triggers cleanup of non-kept arrays.
|
|
func (a *Array) ItemInt32() int32 {
|
|
cleanup()
|
|
var val C.int32_t
|
|
C.mlx_array_item_int32(&val, a.c)
|
|
return int32(val)
|
|
}
|
|
|
|
// Bytes copies the raw bytes out of the array without type conversion.
|
|
// Works with common dtypes (float32, int32, uint32, uint8).
|
|
// For non-contiguous arrays, call Contiguous() first.
|
|
// Note: Triggers cleanup of non-kept arrays.
|
|
func (a *Array) Bytes() []byte {
|
|
cleanup()
|
|
nbytes := a.Nbytes()
|
|
if nbytes == 0 {
|
|
return nil
|
|
}
|
|
|
|
// Get raw pointer based on dtype
|
|
var ptr unsafe.Pointer
|
|
switch a.Dtype() {
|
|
case DtypeFloat32:
|
|
ptr = unsafe.Pointer(C.mlx_array_data_float32(a.c))
|
|
case DtypeInt32:
|
|
ptr = unsafe.Pointer(C.mlx_array_data_int32(a.c))
|
|
case DtypeUint32:
|
|
ptr = unsafe.Pointer(C.mlx_array_data_uint32(a.c))
|
|
case DtypeUint8:
|
|
ptr = unsafe.Pointer(C.mlx_array_data_uint8(a.c))
|
|
default:
|
|
// For other types (bf16, f16, etc), convert to float32
|
|
arr := AsType(a, DtypeFloat32)
|
|
arr.Eval()
|
|
ptr = unsafe.Pointer(C.mlx_array_data_float32(arr.c))
|
|
nbytes = arr.Nbytes()
|
|
}
|
|
|
|
if ptr == nil {
|
|
return nil
|
|
}
|
|
data := make([]byte, nbytes)
|
|
copy(data, unsafe.Slice((*byte)(ptr), nbytes))
|
|
return data
|
|
}
|
|
|
|
// ============ Utility ============
|
|
|
|
// String returns a string representation
|
|
func (a *Array) String() string {
|
|
shape := a.Shape()
|
|
size := a.Size()
|
|
if size <= 20 {
|
|
data := a.Data()
|
|
return fmt.Sprintf("Array(shape=%v, data=%v)", shape, data)
|
|
}
|
|
return fmt.Sprintf("Array(shape=%v, size=%d)", shape, size)
|
|
}
|
|
|
|
// ============ Safetensors Support ============
|
|
|
|
// NewArrayFromBytes creates an array from raw bytes (for safetensors)
|
|
func NewArrayFromBytes(data []byte, shape []int32, dtype Dtype) *Array {
|
|
cData := unsafe.Pointer(&data[0])
|
|
intShape := make([]C.int, len(shape))
|
|
for i, s := range shape {
|
|
intShape[i] = C.int(s)
|
|
}
|
|
handle := C.mlx_array_new_data(cData, &intShape[0], C.int(len(shape)), C.mlx_dtype(dtype))
|
|
return newArray(handle)
|
|
}
|
|
|
|
// ============ Device Control ============
|
|
|
|
// SetDefaultDeviceGPU sets the default device to GPU (Metal)
|
|
func SetDefaultDeviceGPU() {
|
|
dev := C.mlx_device_new_type(C.MLX_GPU, 0)
|
|
C.mlx_set_default_device(dev)
|
|
C.mlx_device_free(dev)
|
|
}
|
|
|
|
// SetDefaultDeviceCPU sets the default device to CPU
|
|
func SetDefaultDeviceCPU() {
|
|
dev := C.mlx_device_new_type(C.MLX_CPU, 0)
|
|
C.mlx_set_default_device(dev)
|
|
C.mlx_device_free(dev)
|
|
}
|
|
|
|
// MetalIsAvailable returns true if Metal GPU is available
|
|
func MetalIsAvailable() bool {
|
|
var available C._Bool
|
|
C.mlx_metal_is_available(&available)
|
|
return bool(available)
|
|
}
|
|
|
|
// MetalStartCapture starts a GPU trace capture to the given file path.
|
|
// The path must not already exist. Run with MTL_CAPTURE_ENABLED=1 env var.
|
|
// Open the resulting .gputrace file in Xcode for analysis.
|
|
func MetalStartCapture(path string) {
|
|
cPath := C.CString(path)
|
|
defer C.free(unsafe.Pointer(cPath))
|
|
C.mlx_metal_start_capture(cPath)
|
|
}
|
|
|
|
// MetalStopCapture stops the current GPU trace capture.
|
|
func MetalStopCapture() {
|
|
C.mlx_metal_stop_capture()
|
|
}
|
|
|
|
// GPUIsAvailable returns true if any GPU (Metal or CUDA) is available
|
|
func GPUIsAvailable() bool {
|
|
// On Linux with CUDA build, GPU is available
|
|
// On macOS, check Metal availability
|
|
if MetalIsAvailable() {
|
|
return true
|
|
}
|
|
// CUDA is available if we compiled with CUDA support (Linux)
|
|
return runtime.GOOS == "linux"
|
|
}
|
|
|
|
// GetDefaultDeviceType returns the current default device (0=CPU, 1=GPU)
|
|
func GetDefaultDeviceType() int {
|
|
var dev C.mlx_device
|
|
C.mlx_get_default_device(&dev)
|
|
var devType C.mlx_device_type
|
|
C.mlx_device_get_type(&devType, dev)
|
|
C.mlx_device_free(dev)
|
|
return int(devType)
|
|
}
|
|
|
|
// Synchronize waits for all GPU operations to complete
|
|
func Synchronize() {
|
|
C.mlx_synchronize(C.default_stream())
|
|
}
|
|
|
|
// ScaledDotProductAttention computes optimized attention using GPU kernel
|
|
// Q, K, V should be [batch, heads, seq, head_dim]
|
|
func ScaledDotProductAttention(q, k, v *Array, scale float32, causalMask bool) *Array {
|
|
res := C.mlx_array_new()
|
|
maskMode := "" // empty string for no mask
|
|
if causalMask {
|
|
maskMode = "causal"
|
|
}
|
|
cMaskMode := C.CString(maskMode)
|
|
defer C.free(unsafe.Pointer(cMaskMode))
|
|
C.mlx_fast_scaled_dot_product_attention(&res, q.c, k.c, v.c, C.float(scale), cMaskMode, C.mlx_array{}, C.mlx_array{}, C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// ScaledDotProductAttentionWithSinks computes attention with sinks support
|
|
// maskMode: "causal", "sliding_window", or "" for none
|
|
// mask: optional attention mask array (nil for none)
|
|
// sinks: attention sinks array (nil for none)
|
|
func ScaledDotProductAttentionWithSinks(q, k, v *Array, scale float32, maskMode string, mask, sinks *Array) *Array {
|
|
res := C.mlx_array_new()
|
|
cMaskMode := C.CString(maskMode)
|
|
defer C.free(unsafe.Pointer(cMaskMode))
|
|
var maskH, sinksH C.mlx_array
|
|
if mask != nil {
|
|
maskH = mask.c
|
|
}
|
|
if sinks != nil {
|
|
sinksH = sinks.c
|
|
}
|
|
C.mlx_fast_scaled_dot_product_attention(&res, q.c, k.c, v.c, C.float(scale), cMaskMode, maskH, sinksH, C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// ============ Native Safetensors Loading ============
|
|
|
|
// SafetensorsFile represents a loaded safetensors file
|
|
type SafetensorsFile struct {
|
|
arrays C.mlx_map_string_to_array
|
|
metadata C.mlx_map_string_to_string
|
|
}
|
|
|
|
// LoadSafetensorsNative loads a safetensors file using MLX's optimized loader.
|
|
// On CUDA, Load::eval_gpu is implemented so we use the default (GPU) stream.
|
|
// On Metal, Load::eval_gpu is not implemented so we must use the CPU stream.
|
|
func LoadSafetensorsNative(path string) (*SafetensorsFile, error) {
|
|
cPath := C.CString(path)
|
|
defer C.free(unsafe.Pointer(cPath))
|
|
|
|
stream := C.default_stream()
|
|
if runtime.GOOS == "darwin" {
|
|
stream = C.cpu_stream()
|
|
}
|
|
|
|
var arrays C.mlx_map_string_to_array
|
|
var metadata C.mlx_map_string_to_string
|
|
if C.mlx_load_safetensors(&arrays, &metadata, cPath, stream) != 0 {
|
|
return nil, fmt.Errorf("failed to load safetensors: %s", path)
|
|
}
|
|
return &SafetensorsFile{arrays: arrays, metadata: metadata}, nil
|
|
}
|
|
|
|
// Get retrieves a tensor by name
|
|
func (s *SafetensorsFile) Get(name string) *Array {
|
|
cName := C.CString(name)
|
|
defer C.free(unsafe.Pointer(cName))
|
|
|
|
var arr C.mlx_array
|
|
if C.mlx_map_string_to_array_get(&arr, s.arrays, cName) != 0 {
|
|
return nil
|
|
}
|
|
if arr.ctx == nil {
|
|
return nil
|
|
}
|
|
return newArray(arr)
|
|
}
|
|
|
|
// Set replaces a tensor in the map (like Python's weights[k] = v)
|
|
func (s *SafetensorsFile) Set(name string, arr *Array) {
|
|
cName := C.CString(name)
|
|
defer C.free(unsafe.Pointer(cName))
|
|
C.mlx_map_string_to_array_insert(s.arrays, cName, arr.c)
|
|
}
|
|
|
|
// Count returns the number of tensors (not directly available, would need iterator)
|
|
func (s *SafetensorsFile) Count() int {
|
|
// mlx-c doesn't have a direct count - would need to iterate
|
|
return 0
|
|
}
|
|
|
|
// GetMetadata retrieves a metadata value by key from the safetensors file
|
|
func (s *SafetensorsFile) GetMetadata(key string) string {
|
|
cKey := C.CString(key)
|
|
defer C.free(unsafe.Pointer(cKey))
|
|
|
|
var cValue *C.char
|
|
if C.mlx_map_string_to_string_get(&cValue, s.metadata, cKey) != 0 {
|
|
return ""
|
|
}
|
|
return C.GoString(cValue)
|
|
}
|
|
|
|
// Free releases the safetensors file
|
|
func (s *SafetensorsFile) Free() {
|
|
C.mlx_map_string_to_array_free(s.arrays)
|
|
C.mlx_map_string_to_string_free(s.metadata)
|
|
}
|
|
|
|
// SaveSafetensors saves arrays to a safetensors file using MLX's native implementation.
|
|
// This correctly handles all dtypes including uint32 for quantized weights.
|
|
func SaveSafetensors(path string, arrays map[string]*Array) error {
|
|
cPath := C.CString(path)
|
|
defer C.free(unsafe.Pointer(cPath))
|
|
|
|
// Create the map
|
|
cArrays := C.mlx_map_string_to_array_new()
|
|
defer C.mlx_map_string_to_array_free(cArrays)
|
|
|
|
// Add each array to the map
|
|
for name, arr := range arrays {
|
|
cName := C.CString(name)
|
|
C.mlx_map_string_to_array_insert(cArrays, cName, arr.c)
|
|
C.free(unsafe.Pointer(cName))
|
|
}
|
|
|
|
// Create empty metadata (optional)
|
|
cMeta := C.mlx_map_string_to_string_new()
|
|
defer C.mlx_map_string_to_string_free(cMeta)
|
|
|
|
// Save
|
|
if C.mlx_save_safetensors(cPath, cArrays, cMeta) != 0 {
|
|
return fmt.Errorf("failed to save safetensors: %s", path)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// SaveSafetensorsWithMetadata saves arrays to a safetensors file with metadata key/value pairs.
|
|
// This is like SaveSafetensors but inserts metadata into the __metadata__ section.
|
|
func SaveSafetensorsWithMetadata(path string, arrays map[string]*Array, metadata map[string]string) error {
|
|
cPath := C.CString(path)
|
|
defer C.free(unsafe.Pointer(cPath))
|
|
|
|
// Create the array map
|
|
cArrays := C.mlx_map_string_to_array_new()
|
|
defer C.mlx_map_string_to_array_free(cArrays)
|
|
|
|
for name, arr := range arrays {
|
|
cName := C.CString(name)
|
|
C.mlx_map_string_to_array_insert(cArrays, cName, arr.c)
|
|
C.free(unsafe.Pointer(cName))
|
|
}
|
|
|
|
// Create metadata map
|
|
cMeta := C.mlx_map_string_to_string_new()
|
|
defer C.mlx_map_string_to_string_free(cMeta)
|
|
|
|
for key, value := range metadata {
|
|
cKey := C.CString(key)
|
|
cValue := C.CString(value)
|
|
C.mlx_map_string_to_string_insert(cMeta, cKey, cValue)
|
|
C.free(unsafe.Pointer(cKey))
|
|
C.free(unsafe.Pointer(cValue))
|
|
}
|
|
|
|
// Save
|
|
if C.mlx_save_safetensors(cPath, cArrays, cMeta) != 0 {
|
|
return fmt.Errorf("failed to save safetensors: %s", path)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ============ NPY Loading ============
|
|
|
|
// LoadNpy loads a numpy array from an npy file
|
|
// Note: Uses CPU stream because Load primitive only runs on CPU
|
|
func LoadNpy(path string) (*Array, error) {
|
|
cPath := C.CString(path)
|
|
defer C.free(unsafe.Pointer(cPath))
|
|
|
|
var arr C.mlx_array
|
|
if C.mlx_load(&arr, cPath, C.cpu_stream()) != 0 {
|
|
return nil, fmt.Errorf("failed to load npy: %s", path)
|
|
}
|
|
if arr.ctx == nil {
|
|
return nil, fmt.Errorf("failed to load npy: %s", path)
|
|
}
|
|
return newArray(arr), nil
|
|
}
|
|
|
|
// ============ Slice Update ============
|
|
|
|
// SliceUpdate updates a slice of the array with new values
|
|
func SliceUpdate(a, update *Array, start, stop []int32) *Array {
|
|
n := len(start)
|
|
cStart := make([]C.int, n)
|
|
cStop := make([]C.int, n)
|
|
cStrides := make([]C.int, n)
|
|
for i := 0; i < n; i++ {
|
|
cStart[i] = C.int(start[i])
|
|
cStop[i] = C.int(stop[i])
|
|
cStrides[i] = 1 // Default stride of 1
|
|
}
|
|
res := C.mlx_array_new()
|
|
C.mlx_slice_update(&res, a.c, update.c, &cStart[0], C.size_t(n), &cStop[0], C.size_t(n), &cStrides[0], C.size_t(n), C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// SliceUpdateInplace updates a slice and returns a new array.
|
|
// Note: Despite the name, this is NOT in-place - MLX arrays are immutable.
|
|
// The caller must use the returned value.
|
|
func SliceUpdateInplace(a, update *Array, start, stop []int32) *Array {
|
|
return SliceUpdate(a, update, start, stop)
|
|
}
|
|
|
|
// ============ Optimized Operations ============
|
|
|
|
// SampleArgmax gets the last logit position and returns argmax (fused operation)
|
|
func SampleArgmax(logits *Array) int32 {
|
|
result := Argmax(logits, -1, false)
|
|
return result.ItemInt32()
|
|
}
|
|
|
|
// ArgmaxKeepArray returns argmax as an Array (for pipelining, no sync)
|
|
// This is like mlx-lm's sampler that returns y as an array, not .item()
|
|
func ArgmaxKeepArray(logits *Array) *Array {
|
|
// For greedy decoding: logits shape is [1, 1, vocab]
|
|
// We want argmax over vocab dimension, return shape []
|
|
return Argmax(logits, -1, false)
|
|
}
|
|
|
|
// RandomState is the global PRNG state, analogous to mx.random.state in Python.
|
|
// It's a slice containing a single key array. Random functions use and update this state.
|
|
//
|
|
// Thread safety: Protected by randomStateMu, mimicking Python's GIL behavior.
|
|
// All random functions that use global state acquire this lock.
|
|
var (
|
|
RandomState = []*Array{nil}
|
|
randomStateMu sync.Mutex
|
|
)
|
|
|
|
var (
|
|
mlxInitialized bool
|
|
mlxInitError error
|
|
)
|
|
|
|
// mlxLibName returns the platform-specific shared library filename.
|
|
func mlxLibName() string {
|
|
switch runtime.GOOS {
|
|
case "windows":
|
|
return "mlxc.dll"
|
|
case "darwin":
|
|
return "libmlxc.dylib"
|
|
default:
|
|
return "libmlxc.so"
|
|
}
|
|
}
|
|
|
|
// findMLXLibrary searches for the MLX shared library in standard locations.
|
|
// Returns the path to the library, or empty string if not found.
|
|
func findMLXLibrary() string {
|
|
libName := mlxLibName()
|
|
|
|
// 1. OLLAMA_LIBRARY_PATH — check each dir and mlx_* subdirs
|
|
if paths, ok := os.LookupEnv("OLLAMA_LIBRARY_PATH"); ok {
|
|
for _, dir := range filepath.SplitList(paths) {
|
|
candidate := filepath.Join(dir, libName)
|
|
if _, err := os.Stat(candidate); err == nil {
|
|
return candidate
|
|
}
|
|
if mlxDirs, err := filepath.Glob(filepath.Join(dir, "mlx*")); err == nil {
|
|
for _, mlxDir := range mlxDirs {
|
|
candidate = filepath.Join(mlxDir, libName)
|
|
if _, err := os.Stat(candidate); err == nil {
|
|
return candidate
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// 2. Executable directory and lib/ollama/mlx* subdirs
|
|
if exe, err := os.Executable(); err == nil {
|
|
if eval, err := filepath.EvalSymlinks(exe); err == nil {
|
|
exe = eval
|
|
}
|
|
exeDir := filepath.Dir(exe)
|
|
|
|
// Check exe dir directly (macOS copies dylib here)
|
|
candidate := filepath.Join(exeDir, libName)
|
|
if _, err := os.Stat(candidate); err == nil {
|
|
return candidate
|
|
}
|
|
|
|
// Check exe_dir/lib/ollama/mlx* subdirectories
|
|
// and exe_dir/../lib/ollama/mlx* (standard bin/lib sibling layout)
|
|
for _, libOllamaDir := range []string{
|
|
filepath.Join(exeDir, "lib", "ollama"),
|
|
filepath.Join(exeDir, "..", "lib", "ollama"),
|
|
} {
|
|
if mlxDirs, err := filepath.Glob(filepath.Join(libOllamaDir, "mlx*")); err == nil {
|
|
for _, mlxDir := range mlxDirs {
|
|
candidate = filepath.Join(mlxDir, libName)
|
|
if _, err := os.Stat(candidate); err == nil {
|
|
return candidate
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// 3. Build directory (for tests run from repo root)
|
|
if cwd, err := os.Getwd(); err == nil {
|
|
candidate := filepath.Join(cwd, "build", "lib", "ollama", libName)
|
|
if _, err := os.Stat(candidate); err == nil {
|
|
return candidate
|
|
}
|
|
}
|
|
|
|
return ""
|
|
}
|
|
|
|
// InitMLX initializes the MLX library by dynamically loading libmlxc.
|
|
// This must be called before using any MLX functions.
|
|
// Returns an error if the library cannot be loaded.
|
|
func InitMLX() error {
|
|
if mlxInitialized {
|
|
return mlxInitError
|
|
}
|
|
|
|
// Search for the library using Go path discovery
|
|
libPath := findMLXLibrary()
|
|
if libPath == "" {
|
|
mlxInitError = fmt.Errorf("failed to initialize MLX: %s not found", mlxLibName())
|
|
return mlxInitError
|
|
}
|
|
|
|
cPath := C.CString(libPath)
|
|
defer C.free(unsafe.Pointer(cPath))
|
|
if C.mlx_dynamic_init_path(cPath) != 0 {
|
|
errMsg := C.GoString(C.mlx_dynamic_error())
|
|
mlxInitError = fmt.Errorf("failed to initialize MLX: %s", errMsg)
|
|
return mlxInitError
|
|
}
|
|
|
|
// Initialize all function pointers via dlsym
|
|
handle := C.mlx_get_handle()
|
|
if C.mlx_load_functions(handle) != 0 {
|
|
mlxInitError = fmt.Errorf("failed to load MLX function symbols")
|
|
return mlxInitError
|
|
}
|
|
|
|
mlxInitialized = true
|
|
mlxInitError = nil
|
|
return nil
|
|
}
|
|
|
|
// IsMLXAvailable returns whether MLX was successfully initialized
|
|
func IsMLXAvailable() bool {
|
|
return mlxInitialized && mlxInitError == nil
|
|
}
|
|
|
|
// GetMLXInitError returns any error that occurred during MLX initialization
|
|
func GetMLXInitError() error {
|
|
return mlxInitError
|
|
}
|
|
|
|
func init() {
|
|
// Initialize MLX dynamic library first
|
|
if err := InitMLX(); err != nil {
|
|
// Don't panic in init - let the caller handle the error
|
|
// Store the error for later retrieval
|
|
mlxInitError = err
|
|
return
|
|
}
|
|
|
|
// Lock main goroutine to OS thread for CUDA context stability.
|
|
// CUDA contexts are bound to threads; Go can migrate goroutines between threads.
|
|
runtime.LockOSThread()
|
|
RandomState[0] = RandomKey(uint64(time.Now().UnixMilli()))
|
|
Keep(RandomState[0]) // Global state should persist
|
|
}
|
|
|
|
// RandomKey creates a PRNG key from a seed
|
|
func RandomKey(seed uint64) *Array {
|
|
var res C.mlx_array
|
|
C.mlx_random_key(&res, C.uint64_t(seed))
|
|
return newArray(res)
|
|
}
|
|
|
|
// RandomSplit splits a PRNG key into two new keys
|
|
func RandomSplit(key *Array) (*Array, *Array) {
|
|
var key1, key2 C.mlx_array
|
|
C.mlx_random_split(&key1, &key2, key.c, C.default_stream())
|
|
return newArray(key1), newArray(key2)
|
|
}
|
|
|
|
// RandomCategoricalWithKey samples from categorical distribution using provided key.
|
|
func RandomCategoricalWithKey(logits, key *Array, axis int, numSamples int) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_random_categorical_num_samples(&res, logits.c, C.int(axis), C.int(numSamples), key.c, C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// RandomCategorical samples using global RandomState.
|
|
// For simple scripts - production code should use RandomCategoricalWithKey with explicit key management.
|
|
func RandomCategorical(logits *Array, axis int, numSamples int) *Array {
|
|
randomStateMu.Lock()
|
|
oldKey := RandomState[0]
|
|
key1, key2 := RandomSplit(oldKey)
|
|
Keep(key1) // key1 becomes the new global state
|
|
oldKey.Free()
|
|
RandomState[0] = key1
|
|
randomStateMu.Unlock()
|
|
return RandomCategoricalWithKey(logits, key2, axis, numSamples)
|
|
}
|
|
|
|
// RandomNormal creates a random normal (Gaussian) tensor in float32
|
|
func RandomNormal(shape []int32, seed uint64) *Array {
|
|
return RandomNormalWithDtype(shape, seed, DtypeFloat32)
|
|
}
|
|
|
|
// RandomNormalWithDtype creates a random normal (Gaussian) tensor with specified dtype
|
|
func RandomNormalWithDtype(shape []int32, seed uint64, dtype Dtype) *Array {
|
|
key := RandomKey(seed)
|
|
res := C.mlx_array_new()
|
|
C.mlx_random_normal(&res, int32ToCInt(shape), C.size_t(len(shape)), C.mlx_dtype(dtype), 0.0, 1.0, key.c, C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// RandomUniform generates uniform random values in [0, 1) with the given shape
|
|
func RandomUniform(shape []int32, seed uint64) *Array {
|
|
key := RandomKey(seed)
|
|
low := C.mlx_array_new_float(0.0)
|
|
high := C.mlx_array_new_float(1.0)
|
|
res := C.mlx_array_new()
|
|
C.mlx_random_uniform(&res, low, high, int32ToCInt(shape), C.size_t(len(shape)), C.MLX_FLOAT32, key.c, C.default_stream())
|
|
C.mlx_array_free(low)
|
|
C.mlx_array_free(high)
|
|
return newArray(res)
|
|
}
|
|
|
|
// Conv2d performs 2D convolution
|
|
// input: [N, H, W, C], weight: [O, kH, kW, C] (MLX uses NHWC layout)
|
|
// Returns: [N, H', W', O]
|
|
func Conv2d(input, weight *Array, stride, padding int32) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_conv2d(&res, input.c, weight.c, C.int(stride), C.int(stride), C.int(padding), C.int(padding), 1, 1, 1, C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// Conv3d performs 3D convolution
|
|
// input: [N, D, H, W, C], weight: [O, kD, kH, kW, C] (MLX uses NDHWC layout)
|
|
// Returns: [N, D', H', W', O]
|
|
func Conv3d(input, weight *Array, strideD, strideH, strideW, padD, padH, padW int32) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_conv3d(&res, input.c, weight.c, C.int(strideD), C.int(strideH), C.int(strideW), C.int(padD), C.int(padH), C.int(padW), 1, 1, 1, 1, C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// ============ Compilation Control ============
|
|
|
|
// EnableCompile enables global compilation/graph fusion
|
|
func EnableCompile() {
|
|
C.mlx_enable_compile()
|
|
}
|
|
|
|
// DisableCompile disables global compilation
|
|
func DisableCompile() {
|
|
C.mlx_disable_compile()
|
|
}
|
|
|
|
// SetCompileMode sets the compile mode
|
|
// 0=disabled, 1=no_simplify, 2=no_fuse, 3=enabled
|
|
func SetCompileMode(mode int) {
|
|
C.mlx_set_compile_mode(C.mlx_compile_mode(mode))
|
|
}
|
|
|
|
// ============ Stream Control ============
|
|
|
|
// Stream represents an MLX execution stream
|
|
type Stream struct {
|
|
c C.mlx_stream
|
|
}
|
|
|
|
// NewStream creates a new execution stream on the default device
|
|
func NewStream() *Stream {
|
|
var dev C.mlx_device
|
|
C.mlx_get_default_device(&dev)
|
|
stream := C.mlx_stream_new_device(dev)
|
|
C.mlx_device_free(dev)
|
|
return &Stream{c: stream}
|
|
}
|
|
|
|
// Free releases the stream
|
|
func (s *Stream) Free() {
|
|
if s.c.ctx != nil {
|
|
C.mlx_stream_free(s.c)
|
|
s.c.ctx = nil
|
|
}
|
|
}
|
|
|
|
// SetDefaultStream sets the default stream for operations
|
|
func SetDefaultStream(s *Stream) {
|
|
C.mlx_set_default_stream(s.c)
|
|
C.set_default_stream(s.c) // Also update our cached stream
|
|
}
|
|
|
|
// GetDefaultStream returns the current default stream
|
|
func GetDefaultStream() *Stream {
|
|
var stream C.mlx_stream
|
|
var dev C.mlx_device
|
|
C.mlx_get_default_device(&dev)
|
|
C.mlx_get_default_stream(&stream, dev)
|
|
C.mlx_device_free(dev)
|
|
return &Stream{c: stream}
|
|
}
|
|
|
|
// SynchronizeStream waits for all operations on the stream to complete
|
|
func SynchronizeStream(s *Stream) {
|
|
C.mlx_synchronize(s.c)
|
|
}
|
|
|
|
// ============ Metal Memory Control ============
|
|
|
|
// MetalGetCacheMemory returns the current cache memory usage in bytes
|
|
func MetalGetCacheMemory() uint64 {
|
|
var size C.size_t
|
|
C.mlx_get_cache_memory(&size)
|
|
return uint64(size)
|
|
}
|
|
|
|
// MetalGetPeakMemory returns the peak memory usage in bytes
|
|
func MetalGetPeakMemory() uint64 {
|
|
var size C.size_t
|
|
C.mlx_get_peak_memory(&size)
|
|
return uint64(size)
|
|
}
|
|
|
|
// MetalResetPeakMemory resets the peak memory counter
|
|
func MetalResetPeakMemory() {
|
|
C.mlx_reset_peak_memory()
|
|
}
|
|
|
|
// MetalSetWiredLimit sets the wired memory limit and returns the previous limit
|
|
// This keeps tensors pinned in GPU memory for faster access
|
|
func MetalSetWiredLimit(limit uint64) uint64 {
|
|
var prev C.size_t
|
|
C.mlx_set_wired_limit(&prev, C.size_t(limit))
|
|
return uint64(prev)
|
|
}
|
|
|
|
// MetalGetActiveMemory returns the current active memory usage in bytes
|
|
func MetalGetActiveMemory() uint64 {
|
|
var size C.size_t
|
|
C.mlx_get_active_memory(&size)
|
|
return uint64(size)
|
|
}
|
|
|
|
// ClearCache clears the MLX memory cache
|
|
func ClearCache() {
|
|
C.mlx_clear_cache()
|
|
}
|
|
|
|
// SetCacheLimit sets the free cache limit in bytes
|
|
// Setting to 0 disables caching (useful for memory-constrained generation)
|
|
// Returns the previous cache limit
|
|
func SetCacheLimit(limit uint64) uint64 {
|
|
var prev C.size_t
|
|
C.mlx_set_cache_limit(&prev, C.size_t(limit))
|
|
return uint64(prev)
|
|
}
|
|
|
|
// SetMemoryLimit sets the overall memory limit in bytes
|
|
// This is a guideline for maximum memory during graph evaluation.
|
|
// When Metal is available, defaults to 1.5x the max recommended working set.
|
|
// Returns the previous memory limit
|
|
func SetMemoryLimit(limit uint64) uint64 {
|
|
var prev C.size_t
|
|
C.mlx_set_memory_limit(&prev, C.size_t(limit))
|
|
return uint64(prev)
|
|
}
|
|
|
|
// GetMemoryLimit returns the current memory limit in bytes
|
|
func GetMemoryLimit() uint64 {
|
|
var size C.size_t
|
|
C.mlx_get_memory_limit(&size)
|
|
return uint64(size)
|
|
}
|
|
|
|
// ============ MoE Operations ============
|
|
|
|
// GatherMM performs gather matrix multiplication for MoE
|
|
// a: input, b: weight matrices
|
|
// lhsIndices, rhsIndices: optional expert selection indices (nil for none)
|
|
func GatherMM(a, b *Array, lhsIndices, rhsIndices *Array, sortedIndices bool) *Array {
|
|
var lhs, rhs C.mlx_array
|
|
if lhsIndices != nil {
|
|
lhs = lhsIndices.c
|
|
}
|
|
if rhsIndices != nil {
|
|
rhs = rhsIndices.c
|
|
}
|
|
res := C.mlx_array_new()
|
|
C.mlx_gather_mm(&res, a.c, b.c, lhs, rhs, C._Bool(sortedIndices), C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// GatherQMM performs quantized gather matrix multiplication for MoE
|
|
// Used for MXFP4 and other quantized MoE inference
|
|
func GatherQMM(x, w, scales *Array, biases, lhsIndices, rhsIndices *Array, transpose bool, groupSize, bits int, mode string, sortedIndices bool) *Array {
|
|
var b, lhs, rhs C.mlx_array
|
|
if biases != nil {
|
|
b = biases.c
|
|
}
|
|
if lhsIndices != nil {
|
|
lhs = lhsIndices.c
|
|
}
|
|
if rhsIndices != nil {
|
|
rhs = rhsIndices.c
|
|
}
|
|
cMode := C.CString(mode)
|
|
defer C.free(unsafe.Pointer(cMode))
|
|
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
|
|
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
|
|
res := C.mlx_array_new()
|
|
C.mlx_gather_qmm(&res, x.c, w.c, scales.c, b, lhs, rhs, C._Bool(transpose), optGroupSize, optBits, cMode, C._Bool(sortedIndices), C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// ============ Quantization ============
|
|
|
|
// Quantize quantizes weights to specified bits per element.
|
|
// Returns (quantized_weights, scales, biases).
|
|
// groupSize: number of elements quantized together (default 64)
|
|
// bits: bits per element, 2, 4, or 8 (default 4)
|
|
// mode: "affine" (default), "mxfp4", or "mxfp8"
|
|
// Note: mxfp8 mode returns nil biases (only weights and scales)
|
|
func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, biases *Array) {
|
|
cMode := C.CString(mode)
|
|
defer C.free(unsafe.Pointer(cMode))
|
|
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
|
|
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
|
|
res := C.mlx_vector_array_new()
|
|
C.mlx_quantize(&res, w.c, optGroupSize, optBits, cMode, C.default_stream())
|
|
|
|
// Result is a vector of arrays: [weights, scales, biases?]
|
|
// mxfp8 mode returns only 2 elements (no biases)
|
|
vecSize := int(C.mlx_vector_array_size(res))
|
|
var w0, w1, w2 C.mlx_array
|
|
C.mlx_vector_array_get(&w0, res, 0)
|
|
C.mlx_vector_array_get(&w1, res, 1)
|
|
if vecSize >= 3 {
|
|
C.mlx_vector_array_get(&w2, res, 2)
|
|
}
|
|
C.mlx_vector_array_free(res)
|
|
|
|
if vecSize >= 3 {
|
|
return newArray(w0), newArray(w1), newArray(w2)
|
|
}
|
|
return newArray(w0), newArray(w1), nil
|
|
}
|
|
|
|
// Dequantize reconstructs weights from quantized form.
|
|
// groupSize: number of elements quantized together (default 64)
|
|
// bits: bits per element, 2, 4, or 8 (default 4)
|
|
// mode: "affine" (default) or "mxfp4"
|
|
func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Array {
|
|
cMode := C.CString(mode)
|
|
defer C.free(unsafe.Pointer(cMode))
|
|
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
|
|
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
|
|
optDtype := C.mlx_optional_dtype{has_value: false}
|
|
|
|
var b C.mlx_array
|
|
if biases != nil {
|
|
b = biases.c
|
|
}
|
|
|
|
res := C.mlx_array_new()
|
|
C.mlx_dequantize(&res, w.c, scales.c, b, optGroupSize, optBits, cMode, optDtype, C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// QuantizedMatmul performs matrix multiplication with quantized weights.
|
|
// x: input tensor [batch..., in_features]
|
|
// w: quantized weights
|
|
// scales, biases: from Quantize
|
|
// transpose: if true, compute x @ w.T (typical for Linear layers)
|
|
// groupSize, bits, mode: must match what was used in Quantize
|
|
func QuantizedMatmul(x, w, scales, biases *Array, transpose bool, groupSize, bits int, mode string) *Array {
|
|
cMode := C.CString(mode)
|
|
defer C.free(unsafe.Pointer(cMode))
|
|
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
|
|
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
|
|
|
|
var b C.mlx_array
|
|
if biases != nil {
|
|
b = biases.c
|
|
}
|
|
|
|
res := C.mlx_array_new()
|
|
C.mlx_quantized_matmul(&res, x.c, w.c, scales.c, b, C._Bool(transpose), optGroupSize, optBits, cMode, C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// ============ Sorting and Top-K ============
|
|
|
|
// TopK returns the k largest elements along an axis
|
|
func TopK(a *Array, k int, axis int) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_topk_axis(&res, a.c, C.int(k), C.int(axis), C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// Argpartition returns indices for partial sort (k-th smallest first)
|
|
func Argpartition(a *Array, kth int, axis int) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_argpartition_axis(&res, a.c, C.int(kth), C.int(axis), C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// TakeAlongAxis takes elements from array using indices along axis
|
|
func TakeAlongAxis(a, indices *Array, axis int) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_take_along_axis(&res, a.c, indices.c, C.int(axis), C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// PutAlongAxis puts values into array at indices along axis
|
|
func PutAlongAxis(a, indices, values *Array, axis int) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_put_along_axis(&res, a.c, indices.c, values.c, C.int(axis), C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// Cumsum computes cumulative sum along an axis
|
|
func Cumsum(a *Array, axis int) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_cumsum(&res, a.c, C.int(axis), false, false, C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// Where selects elements: condition ? a : b
|
|
func Where(condition, a, b *Array) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_where(&res, condition.c, a.c, b.c, C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// LessScalar returns element-wise a < scalar
|
|
func LessScalar(a *Array, s float32) *Array {
|
|
scalar := C.mlx_array_new_float(C.float(s))
|
|
res := C.mlx_array_new()
|
|
C.mlx_less(&res, a.c, scalar, C.default_stream())
|
|
C.mlx_array_free(scalar)
|
|
return newArray(res)
|
|
}
|
|
|
|
// FullDtype creates an array filled with a value with specific dtype
|
|
func FullDtype(value float32, dtype Dtype, shape ...int32) *Array {
|
|
intShape := make([]C.int, len(shape))
|
|
for i, s := range shape {
|
|
intShape[i] = C.int(s)
|
|
}
|
|
vals := C.mlx_array_new_float(C.float(value))
|
|
res := C.mlx_array_new()
|
|
C.mlx_full(&res, &intShape[0], C.size_t(len(shape)), vals, C.mlx_dtype(dtype), C.default_stream())
|
|
C.mlx_array_free(vals)
|
|
return newArray(res)
|
|
}
|
|
|
|
// AsType casts an array to a different dtype
|
|
func AsType(a *Array, dtype Dtype) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_astype(&res, a.c, C.mlx_dtype(dtype), C.default_stream())
|
|
return newArray(res)
|
|
}
|
|
|
|
// ToBFloat16 casts an array to bfloat16
|
|
func ToBFloat16(a *Array) *Array {
|
|
return AsType(a, DtypeBFloat16)
|
|
}
|
|
|
|
// ============ VibeVoice Helper Functions ============
|
|
|
|
// NewScalarArray creates a true 0-dimensional scalar array from a float32 value
|
|
func NewScalarArray(value float32) *Array {
|
|
return newArray(C.mlx_array_new_float(C.float(value)))
|
|
}
|
|
|
|
// Global random seed counter for RandN
|
|
var randnSeedCounter uint64 = uint64(time.Now().UnixNano())
|
|
|
|
// RandN creates an array of random samples from a standard normal distribution
|
|
func RandN(shape []int32) *Array {
|
|
// Use incrementing seed for unique random values each call
|
|
seed := atomic.AddUint64(&randnSeedCounter, 1)
|
|
return RandomNormal(shape, seed)
|
|
}
|
|
|
|
// Pad pads an array with zeros
|
|
// paddings: [before_0, after_0, before_1, after_1, ...] for each dimension
|
|
func Pad(a *Array, paddings []int32) *Array {
|
|
numAxes := len(paddings) / 2
|
|
// Convert to low/high pairs
|
|
lowPad := make([]C.int, numAxes)
|
|
highPad := make([]C.int, numAxes)
|
|
for i := 0; i < numAxes; i++ {
|
|
lowPad[i] = C.int(paddings[i*2])
|
|
highPad[i] = C.int(paddings[i*2+1])
|
|
}
|
|
zero := C.mlx_array_new_float(0.0)
|
|
res := C.mlx_array_new()
|
|
// mlx_pad takes axes, low, high arrays
|
|
axes := make([]C.int, numAxes)
|
|
for i := 0; i < numAxes; i++ {
|
|
axes[i] = C.int(i)
|
|
}
|
|
cMode := C.CString("constant")
|
|
defer C.free(unsafe.Pointer(cMode))
|
|
C.mlx_pad(&res, a.c, &axes[0], C.size_t(numAxes), &lowPad[0], C.size_t(numAxes), &highPad[0], C.size_t(numAxes), zero, cMode, C.default_stream())
|
|
C.mlx_array_free(zero)
|
|
return newArray(res)
|
|
}
|
|
|
|
// Conv1d performs 1D convolution
|
|
// x: [B, L, Cin], weight: [Cout, K, Cin] (MLX uses NLC layout)
|
|
// bias: optional (nil for no bias)
|
|
func Conv1d(x, weight *Array, bias *Array, stride int32) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_conv1d(&res, x.c, weight.c, C.int(stride), C.int(0), C.int(1), 1, C.default_stream())
|
|
// Apply bias if provided
|
|
if bias != nil {
|
|
biased := C.mlx_array_new()
|
|
C.mlx_add(&biased, res, bias.c, C.default_stream())
|
|
C.mlx_array_free(res)
|
|
return newArray(biased)
|
|
}
|
|
return newArray(res)
|
|
}
|
|
|
|
// ConvTranspose1d performs transposed 1D convolution
|
|
// x: [B, L, Cin], weight: [Cout, K, Cin] (MLX uses NLC layout)
|
|
// bias: optional (nil for no bias)
|
|
func ConvTranspose1d(x, weight *Array, bias *Array, stride int32) *Array {
|
|
res := C.mlx_array_new()
|
|
// stride, padding, dilation, output_padding, groups
|
|
C.mlx_conv_transpose1d(&res, x.c, weight.c, C.int(stride), 0, 1, 0, 1, C.default_stream())
|
|
// Apply bias if provided
|
|
if bias != nil {
|
|
biased := C.mlx_array_new()
|
|
C.mlx_add(&biased, res, bias.c, C.default_stream())
|
|
C.mlx_array_free(res)
|
|
return newArray(biased)
|
|
}
|
|
return newArray(res)
|
|
}
|
|
|
|
// DepthwiseConv1d performs depthwise 1D convolution (groups=Cin)
|
|
// x: [B, L, C], weight: [1, K, C] (groups = C)
|
|
// bias: optional (nil for no bias)
|
|
func DepthwiseConv1d(x, weight *Array, bias *Array) *Array {
|
|
// Get number of input channels for groups
|
|
shape := x.Shape()
|
|
groups := int(shape[len(shape)-1])
|
|
res := C.mlx_array_new()
|
|
C.mlx_conv1d(&res, x.c, weight.c, 1, 0, 1, C.int(groups), C.default_stream())
|
|
// Apply bias if provided
|
|
if bias != nil {
|
|
biased := C.mlx_array_new()
|
|
C.mlx_add(&biased, res, bias.c, C.default_stream())
|
|
C.mlx_array_free(res)
|
|
return newArray(biased)
|
|
}
|
|
return newArray(res)
|
|
}
|
|
|
|
// SliceAxis extracts a slice along a specific axis
|
|
func SliceAxis(a *Array, axis int, start, stop int32) *Array {
|
|
shape := a.Shape()
|
|
|
|
// Build start and stop indices for all dimensions
|
|
starts := make([]int32, len(shape))
|
|
stops := make([]int32, len(shape))
|
|
for i := range shape {
|
|
if i == axis {
|
|
starts[i] = start
|
|
stops[i] = stop
|
|
} else {
|
|
starts[i] = 0
|
|
stops[i] = shape[i]
|
|
}
|
|
}
|
|
|
|
return Slice(a, starts, stops)
|
|
}
|
|
|
|
// Tri creates a lower triangular matrix
|
|
func Tri(n, m int32, k int) *Array {
|
|
res := C.mlx_array_new()
|
|
C.mlx_tri(&res, C.int(n), C.int(m), C.int(k), C.MLX_FLOAT32, C.default_stream())
|
|
return newArray(res)
|
|
}
|