mlx: add compiled closure support

Wraps MLX's mlx_compile API so Go functions can be traced into fused
kernels. Contiguous elementwise chains collapse into a single
Metal/CUDA kernel instead of launching one per op.

Exposes Compile plus arity helpers (Compile1/2/3) that mirror Python's
@mx.compile decorator shape, lazily building the closure on first call
so package-level declarations work before the MLX dylib loads.
This commit is contained in:
Jesse Gross
2026-04-13 12:20:33 -07:00
parent 698e04a14b
commit d3e67e305c
6 changed files with 374 additions and 30 deletions

View File

@@ -27,7 +27,11 @@ var arrays []*Array
// New creates a tracked Array with the given name. During a compile trace the
// array is recorded in traceScratch so trace intermediates can be freed as a
// group when the callback returns; otherwise it joins the long-lived arrays
// list.
func New(name string) *Array {
	t := &Array{name: name}
	// An array belongs to exactly one list: trace temporaries must not be
	// appended to arrays as well, or they would be tracked (and swept) twice.
	if tracing {
		traceScratch = append(traceScratch, t)
	} else {
		arrays = append(arrays, t)
	}
	return t
}

192
x/mlxrunner/mlx/compile.go Normal file
View File

@@ -0,0 +1,192 @@
package mlx
// #include <stdlib.h>
// #include "generated.h"
//
// extern int closureCallback(mlx_vector_array* res, mlx_vector_array input, void* payload);
// extern void closureDestructor(void* payload);
import "C"
import (
"log/slog"
"runtime/cgo"
"sync"
"unsafe"
)
// CompileFunc is the signature of a function that can be compiled.
type CompileFunc func(inputs ...*Array) []*Array

// CompileOption configures Compile behavior.
type CompileOption func(*compileConfig)

// compileConfig accumulates the settings applied by CompileOption functions.
type compileConfig struct {
	// shapeless requests a single shape-polymorphic trace; see Shapeless.
	shapeless bool
}
// Shapeless makes the compiled function shape-polymorphic: the function is
// traced once against symbolic shapes and the resulting graph then accepts
// inputs of any shape. Without this option, MLX re-traces on each new
// (shape, dtype) combination and caches each specialization.
func Shapeless() CompileOption {
	return func(cfg *compileConfig) {
		cfg.shapeless = true
	}
}
// Compile returns a compiled version of fn. When called during another
// compile's trace, fn is inlined directly so outer compiles can fuse through
// inner ones.
//
// Compiled functions must not have side effects outside of the function. Do
// not access data other than the arguments passed in (either Go data or MLX
// arrays) unless it is a constant.
func Compile(name string, fn CompileFunc, opts ...CompileOption) CompileFunc {
	var cfg compileConfig
	for _, o := range opts {
		o(&cfg)
	}
	// The MLX closure is built lazily on first invocation (via once) so
	// package-level Compile declarations work before the MLX library loads.
	var closure C.mlx_closure
	var once sync.Once
	return func(inputs ...*Array) []*Array {
		// Inside another compile's trace, call fn directly so the outer
		// trace records (and can fuse) the inner ops.
		if tracing {
			return fn(inputs...)
		}
		once.Do(func() {
			// payload is a C-heap-allocated cgo.Handle to fn; MLX owns it
			// until closureDestructor releases both handle and allocation.
			payload := (*cgo.Handle)(C.malloc(C.size_t(unsafe.Sizeof(cgo.Handle(0)))))
			*payload = cgo.NewHandle(fn)
			src := C.mlx_closure_new_func_payload(
				(*[0]byte)(C.closureCallback),
				unsafe.Pointer(payload),
				(*[0]byte)(C.closureDestructor),
			)
			// src is only the uncompiled wrapper; the compiled closure keeps
			// its own reference, so free our copy when done.
			defer C.mlx_closure_free(src)
			closure = C.mlx_closure_new()
			mlxCheck(name+": compile failed", func() C.int {
				return C.mlx_compile(&closure, src, C.bool(cfg.shapeless))
			})
		})
		// Marshal the Go-side arrays into an MLX vector for the call.
		inVec := C.mlx_vector_array_new()
		defer C.mlx_vector_array_free(inVec)
		for _, in := range inputs {
			C.mlx_vector_array_append_value(inVec, in.ctx)
		}
		outVec := C.mlx_vector_array_new()
		defer C.mlx_vector_array_free(outVec)
		mlxCheck(name+": closure apply failed", func() C.int {
			return C.mlx_closure_apply(&outVec, closure, inVec)
		})
		// Wrap each output handle in a tracked Go Array.
		n := int(C.mlx_vector_array_size(outVec))
		outputs := make([]*Array, n)
		for i := range n {
			outputs[i] = New(name)
			C.mlx_vector_array_get(&outputs[i].ctx, outVec, C.size_t(i))
		}
		return outputs
	}
}
// Compile1 compiles a unary function. See Compile.
func Compile1(name string, fn func(*Array) *Array, opts ...CompileOption) func(*Array) *Array {
	compiled := Compile(name, func(args ...*Array) []*Array {
		return []*Array{fn(args[0])}
	}, opts...)
	return func(x *Array) *Array {
		results := compiled(x)
		return results[0]
	}
}
// Compile2 compiles a binary function. See Compile.
func Compile2(name string, fn func(*Array, *Array) *Array, opts ...CompileOption) func(*Array, *Array) *Array {
	compiled := Compile(name, func(args ...*Array) []*Array {
		return []*Array{fn(args[0], args[1])}
	}, opts...)
	return func(x, y *Array) *Array {
		results := compiled(x, y)
		return results[0]
	}
}
// Compile3 compiles a ternary function. See Compile.
func Compile3(name string, fn func(*Array, *Array, *Array) *Array, opts ...CompileOption) func(*Array, *Array, *Array) *Array {
	compiled := Compile(name, func(args ...*Array) []*Array {
		return []*Array{fn(args[0], args[1], args[2])}
	}, opts...)
	return func(x, y, z *Array) *Array {
		results := compiled(x, y, z)
		return results[0]
	}
}
// tracing is true while a compile callback is running. Since MLX is
// single-threaded at this level a plain Go bool suffices — no atomics or
// locking needed.
var tracing bool

