diff --git a/x/mlxrunner/mlx/array.go b/x/mlxrunner/mlx/array.go index 198162efd..00bf92153 100644 --- a/x/mlxrunner/mlx/array.go +++ b/x/mlxrunner/mlx/array.go @@ -10,6 +10,8 @@ import ( "reflect" "sort" "strings" + "sync" + "sync/atomic" "unsafe" "github.com/ollama/ollama/logutil" @@ -18,15 +20,20 @@ import ( type Array struct { ctx C.mlx_array name string - pinned int + pinned atomic.Int32 } -var arrays []*Array +var ( + arrays []*Array + arraysMu sync.Mutex +) // constructor utilities func New(name string) *Array { t := &Array{name: name} + arraysMu.Lock() + defer arraysMu.Unlock() arrays = append(arrays, t) return t } @@ -127,7 +134,7 @@ func (t *Array) Clone() *Array { func Pin(s ...*Array) { for _, t := range s { if t != nil { - t.pinned++ + t.pinned.Add(1) } } } @@ -136,8 +143,7 @@ func Pin(s ...*Array) { func Unpin(s ...*Array) { for _, t := range s { if t != nil { - t.pinned-- - if t.pinned < 0 { + if t.pinned.Add(-1) < 0 { panic(fmt.Sprintf("mlx.Unpin: negative pin count on array %q", t.name)) } } @@ -147,9 +153,11 @@ func Unpin(s ...*Array) { // Sweep releases all unpinned arrays, primarily intermediate tensors. MLX will truly // free them when there are no other references, including dependencies in the graph. func Sweep() { + arraysMu.Lock() + defer arraysMu.Unlock() n := 0 for _, t := range arrays { - if t.pinned > 0 && t.Valid() { + if t.pinned.Load() > 0 && t.Valid() { arrays[n] = t n++ } else if t.Valid() { @@ -176,7 +184,7 @@ func (t *Array) String() string { func (t *Array) LogValue() slog.Value { attrs := []slog.Attr{ slog.String("name", t.name), - slog.Int("pinned", t.pinned), + slog.Int("pinned", int(t.pinned.Load())), } if t.Valid() { attrs = append(attrs, @@ -190,19 +198,19 @@ func (t *Array) LogValue() slog.Value { // shape utilities -func (t Array) Size() int { +func (t *Array) Size() int { return int(C.mlx_array_size(t.ctx)) } -func (t Array) NumBytes() int { +func (t *Array) NumBytes() int { return int(C.mlx_array_nbytes(t.ctx)) } -func (t Array) NumDims() int { +func (t *Array) NumDims() int { return int(C.mlx_array_ndim(t.ctx)) } -func (t Array) Dims() []int { +func (t *Array) Dims() []int { dims := make([]int, t.NumDims()) for i := range dims { dims[i] = t.Dim(i) @@ -211,29 +219,29 @@ func (t Array) Dims() []int { return dims } -func (t Array) Dim(dim int) int { +func (t *Array) Dim(dim int) int { return int(C.mlx_array_dim(t.ctx, C.int(dim))) } -func (t Array) DType() DType { +func (t *Array) DType() DType { return DType(C.mlx_array_dtype(t.ctx)) } // data utilities -func (t Array) Int() int { +func (t *Array) Int() int { var item C.int64_t C.mlx_array_item_int64(&item, t.ctx) return int(item) } -func (t Array) Float() float64 { +func (t *Array) Float() float64 { var item C.double C.mlx_array_item_float64(&item, t.ctx) return float64(item) } -func (t Array) Ints() []int { +func (t *Array) Ints() []int { ints := make([]int, t.Size()) for i, f := range unsafe.Slice(C.mlx_array_data_int32(t.ctx), len(ints)) { ints[i] = int(f) @@ -241,7 +249,7 @@ func (t Array) Ints() []int { return ints } -func (t Array) Floats() []float32 { +func (t *Array) Floats() []float32 { floats := make([]float32, t.Size()) for i, f := range unsafe.Slice(C.mlx_array_data_float32(t.ctx), len(floats)) { floats[i] = float32(f) @@ -249,7 +257,7 @@ func (t Array) Floats() []float32 { return floats } -func (t Array) Save(name string) error { +func (t *Array) Save(name string) error { cName := C.CString(name) defer C.free(unsafe.Pointer(cName)) C.mlx_save(cName, t.ctx) @@ -258,6 +266,8 @@ func (t Array) Save(name string) error { // LogArrays logs all live arrays, sorted by size func LogArrays() { + arraysMu.Lock() + defer arraysMu.Unlock() sort.Slice(arrays, func(i, j int) bool { return arrays[i].NumBytes() > arrays[j].NumBytes() }) @@ -266,7 +276,7 @@ func LogArrays() { for _, t := range arrays { nb := t.NumBytes() total += nb - logutil.Trace(fmt.Sprintf("tensor %-60s %5s %5s pinned=%d %v", t.name, t.DType(), PrettyBytes(nb), t.pinned, t.Dims())) + logutil.Trace(fmt.Sprintf("tensor %-60s %5s %5s pinned=%d %v", t.name, t.DType(), PrettyBytes(nb), t.pinned.Load(), t.Dims())) } logutil.Trace(fmt.Sprintf("tensors total: %d, size: %s, active: %s", len(arrays), PrettyBytes(total), PrettyBytes(ActiveMemory()))) } diff --git a/x/mlxrunner/mlx/fast.go b/x/mlxrunner/mlx/fast.go index 7feca3b1e..d5b218d1c 100644 --- a/x/mlxrunner/mlx/fast.go +++ b/x/mlxrunner/mlx/fast.go @@ -24,8 +24,8 @@ func ScaledDotProductAttention(query, key, value, mask *Array, scale float32) *A } type LayerNorm struct { - Weight Array `weight:"weight"` - Bias Array `weight:"bias"` + Weight *Array `weight:"weight"` + Bias *Array `weight:"bias"` } func (r *LayerNorm) Forward(x *Array, eps float32) *Array { @@ -35,10 +35,10 @@ func (r *LayerNorm) Forward(x *Array, eps float32) *Array { } type RMSNorm struct { - Weight Array `weight:"weight"` + Weight *Array `weight:"weight"` } -func (r RMSNorm) Forward(x *Array, eps float32) *Array { +func (r *RMSNorm) Forward(x *Array, eps float32) *Array { out := New("FAST_RMSNORM") C.mlx_fast_rms_norm(&out.ctx, x.ctx, r.Weight.ctx, C.float(eps), DefaultStream().ctx) return out diff --git a/x/mlxrunner/mlx/nn.go b/x/mlxrunner/mlx/nn.go index d3a99a6cd..d2e7fb4f1 100644 --- a/x/mlxrunner/mlx/nn.go +++ b/x/mlxrunner/mlx/nn.go @@ -1,12 +1,12 @@ package mlx type Linear struct { - Weight Array `weight:"weight"` - Bias Array `weight:"bias"` + Weight *Array `weight:"weight"` + Bias *Array `weight:"bias"` } // Forward computes the linear transformation: x @ Weight.T + Bias -func (m Linear) Forward(x *Array) *Array { +func (m *Linear) Forward(x *Array) *Array { w := m.Weight.Transpose(1, 0) if m.Bias.Valid() { return m.Bias.Addmm(x, w, 1.0, 1.0) @@ -15,14 +15,14 @@ func (m Linear) Forward(x *Array) *Array { return x.Matmul(w) } -func (m Linear) Gather(x, lhs, rhs *Array, sorted bool) *Array { +func (m *Linear) Gather(x, lhs, rhs *Array, sorted bool) *Array { w := m.Weight.Transpose(0, 2, 1) // TODO: bias return x.GatherMM(w, lhs, rhs, sorted) } type Embedding struct { - Weight Array `weight:"weight"` + Weight *Array `weight:"weight"` } func (e *Embedding) Forward(indices *Array) *Array {