Files
ollama/x/mlxrunner/mlx/fast.go
Jesse Gross d137b850b6 mlx: make array management thread-safe
Use atomic.Int32 for Array.pinned and a sync.Mutex for the global
arrays slice so MLX arrays can be safely created and managed from
multiple goroutines. Convert Array value receivers to pointer
receivers and struct fields from Array to *Array to avoid copying
the atomic.
2026-04-03 19:50:41 -07:00

73 lines
1.5 KiB
Go

package mlx
// #include "generated.h"
import "C"
import (
"unsafe"
)
func ScaledDotProductAttention(query, key, value, mask *Array, scale float32) *Array {
if mask == nil {
mask = New("")
}
sinks := New("")
mode := "causal"
cMode := C.CString(mode)
defer C.free(unsafe.Pointer(cMode))
out := New("FAST_SDPA")
C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, mask.ctx, sinks.ctx, DefaultStream().ctx)
return out
}
type LayerNorm struct {
Weight *Array `weight:"weight"`
Bias *Array `weight:"bias"`
}
func (r *LayerNorm) Forward(x *Array, eps float32) *Array {
out := New("FAST_LAYERNORM")
C.mlx_fast_layer_norm(&out.ctx, x.ctx, r.Weight.ctx, r.Bias.ctx, C.float(eps), DefaultStream().ctx)
return out
}
type RMSNorm struct {
Weight *Array `weight:"weight"`
}
func (r *RMSNorm) Forward(x *Array, eps float32) *Array {
out := New("FAST_RMSNORM")
C.mlx_fast_rms_norm(&out.ctx, x.ctx, r.Weight.ctx, C.float(eps), DefaultStream().ctx)
return out
}
type RoPE struct {
Dims int
Traditional bool
Base float32 `json:"rope_theta"`
Scale float32
}
func (r RoPE) Forward(t *Array, offset int) *Array {
freqs := New("")
out := New("FAST_ROPE")
C.mlx_fast_rope(
&out.ctx,
t.ctx,
C.int(r.Dims),
C._Bool(r.Traditional),
C.mlx_optional_float{
value: C.float(r.Base),
has_value: C._Bool(func() bool { return r.Base != 0 }()),
},
C.float(r.Scale),
C.int(offset),
freqs.ctx,
DefaultStream().ctx,
)
return out
}