// traceScratch collects arrays created during a compile trace so they can be
// freed as a group when the callback returns.
var traceScratch []*Array
//export closureCallback
func closureCallback(res *C.mlx_vector_array, input C.mlx_vector_array, payload unsafe.Pointer) (rc C.int) {
	// Never let a Go panic unwind into C: recover, log it, and report
	// failure via the return code so mlxCheck re-panics on the Go side.
	defer func() {
		if r := recover(); r != nil {
			slog.Error("mlx closure callback panicked", "panic", r)
			rc = 1
		}
	}()
	// payload is the C-heap cgo.Handle stashed by Compile; it holds the
	// user's CompileFunc.
	handle := *(*cgo.Handle)(payload)
	fn := handle.Value().(CompileFunc)
	// When tracing, we track all of the intermediates that are created and free them separately at the end of
	// the process. This will give the effect of a single op - inputs are owned by the original caller (via
	// the MLX layer) and outputs are transferred back to MLX to create a new Go side tensor.
	if tracing {
		panic("mlx: nested compile trace")
	}
	tracing = true
	traceScratch = nil
	defer func() {
		// Free every intermediate created during the trace. A pinned trace
		// temporary is a caller bug, so fail loudly rather than leak it.
		for _, a := range traceScratch {
			if a.pinned > 0 {
				panic("mlx: traced array was pinned during compilation")
			}
			if a.Valid() {
				C.mlx_array_free(a.ctx)
				a.ctx.ctx = nil
			}
		}
		tracing = false
		traceScratch = nil
	}()
	// Wrap each element of the C input vector in a Go Array for fn.
	n := int(C.mlx_vector_array_size(input))
	inputs := make([]*Array, n)
	for i := range n {
		a := New("")
		C.mlx_vector_array_get(&a.ctx, input, C.size_t(i))
		inputs[i] = a
	}
	outputs := fn(inputs...)
	// Hand the output handles back to MLX as a contiguous C array; arrPtr
	// stays nil (with length 0) when fn produced no outputs.
	var arrPtr *C.mlx_array
	if len(outputs) > 0 {
		handles := make([]C.mlx_array, len(outputs))
		for i, out := range outputs {
			handles[i] = out.ctx
		}
		arrPtr = &handles[0]
	}
	C.mlx_vector_array_set_data(res, arrPtr, C.size_t(len(outputs)))
	return 0
}
//export closureDestructor
func closureDestructor(payload unsafe.Pointer) {
	// Invoked by MLX when the closure is released: drop the Go handle first
	// so fn becomes collectible, then free the C allocation that carried it.
	handle := *(*cgo.Handle)(payload)
	handle.Delete()
	C.free(payload)
}

