From d3e67e305cb04b6ff2bbde87e7d6133360619411 Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Mon, 13 Apr 2026 12:20:33 -0700 Subject: [PATCH] mlx: add compiled closure support Wraps MLX's mlx_compile API so Go functions can be traced into fused kernels. Contiguous elementwise chains collapse into a single Metal/CUDA kernel instead of launching one per op. Exposes Compile plus arity helpers (Compile1/2/3) that mirror Python's @mx.compile decorator shape, lazily building the closure on first call so package-level declarations work before the MLX dylib loads. --- x/mlxrunner/mlx/array.go | 6 +- x/mlxrunner/mlx/compile.go | 192 ++++++++++++++++++++++++++++++++ x/mlxrunner/mlx/compile_test.go | 147 ++++++++++++++++++++++++ x/mlxrunner/mlx/mlx.go | 48 ++++---- x/mlxrunner/pipeline.go | 9 -- x/mlxrunner/runner.go | 2 + 6 files changed, 374 insertions(+), 30 deletions(-) create mode 100644 x/mlxrunner/mlx/compile.go create mode 100644 x/mlxrunner/mlx/compile_test.go diff --git a/x/mlxrunner/mlx/array.go b/x/mlxrunner/mlx/array.go index 198162efd..a41aee9cb 100644 --- a/x/mlxrunner/mlx/array.go +++ b/x/mlxrunner/mlx/array.go @@ -27,7 +27,11 @@ var arrays []*Array func New(name string) *Array { t := &Array{name: name} - arrays = append(arrays, t) + if tracing { + traceScratch = append(traceScratch, t) + } else { + arrays = append(arrays, t) + } return t } diff --git a/x/mlxrunner/mlx/compile.go b/x/mlxrunner/mlx/compile.go new file mode 100644 index 000000000..987bb7220 --- /dev/null +++ b/x/mlxrunner/mlx/compile.go @@ -0,0 +1,192 @@ +package mlx + +// #include +// #include "generated.h" +// +// extern int closureCallback(mlx_vector_array* res, mlx_vector_array input, void* payload); +// extern void closureDestructor(void* payload); +import "C" + +import ( + "log/slog" + "runtime/cgo" + "sync" + "unsafe" +) + +// CompileFunc is the signature of a function that can be compiled. +type CompileFunc func(inputs ...*Array) []*Array + +// CompileOption configures Compile behavior. +type CompileOption func(*compileConfig) + +type compileConfig struct { + shapeless bool +} + +// Shapeless traces the function once against symbolic shapes so the compiled +// graph accepts any input shape afterwards. Without this option, MLX re-traces +// on each new (shape, dtype) combination and caches each specialization. +func Shapeless() CompileOption { + return func(c *compileConfig) { c.shapeless = true } +} + +// Compile returns a compiled version of fn. When called during another +// compile's trace, fn is inlined directly so outer compiles can fuse through +// inner ones. +// +// Compiled functions must not have side effects outside of the function. Do +// not access data other than the arguments passed in (either Go data or MLX +// arrays) unless it is a constant. +func Compile(name string, fn CompileFunc, opts ...CompileOption) CompileFunc { + var cfg compileConfig + for _, o := range opts { + o(&cfg) + } + + var closure C.mlx_closure + var once sync.Once + + return func(inputs ...*Array) []*Array { + if tracing { + return fn(inputs...) + } + + once.Do(func() { + payload := (*cgo.Handle)(C.malloc(C.size_t(unsafe.Sizeof(cgo.Handle(0))))) + *payload = cgo.NewHandle(fn) + src := C.mlx_closure_new_func_payload( + (*[0]byte)(C.closureCallback), + unsafe.Pointer(payload), + (*[0]byte)(C.closureDestructor), + ) + defer C.mlx_closure_free(src) + + closure = C.mlx_closure_new() + mlxCheck(name+": compile failed", func() C.int { + return C.mlx_compile(&closure, src, C.bool(cfg.shapeless)) + }) + }) + + inVec := C.mlx_vector_array_new() + defer C.mlx_vector_array_free(inVec) + for _, in := range inputs { + C.mlx_vector_array_append_value(inVec, in.ctx) + } + + outVec := C.mlx_vector_array_new() + defer C.mlx_vector_array_free(outVec) + mlxCheck(name+": closure apply failed", func() C.int { + return C.mlx_closure_apply(&outVec, closure, inVec) + }) + + n := int(C.mlx_vector_array_size(outVec)) + outputs := make([]*Array, n) + for i := range n { + outputs[i] = New(name) + C.mlx_vector_array_get(&outputs[i].ctx, outVec, C.size_t(i)) + } + return outputs + } +} + +// Compile1 compiles a unary function. See Compile. +func Compile1(name string, fn func(*Array) *Array, opts ...CompileOption) func(*Array) *Array { + cf := Compile(name, func(in ...*Array) []*Array { + return []*Array{fn(in[0])} + }, opts...) + return func(a *Array) *Array { + return cf(a)[0] + } +} + +// Compile2 compiles a binary function. See Compile. +func Compile2(name string, fn func(*Array, *Array) *Array, opts ...CompileOption) func(*Array, *Array) *Array { + cf := Compile(name, func(in ...*Array) []*Array { + return []*Array{fn(in[0], in[1])} + }, opts...) + return func(a, b *Array) *Array { + return cf(a, b)[0] + } +} + +// Compile3 compiles a ternary function. See Compile. +func Compile3(name string, fn func(*Array, *Array, *Array) *Array, opts ...CompileOption) func(*Array, *Array, *Array) *Array { + cf := Compile(name, func(in ...*Array) []*Array { + return []*Array{fn(in[0], in[1], in[2])} + }, opts...) + return func(a, b, c *Array) *Array { + return cf(a, b, c)[0] + } +} + +// tracing is true while a compile callback is running. Since MLX is +// single-threaded at this level a plain Go bool suffices. +var tracing bool + +// traceScratch collects arrays created during a compile trace so they can be +// freed as a group when the callback returns. +var traceScratch []*Array + +//export closureCallback +func closureCallback(res *C.mlx_vector_array, input C.mlx_vector_array, payload unsafe.Pointer) (rc C.int) { + defer func() { + if r := recover(); r != nil { + slog.Error("mlx closure callback panicked", "panic", r) + rc = 1 + } + }() + + handle := *(*cgo.Handle)(payload) + fn := handle.Value().(CompileFunc) + + // When tracing, we track all of the intermediates that are created and free them separately at the end of + // the process. This will give the effect of a single op - inputs are owned by the original caller (via + // the MLX layer) and outputs are transferred back to MLX to create a new Go side tensor. + if tracing { + panic("mlx: nested compile trace") + } + tracing = true + traceScratch = nil + defer func() { + for _, a := range traceScratch { + if a.pinned > 0 { + panic("mlx: traced array was pinned during compilation") + } + if a.Valid() { + C.mlx_array_free(a.ctx) + a.ctx.ctx = nil + } + } + tracing = false + traceScratch = nil + }() + + n := int(C.mlx_vector_array_size(input)) + inputs := make([]*Array, n) + for i := range n { + a := New("") + C.mlx_vector_array_get(&a.ctx, input, C.size_t(i)) + inputs[i] = a + } + + outputs := fn(inputs...) + + var arrPtr *C.mlx_array + if len(outputs) > 0 { + handles := make([]C.mlx_array, len(outputs)) + for i, out := range outputs { + handles[i] = out.ctx + } + arrPtr = &handles[0] + } + C.mlx_vector_array_set_data(res, arrPtr, C.size_t(len(outputs))) + return 0 +} + +//export closureDestructor +func closureDestructor(payload unsafe.Pointer) { + handle := *(*cgo.Handle)(payload) + handle.Delete() + C.free(payload) +} diff --git a/x/mlxrunner/mlx/compile_test.go b/x/mlxrunner/mlx/compile_test.go new file mode 100644 index 000000000..801ee6d0d --- /dev/null +++ b/x/mlxrunner/mlx/compile_test.go @@ -0,0 +1,147 @@ +package mlx + +import ( + "testing" +) + +func TestCompileFusion(t *testing.T) { + skipIfNoMLX(t) + + // Compile fuses the ops inside a function body into a single kernel, + // eliminating intermediate buffers. Use a diamond-shaped graph where + // two branches must be materialized simultaneously without fusion, + // then compare peak memory against the compiled version which fuses + // everything into one kernel with no intermediates. + const n = 1024 * 1024 // 4MB per float32 array + data := make([]float32, n) + for i := range data { + data[i] = float32(i + 1) + } + + // Diamond: both a*b and a+b must be live for the final multiply. + // Without fusion: peak includes both intermediates (~8MB extra). + // With fusion: single kernel, no intermediates. + body := func(a, b *Array) *Array { + return a.Multiply(b).Multiply(a.Add(b)) + } + + a := FromValues(data, n) + b := FromValues(data, n) + Pin(a, b) + defer Unpin(a, b) + + // Compiled: ops fused into a single kernel. + EnableCompile() + fn := Compile2("diamond", body, Shapeless()) + warm := fn(a, b) + Eval(warm) + Sweep() + ClearCache() + ResetPeakMemory() + y := fn(a, b) + Eval(y) + compiledPeak := PeakMemory() + Sweep() + + // Uncompiled: ops evaluated individually, intermediates materialized. + ClearCache() + ResetPeakMemory() + z := body(a, b) + Eval(z) + uncompiledPeak := PeakMemory() + Sweep() + + if compiledPeak == 0 && uncompiledPeak == 0 { + t.Skip("peak memory tracking not available") + } + + t.Logf("peak memory: compiled=%d uncompiled=%d", compiledPeak, uncompiledPeak) + + if compiledPeak >= uncompiledPeak { + t.Fatalf("compilation did not reduce peak memory: compiled=%d uncompiled=%d", compiledPeak, uncompiledPeak) + } +} + +func TestCompileNested(t *testing.T) { + skipIfNoMLX(t) + + // A compiled function that calls another compiled function should + // produce correct results. The inner function inlines via isTracing() + // during the outer's trace. + inner := Compile1("silu", func(a *Array) *Array { + return a.Multiply(a.Sigmoid()) + }, Shapeless()) + + outer := Compile2("swiglu", func(gate, up *Array) *Array { + return inner(gate).Multiply(up) + }, Shapeless()) + + gate := FromValues([]float32{0, 1, 2}, 3) + up := FromValues([]float32{1, 1, 1}, 3) + Pin(gate, up) + defer Unpin(gate, up) + + y := outer(gate, up) + Eval(y) + + // silu(x) = x * sigmoid(x); for x=0 → 0, x=1 → ~0.7311, x=2 → ~1.7616 + got := y.Floats() + want := []float32{0, 0.7310586, 1.7615942} + for i, v := range got { + if v-want[i] > 1e-4 || want[i]-v > 1e-4 { + t.Fatalf("got[%d]=%v want %v", i, v, want[i]) + } + } +} + +func TestCompileCallbackPanicRecovers(t *testing.T) { + skipIfNoMLX(t) + + boom := Compile1("boom", func(a *Array) *Array { + panic("intentional test panic") + }) + + x := FromValues([]float32{1}, 1) + Pin(x) + defer Unpin(x) + + defer func() { + r := recover() + if r == nil { + t.Fatal("expected panic from Call, got none") + } + if _, ok := r.(string); !ok { + t.Fatalf("expected string panic, got %T: %v", r, r) + } + }() + boom(x) +} + +func TestCompileNoTrackingGrowth(t *testing.T) { + skipIfNoMLX(t) + + // Repeated invocations of a compiled kernel should not grow the + // tracked-arrays list — the callback's traceScratch collects + // intermediates during tracing and frees them when the callback returns. + fn := Compile2("mul_add", func(a, b *Array) *Array { + return a.Multiply(b).Add(b) + }) + + a := FromValues([]float32{1, 2}, 2) + b := FromValues([]float32{3, 4}, 2) + Pin(a, b) + defer Unpin(a, b) + + Sweep() + before := len(arrays) + + for range 100 { + _ = fn(a, b) + Sweep() + } + + after := len(arrays) + if after > before+2 { + t.Fatalf("tracked arrays grew from %d to %d across 100 calls (includes initial trace)", before, after) + } +} diff --git a/x/mlxrunner/mlx/mlx.go b/x/mlxrunner/mlx/mlx.go index a03488158..291410573 100644 --- a/x/mlxrunner/mlx/mlx.go +++ b/x/mlxrunner/mlx/mlx.go @@ -9,8 +9,8 @@ package mlx // #include "generated.h" // #include // -// static char _mlx_last_error_msg[1024] = {0}; -// static int _mlx_last_error_flag = 0; +// static __thread char _mlx_last_error_msg[1024] = {0}; +// static __thread int _mlx_last_error_flag = 0; // // static void _mlx_capture_error_handler(const char* msg, void* data) { // (void)data; @@ -30,15 +30,13 @@ package mlx // _mlx_last_error_msg[0] = '\0'; // } // -// static int mlx_had_last_error(void) { -// return _mlx_last_error_flag; -// } -// // static const char* mlx_get_last_error(void) { -// return _mlx_last_error_flag ? _mlx_last_error_msg : NULL; +// return _mlx_last_error_flag ? _mlx_last_error_msg : ""; // } import "C" +import "runtime" + func init() { // Replace the default exit(-1) error handler with one that captures // the error message so we can surface it in Go. @@ -53,6 +51,24 @@ func Version() string { return C.GoString(C.mlx_string_data(str)) } +// mlxCheck locks the goroutine to its OS thread, clears the captured error +// state, calls fn, and panics with the captured message if fn returns non-zero. +// The thread lock ensures the thread-local error state is read from the same +// thread that executed the call. +func mlxCheck(fallback string, fn func() C.int) { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + C.mlx_clear_last_error() + if fn() != 0 { + msg := C.GoString(C.mlx_get_last_error()) + if msg == "" { + msg = fallback + } + panic("mlx: " + msg) + } +} + func doEval(outputs []*Array, async bool) { if len(outputs) == 0 { return @@ -67,20 +83,12 @@ func doEval(outputs []*Array, async bool) { } } - C.mlx_clear_last_error() - var rc C.int - if async { - rc = C.mlx_async_eval(vector) - } else { - rc = C.mlx_eval(vector) - } - if rc != 0 { - msg := "mlx eval failed" - if C.mlx_had_last_error() != 0 { - msg = C.GoString(C.mlx_get_last_error()) + mlxCheck("eval failed", func() C.int { + if async { + return C.mlx_async_eval(vector) } - panic("mlx: " + msg) - } + return C.mlx_eval(vector) + }) } func AsyncEval(outputs ...*Array) { diff --git a/x/mlxrunner/pipeline.go b/x/mlxrunner/pipeline.go index 4dcfad01d..a8285d33e 100644 --- a/x/mlxrunner/pipeline.go +++ b/x/mlxrunner/pipeline.go @@ -23,15 +23,6 @@ func (r *Runner) TextGenerationPipeline(request Request) error { return errors.New("model not loaded") } - enableCompile := true - if modelCompile, ok := r.Model.(interface{ EnableCompile() bool }); ok { - enableCompile = modelCompile.EnableCompile() - } - if enableCompile { - mlx.EnableCompile() - } else { - mlx.DisableCompile() - } mlx.ResetPeakMemory() ctx := request.Ctx var ( diff --git a/x/mlxrunner/runner.go b/x/mlxrunner/runner.go index 08a376d43..61b635fcd 100644 --- a/x/mlxrunner/runner.go +++ b/x/mlxrunner/runner.go @@ -79,6 +79,8 @@ func (r *Runner) Load(modelName string) error { r.Model = m r.Tokenizer = m.Tokenizer() r.contextLength = m.MaxContextLength() + + mlx.EnableCompile() return nil }