From d137b850b6bc208b2d676b16947be8f0b43a7708 Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Tue, 31 Mar 2026 14:15:04 -0700 Subject: [PATCH] mlx: make array management thread-safe Use atomic.Int32 for Array.pinned and a sync.Mutex for the global arrays slice so MLX arrays can be safely created and managed from multiple goroutines. Convert Array value receivers to pointer receivers and struct fields from Array to *Array to avoid copying the atomic. --- x/mlxrunner/mlx/array.go | 48 ++++++++++++++++++++++++---------------- x/mlxrunner/mlx/fast.go | 8 +++---- x/mlxrunner/mlx/nn.go | 10 ++++----- 3 files changed, 38 insertions(+), 28 deletions(-) diff --git a/x/mlxrunner/mlx/array.go b/x/mlxrunner/mlx/array.go index 198162efd..00bf92153 100644 --- a/x/mlxrunner/mlx/array.go +++ b/x/mlxrunner/mlx/array.go @@ -10,6 +10,8 @@ import ( "reflect" "sort" "strings" + "sync" + "sync/atomic" "unsafe" "github.com/ollama/ollama/logutil" @@ -18,15 +20,20 @@ import ( type Array struct { ctx C.mlx_array name string - pinned int + pinned atomic.Int32 } -var arrays []*Array +var ( + arrays []*Array + arraysMu sync.Mutex +) // constructor utilities func New(name string) *Array { t := &Array{name: name} + arraysMu.Lock() + defer arraysMu.Unlock() arrays = append(arrays, t) return t } @@ -127,7 +134,7 @@ func (t *Array) Clone() *Array { func Pin(s ...*Array) { for _, t := range s { if t != nil { - t.pinned++ + t.pinned.Add(1) } } } @@ -136,8 +143,7 @@ func Pin(s ...*Array) { func Unpin(s ...*Array) { for _, t := range s { if t != nil { - t.pinned-- - if t.pinned < 0 { + if t.pinned.Add(-1) < 0 { panic(fmt.Sprintf("mlx.Unpin: negative pin count on array %q", t.name)) } } @@ -147,9 +153,11 @@ func Unpin(s ...*Array) { // Sweep releases all unpinned arrays, primarily intermediate tensors. MLX will truly // free them when there are no other references, including dependencies in the graph. func Sweep() { + arraysMu.Lock() + defer arraysMu.Unlock() n := 0 for _, t := range arrays { - if t.pinned > 0 && t.Valid() { + if t.pinned.Load() > 0 && t.Valid() { arrays[n] = t n++ } else if t.Valid() { @@ -176,7 +184,7 @@ func (t *Array) String() string { func (t *Array) LogValue() slog.Value { attrs := []slog.Attr{ slog.String("name", t.name), - slog.Int("pinned", t.pinned), + slog.Int("pinned", int(t.pinned.Load())), } if t.Valid() { attrs = append(attrs, @@ -190,19 +198,19 @@ func (t *Array) LogValue() slog.Value { // shape utilities -func (t Array) Size() int { +func (t *Array) Size() int { return int(C.mlx_array_size(t.ctx)) } -func (t Array) NumBytes() int { +func (t *Array) NumBytes() int { return int(C.mlx_array_nbytes(t.ctx)) } -func (t Array) NumDims() int { +func (t *Array) NumDims() int { return int(C.mlx_array_ndim(t.ctx)) } -func (t Array) Dims() []int { +func (t *Array) Dims() []int { dims := make([]int, t.NumDims()) for i := range dims { dims[i] = t.Dim(i) @@ -211,29 +219,29 @@ func (t Array) Dims() []int { return dims } -func (t Array) Dim(dim int) int { +func (t *Array) Dim(dim int) int { return int(C.mlx_array_dim(t.ctx, C.int(dim))) } -func (t Array) DType() DType { +func (t *Array) DType() DType { return DType(C.mlx_array_dtype(t.ctx)) } // data utilities -func (t Array) Int() int { +func (t *Array) Int() int { var item C.int64_t C.mlx_array_item_int64(&item, t.ctx) return int(item) } -func (t Array) Float() float64 { +func (t *Array) Float() float64 { var item C.double C.mlx_array_item_float64(&item, t.ctx) return float64(item) } -func (t Array) Ints() []int { +func (t *Array) Ints() []int { ints := make([]int, t.Size()) for i, f := range unsafe.Slice(C.mlx_array_data_int32(t.ctx), len(ints)) { ints[i] = int(f) @@ -241,7 +249,7 @@ func (t Array) Ints() []int { return ints } -func (t Array) Floats() []float32 { +func (t *Array) Floats() []float32 { floats := make([]float32, t.Size()) for i, f := range unsafe.Slice(C.mlx_array_data_float32(t.ctx), len(floats)) { floats[i] = float32(f) @@ -249,7 +257,7 @@ func (t Array) Floats() []float32 { return floats } -func (t Array) Save(name string) error { +func (t *Array) Save(name string) error { cName := C.CString(name) defer C.free(unsafe.Pointer(cName)) C.mlx_save(cName, t.ctx) @@ -258,6 +266,8 @@ func (t Array) Save(name string) error { // LogArrays logs all live arrays, sorted by size func LogArrays() { + arraysMu.Lock() + defer arraysMu.Unlock() sort.Slice(arrays, func(i, j int) bool { return arrays[i].NumBytes() > arrays[j].NumBytes() }) @@ -266,7 +276,7 @@ func LogArrays() { for _, t := range arrays { nb := t.NumBytes() total += nb - logutil.Trace(fmt.Sprintf("tensor %-60s %5s %5s pinned=%d %v", t.name, t.DType(), PrettyBytes(nb), t.pinned, t.Dims())) + logutil.Trace(fmt.Sprintf("tensor %-60s %5s %5s pinned=%d %v", t.name, t.DType(), PrettyBytes(nb), t.pinned.Load(), t.Dims())) } logutil.Trace(fmt.Sprintf("tensors total: %d, size: %s, active: %s", len(arrays), PrettyBytes(total), PrettyBytes(ActiveMemory()))) } diff --git a/x/mlxrunner/mlx/fast.go b/x/mlxrunner/mlx/fast.go index 7feca3b1e..d5b218d1c 100644 --- a/x/mlxrunner/mlx/fast.go +++ b/x/mlxrunner/mlx/fast.go @@ -24,8 +24,8 @@ func ScaledDotProductAttention(query, key, value, mask *Array, scale float32) *A } type LayerNorm struct { - Weight Array `weight:"weight"` - Bias Array `weight:"bias"` + Weight *Array `weight:"weight"` + Bias *Array `weight:"bias"` } func (r *LayerNorm) Forward(x *Array, eps float32) *Array { @@ -35,10 +35,10 @@ func (r *LayerNorm) Forward(x *Array, eps float32) *Array { } type RMSNorm struct { - Weight Array `weight:"weight"` + Weight *Array `weight:"weight"` } -func (r RMSNorm) Forward(x *Array, eps float32) *Array { +func (r *RMSNorm) Forward(x *Array, eps float32) *Array { out := New("FAST_RMSNORM") C.mlx_fast_rms_norm(&out.ctx, x.ctx, r.Weight.ctx, C.float(eps), DefaultStream().ctx) return out diff --git a/x/mlxrunner/mlx/nn.go b/x/mlxrunner/mlx/nn.go index d3a99a6cd..d2e7fb4f1 100644 --- a/x/mlxrunner/mlx/nn.go +++ b/x/mlxrunner/mlx/nn.go @@ -1,12 +1,12 @@ package mlx type Linear struct { - Weight Array `weight:"weight"` - Bias Array `weight:"bias"` + Weight *Array `weight:"weight"` + Bias *Array `weight:"bias"` } // Forward computes the linear transformation: x @ Weight.T + Bias -func (m Linear) Forward(x *Array) *Array { +func (m *Linear) Forward(x *Array) *Array { w := m.Weight.Transpose(1, 0) if m.Bias.Valid() { return m.Bias.Addmm(x, w, 1.0, 1.0) @@ -15,14 +15,14 @@ func (m Linear) Forward(x *Array) *Array { return x.Matmul(w) } -func (m Linear) Gather(x, lhs, rhs *Array, sorted bool) *Array { +func (m *Linear) Gather(x, lhs, rhs *Array, sorted bool) *Array { w := m.Weight.Transpose(0, 2, 1) // TODO: bias return x.GatherMM(w, lhs, rhs, sorted) } type Embedding struct { - Weight Array `weight:"weight"` + Weight *Array `weight:"weight"` } func (e *Embedding) Forward(indices *Array) *Array {