mlx: make array management thread-safe

Use atomic.Int32 for Array.pinned and a sync.Mutex for the global
arrays slice so MLX arrays can be safely created and managed from
multiple goroutines. Convert Array value receivers to pointer
receivers and struct fields from Array to *Array to avoid copying
the atomic.
This commit is contained in:
Jesse Gross
2026-03-31 14:15:04 -07:00
parent 4bc2728047
commit d137b850b6
3 changed files with 38 additions and 28 deletions

View File

@@ -10,6 +10,8 @@ import (
"reflect"
"sort"
"strings"
"sync"
"sync/atomic"
"unsafe"
"github.com/ollama/ollama/logutil"
@@ -18,15 +20,20 @@ import (
type Array struct {
ctx C.mlx_array
name string
pinned int
pinned atomic.Int32
}
var arrays []*Array
var (
arrays []*Array
arraysMu sync.Mutex
)
// constructor utilities
func New(name string) *Array {
t := &Array{name: name}
arraysMu.Lock()
defer arraysMu.Unlock()
arrays = append(arrays, t)
return t
}
@@ -127,7 +134,7 @@ func (t *Array) Clone() *Array {
func Pin(s ...*Array) {
for _, t := range s {
if t != nil {
t.pinned++
t.pinned.Add(1)
}
}
}
@@ -136,8 +143,7 @@ func Pin(s ...*Array) {
func Unpin(s ...*Array) {
for _, t := range s {
if t != nil {
t.pinned--
if t.pinned < 0 {
if t.pinned.Add(-1) < 0 {
panic(fmt.Sprintf("mlx.Unpin: negative pin count on array %q", t.name))
}
}
@@ -147,9 +153,11 @@ func Unpin(s ...*Array) {
// Sweep releases all unpinned arrays, primarily intermediate tensors. MLX will truly
// free them when there are no other references, including dependencies in the graph.
func Sweep() {
arraysMu.Lock()
defer arraysMu.Unlock()
n := 0
for _, t := range arrays {
if t.pinned > 0 && t.Valid() {
if t.pinned.Load() > 0 && t.Valid() {
arrays[n] = t
n++
} else if t.Valid() {
@@ -176,7 +184,7 @@ func (t *Array) String() string {
func (t *Array) LogValue() slog.Value {
attrs := []slog.Attr{
slog.String("name", t.name),
slog.Int("pinned", t.pinned),
slog.Int("pinned", int(t.pinned.Load())),
}
if t.Valid() {
attrs = append(attrs,
@@ -190,19 +198,19 @@ func (t *Array) LogValue() slog.Value {
// shape utilities
func (t Array) Size() int {
func (t *Array) Size() int {
return int(C.mlx_array_size(t.ctx))
}
func (t Array) NumBytes() int {
func (t *Array) NumBytes() int {
return int(C.mlx_array_nbytes(t.ctx))
}
func (t Array) NumDims() int {
func (t *Array) NumDims() int {
return int(C.mlx_array_ndim(t.ctx))
}
func (t Array) Dims() []int {
func (t *Array) Dims() []int {
dims := make([]int, t.NumDims())
for i := range dims {
dims[i] = t.Dim(i)
@@ -211,29 +219,29 @@ func (t Array) Dims() []int {
return dims
}
func (t Array) Dim(dim int) int {
func (t *Array) Dim(dim int) int {
return int(C.mlx_array_dim(t.ctx, C.int(dim)))
}
func (t Array) DType() DType {
func (t *Array) DType() DType {
return DType(C.mlx_array_dtype(t.ctx))
}
// data utilities
func (t Array) Int() int {
func (t *Array) Int() int {
var item C.int64_t
C.mlx_array_item_int64(&item, t.ctx)
return int(item)
}
func (t Array) Float() float64 {
func (t *Array) Float() float64 {
var item C.double
C.mlx_array_item_float64(&item, t.ctx)
return float64(item)
}
func (t Array) Ints() []int {
func (t *Array) Ints() []int {
ints := make([]int, t.Size())
for i, f := range unsafe.Slice(C.mlx_array_data_int32(t.ctx), len(ints)) {
ints[i] = int(f)
@@ -241,7 +249,7 @@ func (t Array) Ints() []int {
return ints
}
func (t Array) Floats() []float32 {
func (t *Array) Floats() []float32 {
floats := make([]float32, t.Size())
for i, f := range unsafe.Slice(C.mlx_array_data_float32(t.ctx), len(floats)) {
floats[i] = float32(f)
@@ -249,7 +257,7 @@ func (t Array) Floats() []float32 {
return floats
}
func (t Array) Save(name string) error {
func (t *Array) Save(name string) error {
cName := C.CString(name)
defer C.free(unsafe.Pointer(cName))
C.mlx_save(cName, t.ctx)
@@ -258,6 +266,8 @@ func (t Array) Save(name string) error {
// LogArrays logs all live arrays, sorted by size
func LogArrays() {
arraysMu.Lock()
defer arraysMu.Unlock()
sort.Slice(arrays, func(i, j int) bool {
return arrays[i].NumBytes() > arrays[j].NumBytes()
})
@@ -266,7 +276,7 @@ func LogArrays() {
for _, t := range arrays {
nb := t.NumBytes()
total += nb
logutil.Trace(fmt.Sprintf("tensor %-60s %5s %5s pinned=%d %v", t.name, t.DType(), PrettyBytes(nb), t.pinned, t.Dims()))
logutil.Trace(fmt.Sprintf("tensor %-60s %5s %5s pinned=%d %v", t.name, t.DType(), PrettyBytes(nb), t.pinned.Load(), t.Dims()))
}
logutil.Trace(fmt.Sprintf("tensors total: %d, size: %s, active: %s", len(arrays), PrettyBytes(total), PrettyBytes(ActiveMemory())))
}