Files
ollama/x/mlxrunner/mlx/fast.go
Jesse Gross 04f5f0cdb4 mlx: improve thread safety of array management
Use atomic.Int32 for Array.pinned and a sync.Mutex for the global
arrays slice so MLX arrays can be created and pinned from multiple
goroutines without racing on those structures. Convert Array value
receivers to pointer receivers and struct fields from Array to
*Array to avoid copying the atomic.

This does not fully achieve thread safety even when building
completely independent graphs. The tracing flag and traceScratch
slice in compile.go are unprotected, so concurrent Compile calls
will race. MLX itself is not fully thread-safe either, although
the project is working to improve that.
2026-04-21 14:38:49 -07:00

73 lines
1.5 KiB
Go

package mlx
// #include "generated.h"
import "C"
import (
"unsafe"
)
// ScaledDotProductAttention calls mlx_fast_scaled_dot_product_attention with
// query, key, value and scale on the default stream and returns the output as
// a new array created with New("FAST_SDPA").
//
// A nil mask is replaced with a placeholder array from New(""). The sinks
// argument is always such a placeholder, and the mode string handed to the C
// call is always "causal".
//
// NOTE(review): the mode stays "causal" even when the caller supplies a mask;
// confirm that this combination is what the C API expects.
func ScaledDotProductAttention(query, key, value, mask *Array, scale float32) *Array {
	// C.CString allocates on the C heap, so the string is freed on return.
	maskMode := C.CString("causal")
	defer C.free(unsafe.Pointer(maskMode))

	if mask == nil {
		mask = New("")
	}
	noSinks := New("")
	result := New("FAST_SDPA")

	C.mlx_fast_scaled_dot_product_attention(
		&result.ctx,
		query.ctx,
		key.ctx,
		value.ctx,
		C.float(scale),
		maskMode,
		mask.ctx,
		noSinks.ctx,
		DefaultStream().ctx,
	)
	return result
}
// LayerNorm holds the arrays that (*LayerNorm).Forward passes to
// mlx_fast_layer_norm.
//
// NOTE(review): the `weight` struct tags presumably name the tensors that a
// loader binds to these fields — confirm against the code that reads the tags.
type LayerNorm struct {
// Weight is passed as the weight argument of mlx_fast_layer_norm.
Weight *Array `weight:"weight"`
// Bias is passed as the bias argument of mlx_fast_layer_norm.
Bias *Array `weight:"bias"`
}
// Forward calls mlx_fast_layer_norm on x with the layer's Weight and Bias and
// the given eps on the default stream. The output is returned as a new array
// created with New("FAST_LAYERNORM").
//
// A nil Bias is replaced with a placeholder array from New(""), the same
// stand-in ScaledDotProductAttention uses for a nil mask, so a layer that has
// no bias does not panic on a nil pointer dereference.
//
// NOTE(review): this assumes the C API treats the placeholder as "no bias";
// confirm against the mlx-c declaration of mlx_fast_layer_norm.
func (r *LayerNorm) Forward(x *Array, eps float32) *Array {
	bias := r.Bias
	if bias == nil {
		bias = New("")
	}

	out := New("FAST_LAYERNORM")
	C.mlx_fast_layer_norm(&out.ctx, x.ctx, r.Weight.ctx, bias.ctx, C.float(eps), DefaultStream().ctx)
	return out
}
// RMSNorm holds the array that (*RMSNorm).Forward passes to
// mlx_fast_rms_norm.
//
// NOTE(review): the `weight` struct tag presumably names the tensor that a
// loader binds to this field — confirm against the code that reads the tag.
type RMSNorm struct {
// Weight is passed as the weight argument of mlx_fast_rms_norm.
Weight *Array `weight:"weight"`
}
// Forward calls mlx_fast_rms_norm on x with the layer's Weight and the given
// eps on the default stream, returning the output as a new array created with
// New("FAST_RMSNORM").
func (r *RMSNorm) Forward(x *Array, eps float32) *Array {
	normed := New("FAST_RMSNORM")
	C.mlx_fast_rms_norm(
		&normed.ctx,
		x.ctx,
		r.Weight.ctx,
		C.float(eps),
		DefaultStream().ctx,
	)
	return normed
}
// RoPE holds the settings that (RoPE).Forward passes to mlx_fast_rope.
type RoPE struct {
// Dims is passed to mlx_fast_rope as a C int.
Dims int
// Traditional is passed to mlx_fast_rope as a C bool.
Traditional bool
// Base is passed as an optional float; Forward marks it as unset when it is
// zero. It carries the JSON name "rope_theta".
Base float32 `json:"rope_theta"`
// Scale is passed to mlx_fast_rope as a C float.
Scale float32
}
// Forward calls mlx_fast_rope on t with the receiver's Dims, Traditional,
// Base and Scale plus the given offset on the default stream, returning the
// output as a new array created with New("FAST_ROPE").
//
// Base is passed as an optional float whose has_value flag is set only when
// Base is non-zero. The freqs argument is always a placeholder array from
// New("").
func (r RoPE) Forward(t *Array, offset int) *Array {
	noFreqs := New("")
	rotated := New("FAST_ROPE")

	// A zero Base is reported to the C side as "no value".
	base := C.mlx_optional_float{
		value:     C.float(r.Base),
		has_value: C._Bool(r.Base != 0),
	}

	C.mlx_fast_rope(
		&rotated.ctx,
		t.ctx,
		C.int(r.Dims),
		C._Bool(r.Traditional),
		base,
		C.float(r.Scale),
		C.int(offset),
		noFreqs.ctx,
		DefaultStream().ctx,
	)
	return rotated
}