mirror of
https://github.com/ollama/ollama.git
synced 2026-04-17 15:53:27 +02:00
mlx: add compiled closure support
Wraps MLX's mlx_compile API so Go functions can be traced into fused kernels. Contiguous elementwise chains collapse into a single Metal/CUDA kernel instead of launching one per op. Exposes Compile plus arity helpers (Compile1/2/3) that mirror Python's @mx.compile decorator shape, lazily building the closure on first call so package-level declarations work before the MLX dylib loads.
This commit is contained in:
@@ -23,11 +23,23 @@ type Array struct {
|
||||
|
||||
var arrays []*Array
|
||||
|
||||
// tracing is true while a compile callback is running on this goroutine. While
|
||||
// tracing, New routes new arrays into traceScratch so they can be freed as a
|
||||
// group at the end of the callback instead of polluting the tracked list.
|
||||
var (
|
||||
tracing bool
|
||||
traceScratch []*Array
|
||||
)
|
||||
|
||||
// constructor utilities
|
||||
|
||||
func New(name string) *Array {
|
||||
t := &Array{name: name}
|
||||
arrays = append(arrays, t)
|
||||
if tracing {
|
||||
traceScratch = append(traceScratch, t)
|
||||
} else {
|
||||
arrays = append(arrays, t)
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
|
||||
239
x/mlxrunner/mlx/compile.go
Normal file
239
x/mlxrunner/mlx/compile.go
Normal file
@@ -0,0 +1,239 @@
|
||||
package mlx
|
||||
|
||||
// #include <stdlib.h>
|
||||
// #include "generated.h"
|
||||
//
|
||||
// extern int mlxClosureCallback(mlx_vector_array* res, mlx_vector_array input, void* payload);
|
||||
// extern void mlxClosureDestructor(void* payload);
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"runtime"
|
||||
"runtime/cgo"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// ClosureFunc is the signature of a function that can be compiled.
|
||||
type ClosureFunc func(inputs []*Array) []*Array
|
||||
|
||||
// CompileOption configures Compile behavior.
|
||||
type CompileOption func(*compileConfig)
|
||||
|
||||
type compileConfig struct {
|
||||
shapeless bool
|
||||
}
|
||||
|
||||
// Shapeless traces the function once against symbolic shapes so the compiled
|
||||
// graph accepts any input shape afterwards. Without this option, MLX re-traces
|
||||
// on each new (shape, dtype) combination and caches each specialization.
|
||||
func Shapeless() CompileOption {
|
||||
return func(c *compileConfig) { c.shapeless = true }
|
||||
}
|
||||
|
||||
// CompiledFunc is a compiled MLX function. Compiled functions are intended to
|
||||
// live for the program's lifetime; the underlying MLX closure is released
|
||||
// automatically via runtime.AddCleanup when the CompiledFunc becomes
|
||||
// unreachable.
|
||||
type CompiledFunc struct {
|
||||
name string
|
||||
closure C.mlx_closure
|
||||
}
|
||||
|
||||
// Compile wraps fn as a compiled MLX closure. name is a human-readable tag
|
||||
// used in error messages; it has no effect on execution. MLX traces fn
|
||||
// lazily on the first Call for each distinct (shape, dtype).
|
||||
func Compile(name string, fn ClosureFunc, opts ...CompileOption) *CompiledFunc {
|
||||
var cfg compileConfig
|
||||
for _, o := range opts {
|
||||
o(&cfg)
|
||||
}
|
||||
|
||||
// The payload is a C-allocated slot holding a cgo.Handle wrapping fn.
|
||||
// Using C memory avoids cgo's rule against passing Go memory that
|
||||
// contains unpinned Go pointers; the slot is freed by the destructor
|
||||
// when MLX releases its last reference.
|
||||
payload := (*cgo.Handle)(C.malloc(C.size_t(unsafe.Sizeof(cgo.Handle(0)))))
|
||||
*payload = cgo.NewHandle(fn)
|
||||
src := C.mlx_closure_new_func_payload(
|
||||
(*[0]byte)(C.mlxClosureCallback),
|
||||
unsafe.Pointer(payload),
|
||||
(*[0]byte)(C.mlxClosureDestructor),
|
||||
)
|
||||
// mlx_compile moves fn into the compiled closure, so src is freed
|
||||
// either way. The compiled closure keeps the payload alive via its
|
||||
// own shared_ptr until its mlx_closure_free runs the destructor.
|
||||
defer C.mlx_closure_free(src)
|
||||
|
||||
compiled := C.mlx_closure_new()
|
||||
clearLastError()
|
||||
if rc := C.mlx_compile(&compiled, src, C.bool(cfg.shapeless)); rc != 0 {
|
||||
msg := lastError()
|
||||
if msg == "" {
|
||||
msg = "mlx_compile failed"
|
||||
}
|
||||
panic("mlx: " + name + ": " + msg)
|
||||
}
|
||||
|
||||
cf := &CompiledFunc{name: name, closure: compiled}
|
||||
runtime.AddCleanup(cf, func(c C.mlx_closure) { C.mlx_closure_free(c) }, cf.closure)
|
||||
return cf
|
||||
}
|
||||
|
||||
// Call invokes the compiled function with the given inputs. Returned outputs
|
||||
// participate in the normal Pin/Sweep lifecycle.
|
||||
func (cf *CompiledFunc) Call(inputs ...*Array) []*Array {
|
||||
inVec := C.mlx_vector_array_new()
|
||||
defer C.mlx_vector_array_free(inVec)
|
||||
for _, in := range inputs {
|
||||
C.mlx_vector_array_append_value(inVec, in.ctx)
|
||||
}
|
||||
|
||||
outVec := C.mlx_vector_array_new()
|
||||
defer C.mlx_vector_array_free(outVec)
|
||||
clearLastError()
|
||||
if rc := C.mlx_closure_apply(&outVec, cf.closure, inVec); rc != 0 {
|
||||
msg := lastError()
|
||||
if msg == "" {
|
||||
msg = "mlx_closure_apply failed"
|
||||
}
|
||||
panic("mlx: " + cf.name + ": " + msg)
|
||||
}
|
||||
|
||||
n := int(C.mlx_vector_array_size(outVec))
|
||||
outputs := make([]*Array, n)
|
||||
for i := range n {
|
||||
outputs[i] = New(cf.name)
|
||||
C.mlx_vector_array_get(&outputs[i].ctx, outVec, C.size_t(i))
|
||||
}
|
||||
return outputs
|
||||
}
|
||||
|
||||
// Compile1 compiles a unary function and returns a plain callable. The
|
||||
// underlying closure is built on first call so package-level declarations
|
||||
// work even before MLX's dynamic library is loaded.
|
||||
//
|
||||
// If invoked while another compile is tracing, the original fn is called
|
||||
// directly so its ops are inlined into the outer trace rather than applied
|
||||
// as a nested compiled closure. This matches upstream mlx's @mx.compile
|
||||
// decorator and lets outer compiles fuse through inner ones.
|
||||
func Compile1(name string, fn func(*Array) *Array, opts ...CompileOption) func(*Array) *Array {
|
||||
var cf *CompiledFunc
|
||||
return func(a *Array) *Array {
|
||||
if tracing {
|
||||
return fn(a)
|
||||
}
|
||||
if cf == nil {
|
||||
cf = Compile(name, func(in []*Array) []*Array {
|
||||
return []*Array{fn(in[0])}
|
||||
}, opts...)
|
||||
}
|
||||
return cf.Call(a)[0]
|
||||
}
|
||||
}
|
||||
|
||||
// Compile2 compiles a binary function. See Compile1.
|
||||
func Compile2(name string, fn func(*Array, *Array) *Array, opts ...CompileOption) func(*Array, *Array) *Array {
|
||||
var cf *CompiledFunc
|
||||
return func(a, b *Array) *Array {
|
||||
if tracing {
|
||||
return fn(a, b)
|
||||
}
|
||||
if cf == nil {
|
||||
cf = Compile(name, func(in []*Array) []*Array {
|
||||
return []*Array{fn(in[0], in[1])}
|
||||
}, opts...)
|
||||
}
|
||||
return cf.Call(a, b)[0]
|
||||
}
|
||||
}
|
||||
|
||||
// Compile3 compiles a ternary function. See Compile1.
|
||||
func Compile3(name string, fn func(*Array, *Array, *Array) *Array, opts ...CompileOption) func(*Array, *Array, *Array) *Array {
|
||||
var cf *CompiledFunc
|
||||
return func(a, b, c *Array) *Array {
|
||||
if tracing {
|
||||
return fn(a, b, c)
|
||||
}
|
||||
if cf == nil {
|
||||
cf = Compile(name, func(in []*Array) []*Array {
|
||||
return []*Array{fn(in[0], in[1], in[2])}
|
||||
}, opts...)
|
||||
}
|
||||
return cf.Call(a, b, c)[0]
|
||||
}
|
||||
}
|
||||
|
||||
//export mlxClosureCallback
|
||||
func mlxClosureCallback(res *C.mlx_vector_array, input C.mlx_vector_array, payload unsafe.Pointer) (rc C.int) {
|
||||
// Recover panics so they don't unwind into C (which is UB). MLX
|
||||
// overwrites the user's error message with a generic one after any
|
||||
// non-zero return, so log the original panic and let the caller see
|
||||
// a failed Call via the non-zero rc. Registered first so it is
|
||||
// outermost and catches panics from any subsequent code, including
|
||||
// the handle lookup and type assertion.
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
slog.Error("mlx closure callback panicked", "panic", r)
|
||||
rc = 1
|
||||
}
|
||||
}()
|
||||
|
||||
handle := *(*cgo.Handle)(payload)
|
||||
fn := handle.Value().(ClosureFunc)
|
||||
|
||||
// Route arrays produced during fn through traceScratch. They are
|
||||
// symbolic tracing handles that MLX captures into the compiled graph;
|
||||
// our wrappers must be freed before returning or we leak a handle and
|
||||
// a refcount per traced op.
|
||||
prevTracing := tracing
|
||||
prevScratch := traceScratch
|
||||
tracing = true
|
||||
traceScratch = nil
|
||||
defer func() {
|
||||
for _, a := range traceScratch {
|
||||
if a.pinned > 0 {
|
||||
panic("mlx: traced array was pinned during compilation")
|
||||
}
|
||||
if a.Valid() {
|
||||
C.mlx_array_free(a.ctx)
|
||||
a.ctx.ctx = nil
|
||||
}
|
||||
}
|
||||
tracing = prevTracing
|
||||
traceScratch = prevScratch
|
||||
}()
|
||||
|
||||
// Each mlx_vector_array_get populates a caller-owned handle; route it
|
||||
// into traceScratch so it is freed when the callback returns.
|
||||
n := int(C.mlx_vector_array_size(input))
|
||||
inputs := make([]*Array, n)
|
||||
for i := range n {
|
||||
a := New("")
|
||||
C.mlx_vector_array_get(&a.ctx, input, C.size_t(i))
|
||||
inputs[i] = a
|
||||
}
|
||||
|
||||
outputs := fn(inputs)
|
||||
|
||||
// Populate the output vector via set_data, which handles any initial
|
||||
// state of *res (null or previously allocated) per mlx-c convention.
|
||||
// Our wrappers remain independent and are freed via traceScratch.
|
||||
var arrPtr *C.mlx_array
|
||||
if len(outputs) > 0 {
|
||||
handles := make([]C.mlx_array, len(outputs))
|
||||
for i, out := range outputs {
|
||||
handles[i] = out.ctx
|
||||
}
|
||||
arrPtr = &handles[0]
|
||||
}
|
||||
C.mlx_vector_array_set_data(res, arrPtr, C.size_t(len(outputs)))
|
||||
return 0
|
||||
}
|
||||
|
||||
//export mlxClosureDestructor
|
||||
func mlxClosureDestructor(payload unsafe.Pointer) {
|
||||
handle := *(*cgo.Handle)(payload)
|
||||
handle.Delete()
|
||||
C.free(payload)
|
||||
}
|
||||
166
x/mlxrunner/mlx/compile_test.go
Normal file
166
x/mlxrunner/mlx/compile_test.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package mlx
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCompileCallbackPanicRecovers(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
boom := Compile1("boom", func(a *Array) *Array {
|
||||
panic("intentional test panic")
|
||||
})
|
||||
|
||||
x := FromValues([]float32{1}, 1)
|
||||
Pin(x)
|
||||
defer Unpin(x)
|
||||
|
||||
defer func() {
|
||||
r := recover()
|
||||
if r == nil {
|
||||
t.Fatal("expected panic from Call, got none")
|
||||
}
|
||||
if _, ok := r.(string); !ok {
|
||||
t.Fatalf("expected string panic, got %T: %v", r, r)
|
||||
}
|
||||
// MLX overwrites the message; what matters is that the Go
|
||||
// panic was caught before unwinding into C (which would be UB)
|
||||
// and that Call surfaced a failure instead.
|
||||
}()
|
||||
boom(x)
|
||||
}
|
||||
|
||||
func TestCompileUnary(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
square := Compile1("square", func(a *Array) *Array {
|
||||
return a.Multiply(a)
|
||||
})
|
||||
|
||||
x := FromValues([]float32{1, 2, 3, 4}, 4)
|
||||
Pin(x)
|
||||
defer Unpin(x)
|
||||
|
||||
y := square(x)
|
||||
Eval(y)
|
||||
|
||||
got := y.Floats()
|
||||
want := []float32{1, 4, 9, 16}
|
||||
for i, v := range got {
|
||||
if v != want[i] {
|
||||
t.Fatalf("got[%d]=%v want %v", i, v, want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompileBinary(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
add := Compile2("add", func(a, b *Array) *Array {
|
||||
return a.Add(b)
|
||||
})
|
||||
|
||||
a := FromValues([]float32{1, 2, 3}, 3)
|
||||
b := FromValues([]float32{10, 20, 30}, 3)
|
||||
Pin(a, b)
|
||||
defer Unpin(a, b)
|
||||
|
||||
c := add(a, b)
|
||||
Eval(c)
|
||||
|
||||
got := c.Floats()
|
||||
want := []float32{11, 22, 33}
|
||||
for i, v := range got {
|
||||
if v != want[i] {
|
||||
t.Fatalf("got[%d]=%v want %v", i, v, want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompileShapelessReshape(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
// A shapeless compiled kernel should accept inputs of different shapes
|
||||
// on subsequent calls without recompiling or erroring.
|
||||
fn := Compile1("square", func(a *Array) *Array {
|
||||
return a.Multiply(a)
|
||||
}, Shapeless())
|
||||
|
||||
for _, n := range []int{2, 4, 8} {
|
||||
data := make([]float32, n)
|
||||
for i := range data {
|
||||
data[i] = float32(i + 1)
|
||||
}
|
||||
x := FromValues(data, n)
|
||||
Pin(x)
|
||||
y := fn(x)
|
||||
Eval(y)
|
||||
got := y.Floats()
|
||||
for i, v := range got {
|
||||
want := float32((i + 1) * (i + 1))
|
||||
if v != want {
|
||||
t.Fatalf("n=%d got[%d]=%v want %v", n, i, v, want)
|
||||
}
|
||||
}
|
||||
Unpin(x)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSwiGLU(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
gate := FromValues([]float32{-1, 0, 1, 2}, 4)
|
||||
up := FromValues([]float32{1, 2, 3, 4}, 4)
|
||||
Pin(gate, up)
|
||||
defer Unpin(gate, up)
|
||||
|
||||
y := SwiGLU(gate, up)
|
||||
Eval(y)
|
||||
|
||||
got := y.Floats()
|
||||
// Reference: silu(g) * u = g * sigmoid(g) * u
|
||||
wantVals := []float32{-1, 0, 1, 2}
|
||||
upVals := []float32{1, 2, 3, 4}
|
||||
for i, g := range wantVals {
|
||||
silu := g / float32(1+math.Exp(float64(-g)))
|
||||
want := silu * upVals[i]
|
||||
if math.Abs(float64(got[i]-want)) > 1e-5 {
|
||||
t.Fatalf("i=%d got=%v want=%v", i, got[i], want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompileNoTrackingGrowth(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
// Repeated invocations of a compiled kernel should not grow the
|
||||
// tracked-arrays list by the number of internal ops each call — the
|
||||
// callback's scratch list frees them before return.
|
||||
fn := Compile2("mul_add", func(a, b *Array) *Array {
|
||||
// Two ops per call; if we leaked, we'd see growth proportional to
|
||||
// iterations * 2 in the tracked list.
|
||||
return a.Multiply(b).Add(b)
|
||||
})
|
||||
|
||||
a := FromValues([]float32{1, 2}, 2)
|
||||
b := FromValues([]float32{3, 4}, 2)
|
||||
Pin(a, b)
|
||||
defer Unpin(a, b)
|
||||
|
||||
// Prime so the initial trace's allocations are already accounted for.
|
||||
_ = fn(a, b)
|
||||
Eval()
|
||||
Sweep()
|
||||
|
||||
before := len(arrays)
|
||||
for range 100 {
|
||||
_ = fn(a, b)
|
||||
Eval()
|
||||
Sweep()
|
||||
}
|
||||
after := len(arrays)
|
||||
if after > before+2 {
|
||||
t.Fatalf("tracked arrays grew from %d to %d across 100 calls", before, after)
|
||||
}
|
||||
}
|
||||
@@ -53,6 +53,19 @@ func Version() string {
|
||||
return C.GoString(C.mlx_string_data(str))
|
||||
}
|
||||
|
||||
// clearLastError resets the captured MLX error state before a call that may fail.
|
||||
func clearLastError() {
|
||||
C.mlx_clear_last_error()
|
||||
}
|
||||
|
||||
// lastError returns the last captured MLX error message, or "" if none.
|
||||
func lastError() string {
|
||||
if C.mlx_had_last_error() == 0 {
|
||||
return ""
|
||||
}
|
||||
return C.GoString(C.mlx_get_last_error())
|
||||
}
|
||||
|
||||
func doEval(outputs []*Array, async bool) {
|
||||
if len(outputs) == 0 {
|
||||
return
|
||||
|
||||
@@ -23,15 +23,6 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
return errors.New("model not loaded")
|
||||
}
|
||||
|
||||
enableCompile := true
|
||||
if modelCompile, ok := r.Model.(interface{ EnableCompile() bool }); ok {
|
||||
enableCompile = modelCompile.EnableCompile()
|
||||
}
|
||||
if enableCompile {
|
||||
mlx.EnableCompile()
|
||||
} else {
|
||||
mlx.DisableCompile()
|
||||
}
|
||||
mlx.ResetPeakMemory()
|
||||
ctx := request.Ctx
|
||||
var (
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
@@ -79,6 +81,18 @@ func (r *Runner) Load(modelName string) error {
|
||||
r.Model = m
|
||||
r.Tokenizer = m.Tokenizer()
|
||||
r.contextLength = m.MaxContextLength()
|
||||
|
||||
enableCompile := true
|
||||
if s := os.Getenv("OLLAMA_MLX_COMPILE"); s != "" {
|
||||
if b, err := strconv.ParseBool(s); err == nil {
|
||||
enableCompile = b
|
||||
}
|
||||
}
|
||||
if enableCompile {
|
||||
mlx.EnableCompile()
|
||||
} else {
|
||||
mlx.DisableCompile()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user