mirror of
https://github.com/ollama/ollama.git
synced 2026-04-17 19:54:03 +02:00
mlx: add compiled closure support
Wraps MLX's mlx_compile API so Go functions can be traced into fused kernels. Contiguous elementwise chains collapse into a single Metal/CUDA kernel instead of launching one per op. Exposes Compile plus arity helpers (Compile1/2/3) that mirror Python's @mx.compile decorator shape, lazily building the closure on first call so package-level declarations work before the MLX dylib loads.
This commit is contained in:
@@ -27,7 +27,11 @@ var arrays []*Array
|
||||
|
||||
func New(name string) *Array {
|
||||
t := &Array{name: name}
|
||||
arrays = append(arrays, t)
|
||||
if tracing {
|
||||
traceScratch = append(traceScratch, t)
|
||||
} else {
|
||||
arrays = append(arrays, t)
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
|
||||
192
x/mlxrunner/mlx/compile.go
Normal file
192
x/mlxrunner/mlx/compile.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package mlx
|
||||
|
||||
// #include <stdlib.h>
|
||||
// #include "generated.h"
|
||||
//
|
||||
// extern int closureCallback(mlx_vector_array* res, mlx_vector_array input, void* payload);
|
||||
// extern void closureDestructor(void* payload);
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"runtime/cgo"
|
||||
"sync"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// CompileFunc is the signature of a function that can be compiled.
|
||||
type CompileFunc func(inputs ...*Array) []*Array
|
||||
|
||||
// CompileOption configures Compile behavior.
|
||||
type CompileOption func(*compileConfig)
|
||||
|
||||
type compileConfig struct {
|
||||
shapeless bool
|
||||
}
|
||||
|
||||
// Shapeless traces the function once against symbolic shapes so the compiled
|
||||
// graph accepts any input shape afterwards. Without this option, MLX re-traces
|
||||
// on each new (shape, dtype) combination and caches each specialization.
|
||||
func Shapeless() CompileOption {
|
||||
return func(c *compileConfig) { c.shapeless = true }
|
||||
}
|
||||
|
||||
// Compile returns a compiled version of fn. When called during another
|
||||
// compile's trace, fn is inlined directly so outer compiles can fuse through
|
||||
// inner ones.
|
||||
//
|
||||
// Compiled functions must not have side effects outside of the function. Do
|
||||
// not access data other than the arguments passed in (either Go data or MLX
|
||||
// arrays) unless it is a constant.
|
||||
func Compile(name string, fn CompileFunc, opts ...CompileOption) CompileFunc {
|
||||
var cfg compileConfig
|
||||
for _, o := range opts {
|
||||
o(&cfg)
|
||||
}
|
||||
|
||||
var closure C.mlx_closure
|
||||
var once sync.Once
|
||||
|
||||
return func(inputs ...*Array) []*Array {
|
||||
if tracing {
|
||||
return fn(inputs...)
|
||||
}
|
||||
|
||||
once.Do(func() {
|
||||
payload := (*cgo.Handle)(C.malloc(C.size_t(unsafe.Sizeof(cgo.Handle(0)))))
|
||||
*payload = cgo.NewHandle(fn)
|
||||
src := C.mlx_closure_new_func_payload(
|
||||
(*[0]byte)(C.closureCallback),
|
||||
unsafe.Pointer(payload),
|
||||
(*[0]byte)(C.closureDestructor),
|
||||
)
|
||||
defer C.mlx_closure_free(src)
|
||||
|
||||
closure = C.mlx_closure_new()
|
||||
mlxCheck(name+": compile failed", func() C.int {
|
||||
return C.mlx_compile(&closure, src, C.bool(cfg.shapeless))
|
||||
})
|
||||
})
|
||||
|
||||
inVec := C.mlx_vector_array_new()
|
||||
defer C.mlx_vector_array_free(inVec)
|
||||
for _, in := range inputs {
|
||||
C.mlx_vector_array_append_value(inVec, in.ctx)
|
||||
}
|
||||
|
||||
outVec := C.mlx_vector_array_new()
|
||||
defer C.mlx_vector_array_free(outVec)
|
||||
mlxCheck(name+": closure apply failed", func() C.int {
|
||||
return C.mlx_closure_apply(&outVec, closure, inVec)
|
||||
})
|
||||
|
||||
n := int(C.mlx_vector_array_size(outVec))
|
||||
outputs := make([]*Array, n)
|
||||
for i := range n {
|
||||
outputs[i] = New(name)
|
||||
C.mlx_vector_array_get(&outputs[i].ctx, outVec, C.size_t(i))
|
||||
}
|
||||
return outputs
|
||||
}
|
||||
}
|
||||
|
||||
// Compile1 compiles a unary function. See Compile.
|
||||
func Compile1(name string, fn func(*Array) *Array, opts ...CompileOption) func(*Array) *Array {
|
||||
cf := Compile(name, func(in ...*Array) []*Array {
|
||||
return []*Array{fn(in[0])}
|
||||
}, opts...)
|
||||
return func(a *Array) *Array {
|
||||
return cf(a)[0]
|
||||
}
|
||||
}
|
||||
|
||||
// Compile2 compiles a binary function. See Compile.
|
||||
func Compile2(name string, fn func(*Array, *Array) *Array, opts ...CompileOption) func(*Array, *Array) *Array {
|
||||
cf := Compile(name, func(in ...*Array) []*Array {
|
||||
return []*Array{fn(in[0], in[1])}
|
||||
}, opts...)
|
||||
return func(a, b *Array) *Array {
|
||||
return cf(a, b)[0]
|
||||
}
|
||||
}
|
||||
|
||||
// Compile3 compiles a ternary function. See Compile.
|
||||
func Compile3(name string, fn func(*Array, *Array, *Array) *Array, opts ...CompileOption) func(*Array, *Array, *Array) *Array {
|
||||
cf := Compile(name, func(in ...*Array) []*Array {
|
||||
return []*Array{fn(in[0], in[1], in[2])}
|
||||
}, opts...)
|
||||
return func(a, b, c *Array) *Array {
|
||||
return cf(a, b, c)[0]
|
||||
}
|
||||
}
|
||||
|
||||
// tracing is true while a compile callback is running. Since MLX is
|
||||
// single-threaded at this level a plain Go bool suffices.
|
||||
var tracing bool
|
||||
|
||||
// traceScratch collects arrays created during a compile trace so they can be
|
||||
// freed as a group when the callback returns.
|
||||
var traceScratch []*Array
|
||||
|
||||
//export closureCallback
|
||||
func closureCallback(res *C.mlx_vector_array, input C.mlx_vector_array, payload unsafe.Pointer) (rc C.int) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
slog.Error("mlx closure callback panicked", "panic", r)
|
||||
rc = 1
|
||||
}
|
||||
}()
|
||||
|
||||
handle := *(*cgo.Handle)(payload)
|
||||
fn := handle.Value().(CompileFunc)
|
||||
|
||||
// When tracing, we track all of the intermediates that are created and free them separately at the end of
|
||||
// the process. This will give the effect of a single op - inputs are owned by the original caller (via
|
||||
// the MLX layer) and outputs are transferred back to MLX to create a new Go side tensor.
|
||||
if tracing {
|
||||
panic("mlx: nested compile trace")
|
||||
}
|
||||
tracing = true
|
||||
traceScratch = nil
|
||||
defer func() {
|
||||
for _, a := range traceScratch {
|
||||
if a.pinned > 0 {
|
||||
panic("mlx: traced array was pinned during compilation")
|
||||
}
|
||||
if a.Valid() {
|
||||
C.mlx_array_free(a.ctx)
|
||||
a.ctx.ctx = nil
|
||||
}
|
||||
}
|
||||
tracing = false
|
||||
traceScratch = nil
|
||||
}()
|
||||
|
||||
n := int(C.mlx_vector_array_size(input))
|
||||
inputs := make([]*Array, n)
|
||||
for i := range n {
|
||||
a := New("")
|
||||
C.mlx_vector_array_get(&a.ctx, input, C.size_t(i))
|
||||
inputs[i] = a
|
||||
}
|
||||
|
||||
outputs := fn(inputs...)
|
||||
|
||||
var arrPtr *C.mlx_array
|
||||
if len(outputs) > 0 {
|
||||
handles := make([]C.mlx_array, len(outputs))
|
||||
for i, out := range outputs {
|
||||
handles[i] = out.ctx
|
||||
}
|
||||
arrPtr = &handles[0]
|
||||
}
|
||||
C.mlx_vector_array_set_data(res, arrPtr, C.size_t(len(outputs)))
|
||||
return 0
|
||||
}
|
||||
|
||||
//export closureDestructor
|
||||
func closureDestructor(payload unsafe.Pointer) {
|
||||
handle := *(*cgo.Handle)(payload)
|
||||
handle.Delete()
|
||||
C.free(payload)
|
||||
}
|
||||
147
x/mlxrunner/mlx/compile_test.go
Normal file
147
x/mlxrunner/mlx/compile_test.go
Normal file
@@ -0,0 +1,147 @@
|
||||
package mlx
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCompileFusion(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
// Compile fuses the ops inside a function body into a single kernel,
|
||||
// eliminating intermediate buffers. Use a diamond-shaped graph where
|
||||
// two branches must be materialized simultaneously without fusion,
|
||||
// then compare peak memory against the compiled version which fuses
|
||||
// everything into one kernel with no intermediates.
|
||||
const n = 1024 * 1024 // 4MB per float32 array
|
||||
data := make([]float32, n)
|
||||
for i := range data {
|
||||
data[i] = float32(i + 1)
|
||||
}
|
||||
|
||||
// Diamond: both a*b and a+b must be live for the final multiply.
|
||||
// Without fusion: peak includes both intermediates (~8MB extra).
|
||||
// With fusion: single kernel, no intermediates.
|
||||
body := func(a, b *Array) *Array {
|
||||
return a.Multiply(b).Multiply(a.Add(b))
|
||||
}
|
||||
|
||||
a := FromValues(data, n)
|
||||
b := FromValues(data, n)
|
||||
Pin(a, b)
|
||||
defer Unpin(a, b)
|
||||
|
||||
// Compiled: ops fused into a single kernel.
|
||||
EnableCompile()
|
||||
fn := Compile2("diamond", body, Shapeless())
|
||||
warm := fn(a, b)
|
||||
Eval(warm)
|
||||
Sweep()
|
||||
ClearCache()
|
||||
ResetPeakMemory()
|
||||
y := fn(a, b)
|
||||
Eval(y)
|
||||
compiledPeak := PeakMemory()
|
||||
Sweep()
|
||||
|
||||
// Uncompiled: ops evaluated individually, intermediates materialized.
|
||||
ClearCache()
|
||||
ResetPeakMemory()
|
||||
z := body(a, b)
|
||||
Eval(z)
|
||||
uncompiledPeak := PeakMemory()
|
||||
Sweep()
|
||||
|
||||
if compiledPeak == 0 && uncompiledPeak == 0 {
|
||||
t.Skip("peak memory tracking not available")
|
||||
}
|
||||
|
||||
t.Logf("peak memory: compiled=%d uncompiled=%d", compiledPeak, uncompiledPeak)
|
||||
|
||||
if compiledPeak >= uncompiledPeak {
|
||||
t.Fatalf("compilation did not reduce peak memory: compiled=%d uncompiled=%d", compiledPeak, uncompiledPeak)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompileNested(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
// A compiled function that calls another compiled function should
|
||||
// produce correct results. The inner function inlines via isTracing()
|
||||
// during the outer's trace.
|
||||
inner := Compile1("silu", func(a *Array) *Array {
|
||||
return a.Multiply(a.Sigmoid())
|
||||
}, Shapeless())
|
||||
|
||||
outer := Compile2("swiglu", func(gate, up *Array) *Array {
|
||||
return inner(gate).Multiply(up)
|
||||
}, Shapeless())
|
||||
|
||||
gate := FromValues([]float32{0, 1, 2}, 3)
|
||||
up := FromValues([]float32{1, 1, 1}, 3)
|
||||
Pin(gate, up)
|
||||
defer Unpin(gate, up)
|
||||
|
||||
y := outer(gate, up)
|
||||
Eval(y)
|
||||
|
||||
// silu(x) = x * sigmoid(x); for x=0 → 0, x=1 → ~0.7311, x=2 → ~1.7616
|
||||
got := y.Floats()
|
||||
want := []float32{0, 0.7310586, 1.7615942}
|
||||
for i, v := range got {
|
||||
if v-want[i] > 1e-4 || want[i]-v > 1e-4 {
|
||||
t.Fatalf("got[%d]=%v want %v", i, v, want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompileCallbackPanicRecovers(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
boom := Compile1("boom", func(a *Array) *Array {
|
||||
panic("intentional test panic")
|
||||
})
|
||||
|
||||
x := FromValues([]float32{1}, 1)
|
||||
Pin(x)
|
||||
defer Unpin(x)
|
||||
|
||||
defer func() {
|
||||
r := recover()
|
||||
if r == nil {
|
||||
t.Fatal("expected panic from Call, got none")
|
||||
}
|
||||
if _, ok := r.(string); !ok {
|
||||
t.Fatalf("expected string panic, got %T: %v", r, r)
|
||||
}
|
||||
}()
|
||||
boom(x)
|
||||
}
|
||||
|
||||
func TestCompileNoTrackingGrowth(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
// Repeated invocations of a compiled kernel should not grow the
|
||||
// tracked-arrays list — the callback's traceScratch collects
|
||||
// intermediates during tracing and frees them when the callback returns.
|
||||
fn := Compile2("mul_add", func(a, b *Array) *Array {
|
||||
return a.Multiply(b).Add(b)
|
||||
})
|
||||
|
||||
a := FromValues([]float32{1, 2}, 2)
|
||||
b := FromValues([]float32{3, 4}, 2)
|
||||
Pin(a, b)
|
||||
defer Unpin(a, b)
|
||||
|
||||
Sweep()
|
||||
before := len(arrays)
|
||||
|
||||
for range 100 {
|
||||
_ = fn(a, b)
|
||||
Sweep()
|
||||
}
|
||||
|
||||
after := len(arrays)
|
||||
if after > before+2 {
|
||||
t.Fatalf("tracked arrays grew from %d to %d across 100 calls (includes initial trace)", before, after)
|
||||
}
|
||||
}
|
||||
@@ -9,8 +9,8 @@ package mlx
|
||||
// #include "generated.h"
|
||||
// #include <string.h>
|
||||
//
|
||||
// static char _mlx_last_error_msg[1024] = {0};
|
||||
// static int _mlx_last_error_flag = 0;
|
||||
// static __thread char _mlx_last_error_msg[1024] = {0};
|
||||
// static __thread int _mlx_last_error_flag = 0;
|
||||
//
|
||||
// static void _mlx_capture_error_handler(const char* msg, void* data) {
|
||||
// (void)data;
|
||||
@@ -30,15 +30,13 @@ package mlx
|
||||
// _mlx_last_error_msg[0] = '\0';
|
||||
// }
|
||||
//
|
||||
// static int mlx_had_last_error(void) {
|
||||
// return _mlx_last_error_flag;
|
||||
// }
|
||||
//
|
||||
// static const char* mlx_get_last_error(void) {
|
||||
// return _mlx_last_error_flag ? _mlx_last_error_msg : NULL;
|
||||
// return _mlx_last_error_flag ? _mlx_last_error_msg : "";
|
||||
// }
|
||||
import "C"
|
||||
|
||||
import "runtime"
|
||||
|
||||
func init() {
|
||||
// Replace the default exit(-1) error handler with one that captures
|
||||
// the error message so we can surface it in Go.
|
||||
@@ -53,6 +51,24 @@ func Version() string {
|
||||
return C.GoString(C.mlx_string_data(str))
|
||||
}
|
||||
|
||||
// mlxCheck locks the goroutine to its OS thread, clears the captured error
|
||||
// state, calls fn, and panics with the captured message if fn returns non-zero.
|
||||
// The thread lock ensures the thread-local error state is read from the same
|
||||
// thread that executed the call.
|
||||
func mlxCheck(fallback string, fn func() C.int) {
|
||||
runtime.LockOSThread()
|
||||
defer runtime.UnlockOSThread()
|
||||
|
||||
C.mlx_clear_last_error()
|
||||
if fn() != 0 {
|
||||
msg := C.GoString(C.mlx_get_last_error())
|
||||
if msg == "" {
|
||||
msg = fallback
|
||||
}
|
||||
panic("mlx: " + msg)
|
||||
}
|
||||
}
|
||||
|
||||
func doEval(outputs []*Array, async bool) {
|
||||
if len(outputs) == 0 {
|
||||
return
|
||||
@@ -67,20 +83,12 @@ func doEval(outputs []*Array, async bool) {
|
||||
}
|
||||
}
|
||||
|
||||
C.mlx_clear_last_error()
|
||||
var rc C.int
|
||||
if async {
|
||||
rc = C.mlx_async_eval(vector)
|
||||
} else {
|
||||
rc = C.mlx_eval(vector)
|
||||
}
|
||||
if rc != 0 {
|
||||
msg := "mlx eval failed"
|
||||
if C.mlx_had_last_error() != 0 {
|
||||
msg = C.GoString(C.mlx_get_last_error())
|
||||
mlxCheck("eval failed", func() C.int {
|
||||
if async {
|
||||
return C.mlx_async_eval(vector)
|
||||
}
|
||||
panic("mlx: " + msg)
|
||||
}
|
||||
return C.mlx_eval(vector)
|
||||
})
|
||||
}
|
||||
|
||||
func AsyncEval(outputs ...*Array) {
|
||||
|
||||
@@ -23,15 +23,6 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
return errors.New("model not loaded")
|
||||
}
|
||||
|
||||
enableCompile := true
|
||||
if modelCompile, ok := r.Model.(interface{ EnableCompile() bool }); ok {
|
||||
enableCompile = modelCompile.EnableCompile()
|
||||
}
|
||||
if enableCompile {
|
||||
mlx.EnableCompile()
|
||||
} else {
|
||||
mlx.DisableCompile()
|
||||
}
|
||||
mlx.ResetPeakMemory()
|
||||
ctx := request.Ctx
|
||||
var (
|
||||
|
||||
@@ -79,6 +79,8 @@ func (r *Runner) Load(modelName string) error {
|
||||
r.Model = m
|
||||
r.Tokenizer = m.Tokenizer()
|
||||
r.contextLength = m.MaxContextLength()
|
||||
|
||||
mlx.EnableCompile()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user