mlx: add compiled closure support

Wraps MLX's mlx_compile API so Go functions can be traced into fused
kernels. Contiguous elementwise chains collapse into a single
Metal/CUDA kernel instead of launching one per op.

Exposes Compile plus arity helpers (Compile1/2/3) that mirror Python's
@mx.compile decorator shape, lazily building the closure on first call
so package-level declarations work before the MLX dylib loads.
This commit is contained in:
Jesse Gross
2026-04-13 12:20:33 -07:00
parent 120424d832
commit 3f8e0af045
6 changed files with 445 additions and 10 deletions

View File

@@ -23,11 +23,23 @@ type Array struct {
var arrays []*Array
// tracing is true while a compile callback is running on this goroutine. While
// tracing, New routes new arrays into traceScratch so they can be freed as a
// group at the end of the callback instead of polluting the tracked list.
var (
tracing bool
traceScratch []*Array
)
// constructor utilities
func New(name string) *Array {
t := &Array{name: name}
arrays = append(arrays, t)
if tracing {
traceScratch = append(traceScratch, t)
} else {
arrays = append(arrays, t)
}
return t
}

239
x/mlxrunner/mlx/compile.go Normal file
View File

@@ -0,0 +1,239 @@
package mlx
// #include <stdlib.h>
// #include "generated.h"
//
// extern int mlxClosureCallback(mlx_vector_array* res, mlx_vector_array input, void* payload);
// extern void mlxClosureDestructor(void* payload);
import "C"
import (
"log/slog"
"runtime"
"runtime/cgo"
"unsafe"
)
// ClosureFunc is the signature of a function that can be compiled.
type ClosureFunc func(inputs []*Array) []*Array
// CompileOption configures Compile behavior.
type CompileOption func(*compileConfig)
type compileConfig struct {
shapeless bool
}
// Shapeless traces the function once against symbolic shapes so the compiled
// graph accepts any input shape afterwards. Without this option, MLX re-traces
// on each new (shape, dtype) combination and caches each specialization.
func Shapeless() CompileOption {
return func(c *compileConfig) { c.shapeless = true }
}
// CompiledFunc is a compiled MLX function. Compiled functions are intended to
// live for the program's lifetime; the underlying MLX closure is released
// automatically via runtime.AddCleanup when the CompiledFunc becomes
// unreachable.
type CompiledFunc struct {
name string
closure C.mlx_closure
}
// Compile wraps fn as a compiled MLX closure. name is a human-readable tag
// used in error messages; it has no effect on execution. MLX traces fn
// lazily on the first Call for each distinct (shape, dtype).
func Compile(name string, fn ClosureFunc, opts ...CompileOption) *CompiledFunc {
var cfg compileConfig
for _, o := range opts {
o(&cfg)
}
// The payload is a C-allocated slot holding a cgo.Handle wrapping fn.
// Using C memory avoids cgo's rule against passing Go memory that
// contains unpinned Go pointers; the slot is freed by the destructor
// when MLX releases its last reference.
payload := (*cgo.Handle)(C.malloc(C.size_t(unsafe.Sizeof(cgo.Handle(0)))))
*payload = cgo.NewHandle(fn)
src := C.mlx_closure_new_func_payload(
(*[0]byte)(C.mlxClosureCallback),
unsafe.Pointer(payload),
(*[0]byte)(C.mlxClosureDestructor),
)
// mlx_compile moves fn into the compiled closure, so src is freed
// either way. The compiled closure keeps the payload alive via its
// own shared_ptr until its mlx_closure_free runs the destructor.
defer C.mlx_closure_free(src)
compiled := C.mlx_closure_new()
clearLastError()
if rc := C.mlx_compile(&compiled, src, C.bool(cfg.shapeless)); rc != 0 {
msg := lastError()
if msg == "" {
msg = "mlx_compile failed"
}
panic("mlx: " + name + ": " + msg)
}
cf := &CompiledFunc{name: name, closure: compiled}
runtime.AddCleanup(cf, func(c C.mlx_closure) { C.mlx_closure_free(c) }, cf.closure)
return cf
}
// Call invokes the compiled function with the given inputs. Returned outputs
// participate in the normal Pin/Sweep lifecycle.
func (cf *CompiledFunc) Call(inputs ...*Array) []*Array {
inVec := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(inVec)
for _, in := range inputs {
C.mlx_vector_array_append_value(inVec, in.ctx)
}
outVec := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(outVec)
clearLastError()
if rc := C.mlx_closure_apply(&outVec, cf.closure, inVec); rc != 0 {
msg := lastError()
if msg == "" {
msg = "mlx_closure_apply failed"
}
panic("mlx: " + cf.name + ": " + msg)
}
n := int(C.mlx_vector_array_size(outVec))
outputs := make([]*Array, n)
for i := range n {
outputs[i] = New(cf.name)
C.mlx_vector_array_get(&outputs[i].ctx, outVec, C.size_t(i))
}
return outputs
}
// Compile1 compiles a unary function and returns a plain callable. The
// underlying closure is built on first call so package-level declarations
// work even before MLX's dynamic library is loaded.
//
// If invoked while another compile is tracing, the original fn is called
// directly so its ops are inlined into the outer trace rather than applied
// as a nested compiled closure. This matches upstream mlx's @mx.compile
// decorator and lets outer compiles fuse through inner ones.
func Compile1(name string, fn func(*Array) *Array, opts ...CompileOption) func(*Array) *Array {
var cf *CompiledFunc
return func(a *Array) *Array {
if tracing {
return fn(a)
}
if cf == nil {
cf = Compile(name, func(in []*Array) []*Array {
return []*Array{fn(in[0])}
}, opts...)
}
return cf.Call(a)[0]
}
}
// Compile2 compiles a binary function. See Compile1.
func Compile2(name string, fn func(*Array, *Array) *Array, opts ...CompileOption) func(*Array, *Array) *Array {
var cf *CompiledFunc
return func(a, b *Array) *Array {
if tracing {
return fn(a, b)
}
if cf == nil {
cf = Compile(name, func(in []*Array) []*Array {
return []*Array{fn(in[0], in[1])}
}, opts...)
}
return cf.Call(a, b)[0]
}
}
// Compile3 compiles a ternary function. See Compile1.
func Compile3(name string, fn func(*Array, *Array, *Array) *Array, opts ...CompileOption) func(*Array, *Array, *Array) *Array {
var cf *CompiledFunc
return func(a, b, c *Array) *Array {
if tracing {
return fn(a, b, c)
}
if cf == nil {
cf = Compile(name, func(in []*Array) []*Array {
return []*Array{fn(in[0], in[1], in[2])}
}, opts...)
}
return cf.Call(a, b, c)[0]
}
}
//export mlxClosureCallback
func mlxClosureCallback(res *C.mlx_vector_array, input C.mlx_vector_array, payload unsafe.Pointer) (rc C.int) {
// Recover panics so they don't unwind into C (which is UB). MLX
// overwrites the user's error message with a generic one after any
// non-zero return, so log the original panic and let the caller see
// a failed Call via the non-zero rc. Registered first so it is
// outermost and catches panics from any subsequent code, including
// the handle lookup and type assertion.
defer func() {
if r := recover(); r != nil {
slog.Error("mlx closure callback panicked", "panic", r)
rc = 1
}
}()
handle := *(*cgo.Handle)(payload)
fn := handle.Value().(ClosureFunc)
// Route arrays produced during fn through traceScratch. They are
// symbolic tracing handles that MLX captures into the compiled graph;
// our wrappers must be freed before returning or we leak a handle and
// a refcount per traced op.
prevTracing := tracing
prevScratch := traceScratch
tracing = true
traceScratch = nil
defer func() {
for _, a := range traceScratch {
if a.pinned > 0 {
panic("mlx: traced array was pinned during compilation")
}
if a.Valid() {
C.mlx_array_free(a.ctx)
a.ctx.ctx = nil
}
}
tracing = prevTracing
traceScratch = prevScratch
}()
// Each mlx_vector_array_get populates a caller-owned handle; route it
// into traceScratch so it is freed when the callback returns.
n := int(C.mlx_vector_array_size(input))
inputs := make([]*Array, n)
for i := range n {
a := New("")
C.mlx_vector_array_get(&a.ctx, input, C.size_t(i))
inputs[i] = a
}
outputs := fn(inputs)
// Populate the output vector via set_data, which handles any initial
// state of *res (null or previously allocated) per mlx-c convention.
// Our wrappers remain independent and are freed via traceScratch.
var arrPtr *C.mlx_array
if len(outputs) > 0 {
handles := make([]C.mlx_array, len(outputs))
for i, out := range outputs {
handles[i] = out.ctx
}
arrPtr = &handles[0]
}
C.mlx_vector_array_set_data(res, arrPtr, C.size_t(len(outputs)))
return 0
}
//export mlxClosureDestructor
func mlxClosureDestructor(payload unsafe.Pointer) {
handle := *(*cgo.Handle)(payload)
handle.Delete()
C.free(payload)
}

View File

@@ -0,0 +1,166 @@
package mlx
import (
"math"
"testing"
)
func TestCompileCallbackPanicRecovers(t *testing.T) {
skipIfNoMLX(t)
boom := Compile1("boom", func(a *Array) *Array {
panic("intentional test panic")
})
x := FromValues([]float32{1}, 1)
Pin(x)
defer Unpin(x)
defer func() {
r := recover()
if r == nil {
t.Fatal("expected panic from Call, got none")
}
if _, ok := r.(string); !ok {
t.Fatalf("expected string panic, got %T: %v", r, r)
}
// MLX overwrites the message; what matters is that the Go
// panic was caught before unwinding into C (which would be UB)
// and that Call surfaced a failure instead.
}()
boom(x)
}
func TestCompileUnary(t *testing.T) {
skipIfNoMLX(t)
square := Compile1("square", func(a *Array) *Array {
return a.Multiply(a)
})
x := FromValues([]float32{1, 2, 3, 4}, 4)
Pin(x)
defer Unpin(x)
y := square(x)
Eval(y)
got := y.Floats()
want := []float32{1, 4, 9, 16}
for i, v := range got {
if v != want[i] {
t.Fatalf("got[%d]=%v want %v", i, v, want[i])
}
}
}
func TestCompileBinary(t *testing.T) {
skipIfNoMLX(t)
add := Compile2("add", func(a, b *Array) *Array {
return a.Add(b)
})
a := FromValues([]float32{1, 2, 3}, 3)
b := FromValues([]float32{10, 20, 30}, 3)
Pin(a, b)
defer Unpin(a, b)
c := add(a, b)
Eval(c)
got := c.Floats()
want := []float32{11, 22, 33}
for i, v := range got {
if v != want[i] {
t.Fatalf("got[%d]=%v want %v", i, v, want[i])
}
}
}
func TestCompileShapelessReshape(t *testing.T) {
skipIfNoMLX(t)
// A shapeless compiled kernel should accept inputs of different shapes
// on subsequent calls without recompiling or erroring.
fn := Compile1("square", func(a *Array) *Array {
return a.Multiply(a)
}, Shapeless())
for _, n := range []int{2, 4, 8} {
data := make([]float32, n)
for i := range data {
data[i] = float32(i + 1)
}
x := FromValues(data, n)
Pin(x)
y := fn(x)
Eval(y)
got := y.Floats()
for i, v := range got {
want := float32((i + 1) * (i + 1))
if v != want {
t.Fatalf("n=%d got[%d]=%v want %v", n, i, v, want)
}
}
Unpin(x)
}
}
func TestSwiGLU(t *testing.T) {
skipIfNoMLX(t)
gate := FromValues([]float32{-1, 0, 1, 2}, 4)
up := FromValues([]float32{1, 2, 3, 4}, 4)
Pin(gate, up)
defer Unpin(gate, up)
y := SwiGLU(gate, up)
Eval(y)
got := y.Floats()
// Reference: silu(g) * u = g * sigmoid(g) * u
wantVals := []float32{-1, 0, 1, 2}
upVals := []float32{1, 2, 3, 4}
for i, g := range wantVals {
silu := g / float32(1+math.Exp(float64(-g)))
want := silu * upVals[i]
if math.Abs(float64(got[i]-want)) > 1e-5 {
t.Fatalf("i=%d got=%v want=%v", i, got[i], want)
}
}
}
func TestCompileNoTrackingGrowth(t *testing.T) {
skipIfNoMLX(t)
// Repeated invocations of a compiled kernel should not grow the
// tracked-arrays list by the number of internal ops each call — the
// callback's scratch list frees them before return.
fn := Compile2("mul_add", func(a, b *Array) *Array {
// Two ops per call; if we leaked, we'd see growth proportional to
// iterations * 2 in the tracked list.
return a.Multiply(b).Add(b)
})
a := FromValues([]float32{1, 2}, 2)
b := FromValues([]float32{3, 4}, 2)
Pin(a, b)
defer Unpin(a, b)
// Prime so the initial trace's allocations are already accounted for.
_ = fn(a, b)
Eval()
Sweep()
before := len(arrays)
for range 100 {
_ = fn(a, b)
Eval()
Sweep()
}
after := len(arrays)
if after > before+2 {
t.Fatalf("tracked arrays grew from %d to %d across 100 calls", before, after)
}
}

View File

@@ -53,6 +53,19 @@ func Version() string {
return C.GoString(C.mlx_string_data(str))
}
// clearLastError resets the captured MLX error state before a call that may fail.
func clearLastError() {
C.mlx_clear_last_error()
}
// lastError returns the last captured MLX error message, or "" if none.
func lastError() string {
if C.mlx_had_last_error() == 0 {
return ""
}
return C.GoString(C.mlx_get_last_error())
}
func doEval(outputs []*Array, async bool) {
if len(outputs) == 0 {
return

View File

@@ -23,15 +23,6 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
return errors.New("model not loaded")
}
enableCompile := true
if modelCompile, ok := r.Model.(interface{ EnableCompile() bool }); ok {
enableCompile = modelCompile.EnableCompile()
}
if enableCompile {
mlx.EnableCompile()
} else {
mlx.DisableCompile()
}
mlx.ResetPeakMemory()
ctx := request.Ctx
var (

View File

@@ -6,6 +6,8 @@ import (
"log/slog"
"net"
"net/http"
"os"
"strconv"
"strings"
"golang.org/x/sync/errgroup"
@@ -79,6 +81,18 @@ func (r *Runner) Load(modelName string) error {
r.Model = m
r.Tokenizer = m.Tokenizer()
r.contextLength = m.MaxContextLength()
enableCompile := true
if s := os.Getenv("OLLAMA_MLX_COMPILE"); s != "" {
if b, err := strconv.ParseBool(s); err == nil {
enableCompile = b
}
}
if enableCompile {
mlx.EnableCompile()
} else {
mlx.DisableCompile()
}
return nil
}