View File

@@ -0,0 +1,147 @@
package mlx
import (
"testing"
)
func TestCompileFusion(t *testing.T) {
	skipIfNoMLX(t)
	// Compile fuses the ops inside a function body into a single kernel,
	// eliminating intermediate buffers. Use a diamond-shaped graph where
	// two branches must be materialized simultaneously without fusion,
	// then compare peak memory against the compiled version which fuses
	// everything into one kernel with no intermediates.
	const n = 1024 * 1024 // 4MB per float32 array
	data := make([]float32, n)
	for i := range data {
		data[i] = float32(i + 1)
	}
	// Diamond: both a*b and a+b must be live for the final multiply.
	// Without fusion: peak includes both intermediates (~8MB extra).
	// With fusion: single kernel, no intermediates.
	body := func(a, b *Array) *Array {
		return a.Multiply(b).Multiply(a.Add(b))
	}
	a := FromValues(data, n)
	b := FromValues(data, n)
	Pin(a, b)
	defer Unpin(a, b)
	// Compiled: ops fused into a single kernel.
	EnableCompile()
	fn := Compile2("diamond", body, Shapeless())
	// Warm-up call performs the trace/compile so the measured run below
	// only executes the cached kernel.
	warm := fn(a, b)
	Eval(warm)
	Sweep()
	// Reset allocator state so the peak reflects only the measured run.
	ClearCache()
	ResetPeakMemory()
	y := fn(a, b)
	Eval(y)
	compiledPeak := PeakMemory()
	Sweep()
	// Uncompiled: ops evaluated individually, intermediates materialized.
	ClearCache()
	ResetPeakMemory()
	z := body(a, b)
	Eval(z)
	uncompiledPeak := PeakMemory()
	Sweep()
	if compiledPeak == 0 && uncompiledPeak == 0 {
		t.Skip("peak memory tracking not available")
	}
	t.Logf("peak memory: compiled=%d uncompiled=%d", compiledPeak, uncompiledPeak)
	if compiledPeak >= uncompiledPeak {
		t.Fatalf("compilation did not reduce peak memory: compiled=%d uncompiled=%d", compiledPeak, uncompiledPeak)
	}
}
func TestCompileNested(t *testing.T) {
	skipIfNoMLX(t)
	// Compiling a function that invokes another compiled function must
	// still produce correct results: during the outer trace the inner
	// compiled function detects tracing and inlines itself.
	inner := Compile1("silu", func(a *Array) *Array {
		return a.Multiply(a.Sigmoid())
	}, Shapeless())
	outer := Compile2("swiglu", func(gate, up *Array) *Array {
		return inner(gate).Multiply(up)
	}, Shapeless())
	gate := FromValues([]float32{0, 1, 2}, 3)
	up := FromValues([]float32{1, 1, 1}, 3)
	Pin(gate, up)
	defer Unpin(gate, up)
	out := outer(gate, up)
	Eval(out)
	// silu(x) = x * sigmoid(x): silu(0)=0, silu(1)≈0.7311, silu(2)≈1.7616.
	want := []float32{0, 0.7310586, 1.7615942}
	for i, v := range out.Floats() {
		diff := v - want[i]
		if diff < 0 {
			diff = -diff
		}
		if diff > 1e-4 {
			t.Fatalf("got[%d]=%v want %v", i, v, want[i])
		}
	}
}
func TestCompileCallbackPanicRecovers(t *testing.T) {
	skipIfNoMLX(t)
	// A panic inside a compiled function body must not unwind through the C
	// boundary: closureCallback recovers it and returns an error code, which
	// is surfaced to the caller as a Go string panic ("mlx: ...").
	boom := Compile1("boom", func(a *Array) *Array {
		panic("intentional test panic")
	})
	x := FromValues([]float32{1}, 1)
	Pin(x)
	defer Unpin(x)
	defer func() {
		r := recover()
		if r == nil {
			// The old message referenced a nonexistent "Call"; name the
			// actual operation under test instead.
			t.Fatal("expected panic from compiled function, got none")
		}
		if _, ok := r.(string); !ok {
			t.Fatalf("expected string panic, got %T: %v", r, r)
		}
	}()
	boom(x)
}
func TestCompileNoTrackingGrowth(t *testing.T) {
	skipIfNoMLX(t)
	// Invoking a compiled kernel repeatedly must not leak entries into the
	// tracked-arrays list: the callback's traceScratch collects trace
	// intermediates and frees them when the callback returns.
	compiled := Compile2("mul_add", func(x, y *Array) *Array {
		return x.Multiply(y).Add(y)
	})
	x := FromValues([]float32{1, 2}, 2)
	y := FromValues([]float32{3, 4}, 2)
	Pin(x, y)
	defer Unpin(x, y)
	Sweep()
	baseline := len(arrays)
	for i := 0; i < 100; i++ {
		_ = compiled(x, y)
		Sweep()
	}
	grown := len(arrays)
	if grown > baseline+2 {
		t.Fatalf("tracked arrays grew from %d to %d across 100 calls (includes initial trace)", baseline, grown)
	}
}

View File

@@ -9,8 +9,8 @@ package mlx
// #include "generated.h"
// #include <string.h>
//
// static char _mlx_last_error_msg[1024] = {0};
// static int _mlx_last_error_flag = 0;
// static __thread char _mlx_last_error_msg[1024] = {0};
// static __thread int _mlx_last_error_flag = 0;
//
// static void _mlx_capture_error_handler(const char* msg, void* data) {
// (void)data;
@@ -30,15 +30,13 @@ package mlx
// _mlx_last_error_msg[0] = '\0';
// }
//
// static int mlx_had_last_error(void) {
// return _mlx_last_error_flag;
// }
//
// static const char* mlx_get_last_error(void) {
// return _mlx_last_error_flag ? _mlx_last_error_msg : NULL;
// return _mlx_last_error_flag ? _mlx_last_error_msg : "";
// }
import "C"
import "runtime"
func init() {
// Replace the default exit(-1) error handler with one that captures
// the error message so we can surface it in Go.
@@ -53,6 +51,24 @@ func Version() string {
return C.GoString(C.mlx_string_data(str))
}
// mlxCheck locks the goroutine to its OS thread, clears the captured error
// state, calls fn, and panics with the captured message if fn returns non-zero.
// The thread lock ensures the thread-local error state is read from the same
// thread that executed the call.
func mlxCheck(fallback string, fn func() C.int) {
	runtime.LockOSThread()
	defer runtime.UnlockOSThread()
	C.mlx_clear_last_error()
	if rc := fn(); rc == 0 {
		return
	}
	errMsg := C.GoString(C.mlx_get_last_error())
	if len(errMsg) == 0 {
		errMsg = fallback
	}
	panic("mlx: " + errMsg)
}
func doEval(outputs []*Array, async bool) {
if len(outputs) == 0 {
return
@@ -67,20 +83,12 @@ func doEval(outputs []*Array, async bool) {
}
}
C.mlx_clear_last_error()
var rc C.int
if async {
rc = C.mlx_async_eval(vector)
} else {
rc = C.mlx_eval(vector)
}
if rc != 0 {
msg := "mlx eval failed"
if C.mlx_had_last_error() != 0 {
msg = C.GoString(C.mlx_get_last_error())
mlxCheck("eval failed", func() C.int {
if async {
return C.mlx_async_eval(vector)
}
panic("mlx: " + msg)
}
return C.mlx_eval(vector)
})
}
func AsyncEval(outputs ...*Array) {

View File

@@ -23,15 +23,6 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
return errors.New("model not loaded")
}
enableCompile := true
if modelCompile, ok := r.Model.(interface{ EnableCompile() bool }); ok {
enableCompile = modelCompile.EnableCompile()
}
if enableCompile {
mlx.EnableCompile()
} else {
mlx.DisableCompile()
}
mlx.ResetPeakMemory()
ctx := request.Ctx
var (

View File

@@ -79,6 +79,8 @@ func (r *Runner) Load(modelName string) error {
r.Model = m
r.Tokenizer = m.Tokenizer()
r.contextLength = m.MaxContextLength()
mlx.EnableCompile()
return nil
}