diff --git a/go.mod b/go.mod index bf4c787d1..a4d6ccfed 100644 --- a/go.mod +++ b/go.mod @@ -13,7 +13,7 @@ require ( github.com/mattn/go-sqlite3 v1.14.24 github.com/olekukonko/tablewriter v0.0.5 github.com/spf13/cobra v1.7.0 - github.com/stretchr/testify v1.9.0 + github.com/stretchr/testify v1.10.0 github.com/x448/float16 v0.8.4 golang.org/x/sync v0.17.0 golang.org/x/sys v0.37.0 @@ -31,6 +31,8 @@ require ( github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c github.com/tkrajina/typescriptify-golang-structs v0.2.0 + github.com/tree-sitter/go-tree-sitter v0.25.0 + github.com/tree-sitter/tree-sitter-cpp v0.23.4 github.com/wk8/go-ordered-map/v2 v2.1.8 golang.org/x/image v0.22.0 golang.org/x/mod v0.30.0 @@ -60,6 +62,7 @@ require ( github.com/lucasb-eyer/go-colorful v1.2.0 // indirect github.com/mailru/easyjson v0.7.7 // indirect github.com/mattn/go-localereader v0.0.1 // indirect + github.com/mattn/go-pointer v0.0.1 // indirect github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect github.com/muesli/cancelreader v0.2.2 // indirect github.com/muesli/termenv v0.16.0 // indirect diff --git a/go.sum b/go.sum index b34f17e0e..13df21902 100644 --- a/go.sum +++ b/go.sum @@ -172,6 +172,8 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4= github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88= +github.com/mattn/go-pointer v0.0.1 h1:n+XhsuGeVO6MEAp7xyEukFINEa+Quek5psIR/ylA6o0= +github.com/mattn/go-pointer v0.0.1/go.mod h1:2zXcozF6qYGgmsG+SeTZz3oAbFLdD3OWqnUbNvJZAlc= github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= 
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= @@ -233,12 +235,39 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/tkrajina/go-reflector v0.5.5 h1:gwoQFNye30Kk7NrExj8zm3zFtrGPqOkzFMLuQZg1DtQ= github.com/tkrajina/go-reflector v0.5.5/go.mod h1:ECbqLgccecY5kPmPmXg1MrHW585yMcDkVl6IvJe64T4= github.com/tkrajina/typescriptify-golang-structs v0.2.0 h1:ZedWk82egydDspGTryAatbX0/1NZDQbdiZLoCbOk4f8= github.com/tkrajina/typescriptify-golang-structs v0.2.0/go.mod h1:sjU00nti/PMEOZb07KljFlR+lJ+RotsC0GBQMv9EKls= +github.com/tree-sitter/go-tree-sitter v0.25.0 h1:sx6kcg8raRFCvc9BnXglke6axya12krCJF5xJ2sftRU= +github.com/tree-sitter/go-tree-sitter v0.25.0/go.mod h1:r77ig7BikoZhHrrsjAnv8RqGti5rtSyvDHPzgTPsUuU= +github.com/tree-sitter/tree-sitter-c v0.23.4 h1:nBPH3FV07DzAD7p0GfNvXM+Y7pNIoPenQWBpvM++t4c= +github.com/tree-sitter/tree-sitter-c v0.23.4/go.mod h1:MkI5dOiIpeN94LNjeCp8ljXN/953JCwAby4bClMr6bw= +github.com/tree-sitter/tree-sitter-cpp v0.23.4 h1:LaWZsiqQKvR65yHgKmnaqA+uz6tlDJTJFCyFIeZU/8w= +github.com/tree-sitter/tree-sitter-cpp v0.23.4/go.mod h1:doqNW64BriC7WBCQ1klf0KmJpdEvfxyXtoEybnBo6v8= +github.com/tree-sitter/tree-sitter-embedded-template v0.23.2 h1:nFkkH6Sbe56EXLmZBqHHcamTpmz3TId97I16EnGy4rg= +github.com/tree-sitter/tree-sitter-embedded-template v0.23.2/go.mod 
h1:HNPOhN0qF3hWluYLdxWs5WbzP/iE4aaRVPMsdxuzIaQ= +github.com/tree-sitter/tree-sitter-go v0.23.4 h1:yt5KMGnTHS+86pJmLIAZMWxukr8W7Ae1STPvQUuNROA= +github.com/tree-sitter/tree-sitter-go v0.23.4/go.mod h1:Jrx8QqYN0v7npv1fJRH1AznddllYiCMUChtVjxPK040= +github.com/tree-sitter/tree-sitter-html v0.23.2 h1:1UYDV+Yd05GGRhVnTcbP58GkKLSHHZwVaN+lBZV11Lc= +github.com/tree-sitter/tree-sitter-html v0.23.2/go.mod h1:gpUv/dG3Xl/eebqgeYeFMt+JLOY9cgFinb/Nw08a9og= +github.com/tree-sitter/tree-sitter-java v0.23.5 h1:J9YeMGMwXYlKSP3K4Us8CitC6hjtMjqpeOf2GGo6tig= +github.com/tree-sitter/tree-sitter-java v0.23.5/go.mod h1:NRKlI8+EznxA7t1Yt3xtraPk1Wzqh3GAIC46wxvc320= +github.com/tree-sitter/tree-sitter-javascript v0.23.1 h1:1fWupaRC0ArlHJ/QJzsfQ3Ibyopw7ZfQK4xXc40Zveo= +github.com/tree-sitter/tree-sitter-javascript v0.23.1/go.mod h1:lmGD1EJdCA+v0S1u2fFgepMg/opzSg/4pgFym2FPGAs= +github.com/tree-sitter/tree-sitter-json v0.24.8 h1:tV5rMkihgtiOe14a9LHfDY5kzTl5GNUYe6carZBn0fQ= +github.com/tree-sitter/tree-sitter-json v0.24.8/go.mod h1:F351KK0KGvCaYbZ5zxwx/gWWvZhIDl0eMtn+1r+gQbo= +github.com/tree-sitter/tree-sitter-php v0.23.11 h1:iHewsLNDmznh8kgGyfWfujsZxIz1YGbSd2ZTEM0ZiP8= +github.com/tree-sitter/tree-sitter-php v0.23.11/go.mod h1:T/kbfi+UcCywQfUNAJnGTN/fMSUjnwPXA8k4yoIks74= +github.com/tree-sitter/tree-sitter-python v0.23.6 h1:qHnWFR5WhtMQpxBZRwiaU5Hk/29vGju6CVtmvu5Haas= +github.com/tree-sitter/tree-sitter-python v0.23.6/go.mod h1:cpdthSy/Yoa28aJFBscFHlGiU+cnSiSh1kuDVtI8YeM= +github.com/tree-sitter/tree-sitter-ruby v0.23.1 h1:T/NKHUA+iVbHM440hFx+lzVOzS4dV6z8Qw8ai+72bYo= +github.com/tree-sitter/tree-sitter-ruby v0.23.1/go.mod h1:kUS4kCCQloFcdX6sdpr8p6r2rogbM6ZjTox5ZOQy8cA= +github.com/tree-sitter/tree-sitter-rust v0.23.2 h1:6AtoooCW5GqNrRpfnvl0iUhxTAZEovEmLKDbyHlfw90= +github.com/tree-sitter/tree-sitter-rust v0.23.2/go.mod h1:hfeGWic9BAfgTrc7Xf6FaOAguCFJRo3RBbs7QJ6D7MI= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= 
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= diff --git a/runner/runner.go b/runner/runner.go index c243e2ec8..d2daddf69 100644 --- a/runner/runner.go +++ b/runner/runner.go @@ -4,6 +4,7 @@ import ( "github.com/ollama/ollama/runner/llamarunner" "github.com/ollama/ollama/runner/ollamarunner" "github.com/ollama/ollama/x/imagegen" + "github.com/ollama/ollama/x/mlxrunner" ) func Execute(args []string) error { @@ -17,6 +18,8 @@ func Execute(args []string) error { return ollamarunner.Execute(args[1:]) case "--imagegen-engine": return imagegen.Execute(args[1:]) + case "--mlx-engine": + return mlxrunner.Execute(args[1:]) } } return llamarunner.Execute(args) diff --git a/x/imagegen/manifest/weights.go b/x/imagegen/manifest/weights.go index 19a6ede07..e0ad0399c 100644 --- a/x/imagegen/manifest/weights.go +++ b/x/imagegen/manifest/weights.go @@ -102,8 +102,15 @@ func (mw *ManifestWeights) Load(dtype mlx.Dtype) error { for _, entry := range entries { name := entry.name - // Try to get tensor by manifest name - arr := sf.Get(name) + // Try to get tensor by stripped name first, then with component prefix. + // Blobs may store tensors with the full prefixed name (e.g., "text_encoder/model.layers.0.weight") + // while the tensors map uses stripped names (e.g., "model.layers.0.weight"). 
+ lookupName := name + arr := sf.Get(lookupName) + if arr == nil && mw.component != "" { + lookupName = mw.component + "/" + name + arr = sf.Get(lookupName) + } if arr != nil { // Single-tensor blob or tensor found by name if dtype != 0 && arr.Dtype() != dtype { @@ -114,14 +121,14 @@ func (mw *ManifestWeights) Load(dtype mlx.Dtype) error { arrays = append(arrays, arr) // Check for scale tensor - if scale := sf.Get(name + ".scale"); scale != nil { + if scale := sf.Get(lookupName + ".scale"); scale != nil { scale = mlx.Contiguous(scale) mw.cache[name+"_scale"] = scale arrays = append(arrays, scale) } // Check for bias tensor - if bias := sf.Get(name + ".bias"); bias != nil { + if bias := sf.Get(lookupName + ".bias"); bias != nil { bias = mlx.Contiguous(bias) mw.cache[name+"_qbias"] = bias arrays = append(arrays, bias) @@ -147,20 +154,27 @@ func (mw *ManifestWeights) Load(dtype mlx.Dtype) error { tArr = mlx.AsType(tArr, dtype) } tArr = mlx.Contiguous(tArr) - mw.cache[tensorName] = tArr + + // Strip component prefix from blob-internal names so cache keys + // match the stripped names used by LoadModule. 
+ cacheName := tensorName + if mw.component != "" { + cacheName = strings.TrimPrefix(tensorName, mw.component+"/") + } + mw.cache[cacheName] = tArr arrays = append(arrays, tArr) // Check for scale tensor if scale := sf.Get(tensorName + ".scale"); scale != nil { scale = mlx.Contiguous(scale) - mw.cache[tensorName+"_scale"] = scale + mw.cache[cacheName+"_scale"] = scale arrays = append(arrays, scale) } // Check for bias tensor if bias := sf.Get(tensorName + ".bias"); bias != nil { bias = mlx.Contiguous(bias) - mw.cache[tensorName+"_qbias"] = bias + mw.cache[cacheName+"_qbias"] = bias arrays = append(arrays, bias) } } diff --git a/x/mlxrunner/cache.go b/x/mlxrunner/cache.go new file mode 100644 index 000000000..49ddd04b6 --- /dev/null +++ b/x/mlxrunner/cache.go @@ -0,0 +1,96 @@ +//go:build mlx + +package mlxrunner + +import ( + "log/slog" + + "github.com/ollama/ollama/x/mlxrunner/cache" +) + +type CacheEntry struct { + Caches []cache.Cache + Count int + Entries map[int32]*CacheEntry +} + +func (s Runner) FindNearestCache(tokens []int32) ([]cache.Cache, []int32) { + current := &CacheEntry{Entries: s.CacheEntries} + index, cacheIndex := 0, -1 + for _, token := range tokens { + if _, ok := current.Entries[token]; !ok { + break + } + + current = current.Entries[token] + if len(current.Caches) > 0 { + cacheIndex = index + } + + index += 1 + } + + if cacheIndex == len(tokens)-1 { + slog.Info("Cache hit", "type", "exact", "total", len(tokens), "cached", len(tokens), "left", len(tokens)) + return current.Caches, []int32{} + } else if cacheIndex > 1 { + slog.Info("Cache hit", "type", "partial", "total", len(tokens), "cached", cacheIndex+1, "left", len(tokens[cacheIndex+1:])) + return current.Caches, tokens[cacheIndex+1:] + } else if index > 0 && cacheIndex < 0 { + type stackItem struct { + entry *CacheEntry + tokens []int32 + } + + var best, item stackItem + stack := []stackItem{{entry: current, tokens: []int32{}}} + for len(stack) > 0 { + item, stack = stack[len(stack)-1], 
stack[:len(stack)-1] + if len(item.entry.Caches) > 0 { + if len(best.tokens) == 0 || len(item.tokens) < len(best.tokens) { + best = item + } + } else { + for token, entry := range item.entry.Entries { + stack = append(stack, stackItem{ + entry: entry, + tokens: append(item.tokens, token), + }) + } + } + } + + prefix := min(len(tokens)-1, index) + caches := make([]cache.Cache, len(best.entry.Caches)) + trim := len(best.tokens)+1 + for i := range caches { + caches[i] = best.entry.Caches[i].Clone() + caches[i].Trim(trim) + } + + slog.Info("Cache hit", "type", "prefix", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:]), "trimmed", trim) + return caches, tokens[prefix:] + } + + slog.Info("Cache miss", "left", len(tokens)) + return nil, tokens +} + +func (s *Runner) InsertCache(tokens []int32, caches []cache.Cache) { + current := &CacheEntry{Entries: s.CacheEntries} + for _, token := range tokens { + if _, ok := current.Entries[token]; !ok { + current.Entries[token] = &CacheEntry{ + Entries: make(map[int32]*CacheEntry), + } + } + + current = current.Entries[token] + } + + if len(current.Caches) > 0 { + current.Count += 1 + } else { + current.Caches = caches + } +} diff --git a/x/mlxrunner/cache/cache.go b/x/mlxrunner/cache/cache.go new file mode 100644 index 000000000..05cffbf5e --- /dev/null +++ b/x/mlxrunner/cache/cache.go @@ -0,0 +1,198 @@ +//go:build mlx + +package cache + +import ( + "log/slog" + + "github.com/ollama/ollama/x/mlxrunner/mlx" +) + +type Cache interface { + Update(keys, values *mlx.Array) (newKeys, newValues *mlx.Array) + State() (keys, values *mlx.Array) + Trim(int) int + Clone() Cache + Offset() int + Len() int +} + +type KVCache struct { + keys, values *mlx.Array + offset int + step int +} + +func NewKVCache() *KVCache { + return &KVCache{step: 256} +} + +func (c *KVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) { + B, H, L, Dk, Dv := keys.Dim(0), keys.Dim(1), keys.Dim(2), keys.Dim(3), values.Dim(3) + + prev := 
c.offset + + // Grow buffer if needed + if c.keys == nil || (prev+L) > c.keys.Dim(2) { + steps := (c.step + L - 1) / c.step + newKeys := mlx.Zeros(keys.DType(), B, H, steps*c.step, Dk) + newValues := mlx.Zeros(values.DType(), B, H, steps*c.step, Dv) + + if c.keys != nil { + if prev%c.step != 0 { + c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, prev), mlx.Slice())) + c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, prev), mlx.Slice())) + } + c.keys.Set(c.keys.Concatenate(2, newKeys)) + c.values.Set(c.values.Concatenate(2, newValues)) + } else { + c.keys, c.values = newKeys, newValues + } + } + + c.offset += L + c.keys.Set(c.keys.SliceUpdate(keys, mlx.Slice(), mlx.Slice(), mlx.Slice(prev, c.offset), mlx.Slice())) + c.values.Set(c.values.SliceUpdate(values, mlx.Slice(), mlx.Slice(), mlx.Slice(prev, c.offset), mlx.Slice())) + + return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()), + c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()) +} + +func (c *KVCache) State() (*mlx.Array, *mlx.Array) { + if c.offset == c.keys.Dim(2) { + return c.keys, c.values + } + return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()), + c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()) +} + +func (c *KVCache) Trim(n int) int { + n = min(c.offset, n) + c.offset -= n + return n +} + +func (c *KVCache) Clone() Cache { + return &KVCache{ + keys: c.keys.Clone(), + values: c.values.Clone(), + offset: c.offset, + step: c.step, + } +} + +func (c *KVCache) Offset() int { return c.offset } +func (c *KVCache) Len() int { return c.offset } + +// RotatingKVCache implements sliding window attention with bounded memory +type RotatingKVCache struct { + maxSize int + idx int + + *KVCache +} + +func NewRotatingKVCache(maxSize int) *RotatingKVCache { + return &RotatingKVCache{maxSize: maxSize, KVCache: NewKVCache()} +} + +func (c *RotatingKVCache) Update(keys, values 
*mlx.Array) (*mlx.Array, *mlx.Array) { + if keys.Dim(2) > 1 { + return c.concat(keys, values) + } + return c.update(keys, values) +} + +func (c *RotatingKVCache) concat(keys, values *mlx.Array) (newK *mlx.Array, newV *mlx.Array) { + slog.Debug("(*RotatingKVCache).concat", "keys_dim", keys.Dims(), "values_dim", values.Dims(), "offset", c.offset, "idx", c.idx, "max_size", c.maxSize) + if c.keys == nil { + c.keys, c.values = keys, values + } else { + if c.idx < c.keys.Dim(2) { + c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice())) + c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice())) + } + + // Trim to max_size to maintain sliding window + if trim := c.idx - c.maxSize + 1; trim > 0 { + c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.keys.Dim(2)), mlx.Slice())) + c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.values.Dim(2)), mlx.Slice())) + } + + c.keys.Set(c.keys.Concatenate(2, keys)) + c.values.Set(c.values.Concatenate(2, values)) + c.idx = c.keys.Dim(2) + } + + c.offset += keys.Dim(2) + c.idx = c.keys.Dim(2) + return c.keys, c.values +} + +func (c *RotatingKVCache) update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) { + slog.Debug("(*RotatingKVCache).update", "keys_dim", keys.Dims(), "values_dim", values.Dims(), "offset", c.offset, "idx", c.idx, "max_size", c.maxSize) + B, H, L, Dk, Dv := keys.Dim(0), keys.Dim(1), keys.Dim(2), keys.Dim(3), values.Dim(3) + + prev := c.offset + + // Grow buffer if not yet at max + if c.keys == nil || (prev >= c.keys.Dim(2) && c.keys.Dim(2) < c.maxSize) { + newSize := min(c.step, c.maxSize-prev) + newKeys := mlx.Zeros(keys.DType(), B, H, newSize, Dk) + newValues := mlx.Zeros(values.DType(), B, H, newSize, Dv) + if c.keys != nil { + c.keys.Set(c.keys.Concatenate(2, newKeys)) + c.values.Set(c.values.Concatenate(2, newValues)) + } else { + c.keys, c.values = newKeys, newValues + } + c.idx = prev + } + + // Trim to 
max_size to maintain sliding window + if trim := c.keys.Dim(2) - c.maxSize; trim > 0 { + c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.keys.Dim(2)), mlx.Slice())) + c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.values.Dim(2)), mlx.Slice())) + c.idx = c.maxSize + } + + // Rotate when hitting max + if c.idx >= c.maxSize { + c.idx = 0 + } + + c.keys.Set(c.keys.SliceUpdate(keys, mlx.Slice(), mlx.Slice(), mlx.Slice(c.idx, c.idx+L), mlx.Slice())) + c.values.Set(c.values.SliceUpdate(values, mlx.Slice(), mlx.Slice(), mlx.Slice(c.idx, c.idx+L), mlx.Slice())) + + c.offset += L + c.idx += L + + validLen := min(c.offset, c.maxSize) + return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, validLen), mlx.Slice()), + c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, validLen), mlx.Slice()) +} + +func (c *RotatingKVCache) State() (*mlx.Array, *mlx.Array) { + if c.offset < c.keys.Dim(2) { + return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()), + c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()) + } + return c.keys, c.values +} + +func (c *RotatingKVCache) Trim(n int) int { + n = min(c.offset, n) + c.offset -= n + c.idx -= n + return n +} + +func (c *RotatingKVCache) Clone() Cache { + return &RotatingKVCache{ + maxSize: c.maxSize, + idx: c.idx, + KVCache: c.KVCache.Clone().(*KVCache), + } +} + +func (c *RotatingKVCache) Len() int { return min(c.offset, c.maxSize) } diff --git a/x/mlxrunner/client.go b/x/mlxrunner/client.go new file mode 100644 index 000000000..e3e5157ab --- /dev/null +++ b/x/mlxrunner/client.go @@ -0,0 +1,174 @@ +package mlxrunner + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "math" + "net" + "net/http" + "net/url" + "os/exec" + "strconv" + "strings" + + "github.com/ollama/ollama/llm" + "github.com/ollama/ollama/ml" +) + +type Client struct { + Port int + *exec.Cmd +} + +func (c *Client) JoinPath(path string) string { + 
return (&url.URL{ + Scheme: "http", + Host: net.JoinHostPort("127.0.0.1", strconv.Itoa(c.Port)), + }).JoinPath(path).String() +} + +func (c *Client) CheckError(w *http.Response) error { + if w.StatusCode >= 400 { + return errors.New(w.Status) + } + return nil +} + +// Close implements llm.LlamaServer. +func (c *Client) Close() error { + return c.Cmd.Process.Kill() +} + +// Completion implements llm.LlamaServer. +func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error { + var b bytes.Buffer + if err := json.NewEncoder(&b).Encode(req); err != nil { + return err + } + + w, err := http.Post(c.JoinPath("/v1/completions"), "application/json", &b) + if err != nil { + return err + } + defer w.Body.Close() + + if err := c.CheckError(w); err != nil { + return err + } + + scanner := bufio.NewScanner(w.Body) + for scanner.Scan() { + bts := scanner.Bytes() + + var resp llm.CompletionResponse + if err := json.Unmarshal(bts, &resp); err != nil { + return err + } + + fn(resp) + } + + return nil +} + +func (c *Client) ContextLength() int { + return math.MaxInt +} + +// Detokenize implements llm.LlamaServer. +func (c *Client) Detokenize(ctx context.Context, tokens []int) (string, error) { + panic("unimplemented") +} + +// Embedding implements llm.LlamaServer. +func (c *Client) Embedding(ctx context.Context, input string) ([]float32, int, error) { + panic("unimplemented") +} + +// GetDeviceInfos implements llm.LlamaServer. +func (c *Client) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { + panic("unimplemented") +} + +// GetPort implements llm.LlamaServer. +func (c *Client) GetPort() int { + return c.Port +} + +// HasExited implements llm.LlamaServer. +func (c *Client) HasExited() bool { + panic("unimplemented") +} + +// Load implements llm.LlamaServer. 
+func (c *Client) Load(ctx context.Context, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) ([]ml.DeviceID, error) { + w, err := http.Post(c.JoinPath("/v1/models"), "application/json", nil) + if err != nil { + return nil, err + } + defer w.Body.Close() + + return []ml.DeviceID{}, nil +} + +// ModelPath implements llm.LlamaServer. +func (c *Client) ModelPath() string { + panic("unimplemented") +} + +// Pid implements llm.LlamaServer. +func (c *Client) Pid() int { + panic("unimplemented") +} + +// Ping implements llm.LlamaServer. +func (c *Client) Ping(ctx context.Context) error { + w, err := http.Get(c.JoinPath("/v1/status")) + if err != nil { + return err + } + defer w.Body.Close() + + return nil +} + +// Tokenize implements llm.LlamaServer. +func (c *Client) Tokenize(ctx context.Context, content string) ([]int, error) { + w, err := http.Post(c.JoinPath("/v1/tokenize"), "text/plain", strings.NewReader(content)) + if err != nil { + return nil, err + } + defer w.Body.Close() + + var tokens []int + if err := json.NewDecoder(w.Body).Decode(&tokens); err != nil { + return nil, err + } + + return tokens, nil +} + +// TotalSize implements llm.LlamaServer. +func (c *Client) TotalSize() uint64 { + panic("unimplemented") +} + +// VRAMByGPU implements llm.LlamaServer. +func (c *Client) VRAMByGPU(id ml.DeviceID) uint64 { + panic("unimplemented") +} + +// VRAMSize implements llm.LlamaServer. +func (c *Client) VRAMSize() uint64 { + panic("unimplemented") +} + +// WaitUntilRunning implements llm.LlamaServer. 
+func (c *Client) WaitUntilRunning(ctx context.Context) error { + panic("unimplemented") +} + +var _ llm.LlamaServer = (*Client)(nil) diff --git a/x/mlxrunner/mlx/.gitignore b/x/mlxrunner/mlx/.gitignore new file mode 100644 index 000000000..b3ccd18fc --- /dev/null +++ b/x/mlxrunner/mlx/.gitignore @@ -0,0 +1,3 @@ +_deps +build +dist diff --git a/x/mlxrunner/mlx/CMakeLists.txt b/x/mlxrunner/mlx/CMakeLists.txt new file mode 100644 index 000000000..c41ce46f7 --- /dev/null +++ b/x/mlxrunner/mlx/CMakeLists.txt @@ -0,0 +1,26 @@ +cmake_minimum_required(VERSION 3.5) + +project(mlx) + +if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT) + set(CMAKE_INSTALL_PREFIX "${CMAKE_CURRENT_SOURCE_DIR}/dist" CACHE PATH "" FORCE) +endif() + +set(MLX_BUILD_GGUF OFF CACHE BOOL "" FORCE) +set(MLX_BUILD_SAFETENSORS ON CACHE BOOL "" FORCE) +set(MLX_C_BUILD_EXAMPLES OFF CACHE BOOL "" FORCE) +set(BUILD_SHARED_LIBS ON CACHE BOOL "" FORCE) + +set(CMAKE_INSTALL_RPATH "@loader_path") + +include(FetchContent) + +set(MLX_C_GIT_TAG "v0.4.1" CACHE STRING "") + +FetchContent_Declare( + mlx-c + GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git" + GIT_TAG ${MLX_C_GIT_TAG} +) + +FetchContent_MakeAvailable(mlx-c) diff --git a/x/mlxrunner/mlx/act.go b/x/mlxrunner/mlx/act.go new file mode 100644 index 000000000..3134a127a --- /dev/null +++ b/x/mlxrunner/mlx/act.go @@ -0,0 +1,23 @@ +//go:build mlx + +package mlx + +// #include "generated.h" +import "C" +import "math" + +func GELUApprox(t *Array) *Array { + return t.Multiply( + FromValue[float32](0.5), + ).Multiply( + t.Add( + t.Power(FromValue[float32](3.0)).Multiply(FromValue[float32](0.044715)), + ).Multiply( + FromValue(float32(math.Sqrt(2 / math.Pi))), + ).Tanh().Add(FromValue[float32](1.0)), + ).AsType(t.DType()) +} + +func SILU(t *Array) *Array { + return t.Multiply(t.Sigmoid()).AsType(t.DType()) +} diff --git a/x/mlxrunner/mlx/array.go b/x/mlxrunner/mlx/array.go new file mode 100644 index 000000000..bec8d3444 --- /dev/null +++ 
b/x/mlxrunner/mlx/array.go @@ -0,0 +1,273 @@ +//go:build mlx + +package mlx + +// #include "generated.h" +import "C" + +import ( + "encoding/binary" + "log/slog" + "reflect" + "strings" + "time" + "unsafe" + + "github.com/ollama/ollama/logutil" +) + +type tensorDesc struct { + name string + inputs []*Array + numRefs int +} + +func (d tensorDesc) LogValue() slog.Value { + return slog.GroupValue( + slog.String("name", d.name), + slog.Int("inputs", len(d.inputs)), + slog.Int("num_refs", d.numRefs), + ) +} + +type Array struct { + ctx C.mlx_array + desc tensorDesc +} + +// constructor utilities + +func New(name string, inputs ...*Array) *Array { + t := &Array{ + desc: tensorDesc{ + name: name, + inputs: inputs, + }, + } + + for _, input := range inputs { + input.desc.numRefs++ + } + logutil.Trace("New", "t", t) + return t +} + +type scalarTypes interface { + ~bool | ~int | ~float32 | ~float64 | ~complex64 +} + +func FromValue[T scalarTypes](t T) *Array { + tt := New("") + switch v := any(t).(type) { + case bool: + tt.ctx = C.mlx_array_new_bool(C.bool(v)) + case int: + tt.ctx = C.mlx_array_new_int(C.int(v)) + case float32: + tt.ctx = C.mlx_array_new_float32(C.float(v)) + case float64: + tt.ctx = C.mlx_array_new_float64(C.double(v)) + case complex64: + tt.ctx = C.mlx_array_new_complex(C.float(real(v)), C.float(imag(v))) + default: + panic("unsupported type") + } + return tt +} + +type arrayTypes interface { + ~bool | ~uint8 | ~uint16 | ~uint32 | ~uint64 | + ~int8 | ~int16 | ~int32 | ~int64 | + ~float32 | ~float64 | + ~complex64 +} + +func FromValues[S ~[]E, E arrayTypes](s S, shape ...int) *Array { + if len(shape) == 0 { + panic("shape must be provided for non-scalar tensors") + } + + cShape := make([]C.int, len(shape)) + for i := range shape { + cShape[i] = C.int(shape[i]) + } + + var dtype DType + switch reflect.TypeOf(s).Elem().Kind() { + case reflect.Bool: + dtype = DTypeBool + case reflect.Uint8: + dtype = DTypeUint8 + case reflect.Uint16: + dtype = DTypeUint16 + 
case reflect.Uint32: + dtype = DTypeUint32 + case reflect.Uint64: + dtype = DTypeUint64 + case reflect.Int8: + dtype = DTypeInt8 + case reflect.Int16: + dtype = DTypeInt16 + case reflect.Int32: + dtype = DTypeInt32 + case reflect.Int64: + dtype = DTypeInt64 + case reflect.Float32: + dtype = DTypeFloat32 + case reflect.Float64: + dtype = DTypeFloat64 + case reflect.Complex64: + dtype = DTypeComplex64 + default: + panic("unsupported type") + } + + bts := make([]byte, binary.Size(s)) + if _, err := binary.Encode(bts, binary.LittleEndian, s); err != nil { + panic(err) + } + + tt := New("") + tt.ctx = C.mlx_array_new_data(unsafe.Pointer(&bts[0]), unsafe.SliceData(cShape), C.int(len(cShape)), C.mlx_dtype(dtype)) + return tt +} + +func (t *Array) Set(other *Array) { + other.desc.numRefs++ + t.desc.inputs = []*Array{other} + C.mlx_array_set(&t.ctx, other.ctx) +} + +func (t *Array) Clone() *Array { + tt := New(t.desc.name, t.desc.inputs...) + C.mlx_array_set(&tt.ctx, t.ctx) + return tt +} + +// misc. utilities + +func (t *Array) Valid() bool { + return t.ctx.ctx != nil +} + +func (t *Array) String() string { + str := C.mlx_string_new() + defer C.mlx_string_free(str) + C.mlx_array_tostring(&str, t.ctx) + return strings.TrimSpace(C.GoString(C.mlx_string_data(str))) +} + +func (t *Array) LogValue() slog.Value { + attrs := []slog.Attr{slog.Any("", t.desc)} + if t.Valid() { + attrs = append(attrs, + slog.Any("dtype", t.DType()), + slog.Any("shape", t.Dims()), + slog.Int("num_bytes", t.NumBytes()), + ) + } + return slog.GroupValue(attrs...) 
+} + +// shape utilities + +func (t Array) Size() int { + return int(C.mlx_array_size(t.ctx)) +} + +func (t Array) NumBytes() int { + return int(C.mlx_array_nbytes(t.ctx)) +} + +func (t Array) NumDims() int { + return int(C.mlx_array_ndim(t.ctx)) +} + +func (t Array) Dims() []int { + dims := make([]int, t.NumDims()) + for i := range dims { + dims[i] = t.Dim(i) + } + + return dims +} + +func (t Array) Dim(dim int) int { + return int(C.mlx_array_dim(t.ctx, C.int(dim))) +} + +func (t Array) DType() DType { + return DType(C.mlx_array_dtype(t.ctx)) +} + +// data utilities + +func (t Array) Int() int { + var item C.int64_t + C.mlx_array_item_int64(&item, t.ctx) + return int(item) +} + +func (t Array) Float() float64 { + var item C.double + C.mlx_array_item_float64(&item, t.ctx) + return float64(item) +} + +func (t Array) Ints() []int { + ints := make([]int, t.Size()) + for i, f := range unsafe.Slice(C.mlx_array_data_int32(t.ctx), len(ints)) { + ints[i] = int(f) + } + return ints +} + +func (t Array) Floats() []float32 { + floats := make([]float32, t.Size()) + for i, f := range unsafe.Slice(C.mlx_array_data_float32(t.ctx), len(floats)) { + floats[i] = float32(f) + } + return floats +} + +func (t Array) Save(name string) error { + cName := C.CString(name) + defer C.free(unsafe.Pointer(cName)) + C.mlx_save(cName, t.ctx) + return nil +} + +func Free(s ...*Array) (n int) { + now := time.Now() + defer func() { + if n > 0 { + logutil.Trace("Freed tensors", "num_bytes", PrettyBytes(n), "took", time.Since(now)) + } + }() + + free := make([]*Array, 0, 8192) + fn := func(t *Array) { + if t.Valid() { + free = append(free, t.desc.inputs...) 
+ t.desc.numRefs-- + if t.desc.numRefs <= 0 { + logutil.Trace("Free", "t", t) + n += t.NumBytes() + C.mlx_array_free(t.ctx) + t.ctx.ctx = nil + } + } + } + + for _, t := range s { + fn(t) + } + + for len(free) > 0 { + tail := free[len(free)-1] + free = free[:len(free)-1] + fn(tail) + } + + return n +} diff --git a/x/mlxrunner/mlx/array_test.go b/x/mlxrunner/mlx/array_test.go new file mode 100644 index 000000000..aab5db7ba --- /dev/null +++ b/x/mlxrunner/mlx/array_test.go @@ -0,0 +1,45 @@ +//go:build mlx + +package mlx + +import "testing" + +func TestFromValue(t *testing.T) { + for got, want := range map[*Array]DType{ + FromValue(true): DTypeBool, + FromValue(false): DTypeBool, + FromValue(int(7)): DTypeInt32, + FromValue(float32(3.14)): DTypeFloat32, + FromValue(float64(2.71)): DTypeFloat64, + FromValue(complex64(1 + 2i)): DTypeComplex64, + } { + t.Run(want.String(), func(t *testing.T) { + if got.DType() != want { + t.Errorf("want %v, got %v", want, got) + } + }) + } +} + +func TestFromValues(t *testing.T) { + for got, want := range map[*Array]DType{ + FromValues([]bool{true, false, true}, 3): DTypeBool, + FromValues([]uint8{1, 2, 3}, 3): DTypeUint8, + FromValues([]uint16{1, 2, 3}, 3): DTypeUint16, + FromValues([]uint32{1, 2, 3}, 3): DTypeUint32, + FromValues([]uint64{1, 2, 3}, 3): DTypeUint64, + FromValues([]int8{-1, -2, -3}, 3): DTypeInt8, + FromValues([]int16{-1, -2, -3}, 3): DTypeInt16, + FromValues([]int32{-1, -2, -3}, 3): DTypeInt32, + FromValues([]int64{-1, -2, -3}, 3): DTypeInt64, + FromValues([]float32{3.14, 2.71, 1.61}, 3): DTypeFloat32, + FromValues([]float64{3.14, 2.71, 1.61}, 3): DTypeFloat64, + FromValues([]complex64{1 + 2i, 3 + 4i, 5 + 6i}, 3): DTypeComplex64, + } { + t.Run(want.String(), func(t *testing.T) { + if got.DType() != want { + t.Errorf("want %v, got %v", want, got) + } + }) + } +} diff --git a/x/mlxrunner/mlx/dtype.go b/x/mlxrunner/mlx/dtype.go new file mode 100644 index 000000000..95237c792 --- /dev/null +++ b/x/mlxrunner/mlx/dtype.go @@ 
-0,0 +1,96 @@ +//go:build mlx + +package mlx + +// #include "generated.h" +import "C" + +type DType int + +func (t DType) String() string { + switch t { + case DTypeBool: + return "BOOL" + case DTypeUint8: + return "U8" + case DTypeUint16: + return "U16" + case DTypeUint32: + return "U32" + case DTypeUint64: + return "U64" + case DTypeInt8: + return "I8" + case DTypeInt16: + return "I16" + case DTypeInt32: + return "I32" + case DTypeInt64: + return "I64" + case DTypeFloat16: + return "F16" + case DTypeFloat32: + return "F32" + case DTypeFloat64: + return "F64" + case DTypeBFloat16: + return "BF16" + case DTypeComplex64: + return "C64" + default: + return "Unknown" + } +} + +func (t *DType) UnmarshalJSON(b []byte) error { + switch string(b) { + case `"BOOL"`: + *t = DTypeBool + case `"U8"`: + *t = DTypeUint8 + case `"U16"`: + *t = DTypeUint16 + case `"U32"`: + *t = DTypeUint32 + case `"U64"`: + *t = DTypeUint64 + case `"I8"`: + *t = DTypeInt8 + case `"I16"`: + *t = DTypeInt16 + case `"I32"`: + *t = DTypeInt32 + case `"I64"`: + *t = DTypeInt64 + case `"F16"`: + *t = DTypeFloat16 + case `"F64"`: + *t = DTypeFloat64 + case `"F32"`: + *t = DTypeFloat32 + case `"BF16"`: + *t = DTypeBFloat16 + case `"C64"`: + *t = DTypeComplex64 + default: + return nil + } + return nil +} + +const ( + DTypeBool DType = C.MLX_BOOL + DTypeUint8 DType = C.MLX_UINT8 + DTypeUint16 DType = C.MLX_UINT16 + DTypeUint32 DType = C.MLX_UINT32 + DTypeUint64 DType = C.MLX_UINT64 + DTypeInt8 DType = C.MLX_INT8 + DTypeInt16 DType = C.MLX_INT16 + DTypeInt32 DType = C.MLX_INT32 + DTypeInt64 DType = C.MLX_INT64 + DTypeFloat16 DType = C.MLX_FLOAT16 + DTypeFloat32 DType = C.MLX_FLOAT32 + DTypeFloat64 DType = C.MLX_FLOAT64 + DTypeBFloat16 DType = C.MLX_BFLOAT16 + DTypeComplex64 DType = C.MLX_COMPLEX64 +) diff --git a/x/mlxrunner/mlx/dynamic.c b/x/mlxrunner/mlx/dynamic.c new file mode 100644 index 000000000..d3c4e6e6c --- /dev/null +++ b/x/mlxrunner/mlx/dynamic.c @@ -0,0 +1,34 @@ +#include "dynamic.h" + 
+#include <stdio.h>
+
+#ifdef _WIN32
+#include <windows.h>
+#define DLOPEN(path) LoadLibraryA(path)
+#define DLCLOSE(handle) FreeLibrary((HMODULE)(handle))
+#else
+#ifdef __APPLE__
+#include <TargetConditionals.h> /* NOTE(review): include names lost in extraction; confirm against upstream */
+#include <mach-o/dyld.h>
+#endif
+#include <dlfcn.h>
+#define DLOPEN(path) dlopen(path, RTLD_LAZY | RTLD_GLOBAL)
+#define DLCLOSE(handle) dlclose(handle)
+#endif
+
+static int mlx_dynamic_open(mlx_dynamic_handle* handle, const char* path) {
+  handle->ctx = (void*) DLOPEN(path);
+  CHECK(handle->ctx != NULL);
+  return 0;
+}
+
+int mlx_dynamic_load(mlx_dynamic_handle* handle, const char *path) {
+  return mlx_dynamic_open(handle, path);
+}
+
+void mlx_dynamic_unload(mlx_dynamic_handle* handle) {
+  if (handle->ctx) {
+    DLCLOSE(handle->ctx);
+    handle->ctx = NULL;
+  }
+}
diff --git a/x/mlxrunner/mlx/dynamic.go b/x/mlxrunner/mlx/dynamic.go
new file mode 100644
index 000000000..eb6427fb5
--- /dev/null
+++ b/x/mlxrunner/mlx/dynamic.go
@@ -0,0 +1,65 @@
+//go:build mlx
+
+package mlx
+
+// #include "dynamic.h"
+// #include "generated.h"
+// #include <stdlib.h>
+import "C"
+
+import (
+	"io/fs"
+	"log/slog"
+	"os"
+	"path/filepath"
+	"runtime"
+	"unsafe"
+)
+
+func init() {
+	switch runtime.GOOS {
+	case "darwin":
+
+	case "windows":
+	default:
+		return
+	}
+
+	paths, ok := os.LookupEnv("OLLAMA_LIBRARY_PATH")
+	if !ok {
+		slog.Debug("OLLAMA_LIBRARY_PATH not set, skipping mlx dynamic loading")
+		return
+	}
+
+	for _, path := range filepath.SplitList(paths) {
+		matches, err := fs.Glob(os.DirFS(path), "libmlxc.*")
+		if err != nil {
+			panic(err)
+		}
+
+		for _, match := range matches {
+			path := filepath.Join(path, match) // join against the current entry, not the whole list
+			slog.Info("Loading MLX dynamic library", "path", path)
+
+			cPath := C.CString(path)
+			defer C.free(unsafe.Pointer(cPath))
+
+			var handle C.mlx_dynamic_handle
+			if C.mlx_dynamic_load(&handle, cPath) != 0 {
+				slog.Error("Failed to load MLX dynamic library", "path", path)
+				continue
+			}
+
+			if C.mlx_dynamic_load_symbols(handle) != 0 {
+				slog.Error("Failed to load MLX dynamic library symbols", "path", path)
+				C.mlx_dynamic_unload(&handle)
+				continue
+			}
+
+			slog.Info("Loaded MLX dynamic library", "path", path)
+			return
+		}
+	}
+
+	panic("Failed to load any MLX dynamic library")
+}
diff --git a/x/mlxrunner/mlx/dynamic.h b/x/mlxrunner/mlx/dynamic.h
new file mode 100644
index 000000000..f93d8fab7
--- /dev/null
+++ b/x/mlxrunner/mlx/dynamic.h
@@ -0,0 +1,41 @@
+#ifndef MLX_DYNAMIC_H
+#define MLX_DYNAMIC_H
+
+#ifdef _WIN32
+#include <windows.h>
+#define DLSYM(handle, symbol) GetProcAddress((HMODULE)((handle).ctx), symbol)
+#else
+#include <dlfcn.h>
+#define DLSYM(handle, symbol) dlsym(handle.ctx, symbol)
+#endif
+
+#include <stdint.h>
+
+// Provide fallback typedefs for float16_t and bfloat16_t on non-ARM64
+// platforms where arm_fp16.h and arm_bf16.h are not available. These are
+// only used as function pointer signature placeholders since MLX requires
+// Apple Silicon at runtime.
+#if !defined(__aarch64__) && !defined(__ARM_FEATURE_FP16_SCALAR_ARITHMETIC)
+typedef uint16_t float16_t;
+#endif
+
+#if !defined(__aarch64__) && !defined(__ARM_FEATURE_BF16)
+typedef uint16_t bfloat16_t;
+#endif
+
+#define ERROR(fmt, ...) fprintf(stderr, "%s %s - ERROR - %s:%d - " fmt "\n", __DATE__, __TIME__, __FILE__, __LINE__, ##__VA_ARGS__); return 1
+#define CHECK(x) if (!(x)) { ERROR("CHECK failed: " #x); }
+#define CHECK_LOAD(handle, x) x##_ = DLSYM(handle, #x); CHECK(x##_)
+
+typedef struct {
+  void* ctx;
+} mlx_dynamic_handle;
+
+int mlx_dynamic_load(
+    mlx_dynamic_handle* handle,
+    const char *path);
+
+void mlx_dynamic_unload(
+    mlx_dynamic_handle* handle);
+
+#endif // MLX_DYNAMIC_H
diff --git a/x/mlxrunner/mlx/fast.go b/x/mlxrunner/mlx/fast.go
new file mode 100644
index 000000000..250d42dc8
--- /dev/null
+++ b/x/mlxrunner/mlx/fast.go
@@ -0,0 +1,74 @@
+//go:build mlx
+
+package mlx
+
+// #include "generated.h"
+import "C"
+
+import (
+	"unsafe"
+)
+
+func ScaledDotProductAttention(query, key, value, mask *Array, scale float32) *Array {
+	if mask == nil {
+		mask = New("")
+	}
+
+	sinks := New("")
+
+	mode := "causal"
+	cMode := C.CString(mode)
+	defer C.free(unsafe.Pointer(cMode))
+
+	out := New("FAST_SDPA", query, key, value, mask, sinks)
+	C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, mask.ctx, sinks.ctx, DefaultStream().ctx)
+	return out
+}
+
+type LayerNorm struct {
+	Weight Array `weight:"weight"`
+	Bias   Array `weight:"bias"`
+}
+
+func (r *LayerNorm) Forward(x *Array, eps float32) *Array {
+	out := New("FAST_LAYERNORM", x)
+	C.mlx_fast_layer_norm(&out.ctx, x.ctx, r.Weight.ctx, r.Bias.ctx, C.float(eps), DefaultStream().ctx)
+	return out
+}
+
+type RMSNorm struct {
+	Weight Array `weight:"weight"`
+}
+
+func (r RMSNorm) Forward(x *Array, eps float32) *Array {
+	out := New("FAST_RMSNORM", x)
+	C.mlx_fast_rms_norm(&out.ctx, x.ctx, r.Weight.ctx, C.float(eps), DefaultStream().ctx)
+	return out
+}
+
+type RoPE struct {
+	Dims        int
+	Traditional bool
+	Base        float32 `json:"rope_theta"`
+	Scale       float32
+}
+
+func (r RoPE) Forward(t *Array, offset int) *Array {
+	freqs := New("")
+	out := New("FAST_ROPE", t, freqs)
+	C.mlx_fast_rope(
+
&out.ctx, + t.ctx, + C.int(r.Dims), + C._Bool(r.Traditional), + C.mlx_optional_float{ + value: C.float(r.Base), + has_value: C._Bool(func() bool { return r.Base != 0 }()), + }, + C.float(r.Scale), + C.int(offset), + freqs.ctx, + DefaultStream().ctx, + ) + return out +} diff --git a/x/mlxrunner/mlx/generated.c b/x/mlxrunner/mlx/generated.c new file mode 100644 index 000000000..af99b631e --- /dev/null +++ b/x/mlxrunner/mlx/generated.c @@ -0,0 +1,2724 @@ +// This code is auto-generated; DO NOT EDIT. + +#include "generated.h" + +#include +#include +#include + +size_t (*mlx_dtype_size_)(mlx_dtype dtype) = NULL; +int (*mlx_array_tostring_)(mlx_string* str, const mlx_array arr) = NULL; +mlx_array (*mlx_array_new_)(void) = NULL; +int (*mlx_array_free_)(mlx_array arr) = NULL; +mlx_array (*mlx_array_new_bool_)(bool val) = NULL; +mlx_array (*mlx_array_new_int_)(int val) = NULL; +mlx_array (*mlx_array_new_float32_)(float val) = NULL; +mlx_array (*mlx_array_new_float_)(float val) = NULL; +mlx_array (*mlx_array_new_float64_)(double val) = NULL; +mlx_array (*mlx_array_new_double_)(double val) = NULL; +mlx_array (*mlx_array_new_complex_)(float real_val, float imag_val) = NULL; +mlx_array (*mlx_array_new_data_)( + const void* data, + const int* shape, + int dim, + mlx_dtype dtype) = NULL; +int (*mlx_array_set_)(mlx_array* arr, const mlx_array src) = NULL; +int (*mlx_array_set_bool_)(mlx_array* arr, bool val) = NULL; +int (*mlx_array_set_int_)(mlx_array* arr, int val) = NULL; +int (*mlx_array_set_float32_)(mlx_array* arr, float val) = NULL; +int (*mlx_array_set_float_)(mlx_array* arr, float val) = NULL; +int (*mlx_array_set_float64_)(mlx_array* arr, double val) = NULL; +int (*mlx_array_set_double_)(mlx_array* arr, double val) = NULL; +int (*mlx_array_set_complex_)(mlx_array* arr, float real_val, float imag_val) = NULL; +int (*mlx_array_set_data_)( + mlx_array* arr, + const void* data, + const int* shape, + int dim, + mlx_dtype dtype) = NULL; +size_t (*mlx_array_itemsize_)(const 
mlx_array arr) = NULL; +size_t (*mlx_array_size_)(const mlx_array arr) = NULL; +size_t (*mlx_array_nbytes_)(const mlx_array arr) = NULL; +size_t (*mlx_array_ndim_)(const mlx_array arr) = NULL; +const int * (*mlx_array_shape_)(const mlx_array arr) = NULL; +const size_t * (*mlx_array_strides_)(const mlx_array arr) = NULL; +int (*mlx_array_dim_)(const mlx_array arr, int dim) = NULL; +mlx_dtype (*mlx_array_dtype_)(const mlx_array arr) = NULL; +int (*mlx_array_eval_)(mlx_array arr) = NULL; +int (*mlx_array_item_bool_)(bool* res, const mlx_array arr) = NULL; +int (*mlx_array_item_uint8_)(uint8_t* res, const mlx_array arr) = NULL; +int (*mlx_array_item_uint16_)(uint16_t* res, const mlx_array arr) = NULL; +int (*mlx_array_item_uint32_)(uint32_t* res, const mlx_array arr) = NULL; +int (*mlx_array_item_uint64_)(uint64_t* res, const mlx_array arr) = NULL; +int (*mlx_array_item_int8_)(int8_t* res, const mlx_array arr) = NULL; +int (*mlx_array_item_int16_)(int16_t* res, const mlx_array arr) = NULL; +int (*mlx_array_item_int32_)(int32_t* res, const mlx_array arr) = NULL; +int (*mlx_array_item_int64_)(int64_t* res, const mlx_array arr) = NULL; +int (*mlx_array_item_float32_)(float* res, const mlx_array arr) = NULL; +int (*mlx_array_item_float64_)(double* res, const mlx_array arr) = NULL; +int (*mlx_array_item_complex64_)(float _Complex* res, const mlx_array arr) = NULL; +int (*mlx_array_item_float16_)(float16_t* res, const mlx_array arr) = NULL; +int (*mlx_array_item_bfloat16_)(bfloat16_t* res, const mlx_array arr) = NULL; +const bool * (*mlx_array_data_bool_)(const mlx_array arr) = NULL; +const uint8_t * (*mlx_array_data_uint8_)(const mlx_array arr) = NULL; +const uint16_t * (*mlx_array_data_uint16_)(const mlx_array arr) = NULL; +const uint32_t * (*mlx_array_data_uint32_)(const mlx_array arr) = NULL; +const uint64_t * (*mlx_array_data_uint64_)(const mlx_array arr) = NULL; +const int8_t * (*mlx_array_data_int8_)(const mlx_array arr) = NULL; +const int16_t * 
(*mlx_array_data_int16_)(const mlx_array arr) = NULL; +const int32_t * (*mlx_array_data_int32_)(const mlx_array arr) = NULL; +const int64_t * (*mlx_array_data_int64_)(const mlx_array arr) = NULL; +const float * (*mlx_array_data_float32_)(const mlx_array arr) = NULL; +const double * (*mlx_array_data_float64_)(const mlx_array arr) = NULL; +const float _Complex * (*mlx_array_data_complex64_)(const mlx_array arr) = NULL; +const float16_t * (*mlx_array_data_float16_)(const mlx_array arr) = NULL; +const bfloat16_t * (*mlx_array_data_bfloat16_)(const mlx_array arr) = NULL; +int (*_mlx_array_is_available_)(bool* res, const mlx_array arr) = NULL; +int (*_mlx_array_wait_)(const mlx_array arr) = NULL; +int (*_mlx_array_is_contiguous_)(bool* res, const mlx_array arr) = NULL; +int (*_mlx_array_is_row_contiguous_)(bool* res, const mlx_array arr) = NULL; +int (*_mlx_array_is_col_contiguous_)(bool* res, const mlx_array arr) = NULL; +mlx_closure (*mlx_closure_new_)(void) = NULL; +int (*mlx_closure_free_)(mlx_closure cls) = NULL; +mlx_closure (*mlx_closure_new_func_)( + int (*fun)(mlx_vector_array*, const mlx_vector_array)) = NULL; +mlx_closure (*mlx_closure_new_func_payload_)( + int (*fun)(mlx_vector_array*, const mlx_vector_array, void*), + void* payload, + void (*dtor)(void*)) = NULL; +int (*mlx_closure_set_)(mlx_closure* cls, const mlx_closure src) = NULL; +int (*mlx_closure_apply_)( + mlx_vector_array* res, + mlx_closure cls, + const mlx_vector_array input) = NULL; +mlx_closure (*mlx_closure_new_unary_)(int (*fun)(mlx_array*, const mlx_array)) = NULL; +mlx_closure_kwargs (*mlx_closure_kwargs_new_)(void) = NULL; +int (*mlx_closure_kwargs_free_)(mlx_closure_kwargs cls) = NULL; +mlx_closure_kwargs (*mlx_closure_kwargs_new_func_)(int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_map_string_to_array)) = NULL; +mlx_closure_kwargs (*mlx_closure_kwargs_new_func_payload_)( + int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_map_string_to_array, 
+ void*), + void* payload, + void (*dtor)(void*)) = NULL; +int (*mlx_closure_kwargs_set_)( + mlx_closure_kwargs* cls, + const mlx_closure_kwargs src) = NULL; +int (*mlx_closure_kwargs_apply_)( + mlx_vector_array* res, + mlx_closure_kwargs cls, + const mlx_vector_array input_0, + const mlx_map_string_to_array input_1) = NULL; +mlx_closure_value_and_grad (*mlx_closure_value_and_grad_new_)(void) = NULL; +int (*mlx_closure_value_and_grad_free_)(mlx_closure_value_and_grad cls) = NULL; +mlx_closure_value_and_grad (*mlx_closure_value_and_grad_new_func_)( + int (*fun)(mlx_vector_array*, mlx_vector_array*, const mlx_vector_array)) = NULL; +mlx_closure_value_and_grad (*mlx_closure_value_and_grad_new_func_payload_)( + int (*fun)( + mlx_vector_array*, + mlx_vector_array*, + const mlx_vector_array, + void*), + void* payload, + void (*dtor)(void*)) = NULL; +int (*mlx_closure_value_and_grad_set_)( + mlx_closure_value_and_grad* cls, + const mlx_closure_value_and_grad src) = NULL; +int (*mlx_closure_value_and_grad_apply_)( + mlx_vector_array* res_0, + mlx_vector_array* res_1, + mlx_closure_value_and_grad cls, + const mlx_vector_array input) = NULL; +mlx_closure_custom (*mlx_closure_custom_new_)(void) = NULL; +int (*mlx_closure_custom_free_)(mlx_closure_custom cls) = NULL; +mlx_closure_custom (*mlx_closure_custom_new_func_)(int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_vector_array, + const mlx_vector_array)) = NULL; +mlx_closure_custom (*mlx_closure_custom_new_func_payload_)( + int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_vector_array, + const mlx_vector_array, + void*), + void* payload, + void (*dtor)(void*)) = NULL; +int (*mlx_closure_custom_set_)( + mlx_closure_custom* cls, + const mlx_closure_custom src) = NULL; +int (*mlx_closure_custom_apply_)( + mlx_vector_array* res, + mlx_closure_custom cls, + const mlx_vector_array input_0, + const mlx_vector_array input_1, + const mlx_vector_array input_2) = NULL; +mlx_closure_custom_jvp 
(*mlx_closure_custom_jvp_new_)(void) = NULL; +int (*mlx_closure_custom_jvp_free_)(mlx_closure_custom_jvp cls) = NULL; +mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_)(int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_vector_array, + const int*, + size_t _num)) = NULL; +mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_payload_)( + int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_vector_array, + const int*, + size_t _num, + void*), + void* payload, + void (*dtor)(void*)) = NULL; +int (*mlx_closure_custom_jvp_set_)( + mlx_closure_custom_jvp* cls, + const mlx_closure_custom_jvp src) = NULL; +int (*mlx_closure_custom_jvp_apply_)( + mlx_vector_array* res, + mlx_closure_custom_jvp cls, + const mlx_vector_array input_0, + const mlx_vector_array input_1, + const int* input_2, + size_t input_2_num) = NULL; +mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_)(void) = NULL; +int (*mlx_closure_custom_vmap_free_)(mlx_closure_custom_vmap cls) = NULL; +mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_)(int (*fun)( + mlx_vector_array*, + mlx_vector_int*, + const mlx_vector_array, + const int*, + size_t _num)) = NULL; +mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_payload_)( + int (*fun)( + mlx_vector_array*, + mlx_vector_int*, + const mlx_vector_array, + const int*, + size_t _num, + void*), + void* payload, + void (*dtor)(void*)) = NULL; +int (*mlx_closure_custom_vmap_set_)( + mlx_closure_custom_vmap* cls, + const mlx_closure_custom_vmap src) = NULL; +int (*mlx_closure_custom_vmap_apply_)( + mlx_vector_array* res_0, + mlx_vector_int* res_1, + mlx_closure_custom_vmap cls, + const mlx_vector_array input_0, + const int* input_1, + size_t input_1_num) = NULL; +int (*mlx_compile_)(mlx_closure* res, const mlx_closure fun, bool shapeless) = NULL; +int (*mlx_detail_compile_)( + mlx_closure* res, + const mlx_closure fun, + uintptr_t fun_id, + bool shapeless, + const uint64_t* constants, + size_t 
constants_num) = NULL; +int (*mlx_detail_compile_clear_cache_)(void) = NULL; +int (*mlx_detail_compile_erase_)(uintptr_t fun_id) = NULL; +int (*mlx_disable_compile_)(void) = NULL; +int (*mlx_enable_compile_)(void) = NULL; +int (*mlx_set_compile_mode_)(mlx_compile_mode mode) = NULL; +mlx_device (*mlx_device_new_)(void) = NULL; +mlx_device (*mlx_device_new_type_)(mlx_device_type type, int index) = NULL; +int (*mlx_device_free_)(mlx_device dev) = NULL; +int (*mlx_device_set_)(mlx_device* dev, const mlx_device src) = NULL; +int (*mlx_device_tostring_)(mlx_string* str, mlx_device dev) = NULL; +bool (*mlx_device_equal_)(mlx_device lhs, mlx_device rhs) = NULL; +int (*mlx_device_get_index_)(int* index, mlx_device dev) = NULL; +int (*mlx_device_get_type_)(mlx_device_type* type, mlx_device dev) = NULL; +int (*mlx_get_default_device_)(mlx_device* dev) = NULL; +int (*mlx_set_default_device_)(mlx_device dev) = NULL; +int (*mlx_distributed_group_rank_)(mlx_distributed_group group) = NULL; +int (*mlx_distributed_group_size_)(mlx_distributed_group group) = NULL; +mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key) = NULL; +bool (*mlx_distributed_is_available_)(void) = NULL; +mlx_distributed_group (*mlx_distributed_init_)(bool strict) = NULL; +int (*mlx_distributed_all_gather_)( + mlx_array* res, + const mlx_array x, + const mlx_distributed_group group /* may be null */, + const mlx_stream S) = NULL; +int (*mlx_distributed_all_max_)( + mlx_array* res, + const mlx_array x, + const mlx_distributed_group group /* may be null */, + const mlx_stream s) = NULL; +int (*mlx_distributed_all_min_)( + mlx_array* res, + const mlx_array x, + const mlx_distributed_group group /* may be null */, + const mlx_stream s) = NULL; +int (*mlx_distributed_all_sum_)( + mlx_array* res, + const mlx_array x, + const mlx_distributed_group group /* may be null */, + const mlx_stream s) = NULL; +int (*mlx_distributed_recv_)( + mlx_array* res, + const int* 
shape, + size_t shape_num, + mlx_dtype dtype, + int src, + const mlx_distributed_group group /* may be null */, + const mlx_stream s) = NULL; +int (*mlx_distributed_recv_like_)( + mlx_array* res, + const mlx_array x, + int src, + const mlx_distributed_group group /* may be null */, + const mlx_stream s) = NULL; +int (*mlx_distributed_send_)( + mlx_array* res, + const mlx_array x, + int dst, + const mlx_distributed_group group /* may be null */, + const mlx_stream s) = NULL; +int (*mlx_distributed_sum_scatter_)( + mlx_array* res, + const mlx_array x, + const mlx_distributed_group group /* may be null */, + const mlx_stream s) = NULL; +void (*mlx_set_error_handler_)( + mlx_error_handler_func handler, + void* data, + void (*dtor)(void*)) = NULL; +void (*_mlx_error_)(const char* file, const int line, const char* fmt, ...) = NULL; +int (*mlx_export_function_)( + const char* file, + const mlx_closure fun, + const mlx_vector_array args, + bool shapeless) = NULL; +int (*mlx_export_function_kwargs_)( + const char* file, + const mlx_closure_kwargs fun, + const mlx_vector_array args, + const mlx_map_string_to_array kwargs, + bool shapeless) = NULL; +mlx_function_exporter (*mlx_function_exporter_new_)( + const char* file, + const mlx_closure fun, + bool shapeless) = NULL; +int (*mlx_function_exporter_free_)(mlx_function_exporter xfunc) = NULL; +int (*mlx_function_exporter_apply_)( + const mlx_function_exporter xfunc, + const mlx_vector_array args) = NULL; +int (*mlx_function_exporter_apply_kwargs_)( + const mlx_function_exporter xfunc, + const mlx_vector_array args, + const mlx_map_string_to_array kwargs) = NULL; +mlx_imported_function (*mlx_imported_function_new_)(const char* file) = NULL; +int (*mlx_imported_function_free_)(mlx_imported_function xfunc) = NULL; +int (*mlx_imported_function_apply_)( + mlx_vector_array* res, + const mlx_imported_function xfunc, + const mlx_vector_array args) = NULL; +int (*mlx_imported_function_apply_kwargs_)( + mlx_vector_array* res, + const 
mlx_imported_function xfunc, + const mlx_vector_array args, + const mlx_map_string_to_array kwargs) = NULL; +mlx_fast_cuda_kernel_config (*mlx_fast_cuda_kernel_config_new_)(void) = NULL; +void (*mlx_fast_cuda_kernel_config_free_)(mlx_fast_cuda_kernel_config cls) = NULL; +int (*mlx_fast_cuda_kernel_config_add_output_arg_)( + mlx_fast_cuda_kernel_config cls, + const int* shape, + size_t size, + mlx_dtype dtype) = NULL; +int (*mlx_fast_cuda_kernel_config_set_grid_)( + mlx_fast_cuda_kernel_config cls, + int grid1, + int grid2, + int grid3) = NULL; +int (*mlx_fast_cuda_kernel_config_set_thread_group_)( + mlx_fast_cuda_kernel_config cls, + int thread1, + int thread2, + int thread3) = NULL; +int (*mlx_fast_cuda_kernel_config_set_init_value_)( + mlx_fast_cuda_kernel_config cls, + float value) = NULL; +int (*mlx_fast_cuda_kernel_config_set_verbose_)( + mlx_fast_cuda_kernel_config cls, + bool verbose) = NULL; +int (*mlx_fast_cuda_kernel_config_add_template_arg_dtype_)( + mlx_fast_cuda_kernel_config cls, + const char* name, + mlx_dtype dtype) = NULL; +int (*mlx_fast_cuda_kernel_config_add_template_arg_int_)( + mlx_fast_cuda_kernel_config cls, + const char* name, + int value) = NULL; +int (*mlx_fast_cuda_kernel_config_add_template_arg_bool_)( + mlx_fast_cuda_kernel_config cls, + const char* name, + bool value) = NULL; +mlx_fast_cuda_kernel (*mlx_fast_cuda_kernel_new_)( + const char* name, + const mlx_vector_string input_names, + const mlx_vector_string output_names, + const char* source, + const char* header, + bool ensure_row_contiguous, + int shared_memory) = NULL; +void (*mlx_fast_cuda_kernel_free_)(mlx_fast_cuda_kernel cls) = NULL; +int (*mlx_fast_cuda_kernel_apply_)( + mlx_vector_array* outputs, + mlx_fast_cuda_kernel cls, + const mlx_vector_array inputs, + const mlx_fast_cuda_kernel_config config, + const mlx_stream stream) = NULL; +int (*mlx_fast_layer_norm_)( + mlx_array* res, + const mlx_array x, + const mlx_array weight /* may be null */, + const mlx_array bias /* 
may be null */, + float eps, + const mlx_stream s) = NULL; +mlx_fast_metal_kernel_config (*mlx_fast_metal_kernel_config_new_)(void) = NULL; +void (*mlx_fast_metal_kernel_config_free_)(mlx_fast_metal_kernel_config cls) = NULL; +int (*mlx_fast_metal_kernel_config_add_output_arg_)( + mlx_fast_metal_kernel_config cls, + const int* shape, + size_t size, + mlx_dtype dtype) = NULL; +int (*mlx_fast_metal_kernel_config_set_grid_)( + mlx_fast_metal_kernel_config cls, + int grid1, + int grid2, + int grid3) = NULL; +int (*mlx_fast_metal_kernel_config_set_thread_group_)( + mlx_fast_metal_kernel_config cls, + int thread1, + int thread2, + int thread3) = NULL; +int (*mlx_fast_metal_kernel_config_set_init_value_)( + mlx_fast_metal_kernel_config cls, + float value) = NULL; +int (*mlx_fast_metal_kernel_config_set_verbose_)( + mlx_fast_metal_kernel_config cls, + bool verbose) = NULL; +int (*mlx_fast_metal_kernel_config_add_template_arg_dtype_)( + mlx_fast_metal_kernel_config cls, + const char* name, + mlx_dtype dtype) = NULL; +int (*mlx_fast_metal_kernel_config_add_template_arg_int_)( + mlx_fast_metal_kernel_config cls, + const char* name, + int value) = NULL; +int (*mlx_fast_metal_kernel_config_add_template_arg_bool_)( + mlx_fast_metal_kernel_config cls, + const char* name, + bool value) = NULL; +mlx_fast_metal_kernel (*mlx_fast_metal_kernel_new_)( + const char* name, + const mlx_vector_string input_names, + const mlx_vector_string output_names, + const char* source, + const char* header, + bool ensure_row_contiguous, + bool atomic_outputs) = NULL; +void (*mlx_fast_metal_kernel_free_)(mlx_fast_metal_kernel cls) = NULL; +int (*mlx_fast_metal_kernel_apply_)( + mlx_vector_array* outputs, + mlx_fast_metal_kernel cls, + const mlx_vector_array inputs, + const mlx_fast_metal_kernel_config config, + const mlx_stream stream) = NULL; +int (*mlx_fast_rms_norm_)( + mlx_array* res, + const mlx_array x, + const mlx_array weight /* may be null */, + float eps, + const mlx_stream s) = NULL; +int 
(*mlx_fast_rope_)( + mlx_array* res, + const mlx_array x, + int dims, + bool traditional, + mlx_optional_float base, + float scale, + int offset, + const mlx_array freqs /* may be null */, + const mlx_stream s) = NULL; +int (*mlx_fast_scaled_dot_product_attention_)( + mlx_array* res, + const mlx_array queries, + const mlx_array keys, + const mlx_array values, + float scale, + const char* mask_mode, + const mlx_array mask_arr /* may be null */, + const mlx_array sinks /* may be null */, + const mlx_stream s) = NULL; +int (*mlx_fft_fft_)( + mlx_array* res, + const mlx_array a, + int n, + int axis, + const mlx_stream s) = NULL; +int (*mlx_fft_fft2_)( + mlx_array* res, + const mlx_array a, + const int* n, + size_t n_num, + const int* axes, + size_t axes_num, + const mlx_stream s) = NULL; +int (*mlx_fft_fftn_)( + mlx_array* res, + const mlx_array a, + const int* n, + size_t n_num, + const int* axes, + size_t axes_num, + const mlx_stream s) = NULL; +int (*mlx_fft_fftshift_)( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + const mlx_stream s) = NULL; +int (*mlx_fft_ifft_)( + mlx_array* res, + const mlx_array a, + int n, + int axis, + const mlx_stream s) = NULL; +int (*mlx_fft_ifft2_)( + mlx_array* res, + const mlx_array a, + const int* n, + size_t n_num, + const int* axes, + size_t axes_num, + const mlx_stream s) = NULL; +int (*mlx_fft_ifftn_)( + mlx_array* res, + const mlx_array a, + const int* n, + size_t n_num, + const int* axes, + size_t axes_num, + const mlx_stream s) = NULL; +int (*mlx_fft_ifftshift_)( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + const mlx_stream s) = NULL; +int (*mlx_fft_irfft_)( + mlx_array* res, + const mlx_array a, + int n, + int axis, + const mlx_stream s) = NULL; +int (*mlx_fft_irfft2_)( + mlx_array* res, + const mlx_array a, + const int* n, + size_t n_num, + const int* axes, + size_t axes_num, + const mlx_stream s) = NULL; +int (*mlx_fft_irfftn_)( + mlx_array* res, + const 
mlx_array a, + const int* n, + size_t n_num, + const int* axes, + size_t axes_num, + const mlx_stream s) = NULL; +int (*mlx_fft_rfft_)( + mlx_array* res, + const mlx_array a, + int n, + int axis, + const mlx_stream s) = NULL; +int (*mlx_fft_rfft2_)( + mlx_array* res, + const mlx_array a, + const int* n, + size_t n_num, + const int* axes, + size_t axes_num, + const mlx_stream s) = NULL; +int (*mlx_fft_rfftn_)( + mlx_array* res, + const mlx_array a, + const int* n, + size_t n_num, + const int* axes, + size_t axes_num, + const mlx_stream s) = NULL; +mlx_io_reader (*mlx_io_reader_new_)(void* desc, mlx_io_vtable vtable) = NULL; +int (*mlx_io_reader_descriptor_)(void** desc_, mlx_io_reader io) = NULL; +int (*mlx_io_reader_tostring_)(mlx_string* str_, mlx_io_reader io) = NULL; +int (*mlx_io_reader_free_)(mlx_io_reader io) = NULL; +mlx_io_writer (*mlx_io_writer_new_)(void* desc, mlx_io_vtable vtable) = NULL; +int (*mlx_io_writer_descriptor_)(void** desc_, mlx_io_writer io) = NULL; +int (*mlx_io_writer_tostring_)(mlx_string* str_, mlx_io_writer io) = NULL; +int (*mlx_io_writer_free_)(mlx_io_writer io) = NULL; +int (*mlx_load_reader_)( + mlx_array* res, + mlx_io_reader in_stream, + const mlx_stream s) = NULL; +int (*mlx_load_)(mlx_array* res, const char* file, const mlx_stream s) = NULL; +int (*mlx_load_safetensors_reader_)( + mlx_map_string_to_array* res_0, + mlx_map_string_to_string* res_1, + mlx_io_reader in_stream, + const mlx_stream s) = NULL; +int (*mlx_load_safetensors_)( + mlx_map_string_to_array* res_0, + mlx_map_string_to_string* res_1, + const char* file, + const mlx_stream s) = NULL; +int (*mlx_save_writer_)(mlx_io_writer out_stream, const mlx_array a) = NULL; +int (*mlx_save_)(const char* file, const mlx_array a) = NULL; +int (*mlx_save_safetensors_writer_)( + mlx_io_writer in_stream, + const mlx_map_string_to_array param, + const mlx_map_string_to_string metadata) = NULL; +int (*mlx_save_safetensors_)( + const char* file, + const mlx_map_string_to_array param, 
+ const mlx_map_string_to_string metadata) = NULL; +int (*mlx_linalg_cholesky_)( + mlx_array* res, + const mlx_array a, + bool upper, + const mlx_stream s) = NULL; +int (*mlx_linalg_cholesky_inv_)( + mlx_array* res, + const mlx_array a, + bool upper, + const mlx_stream s) = NULL; +int (*mlx_linalg_cross_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + int axis, + const mlx_stream s) = NULL; +int (*mlx_linalg_eig_)( + mlx_array* res_0, + mlx_array* res_1, + const mlx_array a, + const mlx_stream s) = NULL; +int (*mlx_linalg_eigh_)( + mlx_array* res_0, + mlx_array* res_1, + const mlx_array a, + const char* UPLO, + const mlx_stream s) = NULL; +int (*mlx_linalg_eigvals_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_linalg_eigvalsh_)( + mlx_array* res, + const mlx_array a, + const char* UPLO, + const mlx_stream s) = NULL; +int (*mlx_linalg_inv_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_linalg_lu_)(mlx_vector_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_linalg_lu_factor_)( + mlx_array* res_0, + mlx_array* res_1, + const mlx_array a, + const mlx_stream s) = NULL; +int (*mlx_linalg_norm_)( + mlx_array* res, + const mlx_array a, + double ord, + const int* axis /* may be null */, + size_t axis_num, + bool keepdims, + const mlx_stream s) = NULL; +int (*mlx_linalg_norm_matrix_)( + mlx_array* res, + const mlx_array a, + const char* ord, + const int* axis /* may be null */, + size_t axis_num, + bool keepdims, + const mlx_stream s) = NULL; +int (*mlx_linalg_norm_l2_)( + mlx_array* res, + const mlx_array a, + const int* axis /* may be null */, + size_t axis_num, + bool keepdims, + const mlx_stream s) = NULL; +int (*mlx_linalg_pinv_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_linalg_qr_)( + mlx_array* res_0, + mlx_array* res_1, + const mlx_array a, + const mlx_stream s) = NULL; +int (*mlx_linalg_solve_)( + mlx_array* res, + const mlx_array a, + const 
mlx_array b, + const mlx_stream s) = NULL; +int (*mlx_linalg_solve_triangular_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + bool upper, + const mlx_stream s) = NULL; +int (*mlx_linalg_svd_)( + mlx_vector_array* res, + const mlx_array a, + bool compute_uv, + const mlx_stream s) = NULL; +int (*mlx_linalg_tri_inv_)( + mlx_array* res, + const mlx_array a, + bool upper, + const mlx_stream s) = NULL; +mlx_map_string_to_array (*mlx_map_string_to_array_new_)(void) = NULL; +int (*mlx_map_string_to_array_set_)( + mlx_map_string_to_array* map, + const mlx_map_string_to_array src) = NULL; +int (*mlx_map_string_to_array_free_)(mlx_map_string_to_array map) = NULL; +int (*mlx_map_string_to_array_insert_)( + mlx_map_string_to_array map, + const char* key, + const mlx_array value) = NULL; +int (*mlx_map_string_to_array_get_)( + mlx_array* value, + const mlx_map_string_to_array map, + const char* key) = NULL; +mlx_map_string_to_array_iterator (*mlx_map_string_to_array_iterator_new_)( + mlx_map_string_to_array map) = NULL; +int (*mlx_map_string_to_array_iterator_free_)(mlx_map_string_to_array_iterator it) = NULL; +int (*mlx_map_string_to_array_iterator_next_)( + const char** key, + mlx_array* value, + mlx_map_string_to_array_iterator it) = NULL; +mlx_map_string_to_string (*mlx_map_string_to_string_new_)(void) = NULL; +int (*mlx_map_string_to_string_set_)( + mlx_map_string_to_string* map, + const mlx_map_string_to_string src) = NULL; +int (*mlx_map_string_to_string_free_)(mlx_map_string_to_string map) = NULL; +int (*mlx_map_string_to_string_insert_)( + mlx_map_string_to_string map, + const char* key, + const char* value) = NULL; +int (*mlx_map_string_to_string_get_)( + const char** value, + const mlx_map_string_to_string map, + const char* key) = NULL; +mlx_map_string_to_string_iterator (*mlx_map_string_to_string_iterator_new_)( + mlx_map_string_to_string map) = NULL; +int (*mlx_map_string_to_string_iterator_free_)( + mlx_map_string_to_string_iterator it) = NULL; 
+int (*mlx_map_string_to_string_iterator_next_)( + const char** key, + const char** value, + mlx_map_string_to_string_iterator it) = NULL; +int (*mlx_clear_cache_)(void) = NULL; +int (*mlx_get_active_memory_)(size_t* res) = NULL; +int (*mlx_get_cache_memory_)(size_t* res) = NULL; +int (*mlx_get_memory_limit_)(size_t* res) = NULL; +int (*mlx_get_peak_memory_)(size_t* res) = NULL; +int (*mlx_reset_peak_memory_)(void) = NULL; +int (*mlx_set_cache_limit_)(size_t* res, size_t limit) = NULL; +int (*mlx_set_memory_limit_)(size_t* res, size_t limit) = NULL; +int (*mlx_set_wired_limit_)(size_t* res, size_t limit) = NULL; +mlx_metal_device_info_t (*mlx_metal_device_info_)(void) = NULL; +int (*mlx_metal_is_available_)(bool* res) = NULL; +int (*mlx_metal_start_capture_)(const char* path) = NULL; +int (*mlx_metal_stop_capture_)(void) = NULL; +int (*mlx_abs_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_add_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) = NULL; +int (*mlx_addmm_)( + mlx_array* res, + const mlx_array c, + const mlx_array a, + const mlx_array b, + float alpha, + float beta, + const mlx_stream s) = NULL; +int (*mlx_all_axes_)( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s) = NULL; +int (*mlx_all_axis_)( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s) = NULL; +int (*mlx_all_)( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s) = NULL; +int (*mlx_allclose_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + double rtol, + double atol, + bool equal_nan, + const mlx_stream s) = NULL; +int (*mlx_any_axes_)( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s) = NULL; +int (*mlx_any_axis_)( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s) = NULL; +int 
(*mlx_any_)( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s) = NULL; +int (*mlx_arange_)( + mlx_array* res, + double start, + double stop, + double step, + mlx_dtype dtype, + const mlx_stream s) = NULL; +int (*mlx_arccos_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_arccosh_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_arcsin_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_arcsinh_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_arctan_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_arctan2_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) = NULL; +int (*mlx_arctanh_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_argmax_axis_)( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s) = NULL; +int (*mlx_argmax_)( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s) = NULL; +int (*mlx_argmin_axis_)( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s) = NULL; +int (*mlx_argmin_)( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s) = NULL; +int (*mlx_argpartition_axis_)( + mlx_array* res, + const mlx_array a, + int kth, + int axis, + const mlx_stream s) = NULL; +int (*mlx_argpartition_)( + mlx_array* res, + const mlx_array a, + int kth, + const mlx_stream s) = NULL; +int (*mlx_argsort_axis_)( + mlx_array* res, + const mlx_array a, + int axis, + const mlx_stream s) = NULL; +int (*mlx_argsort_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_array_equal_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + bool equal_nan, + const mlx_stream s) = NULL; +int (*mlx_as_strided_)( + mlx_array* res, + const mlx_array a, + const int* shape, + size_t shape_num, + const int64_t* strides, + 
size_t strides_num, + size_t offset, + const mlx_stream s) = NULL; +int (*mlx_astype_)( + mlx_array* res, + const mlx_array a, + mlx_dtype dtype, + const mlx_stream s) = NULL; +int (*mlx_atleast_1d_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_atleast_2d_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_atleast_3d_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_bitwise_and_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) = NULL; +int (*mlx_bitwise_invert_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_bitwise_or_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) = NULL; +int (*mlx_bitwise_xor_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) = NULL; +int (*mlx_block_masked_mm_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + int block_size, + const mlx_array mask_out /* may be null */, + const mlx_array mask_lhs /* may be null */, + const mlx_array mask_rhs /* may be null */, + const mlx_stream s) = NULL; +int (*mlx_broadcast_arrays_)( + mlx_vector_array* res, + const mlx_vector_array inputs, + const mlx_stream s) = NULL; +int (*mlx_broadcast_to_)( + mlx_array* res, + const mlx_array a, + const int* shape, + size_t shape_num, + const mlx_stream s) = NULL; +int (*mlx_ceil_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_clip_)( + mlx_array* res, + const mlx_array a, + const mlx_array a_min /* may be null */, + const mlx_array a_max /* may be null */, + const mlx_stream s) = NULL; +int (*mlx_concatenate_axis_)( + mlx_array* res, + const mlx_vector_array arrays, + int axis, + const mlx_stream s) = NULL; +int (*mlx_concatenate_)( + mlx_array* res, + const mlx_vector_array arrays, + const mlx_stream s) = NULL; +int (*mlx_conjugate_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int 
(*mlx_contiguous_)( + mlx_array* res, + const mlx_array a, + bool allow_col_major, + const mlx_stream s) = NULL; +int (*mlx_conv1d_)( + mlx_array* res, + const mlx_array input, + const mlx_array weight, + int stride, + int padding, + int dilation, + int groups, + const mlx_stream s) = NULL; +int (*mlx_conv2d_)( + mlx_array* res, + const mlx_array input, + const mlx_array weight, + int stride_0, + int stride_1, + int padding_0, + int padding_1, + int dilation_0, + int dilation_1, + int groups, + const mlx_stream s) = NULL; +int (*mlx_conv3d_)( + mlx_array* res, + const mlx_array input, + const mlx_array weight, + int stride_0, + int stride_1, + int stride_2, + int padding_0, + int padding_1, + int padding_2, + int dilation_0, + int dilation_1, + int dilation_2, + int groups, + const mlx_stream s) = NULL; +int (*mlx_conv_general_)( + mlx_array* res, + const mlx_array input, + const mlx_array weight, + const int* stride, + size_t stride_num, + const int* padding_lo, + size_t padding_lo_num, + const int* padding_hi, + size_t padding_hi_num, + const int* kernel_dilation, + size_t kernel_dilation_num, + const int* input_dilation, + size_t input_dilation_num, + int groups, + bool flip, + const mlx_stream s) = NULL; +int (*mlx_conv_transpose1d_)( + mlx_array* res, + const mlx_array input, + const mlx_array weight, + int stride, + int padding, + int dilation, + int output_padding, + int groups, + const mlx_stream s) = NULL; +int (*mlx_conv_transpose2d_)( + mlx_array* res, + const mlx_array input, + const mlx_array weight, + int stride_0, + int stride_1, + int padding_0, + int padding_1, + int dilation_0, + int dilation_1, + int output_padding_0, + int output_padding_1, + int groups, + const mlx_stream s) = NULL; +int (*mlx_conv_transpose3d_)( + mlx_array* res, + const mlx_array input, + const mlx_array weight, + int stride_0, + int stride_1, + int stride_2, + int padding_0, + int padding_1, + int padding_2, + int dilation_0, + int dilation_1, + int dilation_2, + int 
output_padding_0, + int output_padding_1, + int output_padding_2, + int groups, + const mlx_stream s) = NULL; +int (*mlx_copy_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_cos_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_cosh_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_cummax_)( + mlx_array* res, + const mlx_array a, + int axis, + bool reverse, + bool inclusive, + const mlx_stream s) = NULL; +int (*mlx_cummin_)( + mlx_array* res, + const mlx_array a, + int axis, + bool reverse, + bool inclusive, + const mlx_stream s) = NULL; +int (*mlx_cumprod_)( + mlx_array* res, + const mlx_array a, + int axis, + bool reverse, + bool inclusive, + const mlx_stream s) = NULL; +int (*mlx_cumsum_)( + mlx_array* res, + const mlx_array a, + int axis, + bool reverse, + bool inclusive, + const mlx_stream s) = NULL; +int (*mlx_degrees_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_depends_)( + mlx_vector_array* res, + const mlx_vector_array inputs, + const mlx_vector_array dependencies) = NULL; +int (*mlx_dequantize_)( + mlx_array* res, + const mlx_array w, + const mlx_array scales, + const mlx_array biases /* may be null */, + mlx_optional_int group_size, + mlx_optional_int bits, + const char* mode, + mlx_optional_dtype dtype, + const mlx_stream s) = NULL; +int (*mlx_diag_)(mlx_array* res, const mlx_array a, int k, const mlx_stream s) = NULL; +int (*mlx_diagonal_)( + mlx_array* res, + const mlx_array a, + int offset, + int axis1, + int axis2, + const mlx_stream s) = NULL; +int (*mlx_divide_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) = NULL; +int (*mlx_divmod_)( + mlx_vector_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) = NULL; +int (*mlx_einsum_)( + mlx_array* res, + const char* subscripts, + const mlx_vector_array operands, + const mlx_stream s) = NULL; +int (*mlx_equal_)( + mlx_array* res, + 
const mlx_array a, + const mlx_array b, + const mlx_stream s) = NULL; +int (*mlx_erf_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_erfinv_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_exp_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_expand_dims_axes_)( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + const mlx_stream s) = NULL; +int (*mlx_expand_dims_)( + mlx_array* res, + const mlx_array a, + int axis, + const mlx_stream s) = NULL; +int (*mlx_expm1_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_eye_)( + mlx_array* res, + int n, + int m, + int k, + mlx_dtype dtype, + const mlx_stream s) = NULL; +int (*mlx_flatten_)( + mlx_array* res, + const mlx_array a, + int start_axis, + int end_axis, + const mlx_stream s) = NULL; +int (*mlx_floor_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_floor_divide_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) = NULL; +int (*mlx_from_fp8_)( + mlx_array* res, + const mlx_array x, + mlx_dtype dtype, + const mlx_stream s) = NULL; +int (*mlx_full_)( + mlx_array* res, + const int* shape, + size_t shape_num, + const mlx_array vals, + mlx_dtype dtype, + const mlx_stream s) = NULL; +int (*mlx_full_like_)( + mlx_array* res, + const mlx_array a, + const mlx_array vals, + mlx_dtype dtype, + const mlx_stream s) = NULL; +int (*mlx_gather_)( + mlx_array* res, + const mlx_array a, + const mlx_vector_array indices, + const int* axes, + size_t axes_num, + const int* slice_sizes, + size_t slice_sizes_num, + const mlx_stream s) = NULL; +int (*mlx_gather_mm_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_array lhs_indices /* may be null */, + const mlx_array rhs_indices /* may be null */, + bool sorted_indices, + const mlx_stream s) = NULL; +int (*mlx_gather_qmm_)( + mlx_array* res, + const mlx_array x, + const 
mlx_array w, + const mlx_array scales, + const mlx_array biases /* may be null */, + const mlx_array lhs_indices /* may be null */, + const mlx_array rhs_indices /* may be null */, + bool transpose, + mlx_optional_int group_size, + mlx_optional_int bits, + const char* mode, + bool sorted_indices, + const mlx_stream s) = NULL; +int (*mlx_greater_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) = NULL; +int (*mlx_greater_equal_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) = NULL; +int (*mlx_hadamard_transform_)( + mlx_array* res, + const mlx_array a, + mlx_optional_float scale, + const mlx_stream s) = NULL; +int (*mlx_identity_)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) = NULL; +int (*mlx_imag_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_inner_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) = NULL; +int (*mlx_isclose_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + double rtol, + double atol, + bool equal_nan, + const mlx_stream s) = NULL; +int (*mlx_isfinite_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_isinf_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_isnan_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_isneginf_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_isposinf_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_kron_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) = NULL; +int (*mlx_left_shift_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) = NULL; +int (*mlx_less_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) = NULL; +int (*mlx_less_equal_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) = NULL; 
+int (*mlx_linspace_)( + mlx_array* res, + double start, + double stop, + int num, + mlx_dtype dtype, + const mlx_stream s) = NULL; +int (*mlx_log_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_log10_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_log1p_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_log2_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_logaddexp_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) = NULL; +int (*mlx_logcumsumexp_)( + mlx_array* res, + const mlx_array a, + int axis, + bool reverse, + bool inclusive, + const mlx_stream s) = NULL; +int (*mlx_logical_and_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) = NULL; +int (*mlx_logical_not_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_logical_or_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) = NULL; +int (*mlx_logsumexp_axes_)( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s) = NULL; +int (*mlx_logsumexp_axis_)( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s) = NULL; +int (*mlx_logsumexp_)( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s) = NULL; +int (*mlx_masked_scatter_)( + mlx_array* res, + const mlx_array a, + const mlx_array mask, + const mlx_array src, + const mlx_stream s) = NULL; +int (*mlx_matmul_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) = NULL; +int (*mlx_max_axes_)( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s) = NULL; +int (*mlx_max_axis_)( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s) = NULL; +int (*mlx_max_)( + mlx_array* res, + const mlx_array 
a, + bool keepdims, + const mlx_stream s) = NULL; +int (*mlx_maximum_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) = NULL; +int (*mlx_mean_axes_)( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s) = NULL; +int (*mlx_mean_axis_)( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s) = NULL; +int (*mlx_mean_)( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s) = NULL; +int (*mlx_median_)( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s) = NULL; +int (*mlx_meshgrid_)( + mlx_vector_array* res, + const mlx_vector_array arrays, + bool sparse, + const char* indexing, + const mlx_stream s) = NULL; +int (*mlx_min_axes_)( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s) = NULL; +int (*mlx_min_axis_)( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s) = NULL; +int (*mlx_min_)( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s) = NULL; +int (*mlx_minimum_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) = NULL; +int (*mlx_moveaxis_)( + mlx_array* res, + const mlx_array a, + int source, + int destination, + const mlx_stream s) = NULL; +int (*mlx_multiply_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) = NULL; +int (*mlx_nan_to_num_)( + mlx_array* res, + const mlx_array a, + float nan, + mlx_optional_float posinf, + mlx_optional_float neginf, + const mlx_stream s) = NULL; +int (*mlx_negative_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_not_equal_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) = NULL; +int (*mlx_number_of_elements_)( + mlx_array* res, + const mlx_array a, + const 
int* axes, + size_t axes_num, + bool inverted, + mlx_dtype dtype, + const mlx_stream s) = NULL; +int (*mlx_ones_)( + mlx_array* res, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + const mlx_stream s) = NULL; +int (*mlx_ones_like_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_outer_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) = NULL; +int (*mlx_pad_)( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + const int* low_pad_size, + size_t low_pad_size_num, + const int* high_pad_size, + size_t high_pad_size_num, + const mlx_array pad_value, + const char* mode, + const mlx_stream s) = NULL; +int (*mlx_pad_symmetric_)( + mlx_array* res, + const mlx_array a, + int pad_width, + const mlx_array pad_value, + const char* mode, + const mlx_stream s) = NULL; +int (*mlx_partition_axis_)( + mlx_array* res, + const mlx_array a, + int kth, + int axis, + const mlx_stream s) = NULL; +int (*mlx_partition_)( + mlx_array* res, + const mlx_array a, + int kth, + const mlx_stream s) = NULL; +int (*mlx_power_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) = NULL; +int (*mlx_prod_axes_)( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s) = NULL; +int (*mlx_prod_axis_)( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s) = NULL; +int (*mlx_prod_)( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s) = NULL; +int (*mlx_put_along_axis_)( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + const mlx_array values, + int axis, + const mlx_stream s) = NULL; +int (*mlx_quantize_)( + mlx_vector_array* res, + const mlx_array w, + mlx_optional_int group_size, + mlx_optional_int bits, + const char* mode, + const mlx_stream s) = NULL; +int (*mlx_quantized_matmul_)( + mlx_array* res, + const mlx_array x, + const 
mlx_array w, + const mlx_array scales, + const mlx_array biases /* may be null */, + bool transpose, + mlx_optional_int group_size, + mlx_optional_int bits, + const char* mode, + const mlx_stream s) = NULL; +int (*mlx_radians_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_real_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_reciprocal_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_remainder_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) = NULL; +int (*mlx_repeat_axis_)( + mlx_array* res, + const mlx_array arr, + int repeats, + int axis, + const mlx_stream s) = NULL; +int (*mlx_repeat_)( + mlx_array* res, + const mlx_array arr, + int repeats, + const mlx_stream s) = NULL; +int (*mlx_reshape_)( + mlx_array* res, + const mlx_array a, + const int* shape, + size_t shape_num, + const mlx_stream s) = NULL; +int (*mlx_right_shift_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) = NULL; +int (*mlx_roll_axis_)( + mlx_array* res, + const mlx_array a, + const int* shift, + size_t shift_num, + int axis, + const mlx_stream s) = NULL; +int (*mlx_roll_axes_)( + mlx_array* res, + const mlx_array a, + const int* shift, + size_t shift_num, + const int* axes, + size_t axes_num, + const mlx_stream s) = NULL; +int (*mlx_roll_)( + mlx_array* res, + const mlx_array a, + const int* shift, + size_t shift_num, + const mlx_stream s) = NULL; +int (*mlx_round_)( + mlx_array* res, + const mlx_array a, + int decimals, + const mlx_stream s) = NULL; +int (*mlx_rsqrt_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_scatter_)( + mlx_array* res, + const mlx_array a, + const mlx_vector_array indices, + const mlx_array updates, + const int* axes, + size_t axes_num, + const mlx_stream s) = NULL; +int (*mlx_scatter_add_)( + mlx_array* res, + const mlx_array a, + const mlx_vector_array indices, + const mlx_array updates, + 
const int* axes, + size_t axes_num, + const mlx_stream s) = NULL; +int (*mlx_scatter_add_axis_)( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + const mlx_array values, + int axis, + const mlx_stream s) = NULL; +int (*mlx_scatter_max_)( + mlx_array* res, + const mlx_array a, + const mlx_vector_array indices, + const mlx_array updates, + const int* axes, + size_t axes_num, + const mlx_stream s) = NULL; +int (*mlx_scatter_min_)( + mlx_array* res, + const mlx_array a, + const mlx_vector_array indices, + const mlx_array updates, + const int* axes, + size_t axes_num, + const mlx_stream s) = NULL; +int (*mlx_scatter_prod_)( + mlx_array* res, + const mlx_array a, + const mlx_vector_array indices, + const mlx_array updates, + const int* axes, + size_t axes_num, + const mlx_stream s) = NULL; +int (*mlx_segmented_mm_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_array segments, + const mlx_stream s) = NULL; +int (*mlx_sigmoid_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_sign_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_sin_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_sinh_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_slice_)( + mlx_array* res, + const mlx_array a, + const int* start, + size_t start_num, + const int* stop, + size_t stop_num, + const int* strides, + size_t strides_num, + const mlx_stream s) = NULL; +int (*mlx_slice_dynamic_)( + mlx_array* res, + const mlx_array a, + const mlx_array start, + const int* axes, + size_t axes_num, + const int* slice_size, + size_t slice_size_num, + const mlx_stream s) = NULL; +int (*mlx_slice_update_)( + mlx_array* res, + const mlx_array src, + const mlx_array update, + const int* start, + size_t start_num, + const int* stop, + size_t stop_num, + const int* strides, + size_t strides_num, + const mlx_stream s) = NULL; +int (*mlx_slice_update_dynamic_)( + mlx_array* 
res, + const mlx_array src, + const mlx_array update, + const mlx_array start, + const int* axes, + size_t axes_num, + const mlx_stream s) = NULL; +int (*mlx_softmax_axes_)( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool precise, + const mlx_stream s) = NULL; +int (*mlx_softmax_axis_)( + mlx_array* res, + const mlx_array a, + int axis, + bool precise, + const mlx_stream s) = NULL; +int (*mlx_softmax_)( + mlx_array* res, + const mlx_array a, + bool precise, + const mlx_stream s) = NULL; +int (*mlx_sort_axis_)( + mlx_array* res, + const mlx_array a, + int axis, + const mlx_stream s) = NULL; +int (*mlx_sort_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_split_)( + mlx_vector_array* res, + const mlx_array a, + int num_splits, + int axis, + const mlx_stream s) = NULL; +int (*mlx_split_sections_)( + mlx_vector_array* res, + const mlx_array a, + const int* indices, + size_t indices_num, + int axis, + const mlx_stream s) = NULL; +int (*mlx_sqrt_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_square_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_squeeze_axes_)( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + const mlx_stream s) = NULL; +int (*mlx_squeeze_axis_)( + mlx_array* res, + const mlx_array a, + int axis, + const mlx_stream s) = NULL; +int (*mlx_squeeze_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_stack_axis_)( + mlx_array* res, + const mlx_vector_array arrays, + int axis, + const mlx_stream s) = NULL; +int (*mlx_stack_)( + mlx_array* res, + const mlx_vector_array arrays, + const mlx_stream s) = NULL; +int (*mlx_std_axes_)( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + int ddof, + const mlx_stream s) = NULL; +int (*mlx_std_axis_)( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + int ddof, + const mlx_stream s) = 
NULL; +int (*mlx_std_)( + mlx_array* res, + const mlx_array a, + bool keepdims, + int ddof, + const mlx_stream s) = NULL; +int (*mlx_stop_gradient_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_subtract_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) = NULL; +int (*mlx_sum_axes_)( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s) = NULL; +int (*mlx_sum_axis_)( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s) = NULL; +int (*mlx_sum_)( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s) = NULL; +int (*mlx_swapaxes_)( + mlx_array* res, + const mlx_array a, + int axis1, + int axis2, + const mlx_stream s) = NULL; +int (*mlx_take_axis_)( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + int axis, + const mlx_stream s) = NULL; +int (*mlx_take_)( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + const mlx_stream s) = NULL; +int (*mlx_take_along_axis_)( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + int axis, + const mlx_stream s) = NULL; +int (*mlx_tan_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_tanh_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_tensordot_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const int* axes_a, + size_t axes_a_num, + const int* axes_b, + size_t axes_b_num, + const mlx_stream s) = NULL; +int (*mlx_tensordot_axis_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + int axis, + const mlx_stream s) = NULL; +int (*mlx_tile_)( + mlx_array* res, + const mlx_array arr, + const int* reps, + size_t reps_num, + const mlx_stream s) = NULL; +int (*mlx_to_fp8_)(mlx_array* res, const mlx_array x, const mlx_stream s) = NULL; +int (*mlx_topk_axis_)( + mlx_array* res, + const mlx_array a, + int k, + int axis, + const 
mlx_stream s) = NULL; +int (*mlx_topk_)(mlx_array* res, const mlx_array a, int k, const mlx_stream s) = NULL; +int (*mlx_trace_)( + mlx_array* res, + const mlx_array a, + int offset, + int axis1, + int axis2, + mlx_dtype dtype, + const mlx_stream s) = NULL; +int (*mlx_transpose_axes_)( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + const mlx_stream s) = NULL; +int (*mlx_transpose_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_tri_)( + mlx_array* res, + int n, + int m, + int k, + mlx_dtype type, + const mlx_stream s) = NULL; +int (*mlx_tril_)(mlx_array* res, const mlx_array x, int k, const mlx_stream s) = NULL; +int (*mlx_triu_)(mlx_array* res, const mlx_array x, int k, const mlx_stream s) = NULL; +int (*mlx_unflatten_)( + mlx_array* res, + const mlx_array a, + int axis, + const int* shape, + size_t shape_num, + const mlx_stream s) = NULL; +int (*mlx_var_axes_)( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + int ddof, + const mlx_stream s) = NULL; +int (*mlx_var_axis_)( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + int ddof, + const mlx_stream s) = NULL; +int (*mlx_var_)( + mlx_array* res, + const mlx_array a, + bool keepdims, + int ddof, + const mlx_stream s) = NULL; +int (*mlx_view_)( + mlx_array* res, + const mlx_array a, + mlx_dtype dtype, + const mlx_stream s) = NULL; +int (*mlx_where_)( + mlx_array* res, + const mlx_array condition, + const mlx_array x, + const mlx_array y, + const mlx_stream s) = NULL; +int (*mlx_zeros_)( + mlx_array* res, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + const mlx_stream s) = NULL; +int (*mlx_zeros_like_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_random_bernoulli_)( + mlx_array* res, + const mlx_array p, + const int* shape, + size_t shape_num, + const mlx_array key /* may be null */, + const mlx_stream s) = NULL; +int (*mlx_random_bits_)( + mlx_array* 
res, + const int* shape, + size_t shape_num, + int width, + const mlx_array key /* may be null */, + const mlx_stream s) = NULL; +int (*mlx_random_categorical_shape_)( + mlx_array* res, + const mlx_array logits, + int axis, + const int* shape, + size_t shape_num, + const mlx_array key /* may be null */, + const mlx_stream s) = NULL; +int (*mlx_random_categorical_num_samples_)( + mlx_array* res, + const mlx_array logits_, + int axis, + int num_samples, + const mlx_array key /* may be null */, + const mlx_stream s) = NULL; +int (*mlx_random_categorical_)( + mlx_array* res, + const mlx_array logits, + int axis, + const mlx_array key /* may be null */, + const mlx_stream s) = NULL; +int (*mlx_random_gumbel_)( + mlx_array* res, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + const mlx_array key /* may be null */, + const mlx_stream s) = NULL; +int (*mlx_random_key_)(mlx_array* res, uint64_t seed) = NULL; +int (*mlx_random_laplace_)( + mlx_array* res, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + float loc, + float scale, + const mlx_array key /* may be null */, + const mlx_stream s) = NULL; +int (*mlx_random_multivariate_normal_)( + mlx_array* res, + const mlx_array mean, + const mlx_array cov, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + const mlx_array key /* may be null */, + const mlx_stream s) = NULL; +int (*mlx_random_normal_broadcast_)( + mlx_array* res, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + const mlx_array loc /* may be null */, + const mlx_array scale /* may be null */, + const mlx_array key /* may be null */, + const mlx_stream s) = NULL; +int (*mlx_random_normal_)( + mlx_array* res, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + float loc, + float scale, + const mlx_array key /* may be null */, + const mlx_stream s) = NULL; +int (*mlx_random_permutation_)( + mlx_array* res, + const mlx_array x, + int axis, + const mlx_array key /* may be null */, + const mlx_stream s) = NULL; 
+int (*mlx_random_permutation_arange_)( + mlx_array* res, + int x, + const mlx_array key /* may be null */, + const mlx_stream s) = NULL; +int (*mlx_random_randint_)( + mlx_array* res, + const mlx_array low, + const mlx_array high, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + const mlx_array key /* may be null */, + const mlx_stream s) = NULL; +int (*mlx_random_seed_)(uint64_t seed) = NULL; +int (*mlx_random_split_num_)( + mlx_array* res, + const mlx_array key, + int num, + const mlx_stream s) = NULL; +int (*mlx_random_split_)( + mlx_array* res_0, + mlx_array* res_1, + const mlx_array key, + const mlx_stream s) = NULL; +int (*mlx_random_truncated_normal_)( + mlx_array* res, + const mlx_array lower, + const mlx_array upper, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + const mlx_array key /* may be null */, + const mlx_stream s) = NULL; +int (*mlx_random_uniform_)( + mlx_array* res, + const mlx_array low, + const mlx_array high, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + const mlx_array key /* may be null */, + const mlx_stream s) = NULL; +mlx_stream (*mlx_stream_new_)(void) = NULL; +mlx_stream (*mlx_stream_new_device_)(mlx_device dev) = NULL; +int (*mlx_stream_set_)(mlx_stream* stream, const mlx_stream src) = NULL; +int (*mlx_stream_free_)(mlx_stream stream) = NULL; +int (*mlx_stream_tostring_)(mlx_string* str, mlx_stream stream) = NULL; +bool (*mlx_stream_equal_)(mlx_stream lhs, mlx_stream rhs) = NULL; +int (*mlx_stream_get_device_)(mlx_device* dev, mlx_stream stream) = NULL; +int (*mlx_stream_get_index_)(int* index, mlx_stream stream) = NULL; +int (*mlx_synchronize_)(mlx_stream stream) = NULL; +int (*mlx_get_default_stream_)(mlx_stream* stream, mlx_device dev) = NULL; +int (*mlx_set_default_stream_)(mlx_stream stream) = NULL; +mlx_stream (*mlx_default_cpu_stream_new_)(void) = NULL; +mlx_stream (*mlx_default_gpu_stream_new_)(void) = NULL; +mlx_string (*mlx_string_new_)(void) = NULL; +mlx_string 
(*mlx_string_new_data_)(const char* str) = NULL; +int (*mlx_string_set_)(mlx_string* str, const mlx_string src) = NULL; +const char * (*mlx_string_data_)(mlx_string str) = NULL; +int (*mlx_string_free_)(mlx_string str) = NULL; +int (*mlx_detail_vmap_replace_)( + mlx_vector_array* res, + const mlx_vector_array inputs, + const mlx_vector_array s_inputs, + const mlx_vector_array s_outputs, + const int* in_axes, + size_t in_axes_num, + const int* out_axes, + size_t out_axes_num) = NULL; +int (*mlx_detail_vmap_trace_)( + mlx_vector_array* res_0, + mlx_vector_array* res_1, + const mlx_closure fun, + const mlx_vector_array inputs, + const int* in_axes, + size_t in_axes_num) = NULL; +int (*mlx_async_eval_)(const mlx_vector_array outputs) = NULL; +int (*mlx_checkpoint_)(mlx_closure* res, const mlx_closure fun) = NULL; +int (*mlx_custom_function_)( + mlx_closure* res, + const mlx_closure fun, + const mlx_closure_custom fun_vjp /* may be null */, + const mlx_closure_custom_jvp fun_jvp /* may be null */, + const mlx_closure_custom_vmap fun_vmap /* may be null */) = NULL; +int (*mlx_custom_vjp_)( + mlx_closure* res, + const mlx_closure fun, + const mlx_closure_custom fun_vjp) = NULL; +int (*mlx_eval_)(const mlx_vector_array outputs) = NULL; +int (*mlx_jvp_)( + mlx_vector_array* res_0, + mlx_vector_array* res_1, + const mlx_closure fun, + const mlx_vector_array primals, + const mlx_vector_array tangents) = NULL; +int (*mlx_value_and_grad_)( + mlx_closure_value_and_grad* res, + const mlx_closure fun, + const int* argnums, + size_t argnums_num) = NULL; +int (*mlx_vjp_)( + mlx_vector_array* res_0, + mlx_vector_array* res_1, + const mlx_closure fun, + const mlx_vector_array primals, + const mlx_vector_array cotangents) = NULL; +mlx_vector_array (*mlx_vector_array_new_)(void) = NULL; +int (*mlx_vector_array_set_)(mlx_vector_array* vec, const mlx_vector_array src) = NULL; +int (*mlx_vector_array_free_)(mlx_vector_array vec) = NULL; +mlx_vector_array (*mlx_vector_array_new_data_)(const 
mlx_array* data, size_t size) = NULL; +mlx_vector_array (*mlx_vector_array_new_value_)(const mlx_array val) = NULL; +int (*mlx_vector_array_set_data_)( + mlx_vector_array* vec, + const mlx_array* data, + size_t size) = NULL; +int (*mlx_vector_array_set_value_)(mlx_vector_array* vec, const mlx_array val) = NULL; +int (*mlx_vector_array_append_data_)( + mlx_vector_array vec, + const mlx_array* data, + size_t size) = NULL; +int (*mlx_vector_array_append_value_)(mlx_vector_array vec, const mlx_array val) = NULL; +size_t (*mlx_vector_array_size_)(mlx_vector_array vec) = NULL; +int (*mlx_vector_array_get_)( + mlx_array* res, + const mlx_vector_array vec, + size_t idx) = NULL; +mlx_vector_vector_array (*mlx_vector_vector_array_new_)(void) = NULL; +int (*mlx_vector_vector_array_set_)( + mlx_vector_vector_array* vec, + const mlx_vector_vector_array src) = NULL; +int (*mlx_vector_vector_array_free_)(mlx_vector_vector_array vec) = NULL; +mlx_vector_vector_array (*mlx_vector_vector_array_new_data_)( + const mlx_vector_array* data, + size_t size) = NULL; +mlx_vector_vector_array (*mlx_vector_vector_array_new_value_)( + const mlx_vector_array val) = NULL; +int (*mlx_vector_vector_array_set_data_)( + mlx_vector_vector_array* vec, + const mlx_vector_array* data, + size_t size) = NULL; +int (*mlx_vector_vector_array_set_value_)( + mlx_vector_vector_array* vec, + const mlx_vector_array val) = NULL; +int (*mlx_vector_vector_array_append_data_)( + mlx_vector_vector_array vec, + const mlx_vector_array* data, + size_t size) = NULL; +int (*mlx_vector_vector_array_append_value_)( + mlx_vector_vector_array vec, + const mlx_vector_array val) = NULL; +size_t (*mlx_vector_vector_array_size_)(mlx_vector_vector_array vec) = NULL; +int (*mlx_vector_vector_array_get_)( + mlx_vector_array* res, + const mlx_vector_vector_array vec, + size_t idx) = NULL; +mlx_vector_int (*mlx_vector_int_new_)(void) = NULL; +int (*mlx_vector_int_set_)(mlx_vector_int* vec, const mlx_vector_int src) = NULL; +int 
(*mlx_vector_int_free_)(mlx_vector_int vec) = NULL; +mlx_vector_int (*mlx_vector_int_new_data_)(int* data, size_t size) = NULL; +mlx_vector_int (*mlx_vector_int_new_value_)(int val) = NULL; +int (*mlx_vector_int_set_data_)(mlx_vector_int* vec, int* data, size_t size) = NULL; +int (*mlx_vector_int_set_value_)(mlx_vector_int* vec, int val) = NULL; +int (*mlx_vector_int_append_data_)(mlx_vector_int vec, int* data, size_t size) = NULL; +int (*mlx_vector_int_append_value_)(mlx_vector_int vec, int val) = NULL; +size_t (*mlx_vector_int_size_)(mlx_vector_int vec) = NULL; +int (*mlx_vector_int_get_)(int* res, const mlx_vector_int vec, size_t idx) = NULL; +mlx_vector_string (*mlx_vector_string_new_)(void) = NULL; +int (*mlx_vector_string_set_)(mlx_vector_string* vec, const mlx_vector_string src) = NULL; +int (*mlx_vector_string_free_)(mlx_vector_string vec) = NULL; +mlx_vector_string (*mlx_vector_string_new_data_)(const char** data, size_t size) = NULL; +mlx_vector_string (*mlx_vector_string_new_value_)(const char* val) = NULL; +int (*mlx_vector_string_set_data_)( + mlx_vector_string* vec, + const char** data, + size_t size) = NULL; +int (*mlx_vector_string_set_value_)(mlx_vector_string* vec, const char* val) = NULL; +int (*mlx_vector_string_append_data_)( + mlx_vector_string vec, + const char** data, + size_t size) = NULL; +int (*mlx_vector_string_append_value_)(mlx_vector_string vec, const char* val) = NULL; +size_t (*mlx_vector_string_size_)(mlx_vector_string vec) = NULL; +int (*mlx_vector_string_get_)(char** res, const mlx_vector_string vec, size_t idx) = NULL; +int (*mlx_version_)(mlx_string* str_) = NULL; + +int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) { + CHECK_LOAD(handle, mlx_dtype_size); + CHECK_LOAD(handle, mlx_array_tostring); + CHECK_LOAD(handle, mlx_array_new); + CHECK_LOAD(handle, mlx_array_free); + CHECK_LOAD(handle, mlx_array_new_bool); + CHECK_LOAD(handle, mlx_array_new_int); + CHECK_LOAD(handle, mlx_array_new_float32); + CHECK_LOAD(handle, 
mlx_array_new_float); + CHECK_LOAD(handle, mlx_array_new_float64); + CHECK_LOAD(handle, mlx_array_new_double); + CHECK_LOAD(handle, mlx_array_new_complex); + CHECK_LOAD(handle, mlx_array_new_data); + CHECK_LOAD(handle, mlx_array_set); + CHECK_LOAD(handle, mlx_array_set_bool); + CHECK_LOAD(handle, mlx_array_set_int); + CHECK_LOAD(handle, mlx_array_set_float32); + CHECK_LOAD(handle, mlx_array_set_float); + CHECK_LOAD(handle, mlx_array_set_float64); + CHECK_LOAD(handle, mlx_array_set_double); + CHECK_LOAD(handle, mlx_array_set_complex); + CHECK_LOAD(handle, mlx_array_set_data); + CHECK_LOAD(handle, mlx_array_itemsize); + CHECK_LOAD(handle, mlx_array_size); + CHECK_LOAD(handle, mlx_array_nbytes); + CHECK_LOAD(handle, mlx_array_ndim); + CHECK_LOAD(handle, mlx_array_shape); + CHECK_LOAD(handle, mlx_array_strides); + CHECK_LOAD(handle, mlx_array_dim); + CHECK_LOAD(handle, mlx_array_dtype); + CHECK_LOAD(handle, mlx_array_eval); + CHECK_LOAD(handle, mlx_array_item_bool); + CHECK_LOAD(handle, mlx_array_item_uint8); + CHECK_LOAD(handle, mlx_array_item_uint16); + CHECK_LOAD(handle, mlx_array_item_uint32); + CHECK_LOAD(handle, mlx_array_item_uint64); + CHECK_LOAD(handle, mlx_array_item_int8); + CHECK_LOAD(handle, mlx_array_item_int16); + CHECK_LOAD(handle, mlx_array_item_int32); + CHECK_LOAD(handle, mlx_array_item_int64); + CHECK_LOAD(handle, mlx_array_item_float32); + CHECK_LOAD(handle, mlx_array_item_float64); + CHECK_LOAD(handle, mlx_array_item_complex64); + CHECK_LOAD(handle, mlx_array_item_float16); + CHECK_LOAD(handle, mlx_array_item_bfloat16); + CHECK_LOAD(handle, mlx_array_data_bool); + CHECK_LOAD(handle, mlx_array_data_uint8); + CHECK_LOAD(handle, mlx_array_data_uint16); + CHECK_LOAD(handle, mlx_array_data_uint32); + CHECK_LOAD(handle, mlx_array_data_uint64); + CHECK_LOAD(handle, mlx_array_data_int8); + CHECK_LOAD(handle, mlx_array_data_int16); + CHECK_LOAD(handle, mlx_array_data_int32); + CHECK_LOAD(handle, mlx_array_data_int64); + CHECK_LOAD(handle, 
mlx_array_data_float32); + CHECK_LOAD(handle, mlx_array_data_float64); + CHECK_LOAD(handle, mlx_array_data_complex64); + CHECK_LOAD(handle, mlx_array_data_float16); + CHECK_LOAD(handle, mlx_array_data_bfloat16); + CHECK_LOAD(handle, _mlx_array_is_available); + CHECK_LOAD(handle, _mlx_array_wait); + CHECK_LOAD(handle, _mlx_array_is_contiguous); + CHECK_LOAD(handle, _mlx_array_is_row_contiguous); + CHECK_LOAD(handle, _mlx_array_is_col_contiguous); + CHECK_LOAD(handle, mlx_closure_new); + CHECK_LOAD(handle, mlx_closure_free); + CHECK_LOAD(handle, mlx_closure_new_func); + CHECK_LOAD(handle, mlx_closure_new_func_payload); + CHECK_LOAD(handle, mlx_closure_set); + CHECK_LOAD(handle, mlx_closure_apply); + CHECK_LOAD(handle, mlx_closure_new_unary); + CHECK_LOAD(handle, mlx_closure_kwargs_new); + CHECK_LOAD(handle, mlx_closure_kwargs_free); + CHECK_LOAD(handle, mlx_closure_kwargs_new_func); + CHECK_LOAD(handle, mlx_closure_kwargs_new_func_payload); + CHECK_LOAD(handle, mlx_closure_kwargs_set); + CHECK_LOAD(handle, mlx_closure_kwargs_apply); + CHECK_LOAD(handle, mlx_closure_value_and_grad_new); + CHECK_LOAD(handle, mlx_closure_value_and_grad_free); + CHECK_LOAD(handle, mlx_closure_value_and_grad_new_func); + CHECK_LOAD(handle, mlx_closure_value_and_grad_new_func_payload); + CHECK_LOAD(handle, mlx_closure_value_and_grad_set); + CHECK_LOAD(handle, mlx_closure_value_and_grad_apply); + CHECK_LOAD(handle, mlx_closure_custom_new); + CHECK_LOAD(handle, mlx_closure_custom_free); + CHECK_LOAD(handle, mlx_closure_custom_new_func); + CHECK_LOAD(handle, mlx_closure_custom_new_func_payload); + CHECK_LOAD(handle, mlx_closure_custom_set); + CHECK_LOAD(handle, mlx_closure_custom_apply); + CHECK_LOAD(handle, mlx_closure_custom_jvp_new); + CHECK_LOAD(handle, mlx_closure_custom_jvp_free); + CHECK_LOAD(handle, mlx_closure_custom_jvp_new_func); + CHECK_LOAD(handle, mlx_closure_custom_jvp_new_func_payload); + CHECK_LOAD(handle, mlx_closure_custom_jvp_set); + CHECK_LOAD(handle, 
mlx_closure_custom_jvp_apply); + CHECK_LOAD(handle, mlx_closure_custom_vmap_new); + CHECK_LOAD(handle, mlx_closure_custom_vmap_free); + CHECK_LOAD(handle, mlx_closure_custom_vmap_new_func); + CHECK_LOAD(handle, mlx_closure_custom_vmap_new_func_payload); + CHECK_LOAD(handle, mlx_closure_custom_vmap_set); + CHECK_LOAD(handle, mlx_closure_custom_vmap_apply); + CHECK_LOAD(handle, mlx_compile); + CHECK_LOAD(handle, mlx_detail_compile); + CHECK_LOAD(handle, mlx_detail_compile_clear_cache); + CHECK_LOAD(handle, mlx_detail_compile_erase); + CHECK_LOAD(handle, mlx_disable_compile); + CHECK_LOAD(handle, mlx_enable_compile); + CHECK_LOAD(handle, mlx_set_compile_mode); + CHECK_LOAD(handle, mlx_device_new); + CHECK_LOAD(handle, mlx_device_new_type); + CHECK_LOAD(handle, mlx_device_free); + CHECK_LOAD(handle, mlx_device_set); + CHECK_LOAD(handle, mlx_device_tostring); + CHECK_LOAD(handle, mlx_device_equal); + CHECK_LOAD(handle, mlx_device_get_index); + CHECK_LOAD(handle, mlx_device_get_type); + CHECK_LOAD(handle, mlx_get_default_device); + CHECK_LOAD(handle, mlx_set_default_device); + CHECK_LOAD(handle, mlx_distributed_group_rank); + CHECK_LOAD(handle, mlx_distributed_group_size); + CHECK_LOAD(handle, mlx_distributed_group_split); + CHECK_LOAD(handle, mlx_distributed_is_available); + CHECK_LOAD(handle, mlx_distributed_init); + CHECK_LOAD(handle, mlx_distributed_all_gather); + CHECK_LOAD(handle, mlx_distributed_all_max); + CHECK_LOAD(handle, mlx_distributed_all_min); + CHECK_LOAD(handle, mlx_distributed_all_sum); + CHECK_LOAD(handle, mlx_distributed_recv); + CHECK_LOAD(handle, mlx_distributed_recv_like); + CHECK_LOAD(handle, mlx_distributed_send); + CHECK_LOAD(handle, mlx_distributed_sum_scatter); + CHECK_LOAD(handle, mlx_set_error_handler); + CHECK_LOAD(handle, _mlx_error); + CHECK_LOAD(handle, mlx_export_function); + CHECK_LOAD(handle, mlx_export_function_kwargs); + CHECK_LOAD(handle, mlx_function_exporter_new); + CHECK_LOAD(handle, mlx_function_exporter_free); + 
CHECK_LOAD(handle, mlx_function_exporter_apply); + CHECK_LOAD(handle, mlx_function_exporter_apply_kwargs); + CHECK_LOAD(handle, mlx_imported_function_new); + CHECK_LOAD(handle, mlx_imported_function_free); + CHECK_LOAD(handle, mlx_imported_function_apply); + CHECK_LOAD(handle, mlx_imported_function_apply_kwargs); + CHECK_LOAD(handle, mlx_fast_cuda_kernel_config_new); + CHECK_LOAD(handle, mlx_fast_cuda_kernel_config_free); + CHECK_LOAD(handle, mlx_fast_cuda_kernel_config_add_output_arg); + CHECK_LOAD(handle, mlx_fast_cuda_kernel_config_set_grid); + CHECK_LOAD(handle, mlx_fast_cuda_kernel_config_set_thread_group); + CHECK_LOAD(handle, mlx_fast_cuda_kernel_config_set_init_value); + CHECK_LOAD(handle, mlx_fast_cuda_kernel_config_set_verbose); + CHECK_LOAD(handle, mlx_fast_cuda_kernel_config_add_template_arg_dtype); + CHECK_LOAD(handle, mlx_fast_cuda_kernel_config_add_template_arg_int); + CHECK_LOAD(handle, mlx_fast_cuda_kernel_config_add_template_arg_bool); + CHECK_LOAD(handle, mlx_fast_cuda_kernel_new); + CHECK_LOAD(handle, mlx_fast_cuda_kernel_free); + CHECK_LOAD(handle, mlx_fast_cuda_kernel_apply); + CHECK_LOAD(handle, mlx_fast_layer_norm); + CHECK_LOAD(handle, mlx_fast_metal_kernel_config_new); + CHECK_LOAD(handle, mlx_fast_metal_kernel_config_free); + CHECK_LOAD(handle, mlx_fast_metal_kernel_config_add_output_arg); + CHECK_LOAD(handle, mlx_fast_metal_kernel_config_set_grid); + CHECK_LOAD(handle, mlx_fast_metal_kernel_config_set_thread_group); + CHECK_LOAD(handle, mlx_fast_metal_kernel_config_set_init_value); + CHECK_LOAD(handle, mlx_fast_metal_kernel_config_set_verbose); + CHECK_LOAD(handle, mlx_fast_metal_kernel_config_add_template_arg_dtype); + CHECK_LOAD(handle, mlx_fast_metal_kernel_config_add_template_arg_int); + CHECK_LOAD(handle, mlx_fast_metal_kernel_config_add_template_arg_bool); + CHECK_LOAD(handle, mlx_fast_metal_kernel_new); + CHECK_LOAD(handle, mlx_fast_metal_kernel_free); + CHECK_LOAD(handle, mlx_fast_metal_kernel_apply); + CHECK_LOAD(handle, 
mlx_fast_rms_norm); + CHECK_LOAD(handle, mlx_fast_rope); + CHECK_LOAD(handle, mlx_fast_scaled_dot_product_attention); + CHECK_LOAD(handle, mlx_fft_fft); + CHECK_LOAD(handle, mlx_fft_fft2); + CHECK_LOAD(handle, mlx_fft_fftn); + CHECK_LOAD(handle, mlx_fft_fftshift); + CHECK_LOAD(handle, mlx_fft_ifft); + CHECK_LOAD(handle, mlx_fft_ifft2); + CHECK_LOAD(handle, mlx_fft_ifftn); + CHECK_LOAD(handle, mlx_fft_ifftshift); + CHECK_LOAD(handle, mlx_fft_irfft); + CHECK_LOAD(handle, mlx_fft_irfft2); + CHECK_LOAD(handle, mlx_fft_irfftn); + CHECK_LOAD(handle, mlx_fft_rfft); + CHECK_LOAD(handle, mlx_fft_rfft2); + CHECK_LOAD(handle, mlx_fft_rfftn); + CHECK_LOAD(handle, mlx_io_reader_new); + CHECK_LOAD(handle, mlx_io_reader_descriptor); + CHECK_LOAD(handle, mlx_io_reader_tostring); + CHECK_LOAD(handle, mlx_io_reader_free); + CHECK_LOAD(handle, mlx_io_writer_new); + CHECK_LOAD(handle, mlx_io_writer_descriptor); + CHECK_LOAD(handle, mlx_io_writer_tostring); + CHECK_LOAD(handle, mlx_io_writer_free); + CHECK_LOAD(handle, mlx_load_reader); + CHECK_LOAD(handle, mlx_load); + CHECK_LOAD(handle, mlx_load_safetensors_reader); + CHECK_LOAD(handle, mlx_load_safetensors); + CHECK_LOAD(handle, mlx_save_writer); + CHECK_LOAD(handle, mlx_save); + CHECK_LOAD(handle, mlx_save_safetensors_writer); + CHECK_LOAD(handle, mlx_save_safetensors); + CHECK_LOAD(handle, mlx_linalg_cholesky); + CHECK_LOAD(handle, mlx_linalg_cholesky_inv); + CHECK_LOAD(handle, mlx_linalg_cross); + CHECK_LOAD(handle, mlx_linalg_eig); + CHECK_LOAD(handle, mlx_linalg_eigh); + CHECK_LOAD(handle, mlx_linalg_eigvals); + CHECK_LOAD(handle, mlx_linalg_eigvalsh); + CHECK_LOAD(handle, mlx_linalg_inv); + CHECK_LOAD(handle, mlx_linalg_lu); + CHECK_LOAD(handle, mlx_linalg_lu_factor); + CHECK_LOAD(handle, mlx_linalg_norm); + CHECK_LOAD(handle, mlx_linalg_norm_matrix); + CHECK_LOAD(handle, mlx_linalg_norm_l2); + CHECK_LOAD(handle, mlx_linalg_pinv); + CHECK_LOAD(handle, mlx_linalg_qr); + CHECK_LOAD(handle, mlx_linalg_solve); + CHECK_LOAD(handle, 
mlx_linalg_solve_triangular); + CHECK_LOAD(handle, mlx_linalg_svd); + CHECK_LOAD(handle, mlx_linalg_tri_inv); + CHECK_LOAD(handle, mlx_map_string_to_array_new); + CHECK_LOAD(handle, mlx_map_string_to_array_set); + CHECK_LOAD(handle, mlx_map_string_to_array_free); + CHECK_LOAD(handle, mlx_map_string_to_array_insert); + CHECK_LOAD(handle, mlx_map_string_to_array_get); + CHECK_LOAD(handle, mlx_map_string_to_array_iterator_new); + CHECK_LOAD(handle, mlx_map_string_to_array_iterator_free); + CHECK_LOAD(handle, mlx_map_string_to_array_iterator_next); + CHECK_LOAD(handle, mlx_map_string_to_string_new); + CHECK_LOAD(handle, mlx_map_string_to_string_set); + CHECK_LOAD(handle, mlx_map_string_to_string_free); + CHECK_LOAD(handle, mlx_map_string_to_string_insert); + CHECK_LOAD(handle, mlx_map_string_to_string_get); + CHECK_LOAD(handle, mlx_map_string_to_string_iterator_new); + CHECK_LOAD(handle, mlx_map_string_to_string_iterator_free); + CHECK_LOAD(handle, mlx_map_string_to_string_iterator_next); + CHECK_LOAD(handle, mlx_clear_cache); + CHECK_LOAD(handle, mlx_get_active_memory); + CHECK_LOAD(handle, mlx_get_cache_memory); + CHECK_LOAD(handle, mlx_get_memory_limit); + CHECK_LOAD(handle, mlx_get_peak_memory); + CHECK_LOAD(handle, mlx_reset_peak_memory); + CHECK_LOAD(handle, mlx_set_cache_limit); + CHECK_LOAD(handle, mlx_set_memory_limit); + CHECK_LOAD(handle, mlx_set_wired_limit); + CHECK_LOAD(handle, mlx_metal_device_info); + CHECK_LOAD(handle, mlx_metal_is_available); + CHECK_LOAD(handle, mlx_metal_start_capture); + CHECK_LOAD(handle, mlx_metal_stop_capture); + CHECK_LOAD(handle, mlx_abs); + CHECK_LOAD(handle, mlx_add); + CHECK_LOAD(handle, mlx_addmm); + CHECK_LOAD(handle, mlx_all_axes); + CHECK_LOAD(handle, mlx_all_axis); + CHECK_LOAD(handle, mlx_all); + CHECK_LOAD(handle, mlx_allclose); + CHECK_LOAD(handle, mlx_any_axes); + CHECK_LOAD(handle, mlx_any_axis); + CHECK_LOAD(handle, mlx_any); + CHECK_LOAD(handle, mlx_arange); + CHECK_LOAD(handle, mlx_arccos); + CHECK_LOAD(handle, 
mlx_arccosh); + CHECK_LOAD(handle, mlx_arcsin); + CHECK_LOAD(handle, mlx_arcsinh); + CHECK_LOAD(handle, mlx_arctan); + CHECK_LOAD(handle, mlx_arctan2); + CHECK_LOAD(handle, mlx_arctanh); + CHECK_LOAD(handle, mlx_argmax_axis); + CHECK_LOAD(handle, mlx_argmax); + CHECK_LOAD(handle, mlx_argmin_axis); + CHECK_LOAD(handle, mlx_argmin); + CHECK_LOAD(handle, mlx_argpartition_axis); + CHECK_LOAD(handle, mlx_argpartition); + CHECK_LOAD(handle, mlx_argsort_axis); + CHECK_LOAD(handle, mlx_argsort); + CHECK_LOAD(handle, mlx_array_equal); + CHECK_LOAD(handle, mlx_as_strided); + CHECK_LOAD(handle, mlx_astype); + CHECK_LOAD(handle, mlx_atleast_1d); + CHECK_LOAD(handle, mlx_atleast_2d); + CHECK_LOAD(handle, mlx_atleast_3d); + CHECK_LOAD(handle, mlx_bitwise_and); + CHECK_LOAD(handle, mlx_bitwise_invert); + CHECK_LOAD(handle, mlx_bitwise_or); + CHECK_LOAD(handle, mlx_bitwise_xor); + CHECK_LOAD(handle, mlx_block_masked_mm); + CHECK_LOAD(handle, mlx_broadcast_arrays); + CHECK_LOAD(handle, mlx_broadcast_to); + CHECK_LOAD(handle, mlx_ceil); + CHECK_LOAD(handle, mlx_clip); + CHECK_LOAD(handle, mlx_concatenate_axis); + CHECK_LOAD(handle, mlx_concatenate); + CHECK_LOAD(handle, mlx_conjugate); + CHECK_LOAD(handle, mlx_contiguous); + CHECK_LOAD(handle, mlx_conv1d); + CHECK_LOAD(handle, mlx_conv2d); + CHECK_LOAD(handle, mlx_conv3d); + CHECK_LOAD(handle, mlx_conv_general); + CHECK_LOAD(handle, mlx_conv_transpose1d); + CHECK_LOAD(handle, mlx_conv_transpose2d); + CHECK_LOAD(handle, mlx_conv_transpose3d); + CHECK_LOAD(handle, mlx_copy); + CHECK_LOAD(handle, mlx_cos); + CHECK_LOAD(handle, mlx_cosh); + CHECK_LOAD(handle, mlx_cummax); + CHECK_LOAD(handle, mlx_cummin); + CHECK_LOAD(handle, mlx_cumprod); + CHECK_LOAD(handle, mlx_cumsum); + CHECK_LOAD(handle, mlx_degrees); + CHECK_LOAD(handle, mlx_depends); + CHECK_LOAD(handle, mlx_dequantize); + CHECK_LOAD(handle, mlx_diag); + CHECK_LOAD(handle, mlx_diagonal); + CHECK_LOAD(handle, mlx_divide); + CHECK_LOAD(handle, mlx_divmod); + CHECK_LOAD(handle, 
mlx_einsum); + CHECK_LOAD(handle, mlx_equal); + CHECK_LOAD(handle, mlx_erf); + CHECK_LOAD(handle, mlx_erfinv); + CHECK_LOAD(handle, mlx_exp); + CHECK_LOAD(handle, mlx_expand_dims_axes); + CHECK_LOAD(handle, mlx_expand_dims); + CHECK_LOAD(handle, mlx_expm1); + CHECK_LOAD(handle, mlx_eye); + CHECK_LOAD(handle, mlx_flatten); + CHECK_LOAD(handle, mlx_floor); + CHECK_LOAD(handle, mlx_floor_divide); + CHECK_LOAD(handle, mlx_from_fp8); + CHECK_LOAD(handle, mlx_full); + CHECK_LOAD(handle, mlx_full_like); + CHECK_LOAD(handle, mlx_gather); + CHECK_LOAD(handle, mlx_gather_mm); + CHECK_LOAD(handle, mlx_gather_qmm); + CHECK_LOAD(handle, mlx_greater); + CHECK_LOAD(handle, mlx_greater_equal); + CHECK_LOAD(handle, mlx_hadamard_transform); + CHECK_LOAD(handle, mlx_identity); + CHECK_LOAD(handle, mlx_imag); + CHECK_LOAD(handle, mlx_inner); + CHECK_LOAD(handle, mlx_isclose); + CHECK_LOAD(handle, mlx_isfinite); + CHECK_LOAD(handle, mlx_isinf); + CHECK_LOAD(handle, mlx_isnan); + CHECK_LOAD(handle, mlx_isneginf); + CHECK_LOAD(handle, mlx_isposinf); + CHECK_LOAD(handle, mlx_kron); + CHECK_LOAD(handle, mlx_left_shift); + CHECK_LOAD(handle, mlx_less); + CHECK_LOAD(handle, mlx_less_equal); + CHECK_LOAD(handle, mlx_linspace); + CHECK_LOAD(handle, mlx_log); + CHECK_LOAD(handle, mlx_log10); + CHECK_LOAD(handle, mlx_log1p); + CHECK_LOAD(handle, mlx_log2); + CHECK_LOAD(handle, mlx_logaddexp); + CHECK_LOAD(handle, mlx_logcumsumexp); + CHECK_LOAD(handle, mlx_logical_and); + CHECK_LOAD(handle, mlx_logical_not); + CHECK_LOAD(handle, mlx_logical_or); + CHECK_LOAD(handle, mlx_logsumexp_axes); + CHECK_LOAD(handle, mlx_logsumexp_axis); + CHECK_LOAD(handle, mlx_logsumexp); + CHECK_LOAD(handle, mlx_masked_scatter); + CHECK_LOAD(handle, mlx_matmul); + CHECK_LOAD(handle, mlx_max_axes); + CHECK_LOAD(handle, mlx_max_axis); + CHECK_LOAD(handle, mlx_max); + CHECK_LOAD(handle, mlx_maximum); + CHECK_LOAD(handle, mlx_mean_axes); + CHECK_LOAD(handle, mlx_mean_axis); + CHECK_LOAD(handle, mlx_mean); + 
CHECK_LOAD(handle, mlx_median); + CHECK_LOAD(handle, mlx_meshgrid); + CHECK_LOAD(handle, mlx_min_axes); + CHECK_LOAD(handle, mlx_min_axis); + CHECK_LOAD(handle, mlx_min); + CHECK_LOAD(handle, mlx_minimum); + CHECK_LOAD(handle, mlx_moveaxis); + CHECK_LOAD(handle, mlx_multiply); + CHECK_LOAD(handle, mlx_nan_to_num); + CHECK_LOAD(handle, mlx_negative); + CHECK_LOAD(handle, mlx_not_equal); + CHECK_LOAD(handle, mlx_number_of_elements); + CHECK_LOAD(handle, mlx_ones); + CHECK_LOAD(handle, mlx_ones_like); + CHECK_LOAD(handle, mlx_outer); + CHECK_LOAD(handle, mlx_pad); + CHECK_LOAD(handle, mlx_pad_symmetric); + CHECK_LOAD(handle, mlx_partition_axis); + CHECK_LOAD(handle, mlx_partition); + CHECK_LOAD(handle, mlx_power); + CHECK_LOAD(handle, mlx_prod_axes); + CHECK_LOAD(handle, mlx_prod_axis); + CHECK_LOAD(handle, mlx_prod); + CHECK_LOAD(handle, mlx_put_along_axis); + CHECK_LOAD(handle, mlx_quantize); + CHECK_LOAD(handle, mlx_quantized_matmul); + CHECK_LOAD(handle, mlx_radians); + CHECK_LOAD(handle, mlx_real); + CHECK_LOAD(handle, mlx_reciprocal); + CHECK_LOAD(handle, mlx_remainder); + CHECK_LOAD(handle, mlx_repeat_axis); + CHECK_LOAD(handle, mlx_repeat); + CHECK_LOAD(handle, mlx_reshape); + CHECK_LOAD(handle, mlx_right_shift); + CHECK_LOAD(handle, mlx_roll_axis); + CHECK_LOAD(handle, mlx_roll_axes); + CHECK_LOAD(handle, mlx_roll); + CHECK_LOAD(handle, mlx_round); + CHECK_LOAD(handle, mlx_rsqrt); + CHECK_LOAD(handle, mlx_scatter); + CHECK_LOAD(handle, mlx_scatter_add); + CHECK_LOAD(handle, mlx_scatter_add_axis); + CHECK_LOAD(handle, mlx_scatter_max); + CHECK_LOAD(handle, mlx_scatter_min); + CHECK_LOAD(handle, mlx_scatter_prod); + CHECK_LOAD(handle, mlx_segmented_mm); + CHECK_LOAD(handle, mlx_sigmoid); + CHECK_LOAD(handle, mlx_sign); + CHECK_LOAD(handle, mlx_sin); + CHECK_LOAD(handle, mlx_sinh); + CHECK_LOAD(handle, mlx_slice); + CHECK_LOAD(handle, mlx_slice_dynamic); + CHECK_LOAD(handle, mlx_slice_update); + CHECK_LOAD(handle, mlx_slice_update_dynamic); + CHECK_LOAD(handle, 
mlx_softmax_axes); + CHECK_LOAD(handle, mlx_softmax_axis); + CHECK_LOAD(handle, mlx_softmax); + CHECK_LOAD(handle, mlx_sort_axis); + CHECK_LOAD(handle, mlx_sort); + CHECK_LOAD(handle, mlx_split); + CHECK_LOAD(handle, mlx_split_sections); + CHECK_LOAD(handle, mlx_sqrt); + CHECK_LOAD(handle, mlx_square); + CHECK_LOAD(handle, mlx_squeeze_axes); + CHECK_LOAD(handle, mlx_squeeze_axis); + CHECK_LOAD(handle, mlx_squeeze); + CHECK_LOAD(handle, mlx_stack_axis); + CHECK_LOAD(handle, mlx_stack); + CHECK_LOAD(handle, mlx_std_axes); + CHECK_LOAD(handle, mlx_std_axis); + CHECK_LOAD(handle, mlx_std); + CHECK_LOAD(handle, mlx_stop_gradient); + CHECK_LOAD(handle, mlx_subtract); + CHECK_LOAD(handle, mlx_sum_axes); + CHECK_LOAD(handle, mlx_sum_axis); + CHECK_LOAD(handle, mlx_sum); + CHECK_LOAD(handle, mlx_swapaxes); + CHECK_LOAD(handle, mlx_take_axis); + CHECK_LOAD(handle, mlx_take); + CHECK_LOAD(handle, mlx_take_along_axis); + CHECK_LOAD(handle, mlx_tan); + CHECK_LOAD(handle, mlx_tanh); + CHECK_LOAD(handle, mlx_tensordot); + CHECK_LOAD(handle, mlx_tensordot_axis); + CHECK_LOAD(handle, mlx_tile); + CHECK_LOAD(handle, mlx_to_fp8); + CHECK_LOAD(handle, mlx_topk_axis); + CHECK_LOAD(handle, mlx_topk); + CHECK_LOAD(handle, mlx_trace); + CHECK_LOAD(handle, mlx_transpose_axes); + CHECK_LOAD(handle, mlx_transpose); + CHECK_LOAD(handle, mlx_tri); + CHECK_LOAD(handle, mlx_tril); + CHECK_LOAD(handle, mlx_triu); + CHECK_LOAD(handle, mlx_unflatten); + CHECK_LOAD(handle, mlx_var_axes); + CHECK_LOAD(handle, mlx_var_axis); + CHECK_LOAD(handle, mlx_var); + CHECK_LOAD(handle, mlx_view); + CHECK_LOAD(handle, mlx_where); + CHECK_LOAD(handle, mlx_zeros); + CHECK_LOAD(handle, mlx_zeros_like); + CHECK_LOAD(handle, mlx_random_bernoulli); + CHECK_LOAD(handle, mlx_random_bits); + CHECK_LOAD(handle, mlx_random_categorical_shape); + CHECK_LOAD(handle, mlx_random_categorical_num_samples); + CHECK_LOAD(handle, mlx_random_categorical); + CHECK_LOAD(handle, mlx_random_gumbel); + CHECK_LOAD(handle, mlx_random_key); 
+ CHECK_LOAD(handle, mlx_random_laplace); + CHECK_LOAD(handle, mlx_random_multivariate_normal); + CHECK_LOAD(handle, mlx_random_normal_broadcast); + CHECK_LOAD(handle, mlx_random_normal); + CHECK_LOAD(handle, mlx_random_permutation); + CHECK_LOAD(handle, mlx_random_permutation_arange); + CHECK_LOAD(handle, mlx_random_randint); + CHECK_LOAD(handle, mlx_random_seed); + CHECK_LOAD(handle, mlx_random_split_num); + CHECK_LOAD(handle, mlx_random_split); + CHECK_LOAD(handle, mlx_random_truncated_normal); + CHECK_LOAD(handle, mlx_random_uniform); + CHECK_LOAD(handle, mlx_stream_new); + CHECK_LOAD(handle, mlx_stream_new_device); + CHECK_LOAD(handle, mlx_stream_set); + CHECK_LOAD(handle, mlx_stream_free); + CHECK_LOAD(handle, mlx_stream_tostring); + CHECK_LOAD(handle, mlx_stream_equal); + CHECK_LOAD(handle, mlx_stream_get_device); + CHECK_LOAD(handle, mlx_stream_get_index); + CHECK_LOAD(handle, mlx_synchronize); + CHECK_LOAD(handle, mlx_get_default_stream); + CHECK_LOAD(handle, mlx_set_default_stream); + CHECK_LOAD(handle, mlx_default_cpu_stream_new); + CHECK_LOAD(handle, mlx_default_gpu_stream_new); + CHECK_LOAD(handle, mlx_string_new); + CHECK_LOAD(handle, mlx_string_new_data); + CHECK_LOAD(handle, mlx_string_set); + CHECK_LOAD(handle, mlx_string_data); + CHECK_LOAD(handle, mlx_string_free); + CHECK_LOAD(handle, mlx_detail_vmap_replace); + CHECK_LOAD(handle, mlx_detail_vmap_trace); + CHECK_LOAD(handle, mlx_async_eval); + CHECK_LOAD(handle, mlx_checkpoint); + CHECK_LOAD(handle, mlx_custom_function); + CHECK_LOAD(handle, mlx_custom_vjp); + CHECK_LOAD(handle, mlx_eval); + CHECK_LOAD(handle, mlx_jvp); + CHECK_LOAD(handle, mlx_value_and_grad); + CHECK_LOAD(handle, mlx_vjp); + CHECK_LOAD(handle, mlx_vector_array_new); + CHECK_LOAD(handle, mlx_vector_array_set); + CHECK_LOAD(handle, mlx_vector_array_free); + CHECK_LOAD(handle, mlx_vector_array_new_data); + CHECK_LOAD(handle, mlx_vector_array_new_value); + CHECK_LOAD(handle, mlx_vector_array_set_data); + CHECK_LOAD(handle, 
mlx_vector_array_set_value); + CHECK_LOAD(handle, mlx_vector_array_append_data); + CHECK_LOAD(handle, mlx_vector_array_append_value); + CHECK_LOAD(handle, mlx_vector_array_size); + CHECK_LOAD(handle, mlx_vector_array_get); + CHECK_LOAD(handle, mlx_vector_vector_array_new); + CHECK_LOAD(handle, mlx_vector_vector_array_set); + CHECK_LOAD(handle, mlx_vector_vector_array_free); + CHECK_LOAD(handle, mlx_vector_vector_array_new_data); + CHECK_LOAD(handle, mlx_vector_vector_array_new_value); + CHECK_LOAD(handle, mlx_vector_vector_array_set_data); + CHECK_LOAD(handle, mlx_vector_vector_array_set_value); + CHECK_LOAD(handle, mlx_vector_vector_array_append_data); + CHECK_LOAD(handle, mlx_vector_vector_array_append_value); + CHECK_LOAD(handle, mlx_vector_vector_array_size); + CHECK_LOAD(handle, mlx_vector_vector_array_get); + CHECK_LOAD(handle, mlx_vector_int_new); + CHECK_LOAD(handle, mlx_vector_int_set); + CHECK_LOAD(handle, mlx_vector_int_free); + CHECK_LOAD(handle, mlx_vector_int_new_data); + CHECK_LOAD(handle, mlx_vector_int_new_value); + CHECK_LOAD(handle, mlx_vector_int_set_data); + CHECK_LOAD(handle, mlx_vector_int_set_value); + CHECK_LOAD(handle, mlx_vector_int_append_data); + CHECK_LOAD(handle, mlx_vector_int_append_value); + CHECK_LOAD(handle, mlx_vector_int_size); + CHECK_LOAD(handle, mlx_vector_int_get); + CHECK_LOAD(handle, mlx_vector_string_new); + CHECK_LOAD(handle, mlx_vector_string_set); + CHECK_LOAD(handle, mlx_vector_string_free); + CHECK_LOAD(handle, mlx_vector_string_new_data); + CHECK_LOAD(handle, mlx_vector_string_new_value); + CHECK_LOAD(handle, mlx_vector_string_set_data); + CHECK_LOAD(handle, mlx_vector_string_set_value); + CHECK_LOAD(handle, mlx_vector_string_append_data); + CHECK_LOAD(handle, mlx_vector_string_append_value); + CHECK_LOAD(handle, mlx_vector_string_size); + CHECK_LOAD(handle, mlx_vector_string_get); + CHECK_LOAD(handle, mlx_version); + return 0; +} diff --git a/x/mlxrunner/mlx/generated.h b/x/mlxrunner/mlx/generated.h new file mode 
100644 index 000000000..c88946d9f --- /dev/null +++ b/x/mlxrunner/mlx/generated.h @@ -0,0 +1,7135 @@ +// This code is auto-generated; DO NOT EDIT. + +#ifndef MLX_GENERATED_H +#define MLX_GENERATED_H + +#include "dynamic.h" + +#define mlx_dtype_size mlx_dtype_size_mlx_gen_orig_ +#define mlx_array_tostring mlx_array_tostring_mlx_gen_orig_ +#define mlx_array_new mlx_array_new_mlx_gen_orig_ +#define mlx_array_free mlx_array_free_mlx_gen_orig_ +#define mlx_array_new_bool mlx_array_new_bool_mlx_gen_orig_ +#define mlx_array_new_int mlx_array_new_int_mlx_gen_orig_ +#define mlx_array_new_float32 mlx_array_new_float32_mlx_gen_orig_ +#define mlx_array_new_float mlx_array_new_float_mlx_gen_orig_ +#define mlx_array_new_float64 mlx_array_new_float64_mlx_gen_orig_ +#define mlx_array_new_double mlx_array_new_double_mlx_gen_orig_ +#define mlx_array_new_complex mlx_array_new_complex_mlx_gen_orig_ +#define mlx_array_new_data mlx_array_new_data_mlx_gen_orig_ +#define mlx_array_set mlx_array_set_mlx_gen_orig_ +#define mlx_array_set_bool mlx_array_set_bool_mlx_gen_orig_ +#define mlx_array_set_int mlx_array_set_int_mlx_gen_orig_ +#define mlx_array_set_float32 mlx_array_set_float32_mlx_gen_orig_ +#define mlx_array_set_float mlx_array_set_float_mlx_gen_orig_ +#define mlx_array_set_float64 mlx_array_set_float64_mlx_gen_orig_ +#define mlx_array_set_double mlx_array_set_double_mlx_gen_orig_ +#define mlx_array_set_complex mlx_array_set_complex_mlx_gen_orig_ +#define mlx_array_set_data mlx_array_set_data_mlx_gen_orig_ +#define mlx_array_itemsize mlx_array_itemsize_mlx_gen_orig_ +#define mlx_array_size mlx_array_size_mlx_gen_orig_ +#define mlx_array_nbytes mlx_array_nbytes_mlx_gen_orig_ +#define mlx_array_ndim mlx_array_ndim_mlx_gen_orig_ +#define mlx_array_shape mlx_array_shape_mlx_gen_orig_ +#define mlx_array_strides mlx_array_strides_mlx_gen_orig_ +#define mlx_array_dim mlx_array_dim_mlx_gen_orig_ +#define mlx_array_dtype mlx_array_dtype_mlx_gen_orig_ +#define mlx_array_eval 
mlx_array_eval_mlx_gen_orig_ +#define mlx_array_item_bool mlx_array_item_bool_mlx_gen_orig_ +#define mlx_array_item_uint8 mlx_array_item_uint8_mlx_gen_orig_ +#define mlx_array_item_uint16 mlx_array_item_uint16_mlx_gen_orig_ +#define mlx_array_item_uint32 mlx_array_item_uint32_mlx_gen_orig_ +#define mlx_array_item_uint64 mlx_array_item_uint64_mlx_gen_orig_ +#define mlx_array_item_int8 mlx_array_item_int8_mlx_gen_orig_ +#define mlx_array_item_int16 mlx_array_item_int16_mlx_gen_orig_ +#define mlx_array_item_int32 mlx_array_item_int32_mlx_gen_orig_ +#define mlx_array_item_int64 mlx_array_item_int64_mlx_gen_orig_ +#define mlx_array_item_float32 mlx_array_item_float32_mlx_gen_orig_ +#define mlx_array_item_float64 mlx_array_item_float64_mlx_gen_orig_ +#define mlx_array_item_complex64 mlx_array_item_complex64_mlx_gen_orig_ +#define mlx_array_item_float16 mlx_array_item_float16_mlx_gen_orig_ +#define mlx_array_item_bfloat16 mlx_array_item_bfloat16_mlx_gen_orig_ +#define mlx_array_data_bool mlx_array_data_bool_mlx_gen_orig_ +#define mlx_array_data_uint8 mlx_array_data_uint8_mlx_gen_orig_ +#define mlx_array_data_uint16 mlx_array_data_uint16_mlx_gen_orig_ +#define mlx_array_data_uint32 mlx_array_data_uint32_mlx_gen_orig_ +#define mlx_array_data_uint64 mlx_array_data_uint64_mlx_gen_orig_ +#define mlx_array_data_int8 mlx_array_data_int8_mlx_gen_orig_ +#define mlx_array_data_int16 mlx_array_data_int16_mlx_gen_orig_ +#define mlx_array_data_int32 mlx_array_data_int32_mlx_gen_orig_ +#define mlx_array_data_int64 mlx_array_data_int64_mlx_gen_orig_ +#define mlx_array_data_float32 mlx_array_data_float32_mlx_gen_orig_ +#define mlx_array_data_float64 mlx_array_data_float64_mlx_gen_orig_ +#define mlx_array_data_complex64 mlx_array_data_complex64_mlx_gen_orig_ +#define mlx_array_data_float16 mlx_array_data_float16_mlx_gen_orig_ +#define mlx_array_data_bfloat16 mlx_array_data_bfloat16_mlx_gen_orig_ +#define _mlx_array_is_available _mlx_array_is_available_mlx_gen_orig_ +#define 
_mlx_array_wait _mlx_array_wait_mlx_gen_orig_ +#define _mlx_array_is_contiguous _mlx_array_is_contiguous_mlx_gen_orig_ +#define _mlx_array_is_row_contiguous _mlx_array_is_row_contiguous_mlx_gen_orig_ +#define _mlx_array_is_col_contiguous _mlx_array_is_col_contiguous_mlx_gen_orig_ +#define mlx_closure_new mlx_closure_new_mlx_gen_orig_ +#define mlx_closure_free mlx_closure_free_mlx_gen_orig_ +#define mlx_closure_new_func mlx_closure_new_func_mlx_gen_orig_ +#define mlx_closure_new_func_payload mlx_closure_new_func_payload_mlx_gen_orig_ +#define mlx_closure_set mlx_closure_set_mlx_gen_orig_ +#define mlx_closure_apply mlx_closure_apply_mlx_gen_orig_ +#define mlx_closure_new_unary mlx_closure_new_unary_mlx_gen_orig_ +#define mlx_closure_kwargs_new mlx_closure_kwargs_new_mlx_gen_orig_ +#define mlx_closure_kwargs_free mlx_closure_kwargs_free_mlx_gen_orig_ +#define mlx_closure_kwargs_new_func mlx_closure_kwargs_new_func_mlx_gen_orig_ +#define mlx_closure_kwargs_new_func_payload mlx_closure_kwargs_new_func_payload_mlx_gen_orig_ +#define mlx_closure_kwargs_set mlx_closure_kwargs_set_mlx_gen_orig_ +#define mlx_closure_kwargs_apply mlx_closure_kwargs_apply_mlx_gen_orig_ +#define mlx_closure_value_and_grad_new mlx_closure_value_and_grad_new_mlx_gen_orig_ +#define mlx_closure_value_and_grad_free mlx_closure_value_and_grad_free_mlx_gen_orig_ +#define mlx_closure_value_and_grad_new_func mlx_closure_value_and_grad_new_func_mlx_gen_orig_ +#define mlx_closure_value_and_grad_new_func_payload mlx_closure_value_and_grad_new_func_payload_mlx_gen_orig_ +#define mlx_closure_value_and_grad_set mlx_closure_value_and_grad_set_mlx_gen_orig_ +#define mlx_closure_value_and_grad_apply mlx_closure_value_and_grad_apply_mlx_gen_orig_ +#define mlx_closure_custom_new mlx_closure_custom_new_mlx_gen_orig_ +#define mlx_closure_custom_free mlx_closure_custom_free_mlx_gen_orig_ +#define mlx_closure_custom_new_func mlx_closure_custom_new_func_mlx_gen_orig_ +#define mlx_closure_custom_new_func_payload 
mlx_closure_custom_new_func_payload_mlx_gen_orig_ +#define mlx_closure_custom_set mlx_closure_custom_set_mlx_gen_orig_ +#define mlx_closure_custom_apply mlx_closure_custom_apply_mlx_gen_orig_ +#define mlx_closure_custom_jvp_new mlx_closure_custom_jvp_new_mlx_gen_orig_ +#define mlx_closure_custom_jvp_free mlx_closure_custom_jvp_free_mlx_gen_orig_ +#define mlx_closure_custom_jvp_new_func mlx_closure_custom_jvp_new_func_mlx_gen_orig_ +#define mlx_closure_custom_jvp_new_func_payload mlx_closure_custom_jvp_new_func_payload_mlx_gen_orig_ +#define mlx_closure_custom_jvp_set mlx_closure_custom_jvp_set_mlx_gen_orig_ +#define mlx_closure_custom_jvp_apply mlx_closure_custom_jvp_apply_mlx_gen_orig_ +#define mlx_closure_custom_vmap_new mlx_closure_custom_vmap_new_mlx_gen_orig_ +#define mlx_closure_custom_vmap_free mlx_closure_custom_vmap_free_mlx_gen_orig_ +#define mlx_closure_custom_vmap_new_func mlx_closure_custom_vmap_new_func_mlx_gen_orig_ +#define mlx_closure_custom_vmap_new_func_payload mlx_closure_custom_vmap_new_func_payload_mlx_gen_orig_ +#define mlx_closure_custom_vmap_set mlx_closure_custom_vmap_set_mlx_gen_orig_ +#define mlx_closure_custom_vmap_apply mlx_closure_custom_vmap_apply_mlx_gen_orig_ +#define mlx_compile mlx_compile_mlx_gen_orig_ +#define mlx_detail_compile mlx_detail_compile_mlx_gen_orig_ +#define mlx_detail_compile_clear_cache mlx_detail_compile_clear_cache_mlx_gen_orig_ +#define mlx_detail_compile_erase mlx_detail_compile_erase_mlx_gen_orig_ +#define mlx_disable_compile mlx_disable_compile_mlx_gen_orig_ +#define mlx_enable_compile mlx_enable_compile_mlx_gen_orig_ +#define mlx_set_compile_mode mlx_set_compile_mode_mlx_gen_orig_ +#define mlx_device_new mlx_device_new_mlx_gen_orig_ +#define mlx_device_new_type mlx_device_new_type_mlx_gen_orig_ +#define mlx_device_free mlx_device_free_mlx_gen_orig_ +#define mlx_device_set mlx_device_set_mlx_gen_orig_ +#define mlx_device_tostring mlx_device_tostring_mlx_gen_orig_ +#define mlx_device_equal 
mlx_device_equal_mlx_gen_orig_ +#define mlx_device_get_index mlx_device_get_index_mlx_gen_orig_ +#define mlx_device_get_type mlx_device_get_type_mlx_gen_orig_ +#define mlx_get_default_device mlx_get_default_device_mlx_gen_orig_ +#define mlx_set_default_device mlx_set_default_device_mlx_gen_orig_ +#define mlx_distributed_group_rank mlx_distributed_group_rank_mlx_gen_orig_ +#define mlx_distributed_group_size mlx_distributed_group_size_mlx_gen_orig_ +#define mlx_distributed_group_split mlx_distributed_group_split_mlx_gen_orig_ +#define mlx_distributed_is_available mlx_distributed_is_available_mlx_gen_orig_ +#define mlx_distributed_init mlx_distributed_init_mlx_gen_orig_ +#define mlx_distributed_all_gather mlx_distributed_all_gather_mlx_gen_orig_ +#define mlx_distributed_all_max mlx_distributed_all_max_mlx_gen_orig_ +#define mlx_distributed_all_min mlx_distributed_all_min_mlx_gen_orig_ +#define mlx_distributed_all_sum mlx_distributed_all_sum_mlx_gen_orig_ +#define mlx_distributed_recv mlx_distributed_recv_mlx_gen_orig_ +#define mlx_distributed_recv_like mlx_distributed_recv_like_mlx_gen_orig_ +#define mlx_distributed_send mlx_distributed_send_mlx_gen_orig_ +#define mlx_distributed_sum_scatter mlx_distributed_sum_scatter_mlx_gen_orig_ +#define mlx_set_error_handler mlx_set_error_handler_mlx_gen_orig_ +#define _mlx_error _mlx_error_mlx_gen_orig_ +#define mlx_export_function mlx_export_function_mlx_gen_orig_ +#define mlx_export_function_kwargs mlx_export_function_kwargs_mlx_gen_orig_ +#define mlx_function_exporter_new mlx_function_exporter_new_mlx_gen_orig_ +#define mlx_function_exporter_free mlx_function_exporter_free_mlx_gen_orig_ +#define mlx_function_exporter_apply mlx_function_exporter_apply_mlx_gen_orig_ +#define mlx_function_exporter_apply_kwargs mlx_function_exporter_apply_kwargs_mlx_gen_orig_ +#define mlx_imported_function_new mlx_imported_function_new_mlx_gen_orig_ +#define mlx_imported_function_free mlx_imported_function_free_mlx_gen_orig_ +#define 
mlx_imported_function_apply mlx_imported_function_apply_mlx_gen_orig_ +#define mlx_imported_function_apply_kwargs mlx_imported_function_apply_kwargs_mlx_gen_orig_ +#define mlx_fast_cuda_kernel_config_new mlx_fast_cuda_kernel_config_new_mlx_gen_orig_ +#define mlx_fast_cuda_kernel_config_free mlx_fast_cuda_kernel_config_free_mlx_gen_orig_ +#define mlx_fast_cuda_kernel_config_add_output_arg mlx_fast_cuda_kernel_config_add_output_arg_mlx_gen_orig_ +#define mlx_fast_cuda_kernel_config_set_grid mlx_fast_cuda_kernel_config_set_grid_mlx_gen_orig_ +#define mlx_fast_cuda_kernel_config_set_thread_group mlx_fast_cuda_kernel_config_set_thread_group_mlx_gen_orig_ +#define mlx_fast_cuda_kernel_config_set_init_value mlx_fast_cuda_kernel_config_set_init_value_mlx_gen_orig_ +#define mlx_fast_cuda_kernel_config_set_verbose mlx_fast_cuda_kernel_config_set_verbose_mlx_gen_orig_ +#define mlx_fast_cuda_kernel_config_add_template_arg_dtype mlx_fast_cuda_kernel_config_add_template_arg_dtype_mlx_gen_orig_ +#define mlx_fast_cuda_kernel_config_add_template_arg_int mlx_fast_cuda_kernel_config_add_template_arg_int_mlx_gen_orig_ +#define mlx_fast_cuda_kernel_config_add_template_arg_bool mlx_fast_cuda_kernel_config_add_template_arg_bool_mlx_gen_orig_ +#define mlx_fast_cuda_kernel_new mlx_fast_cuda_kernel_new_mlx_gen_orig_ +#define mlx_fast_cuda_kernel_free mlx_fast_cuda_kernel_free_mlx_gen_orig_ +#define mlx_fast_cuda_kernel_apply mlx_fast_cuda_kernel_apply_mlx_gen_orig_ +#define mlx_fast_layer_norm mlx_fast_layer_norm_mlx_gen_orig_ +#define mlx_fast_metal_kernel_config_new mlx_fast_metal_kernel_config_new_mlx_gen_orig_ +#define mlx_fast_metal_kernel_config_free mlx_fast_metal_kernel_config_free_mlx_gen_orig_ +#define mlx_fast_metal_kernel_config_add_output_arg mlx_fast_metal_kernel_config_add_output_arg_mlx_gen_orig_ +#define mlx_fast_metal_kernel_config_set_grid mlx_fast_metal_kernel_config_set_grid_mlx_gen_orig_ +#define mlx_fast_metal_kernel_config_set_thread_group 
mlx_fast_metal_kernel_config_set_thread_group_mlx_gen_orig_ +#define mlx_fast_metal_kernel_config_set_init_value mlx_fast_metal_kernel_config_set_init_value_mlx_gen_orig_ +#define mlx_fast_metal_kernel_config_set_verbose mlx_fast_metal_kernel_config_set_verbose_mlx_gen_orig_ +#define mlx_fast_metal_kernel_config_add_template_arg_dtype mlx_fast_metal_kernel_config_add_template_arg_dtype_mlx_gen_orig_ +#define mlx_fast_metal_kernel_config_add_template_arg_int mlx_fast_metal_kernel_config_add_template_arg_int_mlx_gen_orig_ +#define mlx_fast_metal_kernel_config_add_template_arg_bool mlx_fast_metal_kernel_config_add_template_arg_bool_mlx_gen_orig_ +#define mlx_fast_metal_kernel_new mlx_fast_metal_kernel_new_mlx_gen_orig_ +#define mlx_fast_metal_kernel_free mlx_fast_metal_kernel_free_mlx_gen_orig_ +#define mlx_fast_metal_kernel_apply mlx_fast_metal_kernel_apply_mlx_gen_orig_ +#define mlx_fast_rms_norm mlx_fast_rms_norm_mlx_gen_orig_ +#define mlx_fast_rope mlx_fast_rope_mlx_gen_orig_ +#define mlx_fast_scaled_dot_product_attention mlx_fast_scaled_dot_product_attention_mlx_gen_orig_ +#define mlx_fft_fft mlx_fft_fft_mlx_gen_orig_ +#define mlx_fft_fft2 mlx_fft_fft2_mlx_gen_orig_ +#define mlx_fft_fftn mlx_fft_fftn_mlx_gen_orig_ +#define mlx_fft_fftshift mlx_fft_fftshift_mlx_gen_orig_ +#define mlx_fft_ifft mlx_fft_ifft_mlx_gen_orig_ +#define mlx_fft_ifft2 mlx_fft_ifft2_mlx_gen_orig_ +#define mlx_fft_ifftn mlx_fft_ifftn_mlx_gen_orig_ +#define mlx_fft_ifftshift mlx_fft_ifftshift_mlx_gen_orig_ +#define mlx_fft_irfft mlx_fft_irfft_mlx_gen_orig_ +#define mlx_fft_irfft2 mlx_fft_irfft2_mlx_gen_orig_ +#define mlx_fft_irfftn mlx_fft_irfftn_mlx_gen_orig_ +#define mlx_fft_rfft mlx_fft_rfft_mlx_gen_orig_ +#define mlx_fft_rfft2 mlx_fft_rfft2_mlx_gen_orig_ +#define mlx_fft_rfftn mlx_fft_rfftn_mlx_gen_orig_ +#define mlx_io_reader_new mlx_io_reader_new_mlx_gen_orig_ +#define mlx_io_reader_descriptor mlx_io_reader_descriptor_mlx_gen_orig_ +#define mlx_io_reader_tostring 
mlx_io_reader_tostring_mlx_gen_orig_ +#define mlx_io_reader_free mlx_io_reader_free_mlx_gen_orig_ +#define mlx_io_writer_new mlx_io_writer_new_mlx_gen_orig_ +#define mlx_io_writer_descriptor mlx_io_writer_descriptor_mlx_gen_orig_ +#define mlx_io_writer_tostring mlx_io_writer_tostring_mlx_gen_orig_ +#define mlx_io_writer_free mlx_io_writer_free_mlx_gen_orig_ +#define mlx_load_reader mlx_load_reader_mlx_gen_orig_ +#define mlx_load mlx_load_mlx_gen_orig_ +#define mlx_load_safetensors_reader mlx_load_safetensors_reader_mlx_gen_orig_ +#define mlx_load_safetensors mlx_load_safetensors_mlx_gen_orig_ +#define mlx_save_writer mlx_save_writer_mlx_gen_orig_ +#define mlx_save mlx_save_mlx_gen_orig_ +#define mlx_save_safetensors_writer mlx_save_safetensors_writer_mlx_gen_orig_ +#define mlx_save_safetensors mlx_save_safetensors_mlx_gen_orig_ +#define mlx_linalg_cholesky mlx_linalg_cholesky_mlx_gen_orig_ +#define mlx_linalg_cholesky_inv mlx_linalg_cholesky_inv_mlx_gen_orig_ +#define mlx_linalg_cross mlx_linalg_cross_mlx_gen_orig_ +#define mlx_linalg_eig mlx_linalg_eig_mlx_gen_orig_ +#define mlx_linalg_eigh mlx_linalg_eigh_mlx_gen_orig_ +#define mlx_linalg_eigvals mlx_linalg_eigvals_mlx_gen_orig_ +#define mlx_linalg_eigvalsh mlx_linalg_eigvalsh_mlx_gen_orig_ +#define mlx_linalg_inv mlx_linalg_inv_mlx_gen_orig_ +#define mlx_linalg_lu mlx_linalg_lu_mlx_gen_orig_ +#define mlx_linalg_lu_factor mlx_linalg_lu_factor_mlx_gen_orig_ +#define mlx_linalg_norm mlx_linalg_norm_mlx_gen_orig_ +#define mlx_linalg_norm_matrix mlx_linalg_norm_matrix_mlx_gen_orig_ +#define mlx_linalg_norm_l2 mlx_linalg_norm_l2_mlx_gen_orig_ +#define mlx_linalg_pinv mlx_linalg_pinv_mlx_gen_orig_ +#define mlx_linalg_qr mlx_linalg_qr_mlx_gen_orig_ +#define mlx_linalg_solve mlx_linalg_solve_mlx_gen_orig_ +#define mlx_linalg_solve_triangular mlx_linalg_solve_triangular_mlx_gen_orig_ +#define mlx_linalg_svd mlx_linalg_svd_mlx_gen_orig_ +#define mlx_linalg_tri_inv mlx_linalg_tri_inv_mlx_gen_orig_ +#define 
mlx_map_string_to_array_new mlx_map_string_to_array_new_mlx_gen_orig_ +#define mlx_map_string_to_array_set mlx_map_string_to_array_set_mlx_gen_orig_ +#define mlx_map_string_to_array_free mlx_map_string_to_array_free_mlx_gen_orig_ +#define mlx_map_string_to_array_insert mlx_map_string_to_array_insert_mlx_gen_orig_ +#define mlx_map_string_to_array_get mlx_map_string_to_array_get_mlx_gen_orig_ +#define mlx_map_string_to_array_iterator_new mlx_map_string_to_array_iterator_new_mlx_gen_orig_ +#define mlx_map_string_to_array_iterator_free mlx_map_string_to_array_iterator_free_mlx_gen_orig_ +#define mlx_map_string_to_array_iterator_next mlx_map_string_to_array_iterator_next_mlx_gen_orig_ +#define mlx_map_string_to_string_new mlx_map_string_to_string_new_mlx_gen_orig_ +#define mlx_map_string_to_string_set mlx_map_string_to_string_set_mlx_gen_orig_ +#define mlx_map_string_to_string_free mlx_map_string_to_string_free_mlx_gen_orig_ +#define mlx_map_string_to_string_insert mlx_map_string_to_string_insert_mlx_gen_orig_ +#define mlx_map_string_to_string_get mlx_map_string_to_string_get_mlx_gen_orig_ +#define mlx_map_string_to_string_iterator_new mlx_map_string_to_string_iterator_new_mlx_gen_orig_ +#define mlx_map_string_to_string_iterator_free mlx_map_string_to_string_iterator_free_mlx_gen_orig_ +#define mlx_map_string_to_string_iterator_next mlx_map_string_to_string_iterator_next_mlx_gen_orig_ +#define mlx_clear_cache mlx_clear_cache_mlx_gen_orig_ +#define mlx_get_active_memory mlx_get_active_memory_mlx_gen_orig_ +#define mlx_get_cache_memory mlx_get_cache_memory_mlx_gen_orig_ +#define mlx_get_memory_limit mlx_get_memory_limit_mlx_gen_orig_ +#define mlx_get_peak_memory mlx_get_peak_memory_mlx_gen_orig_ +#define mlx_reset_peak_memory mlx_reset_peak_memory_mlx_gen_orig_ +#define mlx_set_cache_limit mlx_set_cache_limit_mlx_gen_orig_ +#define mlx_set_memory_limit mlx_set_memory_limit_mlx_gen_orig_ +#define mlx_set_wired_limit mlx_set_wired_limit_mlx_gen_orig_ +#define 
mlx_metal_device_info mlx_metal_device_info_mlx_gen_orig_ +#define mlx_metal_is_available mlx_metal_is_available_mlx_gen_orig_ +#define mlx_metal_start_capture mlx_metal_start_capture_mlx_gen_orig_ +#define mlx_metal_stop_capture mlx_metal_stop_capture_mlx_gen_orig_ +#define mlx_abs mlx_abs_mlx_gen_orig_ +#define mlx_add mlx_add_mlx_gen_orig_ +#define mlx_addmm mlx_addmm_mlx_gen_orig_ +#define mlx_all_axes mlx_all_axes_mlx_gen_orig_ +#define mlx_all_axis mlx_all_axis_mlx_gen_orig_ +#define mlx_all mlx_all_mlx_gen_orig_ +#define mlx_allclose mlx_allclose_mlx_gen_orig_ +#define mlx_any_axes mlx_any_axes_mlx_gen_orig_ +#define mlx_any_axis mlx_any_axis_mlx_gen_orig_ +#define mlx_any mlx_any_mlx_gen_orig_ +#define mlx_arange mlx_arange_mlx_gen_orig_ +#define mlx_arccos mlx_arccos_mlx_gen_orig_ +#define mlx_arccosh mlx_arccosh_mlx_gen_orig_ +#define mlx_arcsin mlx_arcsin_mlx_gen_orig_ +#define mlx_arcsinh mlx_arcsinh_mlx_gen_orig_ +#define mlx_arctan mlx_arctan_mlx_gen_orig_ +#define mlx_arctan2 mlx_arctan2_mlx_gen_orig_ +#define mlx_arctanh mlx_arctanh_mlx_gen_orig_ +#define mlx_argmax_axis mlx_argmax_axis_mlx_gen_orig_ +#define mlx_argmax mlx_argmax_mlx_gen_orig_ +#define mlx_argmin_axis mlx_argmin_axis_mlx_gen_orig_ +#define mlx_argmin mlx_argmin_mlx_gen_orig_ +#define mlx_argpartition_axis mlx_argpartition_axis_mlx_gen_orig_ +#define mlx_argpartition mlx_argpartition_mlx_gen_orig_ +#define mlx_argsort_axis mlx_argsort_axis_mlx_gen_orig_ +#define mlx_argsort mlx_argsort_mlx_gen_orig_ +#define mlx_array_equal mlx_array_equal_mlx_gen_orig_ +#define mlx_as_strided mlx_as_strided_mlx_gen_orig_ +#define mlx_astype mlx_astype_mlx_gen_orig_ +#define mlx_atleast_1d mlx_atleast_1d_mlx_gen_orig_ +#define mlx_atleast_2d mlx_atleast_2d_mlx_gen_orig_ +#define mlx_atleast_3d mlx_atleast_3d_mlx_gen_orig_ +#define mlx_bitwise_and mlx_bitwise_and_mlx_gen_orig_ +#define mlx_bitwise_invert mlx_bitwise_invert_mlx_gen_orig_ +#define mlx_bitwise_or mlx_bitwise_or_mlx_gen_orig_ +#define 
mlx_bitwise_xor mlx_bitwise_xor_mlx_gen_orig_ +#define mlx_block_masked_mm mlx_block_masked_mm_mlx_gen_orig_ +#define mlx_broadcast_arrays mlx_broadcast_arrays_mlx_gen_orig_ +#define mlx_broadcast_to mlx_broadcast_to_mlx_gen_orig_ +#define mlx_ceil mlx_ceil_mlx_gen_orig_ +#define mlx_clip mlx_clip_mlx_gen_orig_ +#define mlx_concatenate_axis mlx_concatenate_axis_mlx_gen_orig_ +#define mlx_concatenate mlx_concatenate_mlx_gen_orig_ +#define mlx_conjugate mlx_conjugate_mlx_gen_orig_ +#define mlx_contiguous mlx_contiguous_mlx_gen_orig_ +#define mlx_conv1d mlx_conv1d_mlx_gen_orig_ +#define mlx_conv2d mlx_conv2d_mlx_gen_orig_ +#define mlx_conv3d mlx_conv3d_mlx_gen_orig_ +#define mlx_conv_general mlx_conv_general_mlx_gen_orig_ +#define mlx_conv_transpose1d mlx_conv_transpose1d_mlx_gen_orig_ +#define mlx_conv_transpose2d mlx_conv_transpose2d_mlx_gen_orig_ +#define mlx_conv_transpose3d mlx_conv_transpose3d_mlx_gen_orig_ +#define mlx_copy mlx_copy_mlx_gen_orig_ +#define mlx_cos mlx_cos_mlx_gen_orig_ +#define mlx_cosh mlx_cosh_mlx_gen_orig_ +#define mlx_cummax mlx_cummax_mlx_gen_orig_ +#define mlx_cummin mlx_cummin_mlx_gen_orig_ +#define mlx_cumprod mlx_cumprod_mlx_gen_orig_ +#define mlx_cumsum mlx_cumsum_mlx_gen_orig_ +#define mlx_degrees mlx_degrees_mlx_gen_orig_ +#define mlx_depends mlx_depends_mlx_gen_orig_ +#define mlx_dequantize mlx_dequantize_mlx_gen_orig_ +#define mlx_diag mlx_diag_mlx_gen_orig_ +#define mlx_diagonal mlx_diagonal_mlx_gen_orig_ +#define mlx_divide mlx_divide_mlx_gen_orig_ +#define mlx_divmod mlx_divmod_mlx_gen_orig_ +#define mlx_einsum mlx_einsum_mlx_gen_orig_ +#define mlx_equal mlx_equal_mlx_gen_orig_ +#define mlx_erf mlx_erf_mlx_gen_orig_ +#define mlx_erfinv mlx_erfinv_mlx_gen_orig_ +#define mlx_exp mlx_exp_mlx_gen_orig_ +#define mlx_expand_dims_axes mlx_expand_dims_axes_mlx_gen_orig_ +#define mlx_expand_dims mlx_expand_dims_mlx_gen_orig_ +#define mlx_expm1 mlx_expm1_mlx_gen_orig_ +#define mlx_eye mlx_eye_mlx_gen_orig_ +#define mlx_flatten 
mlx_flatten_mlx_gen_orig_ +#define mlx_floor mlx_floor_mlx_gen_orig_ +#define mlx_floor_divide mlx_floor_divide_mlx_gen_orig_ +#define mlx_from_fp8 mlx_from_fp8_mlx_gen_orig_ +#define mlx_full mlx_full_mlx_gen_orig_ +#define mlx_full_like mlx_full_like_mlx_gen_orig_ +#define mlx_gather mlx_gather_mlx_gen_orig_ +#define mlx_gather_mm mlx_gather_mm_mlx_gen_orig_ +#define mlx_gather_qmm mlx_gather_qmm_mlx_gen_orig_ +#define mlx_greater mlx_greater_mlx_gen_orig_ +#define mlx_greater_equal mlx_greater_equal_mlx_gen_orig_ +#define mlx_hadamard_transform mlx_hadamard_transform_mlx_gen_orig_ +#define mlx_identity mlx_identity_mlx_gen_orig_ +#define mlx_imag mlx_imag_mlx_gen_orig_ +#define mlx_inner mlx_inner_mlx_gen_orig_ +#define mlx_isclose mlx_isclose_mlx_gen_orig_ +#define mlx_isfinite mlx_isfinite_mlx_gen_orig_ +#define mlx_isinf mlx_isinf_mlx_gen_orig_ +#define mlx_isnan mlx_isnan_mlx_gen_orig_ +#define mlx_isneginf mlx_isneginf_mlx_gen_orig_ +#define mlx_isposinf mlx_isposinf_mlx_gen_orig_ +#define mlx_kron mlx_kron_mlx_gen_orig_ +#define mlx_left_shift mlx_left_shift_mlx_gen_orig_ +#define mlx_less mlx_less_mlx_gen_orig_ +#define mlx_less_equal mlx_less_equal_mlx_gen_orig_ +#define mlx_linspace mlx_linspace_mlx_gen_orig_ +#define mlx_log mlx_log_mlx_gen_orig_ +#define mlx_log10 mlx_log10_mlx_gen_orig_ +#define mlx_log1p mlx_log1p_mlx_gen_orig_ +#define mlx_log2 mlx_log2_mlx_gen_orig_ +#define mlx_logaddexp mlx_logaddexp_mlx_gen_orig_ +#define mlx_logcumsumexp mlx_logcumsumexp_mlx_gen_orig_ +#define mlx_logical_and mlx_logical_and_mlx_gen_orig_ +#define mlx_logical_not mlx_logical_not_mlx_gen_orig_ +#define mlx_logical_or mlx_logical_or_mlx_gen_orig_ +#define mlx_logsumexp_axes mlx_logsumexp_axes_mlx_gen_orig_ +#define mlx_logsumexp_axis mlx_logsumexp_axis_mlx_gen_orig_ +#define mlx_logsumexp mlx_logsumexp_mlx_gen_orig_ +#define mlx_masked_scatter mlx_masked_scatter_mlx_gen_orig_ +#define mlx_matmul mlx_matmul_mlx_gen_orig_ +#define mlx_max_axes 
mlx_max_axes_mlx_gen_orig_ +#define mlx_max_axis mlx_max_axis_mlx_gen_orig_ +#define mlx_max mlx_max_mlx_gen_orig_ +#define mlx_maximum mlx_maximum_mlx_gen_orig_ +#define mlx_mean_axes mlx_mean_axes_mlx_gen_orig_ +#define mlx_mean_axis mlx_mean_axis_mlx_gen_orig_ +#define mlx_mean mlx_mean_mlx_gen_orig_ +#define mlx_median mlx_median_mlx_gen_orig_ +#define mlx_meshgrid mlx_meshgrid_mlx_gen_orig_ +#define mlx_min_axes mlx_min_axes_mlx_gen_orig_ +#define mlx_min_axis mlx_min_axis_mlx_gen_orig_ +#define mlx_min mlx_min_mlx_gen_orig_ +#define mlx_minimum mlx_minimum_mlx_gen_orig_ +#define mlx_moveaxis mlx_moveaxis_mlx_gen_orig_ +#define mlx_multiply mlx_multiply_mlx_gen_orig_ +#define mlx_nan_to_num mlx_nan_to_num_mlx_gen_orig_ +#define mlx_negative mlx_negative_mlx_gen_orig_ +#define mlx_not_equal mlx_not_equal_mlx_gen_orig_ +#define mlx_number_of_elements mlx_number_of_elements_mlx_gen_orig_ +#define mlx_ones mlx_ones_mlx_gen_orig_ +#define mlx_ones_like mlx_ones_like_mlx_gen_orig_ +#define mlx_outer mlx_outer_mlx_gen_orig_ +#define mlx_pad mlx_pad_mlx_gen_orig_ +#define mlx_pad_symmetric mlx_pad_symmetric_mlx_gen_orig_ +#define mlx_partition_axis mlx_partition_axis_mlx_gen_orig_ +#define mlx_partition mlx_partition_mlx_gen_orig_ +#define mlx_power mlx_power_mlx_gen_orig_ +#define mlx_prod_axes mlx_prod_axes_mlx_gen_orig_ +#define mlx_prod_axis mlx_prod_axis_mlx_gen_orig_ +#define mlx_prod mlx_prod_mlx_gen_orig_ +#define mlx_put_along_axis mlx_put_along_axis_mlx_gen_orig_ +#define mlx_quantize mlx_quantize_mlx_gen_orig_ +#define mlx_quantized_matmul mlx_quantized_matmul_mlx_gen_orig_ +#define mlx_radians mlx_radians_mlx_gen_orig_ +#define mlx_real mlx_real_mlx_gen_orig_ +#define mlx_reciprocal mlx_reciprocal_mlx_gen_orig_ +#define mlx_remainder mlx_remainder_mlx_gen_orig_ +#define mlx_repeat_axis mlx_repeat_axis_mlx_gen_orig_ +#define mlx_repeat mlx_repeat_mlx_gen_orig_ +#define mlx_reshape mlx_reshape_mlx_gen_orig_ +#define mlx_right_shift 
mlx_right_shift_mlx_gen_orig_ +#define mlx_roll_axis mlx_roll_axis_mlx_gen_orig_ +#define mlx_roll_axes mlx_roll_axes_mlx_gen_orig_ +#define mlx_roll mlx_roll_mlx_gen_orig_ +#define mlx_round mlx_round_mlx_gen_orig_ +#define mlx_rsqrt mlx_rsqrt_mlx_gen_orig_ +#define mlx_scatter mlx_scatter_mlx_gen_orig_ +#define mlx_scatter_add mlx_scatter_add_mlx_gen_orig_ +#define mlx_scatter_add_axis mlx_scatter_add_axis_mlx_gen_orig_ +#define mlx_scatter_max mlx_scatter_max_mlx_gen_orig_ +#define mlx_scatter_min mlx_scatter_min_mlx_gen_orig_ +#define mlx_scatter_prod mlx_scatter_prod_mlx_gen_orig_ +#define mlx_segmented_mm mlx_segmented_mm_mlx_gen_orig_ +#define mlx_sigmoid mlx_sigmoid_mlx_gen_orig_ +#define mlx_sign mlx_sign_mlx_gen_orig_ +#define mlx_sin mlx_sin_mlx_gen_orig_ +#define mlx_sinh mlx_sinh_mlx_gen_orig_ +#define mlx_slice mlx_slice_mlx_gen_orig_ +#define mlx_slice_dynamic mlx_slice_dynamic_mlx_gen_orig_ +#define mlx_slice_update mlx_slice_update_mlx_gen_orig_ +#define mlx_slice_update_dynamic mlx_slice_update_dynamic_mlx_gen_orig_ +#define mlx_softmax_axes mlx_softmax_axes_mlx_gen_orig_ +#define mlx_softmax_axis mlx_softmax_axis_mlx_gen_orig_ +#define mlx_softmax mlx_softmax_mlx_gen_orig_ +#define mlx_sort_axis mlx_sort_axis_mlx_gen_orig_ +#define mlx_sort mlx_sort_mlx_gen_orig_ +#define mlx_split mlx_split_mlx_gen_orig_ +#define mlx_split_sections mlx_split_sections_mlx_gen_orig_ +#define mlx_sqrt mlx_sqrt_mlx_gen_orig_ +#define mlx_square mlx_square_mlx_gen_orig_ +#define mlx_squeeze_axes mlx_squeeze_axes_mlx_gen_orig_ +#define mlx_squeeze_axis mlx_squeeze_axis_mlx_gen_orig_ +#define mlx_squeeze mlx_squeeze_mlx_gen_orig_ +#define mlx_stack_axis mlx_stack_axis_mlx_gen_orig_ +#define mlx_stack mlx_stack_mlx_gen_orig_ +#define mlx_std_axes mlx_std_axes_mlx_gen_orig_ +#define mlx_std_axis mlx_std_axis_mlx_gen_orig_ +#define mlx_std mlx_std_mlx_gen_orig_ +#define mlx_stop_gradient mlx_stop_gradient_mlx_gen_orig_ +#define mlx_subtract mlx_subtract_mlx_gen_orig_ 
+#define mlx_sum_axes mlx_sum_axes_mlx_gen_orig_ +#define mlx_sum_axis mlx_sum_axis_mlx_gen_orig_ +#define mlx_sum mlx_sum_mlx_gen_orig_ +#define mlx_swapaxes mlx_swapaxes_mlx_gen_orig_ +#define mlx_take_axis mlx_take_axis_mlx_gen_orig_ +#define mlx_take mlx_take_mlx_gen_orig_ +#define mlx_take_along_axis mlx_take_along_axis_mlx_gen_orig_ +#define mlx_tan mlx_tan_mlx_gen_orig_ +#define mlx_tanh mlx_tanh_mlx_gen_orig_ +#define mlx_tensordot mlx_tensordot_mlx_gen_orig_ +#define mlx_tensordot_axis mlx_tensordot_axis_mlx_gen_orig_ +#define mlx_tile mlx_tile_mlx_gen_orig_ +#define mlx_to_fp8 mlx_to_fp8_mlx_gen_orig_ +#define mlx_topk_axis mlx_topk_axis_mlx_gen_orig_ +#define mlx_topk mlx_topk_mlx_gen_orig_ +#define mlx_trace mlx_trace_mlx_gen_orig_ +#define mlx_transpose_axes mlx_transpose_axes_mlx_gen_orig_ +#define mlx_transpose mlx_transpose_mlx_gen_orig_ +#define mlx_tri mlx_tri_mlx_gen_orig_ +#define mlx_tril mlx_tril_mlx_gen_orig_ +#define mlx_triu mlx_triu_mlx_gen_orig_ +#define mlx_unflatten mlx_unflatten_mlx_gen_orig_ +#define mlx_var_axes mlx_var_axes_mlx_gen_orig_ +#define mlx_var_axis mlx_var_axis_mlx_gen_orig_ +#define mlx_var mlx_var_mlx_gen_orig_ +#define mlx_view mlx_view_mlx_gen_orig_ +#define mlx_where mlx_where_mlx_gen_orig_ +#define mlx_zeros mlx_zeros_mlx_gen_orig_ +#define mlx_zeros_like mlx_zeros_like_mlx_gen_orig_ +#define mlx_random_bernoulli mlx_random_bernoulli_mlx_gen_orig_ +#define mlx_random_bits mlx_random_bits_mlx_gen_orig_ +#define mlx_random_categorical_shape mlx_random_categorical_shape_mlx_gen_orig_ +#define mlx_random_categorical_num_samples mlx_random_categorical_num_samples_mlx_gen_orig_ +#define mlx_random_categorical mlx_random_categorical_mlx_gen_orig_ +#define mlx_random_gumbel mlx_random_gumbel_mlx_gen_orig_ +#define mlx_random_key mlx_random_key_mlx_gen_orig_ +#define mlx_random_laplace mlx_random_laplace_mlx_gen_orig_ +#define mlx_random_multivariate_normal mlx_random_multivariate_normal_mlx_gen_orig_ +#define 
mlx_random_normal_broadcast mlx_random_normal_broadcast_mlx_gen_orig_ +#define mlx_random_normal mlx_random_normal_mlx_gen_orig_ +#define mlx_random_permutation mlx_random_permutation_mlx_gen_orig_ +#define mlx_random_permutation_arange mlx_random_permutation_arange_mlx_gen_orig_ +#define mlx_random_randint mlx_random_randint_mlx_gen_orig_ +#define mlx_random_seed mlx_random_seed_mlx_gen_orig_ +#define mlx_random_split_num mlx_random_split_num_mlx_gen_orig_ +#define mlx_random_split mlx_random_split_mlx_gen_orig_ +#define mlx_random_truncated_normal mlx_random_truncated_normal_mlx_gen_orig_ +#define mlx_random_uniform mlx_random_uniform_mlx_gen_orig_ +#define mlx_stream_new mlx_stream_new_mlx_gen_orig_ +#define mlx_stream_new_device mlx_stream_new_device_mlx_gen_orig_ +#define mlx_stream_set mlx_stream_set_mlx_gen_orig_ +#define mlx_stream_free mlx_stream_free_mlx_gen_orig_ +#define mlx_stream_tostring mlx_stream_tostring_mlx_gen_orig_ +#define mlx_stream_equal mlx_stream_equal_mlx_gen_orig_ +#define mlx_stream_get_device mlx_stream_get_device_mlx_gen_orig_ +#define mlx_stream_get_index mlx_stream_get_index_mlx_gen_orig_ +#define mlx_synchronize mlx_synchronize_mlx_gen_orig_ +#define mlx_get_default_stream mlx_get_default_stream_mlx_gen_orig_ +#define mlx_set_default_stream mlx_set_default_stream_mlx_gen_orig_ +#define mlx_default_cpu_stream_new mlx_default_cpu_stream_new_mlx_gen_orig_ +#define mlx_default_gpu_stream_new mlx_default_gpu_stream_new_mlx_gen_orig_ +#define mlx_string_new mlx_string_new_mlx_gen_orig_ +#define mlx_string_new_data mlx_string_new_data_mlx_gen_orig_ +#define mlx_string_set mlx_string_set_mlx_gen_orig_ +#define mlx_string_data mlx_string_data_mlx_gen_orig_ +#define mlx_string_free mlx_string_free_mlx_gen_orig_ +#define mlx_detail_vmap_replace mlx_detail_vmap_replace_mlx_gen_orig_ +#define mlx_detail_vmap_trace mlx_detail_vmap_trace_mlx_gen_orig_ +#define mlx_async_eval mlx_async_eval_mlx_gen_orig_ +#define mlx_checkpoint 
mlx_checkpoint_mlx_gen_orig_ +#define mlx_custom_function mlx_custom_function_mlx_gen_orig_ +#define mlx_custom_vjp mlx_custom_vjp_mlx_gen_orig_ +#define mlx_eval mlx_eval_mlx_gen_orig_ +#define mlx_jvp mlx_jvp_mlx_gen_orig_ +#define mlx_value_and_grad mlx_value_and_grad_mlx_gen_orig_ +#define mlx_vjp mlx_vjp_mlx_gen_orig_ +#define mlx_vector_array_new mlx_vector_array_new_mlx_gen_orig_ +#define mlx_vector_array_set mlx_vector_array_set_mlx_gen_orig_ +#define mlx_vector_array_free mlx_vector_array_free_mlx_gen_orig_ +#define mlx_vector_array_new_data mlx_vector_array_new_data_mlx_gen_orig_ +#define mlx_vector_array_new_value mlx_vector_array_new_value_mlx_gen_orig_ +#define mlx_vector_array_set_data mlx_vector_array_set_data_mlx_gen_orig_ +#define mlx_vector_array_set_value mlx_vector_array_set_value_mlx_gen_orig_ +#define mlx_vector_array_append_data mlx_vector_array_append_data_mlx_gen_orig_ +#define mlx_vector_array_append_value mlx_vector_array_append_value_mlx_gen_orig_ +#define mlx_vector_array_size mlx_vector_array_size_mlx_gen_orig_ +#define mlx_vector_array_get mlx_vector_array_get_mlx_gen_orig_ +#define mlx_vector_vector_array_new mlx_vector_vector_array_new_mlx_gen_orig_ +#define mlx_vector_vector_array_set mlx_vector_vector_array_set_mlx_gen_orig_ +#define mlx_vector_vector_array_free mlx_vector_vector_array_free_mlx_gen_orig_ +#define mlx_vector_vector_array_new_data mlx_vector_vector_array_new_data_mlx_gen_orig_ +#define mlx_vector_vector_array_new_value mlx_vector_vector_array_new_value_mlx_gen_orig_ +#define mlx_vector_vector_array_set_data mlx_vector_vector_array_set_data_mlx_gen_orig_ +#define mlx_vector_vector_array_set_value mlx_vector_vector_array_set_value_mlx_gen_orig_ +#define mlx_vector_vector_array_append_data mlx_vector_vector_array_append_data_mlx_gen_orig_ +#define mlx_vector_vector_array_append_value mlx_vector_vector_array_append_value_mlx_gen_orig_ +#define mlx_vector_vector_array_size mlx_vector_vector_array_size_mlx_gen_orig_ 
+#define mlx_vector_vector_array_get mlx_vector_vector_array_get_mlx_gen_orig_ +#define mlx_vector_int_new mlx_vector_int_new_mlx_gen_orig_ +#define mlx_vector_int_set mlx_vector_int_set_mlx_gen_orig_ +#define mlx_vector_int_free mlx_vector_int_free_mlx_gen_orig_ +#define mlx_vector_int_new_data mlx_vector_int_new_data_mlx_gen_orig_ +#define mlx_vector_int_new_value mlx_vector_int_new_value_mlx_gen_orig_ +#define mlx_vector_int_set_data mlx_vector_int_set_data_mlx_gen_orig_ +#define mlx_vector_int_set_value mlx_vector_int_set_value_mlx_gen_orig_ +#define mlx_vector_int_append_data mlx_vector_int_append_data_mlx_gen_orig_ +#define mlx_vector_int_append_value mlx_vector_int_append_value_mlx_gen_orig_ +#define mlx_vector_int_size mlx_vector_int_size_mlx_gen_orig_ +#define mlx_vector_int_get mlx_vector_int_get_mlx_gen_orig_ +#define mlx_vector_string_new mlx_vector_string_new_mlx_gen_orig_ +#define mlx_vector_string_set mlx_vector_string_set_mlx_gen_orig_ +#define mlx_vector_string_free mlx_vector_string_free_mlx_gen_orig_ +#define mlx_vector_string_new_data mlx_vector_string_new_data_mlx_gen_orig_ +#define mlx_vector_string_new_value mlx_vector_string_new_value_mlx_gen_orig_ +#define mlx_vector_string_set_data mlx_vector_string_set_data_mlx_gen_orig_ +#define mlx_vector_string_set_value mlx_vector_string_set_value_mlx_gen_orig_ +#define mlx_vector_string_append_data mlx_vector_string_append_data_mlx_gen_orig_ +#define mlx_vector_string_append_value mlx_vector_string_append_value_mlx_gen_orig_ +#define mlx_vector_string_size mlx_vector_string_size_mlx_gen_orig_ +#define mlx_vector_string_get mlx_vector_string_get_mlx_gen_orig_ +#define mlx_version mlx_version_mlx_gen_orig_ + +#include "mlx/c/mlx.h" + +#undef mlx_dtype_size +#undef mlx_array_tostring +#undef mlx_array_new +#undef mlx_array_free +#undef mlx_array_new_bool +#undef mlx_array_new_int +#undef mlx_array_new_float32 +#undef mlx_array_new_float +#undef mlx_array_new_float64 +#undef mlx_array_new_double +#undef 
mlx_array_new_complex +#undef mlx_array_new_data +#undef mlx_array_set +#undef mlx_array_set_bool +#undef mlx_array_set_int +#undef mlx_array_set_float32 +#undef mlx_array_set_float +#undef mlx_array_set_float64 +#undef mlx_array_set_double +#undef mlx_array_set_complex +#undef mlx_array_set_data +#undef mlx_array_itemsize +#undef mlx_array_size +#undef mlx_array_nbytes +#undef mlx_array_ndim +#undef mlx_array_shape +#undef mlx_array_strides +#undef mlx_array_dim +#undef mlx_array_dtype +#undef mlx_array_eval +#undef mlx_array_item_bool +#undef mlx_array_item_uint8 +#undef mlx_array_item_uint16 +#undef mlx_array_item_uint32 +#undef mlx_array_item_uint64 +#undef mlx_array_item_int8 +#undef mlx_array_item_int16 +#undef mlx_array_item_int32 +#undef mlx_array_item_int64 +#undef mlx_array_item_float32 +#undef mlx_array_item_float64 +#undef mlx_array_item_complex64 +#undef mlx_array_item_float16 +#undef mlx_array_item_bfloat16 +#undef mlx_array_data_bool +#undef mlx_array_data_uint8 +#undef mlx_array_data_uint16 +#undef mlx_array_data_uint32 +#undef mlx_array_data_uint64 +#undef mlx_array_data_int8 +#undef mlx_array_data_int16 +#undef mlx_array_data_int32 +#undef mlx_array_data_int64 +#undef mlx_array_data_float32 +#undef mlx_array_data_float64 +#undef mlx_array_data_complex64 +#undef mlx_array_data_float16 +#undef mlx_array_data_bfloat16 +#undef _mlx_array_is_available +#undef _mlx_array_wait +#undef _mlx_array_is_contiguous +#undef _mlx_array_is_row_contiguous +#undef _mlx_array_is_col_contiguous +#undef mlx_closure_new +#undef mlx_closure_free +#undef mlx_closure_new_func +#undef mlx_closure_new_func_payload +#undef mlx_closure_set +#undef mlx_closure_apply +#undef mlx_closure_new_unary +#undef mlx_closure_kwargs_new +#undef mlx_closure_kwargs_free +#undef mlx_closure_kwargs_new_func +#undef mlx_closure_kwargs_new_func_payload +#undef mlx_closure_kwargs_set +#undef mlx_closure_kwargs_apply +#undef mlx_closure_value_and_grad_new +#undef mlx_closure_value_and_grad_free 
+#undef mlx_closure_value_and_grad_new_func +#undef mlx_closure_value_and_grad_new_func_payload +#undef mlx_closure_value_and_grad_set +#undef mlx_closure_value_and_grad_apply +#undef mlx_closure_custom_new +#undef mlx_closure_custom_free +#undef mlx_closure_custom_new_func +#undef mlx_closure_custom_new_func_payload +#undef mlx_closure_custom_set +#undef mlx_closure_custom_apply +#undef mlx_closure_custom_jvp_new +#undef mlx_closure_custom_jvp_free +#undef mlx_closure_custom_jvp_new_func +#undef mlx_closure_custom_jvp_new_func_payload +#undef mlx_closure_custom_jvp_set +#undef mlx_closure_custom_jvp_apply +#undef mlx_closure_custom_vmap_new +#undef mlx_closure_custom_vmap_free +#undef mlx_closure_custom_vmap_new_func +#undef mlx_closure_custom_vmap_new_func_payload +#undef mlx_closure_custom_vmap_set +#undef mlx_closure_custom_vmap_apply +#undef mlx_compile +#undef mlx_detail_compile +#undef mlx_detail_compile_clear_cache +#undef mlx_detail_compile_erase +#undef mlx_disable_compile +#undef mlx_enable_compile +#undef mlx_set_compile_mode +#undef mlx_device_new +#undef mlx_device_new_type +#undef mlx_device_free +#undef mlx_device_set +#undef mlx_device_tostring +#undef mlx_device_equal +#undef mlx_device_get_index +#undef mlx_device_get_type +#undef mlx_get_default_device +#undef mlx_set_default_device +#undef mlx_distributed_group_rank +#undef mlx_distributed_group_size +#undef mlx_distributed_group_split +#undef mlx_distributed_is_available +#undef mlx_distributed_init +#undef mlx_distributed_all_gather +#undef mlx_distributed_all_max +#undef mlx_distributed_all_min +#undef mlx_distributed_all_sum +#undef mlx_distributed_recv +#undef mlx_distributed_recv_like +#undef mlx_distributed_send +#undef mlx_distributed_sum_scatter +#undef mlx_set_error_handler +#undef _mlx_error +#undef mlx_export_function +#undef mlx_export_function_kwargs +#undef mlx_function_exporter_new +#undef mlx_function_exporter_free +#undef mlx_function_exporter_apply +#undef 
mlx_function_exporter_apply_kwargs +#undef mlx_imported_function_new +#undef mlx_imported_function_free +#undef mlx_imported_function_apply +#undef mlx_imported_function_apply_kwargs +#undef mlx_fast_cuda_kernel_config_new +#undef mlx_fast_cuda_kernel_config_free +#undef mlx_fast_cuda_kernel_config_add_output_arg +#undef mlx_fast_cuda_kernel_config_set_grid +#undef mlx_fast_cuda_kernel_config_set_thread_group +#undef mlx_fast_cuda_kernel_config_set_init_value +#undef mlx_fast_cuda_kernel_config_set_verbose +#undef mlx_fast_cuda_kernel_config_add_template_arg_dtype +#undef mlx_fast_cuda_kernel_config_add_template_arg_int +#undef mlx_fast_cuda_kernel_config_add_template_arg_bool +#undef mlx_fast_cuda_kernel_new +#undef mlx_fast_cuda_kernel_free +#undef mlx_fast_cuda_kernel_apply +#undef mlx_fast_layer_norm +#undef mlx_fast_metal_kernel_config_new +#undef mlx_fast_metal_kernel_config_free +#undef mlx_fast_metal_kernel_config_add_output_arg +#undef mlx_fast_metal_kernel_config_set_grid +#undef mlx_fast_metal_kernel_config_set_thread_group +#undef mlx_fast_metal_kernel_config_set_init_value +#undef mlx_fast_metal_kernel_config_set_verbose +#undef mlx_fast_metal_kernel_config_add_template_arg_dtype +#undef mlx_fast_metal_kernel_config_add_template_arg_int +#undef mlx_fast_metal_kernel_config_add_template_arg_bool +#undef mlx_fast_metal_kernel_new +#undef mlx_fast_metal_kernel_free +#undef mlx_fast_metal_kernel_apply +#undef mlx_fast_rms_norm +#undef mlx_fast_rope +#undef mlx_fast_scaled_dot_product_attention +#undef mlx_fft_fft +#undef mlx_fft_fft2 +#undef mlx_fft_fftn +#undef mlx_fft_fftshift +#undef mlx_fft_ifft +#undef mlx_fft_ifft2 +#undef mlx_fft_ifftn +#undef mlx_fft_ifftshift +#undef mlx_fft_irfft +#undef mlx_fft_irfft2 +#undef mlx_fft_irfftn +#undef mlx_fft_rfft +#undef mlx_fft_rfft2 +#undef mlx_fft_rfftn +#undef mlx_io_reader_new +#undef mlx_io_reader_descriptor +#undef mlx_io_reader_tostring +#undef mlx_io_reader_free +#undef mlx_io_writer_new +#undef 
mlx_io_writer_descriptor +#undef mlx_io_writer_tostring +#undef mlx_io_writer_free +#undef mlx_load_reader +#undef mlx_load +#undef mlx_load_safetensors_reader +#undef mlx_load_safetensors +#undef mlx_save_writer +#undef mlx_save +#undef mlx_save_safetensors_writer +#undef mlx_save_safetensors +#undef mlx_linalg_cholesky +#undef mlx_linalg_cholesky_inv +#undef mlx_linalg_cross +#undef mlx_linalg_eig +#undef mlx_linalg_eigh +#undef mlx_linalg_eigvals +#undef mlx_linalg_eigvalsh +#undef mlx_linalg_inv +#undef mlx_linalg_lu +#undef mlx_linalg_lu_factor +#undef mlx_linalg_norm +#undef mlx_linalg_norm_matrix +#undef mlx_linalg_norm_l2 +#undef mlx_linalg_pinv +#undef mlx_linalg_qr +#undef mlx_linalg_solve +#undef mlx_linalg_solve_triangular +#undef mlx_linalg_svd +#undef mlx_linalg_tri_inv +#undef mlx_map_string_to_array_new +#undef mlx_map_string_to_array_set +#undef mlx_map_string_to_array_free +#undef mlx_map_string_to_array_insert +#undef mlx_map_string_to_array_get +#undef mlx_map_string_to_array_iterator_new +#undef mlx_map_string_to_array_iterator_free +#undef mlx_map_string_to_array_iterator_next +#undef mlx_map_string_to_string_new +#undef mlx_map_string_to_string_set +#undef mlx_map_string_to_string_free +#undef mlx_map_string_to_string_insert +#undef mlx_map_string_to_string_get +#undef mlx_map_string_to_string_iterator_new +#undef mlx_map_string_to_string_iterator_free +#undef mlx_map_string_to_string_iterator_next +#undef mlx_clear_cache +#undef mlx_get_active_memory +#undef mlx_get_cache_memory +#undef mlx_get_memory_limit +#undef mlx_get_peak_memory +#undef mlx_reset_peak_memory +#undef mlx_set_cache_limit +#undef mlx_set_memory_limit +#undef mlx_set_wired_limit +#undef mlx_metal_device_info +#undef mlx_metal_is_available +#undef mlx_metal_start_capture +#undef mlx_metal_stop_capture +#undef mlx_abs +#undef mlx_add +#undef mlx_addmm +#undef mlx_all_axes +#undef mlx_all_axis +#undef mlx_all +#undef mlx_allclose +#undef mlx_any_axes +#undef mlx_any_axis 
+#undef mlx_any +#undef mlx_arange +#undef mlx_arccos +#undef mlx_arccosh +#undef mlx_arcsin +#undef mlx_arcsinh +#undef mlx_arctan +#undef mlx_arctan2 +#undef mlx_arctanh +#undef mlx_argmax_axis +#undef mlx_argmax +#undef mlx_argmin_axis +#undef mlx_argmin +#undef mlx_argpartition_axis +#undef mlx_argpartition +#undef mlx_argsort_axis +#undef mlx_argsort +#undef mlx_array_equal +#undef mlx_as_strided +#undef mlx_astype +#undef mlx_atleast_1d +#undef mlx_atleast_2d +#undef mlx_atleast_3d +#undef mlx_bitwise_and +#undef mlx_bitwise_invert +#undef mlx_bitwise_or +#undef mlx_bitwise_xor +#undef mlx_block_masked_mm +#undef mlx_broadcast_arrays +#undef mlx_broadcast_to +#undef mlx_ceil +#undef mlx_clip +#undef mlx_concatenate_axis +#undef mlx_concatenate +#undef mlx_conjugate +#undef mlx_contiguous +#undef mlx_conv1d +#undef mlx_conv2d +#undef mlx_conv3d +#undef mlx_conv_general +#undef mlx_conv_transpose1d +#undef mlx_conv_transpose2d +#undef mlx_conv_transpose3d +#undef mlx_copy +#undef mlx_cos +#undef mlx_cosh +#undef mlx_cummax +#undef mlx_cummin +#undef mlx_cumprod +#undef mlx_cumsum +#undef mlx_degrees +#undef mlx_depends +#undef mlx_dequantize +#undef mlx_diag +#undef mlx_diagonal +#undef mlx_divide +#undef mlx_divmod +#undef mlx_einsum +#undef mlx_equal +#undef mlx_erf +#undef mlx_erfinv +#undef mlx_exp +#undef mlx_expand_dims_axes +#undef mlx_expand_dims +#undef mlx_expm1 +#undef mlx_eye +#undef mlx_flatten +#undef mlx_floor +#undef mlx_floor_divide +#undef mlx_from_fp8 +#undef mlx_full +#undef mlx_full_like +#undef mlx_gather +#undef mlx_gather_mm +#undef mlx_gather_qmm +#undef mlx_greater +#undef mlx_greater_equal +#undef mlx_hadamard_transform +#undef mlx_identity +#undef mlx_imag +#undef mlx_inner +#undef mlx_isclose +#undef mlx_isfinite +#undef mlx_isinf +#undef mlx_isnan +#undef mlx_isneginf +#undef mlx_isposinf +#undef mlx_kron +#undef mlx_left_shift +#undef mlx_less +#undef mlx_less_equal +#undef mlx_linspace +#undef mlx_log +#undef mlx_log10 +#undef 
mlx_log1p +#undef mlx_log2 +#undef mlx_logaddexp +#undef mlx_logcumsumexp +#undef mlx_logical_and +#undef mlx_logical_not +#undef mlx_logical_or +#undef mlx_logsumexp_axes +#undef mlx_logsumexp_axis +#undef mlx_logsumexp +#undef mlx_masked_scatter +#undef mlx_matmul +#undef mlx_max_axes +#undef mlx_max_axis +#undef mlx_max +#undef mlx_maximum +#undef mlx_mean_axes +#undef mlx_mean_axis +#undef mlx_mean +#undef mlx_median +#undef mlx_meshgrid +#undef mlx_min_axes +#undef mlx_min_axis +#undef mlx_min +#undef mlx_minimum +#undef mlx_moveaxis +#undef mlx_multiply +#undef mlx_nan_to_num +#undef mlx_negative +#undef mlx_not_equal +#undef mlx_number_of_elements +#undef mlx_ones +#undef mlx_ones_like +#undef mlx_outer +#undef mlx_pad +#undef mlx_pad_symmetric +#undef mlx_partition_axis +#undef mlx_partition +#undef mlx_power +#undef mlx_prod_axes +#undef mlx_prod_axis +#undef mlx_prod +#undef mlx_put_along_axis +#undef mlx_quantize +#undef mlx_quantized_matmul +#undef mlx_radians +#undef mlx_real +#undef mlx_reciprocal +#undef mlx_remainder +#undef mlx_repeat_axis +#undef mlx_repeat +#undef mlx_reshape +#undef mlx_right_shift +#undef mlx_roll_axis +#undef mlx_roll_axes +#undef mlx_roll +#undef mlx_round +#undef mlx_rsqrt +#undef mlx_scatter +#undef mlx_scatter_add +#undef mlx_scatter_add_axis +#undef mlx_scatter_max +#undef mlx_scatter_min +#undef mlx_scatter_prod +#undef mlx_segmented_mm +#undef mlx_sigmoid +#undef mlx_sign +#undef mlx_sin +#undef mlx_sinh +#undef mlx_slice +#undef mlx_slice_dynamic +#undef mlx_slice_update +#undef mlx_slice_update_dynamic +#undef mlx_softmax_axes +#undef mlx_softmax_axis +#undef mlx_softmax +#undef mlx_sort_axis +#undef mlx_sort +#undef mlx_split +#undef mlx_split_sections +#undef mlx_sqrt +#undef mlx_square +#undef mlx_squeeze_axes +#undef mlx_squeeze_axis +#undef mlx_squeeze +#undef mlx_stack_axis +#undef mlx_stack +#undef mlx_std_axes +#undef mlx_std_axis +#undef mlx_std +#undef mlx_stop_gradient +#undef mlx_subtract +#undef 
mlx_sum_axes +#undef mlx_sum_axis +#undef mlx_sum +#undef mlx_swapaxes +#undef mlx_take_axis +#undef mlx_take +#undef mlx_take_along_axis +#undef mlx_tan +#undef mlx_tanh +#undef mlx_tensordot +#undef mlx_tensordot_axis +#undef mlx_tile +#undef mlx_to_fp8 +#undef mlx_topk_axis +#undef mlx_topk +#undef mlx_trace +#undef mlx_transpose_axes +#undef mlx_transpose +#undef mlx_tri +#undef mlx_tril +#undef mlx_triu +#undef mlx_unflatten +#undef mlx_var_axes +#undef mlx_var_axis +#undef mlx_var +#undef mlx_view +#undef mlx_where +#undef mlx_zeros +#undef mlx_zeros_like +#undef mlx_random_bernoulli +#undef mlx_random_bits +#undef mlx_random_categorical_shape +#undef mlx_random_categorical_num_samples +#undef mlx_random_categorical +#undef mlx_random_gumbel +#undef mlx_random_key +#undef mlx_random_laplace +#undef mlx_random_multivariate_normal +#undef mlx_random_normal_broadcast +#undef mlx_random_normal +#undef mlx_random_permutation +#undef mlx_random_permutation_arange +#undef mlx_random_randint +#undef mlx_random_seed +#undef mlx_random_split_num +#undef mlx_random_split +#undef mlx_random_truncated_normal +#undef mlx_random_uniform +#undef mlx_stream_new +#undef mlx_stream_new_device +#undef mlx_stream_set +#undef mlx_stream_free +#undef mlx_stream_tostring +#undef mlx_stream_equal +#undef mlx_stream_get_device +#undef mlx_stream_get_index +#undef mlx_synchronize +#undef mlx_get_default_stream +#undef mlx_set_default_stream +#undef mlx_default_cpu_stream_new +#undef mlx_default_gpu_stream_new +#undef mlx_string_new +#undef mlx_string_new_data +#undef mlx_string_set +#undef mlx_string_data +#undef mlx_string_free +#undef mlx_detail_vmap_replace +#undef mlx_detail_vmap_trace +#undef mlx_async_eval +#undef mlx_checkpoint +#undef mlx_custom_function +#undef mlx_custom_vjp +#undef mlx_eval +#undef mlx_jvp +#undef mlx_value_and_grad +#undef mlx_vjp +#undef mlx_vector_array_new +#undef mlx_vector_array_set +#undef mlx_vector_array_free +#undef mlx_vector_array_new_data 
+#undef mlx_vector_array_new_value +#undef mlx_vector_array_set_data +#undef mlx_vector_array_set_value +#undef mlx_vector_array_append_data +#undef mlx_vector_array_append_value +#undef mlx_vector_array_size +#undef mlx_vector_array_get +#undef mlx_vector_vector_array_new +#undef mlx_vector_vector_array_set +#undef mlx_vector_vector_array_free +#undef mlx_vector_vector_array_new_data +#undef mlx_vector_vector_array_new_value +#undef mlx_vector_vector_array_set_data +#undef mlx_vector_vector_array_set_value +#undef mlx_vector_vector_array_append_data +#undef mlx_vector_vector_array_append_value +#undef mlx_vector_vector_array_size +#undef mlx_vector_vector_array_get +#undef mlx_vector_int_new +#undef mlx_vector_int_set +#undef mlx_vector_int_free +#undef mlx_vector_int_new_data +#undef mlx_vector_int_new_value +#undef mlx_vector_int_set_data +#undef mlx_vector_int_set_value +#undef mlx_vector_int_append_data +#undef mlx_vector_int_append_value +#undef mlx_vector_int_size +#undef mlx_vector_int_get +#undef mlx_vector_string_new +#undef mlx_vector_string_set +#undef mlx_vector_string_free +#undef mlx_vector_string_new_data +#undef mlx_vector_string_new_value +#undef mlx_vector_string_set_data +#undef mlx_vector_string_set_value +#undef mlx_vector_string_append_data +#undef mlx_vector_string_append_value +#undef mlx_vector_string_size +#undef mlx_vector_string_get +#undef mlx_version + +extern size_t (*mlx_dtype_size_)(mlx_dtype dtype); +extern int (*mlx_array_tostring_)(mlx_string* str, const mlx_array arr); +extern mlx_array (*mlx_array_new_)(void); +extern int (*mlx_array_free_)(mlx_array arr); +extern mlx_array (*mlx_array_new_bool_)(bool val); +extern mlx_array (*mlx_array_new_int_)(int val); +extern mlx_array (*mlx_array_new_float32_)(float val); +extern mlx_array (*mlx_array_new_float_)(float val); +extern mlx_array (*mlx_array_new_float64_)(double val); +extern mlx_array (*mlx_array_new_double_)(double val); +extern mlx_array (*mlx_array_new_complex_)(float 
real_val, float imag_val); +extern mlx_array (*mlx_array_new_data_)( + const void* data, + const int* shape, + int dim, + mlx_dtype dtype); +extern int (*mlx_array_set_)(mlx_array* arr, const mlx_array src); +extern int (*mlx_array_set_bool_)(mlx_array* arr, bool val); +extern int (*mlx_array_set_int_)(mlx_array* arr, int val); +extern int (*mlx_array_set_float32_)(mlx_array* arr, float val); +extern int (*mlx_array_set_float_)(mlx_array* arr, float val); +extern int (*mlx_array_set_float64_)(mlx_array* arr, double val); +extern int (*mlx_array_set_double_)(mlx_array* arr, double val); +extern int (*mlx_array_set_complex_)(mlx_array* arr, float real_val, float imag_val); +extern int (*mlx_array_set_data_)( + mlx_array* arr, + const void* data, + const int* shape, + int dim, + mlx_dtype dtype); +extern size_t (*mlx_array_itemsize_)(const mlx_array arr); +extern size_t (*mlx_array_size_)(const mlx_array arr); +extern size_t (*mlx_array_nbytes_)(const mlx_array arr); +extern size_t (*mlx_array_ndim_)(const mlx_array arr); +extern const int * (*mlx_array_shape_)(const mlx_array arr); +extern const size_t * (*mlx_array_strides_)(const mlx_array arr); +extern int (*mlx_array_dim_)(const mlx_array arr, int dim); +extern mlx_dtype (*mlx_array_dtype_)(const mlx_array arr); +extern int (*mlx_array_eval_)(mlx_array arr); +extern int (*mlx_array_item_bool_)(bool* res, const mlx_array arr); +extern int (*mlx_array_item_uint8_)(uint8_t* res, const mlx_array arr); +extern int (*mlx_array_item_uint16_)(uint16_t* res, const mlx_array arr); +extern int (*mlx_array_item_uint32_)(uint32_t* res, const mlx_array arr); +extern int (*mlx_array_item_uint64_)(uint64_t* res, const mlx_array arr); +extern int (*mlx_array_item_int8_)(int8_t* res, const mlx_array arr); +extern int (*mlx_array_item_int16_)(int16_t* res, const mlx_array arr); +extern int (*mlx_array_item_int32_)(int32_t* res, const mlx_array arr); +extern int (*mlx_array_item_int64_)(int64_t* res, const mlx_array arr); +extern 
int (*mlx_array_item_float32_)(float* res, const mlx_array arr); +extern int (*mlx_array_item_float64_)(double* res, const mlx_array arr); +extern int (*mlx_array_item_complex64_)(float _Complex* res, const mlx_array arr); +extern int (*mlx_array_item_float16_)(float16_t* res, const mlx_array arr); +extern int (*mlx_array_item_bfloat16_)(bfloat16_t* res, const mlx_array arr); +extern const bool * (*mlx_array_data_bool_)(const mlx_array arr); +extern const uint8_t * (*mlx_array_data_uint8_)(const mlx_array arr); +extern const uint16_t * (*mlx_array_data_uint16_)(const mlx_array arr); +extern const uint32_t * (*mlx_array_data_uint32_)(const mlx_array arr); +extern const uint64_t * (*mlx_array_data_uint64_)(const mlx_array arr); +extern const int8_t * (*mlx_array_data_int8_)(const mlx_array arr); +extern const int16_t * (*mlx_array_data_int16_)(const mlx_array arr); +extern const int32_t * (*mlx_array_data_int32_)(const mlx_array arr); +extern const int64_t * (*mlx_array_data_int64_)(const mlx_array arr); +extern const float * (*mlx_array_data_float32_)(const mlx_array arr); +extern const double * (*mlx_array_data_float64_)(const mlx_array arr); +extern const float _Complex * (*mlx_array_data_complex64_)(const mlx_array arr); +extern const float16_t * (*mlx_array_data_float16_)(const mlx_array arr); +extern const bfloat16_t * (*mlx_array_data_bfloat16_)(const mlx_array arr); +extern int (*_mlx_array_is_available_)(bool* res, const mlx_array arr); +extern int (*_mlx_array_wait_)(const mlx_array arr); +extern int (*_mlx_array_is_contiguous_)(bool* res, const mlx_array arr); +extern int (*_mlx_array_is_row_contiguous_)(bool* res, const mlx_array arr); +extern int (*_mlx_array_is_col_contiguous_)(bool* res, const mlx_array arr); +extern mlx_closure (*mlx_closure_new_)(void); +extern int (*mlx_closure_free_)(mlx_closure cls); +extern mlx_closure (*mlx_closure_new_func_)( + int (*fun)(mlx_vector_array*, const mlx_vector_array)); +extern mlx_closure 
(*mlx_closure_new_func_payload_)( + int (*fun)(mlx_vector_array*, const mlx_vector_array, void*), + void* payload, + void (*dtor)(void*)); +extern int (*mlx_closure_set_)(mlx_closure* cls, const mlx_closure src); +extern int (*mlx_closure_apply_)( + mlx_vector_array* res, + mlx_closure cls, + const mlx_vector_array input); +extern mlx_closure (*mlx_closure_new_unary_)(int (*fun)(mlx_array*, const mlx_array)); +extern mlx_closure_kwargs (*mlx_closure_kwargs_new_)(void); +extern int (*mlx_closure_kwargs_free_)(mlx_closure_kwargs cls); +extern mlx_closure_kwargs (*mlx_closure_kwargs_new_func_)(int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_map_string_to_array)); +extern mlx_closure_kwargs (*mlx_closure_kwargs_new_func_payload_)( + int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_map_string_to_array, + void*), + void* payload, + void (*dtor)(void*)); +extern int (*mlx_closure_kwargs_set_)( + mlx_closure_kwargs* cls, + const mlx_closure_kwargs src); +extern int (*mlx_closure_kwargs_apply_)( + mlx_vector_array* res, + mlx_closure_kwargs cls, + const mlx_vector_array input_0, + const mlx_map_string_to_array input_1); +extern mlx_closure_value_and_grad (*mlx_closure_value_and_grad_new_)(void); +extern int (*mlx_closure_value_and_grad_free_)(mlx_closure_value_and_grad cls); +extern mlx_closure_value_and_grad (*mlx_closure_value_and_grad_new_func_)( + int (*fun)(mlx_vector_array*, mlx_vector_array*, const mlx_vector_array)); +extern mlx_closure_value_and_grad (*mlx_closure_value_and_grad_new_func_payload_)( + int (*fun)( + mlx_vector_array*, + mlx_vector_array*, + const mlx_vector_array, + void*), + void* payload, + void (*dtor)(void*)); +extern int (*mlx_closure_value_and_grad_set_)( + mlx_closure_value_and_grad* cls, + const mlx_closure_value_and_grad src); +extern int (*mlx_closure_value_and_grad_apply_)( + mlx_vector_array* res_0, + mlx_vector_array* res_1, + mlx_closure_value_and_grad cls, + const mlx_vector_array input); 
+extern mlx_closure_custom (*mlx_closure_custom_new_)(void); +extern int (*mlx_closure_custom_free_)(mlx_closure_custom cls); +extern mlx_closure_custom (*mlx_closure_custom_new_func_)(int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_vector_array, + const mlx_vector_array)); +extern mlx_closure_custom (*mlx_closure_custom_new_func_payload_)( + int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_vector_array, + const mlx_vector_array, + void*), + void* payload, + void (*dtor)(void*)); +extern int (*mlx_closure_custom_set_)( + mlx_closure_custom* cls, + const mlx_closure_custom src); +extern int (*mlx_closure_custom_apply_)( + mlx_vector_array* res, + mlx_closure_custom cls, + const mlx_vector_array input_0, + const mlx_vector_array input_1, + const mlx_vector_array input_2); +extern mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_)(void); +extern int (*mlx_closure_custom_jvp_free_)(mlx_closure_custom_jvp cls); +extern mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_)(int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_vector_array, + const int*, + size_t _num)); +extern mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_payload_)( + int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_vector_array, + const int*, + size_t _num, + void*), + void* payload, + void (*dtor)(void*)); +extern int (*mlx_closure_custom_jvp_set_)( + mlx_closure_custom_jvp* cls, + const mlx_closure_custom_jvp src); +extern int (*mlx_closure_custom_jvp_apply_)( + mlx_vector_array* res, + mlx_closure_custom_jvp cls, + const mlx_vector_array input_0, + const mlx_vector_array input_1, + const int* input_2, + size_t input_2_num); +extern mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_)(void); +extern int (*mlx_closure_custom_vmap_free_)(mlx_closure_custom_vmap cls); +extern mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_)(int (*fun)( + mlx_vector_array*, + mlx_vector_int*, + const 
mlx_vector_array, + const int*, + size_t _num)); +extern mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_payload_)( + int (*fun)( + mlx_vector_array*, + mlx_vector_int*, + const mlx_vector_array, + const int*, + size_t _num, + void*), + void* payload, + void (*dtor)(void*)); +extern int (*mlx_closure_custom_vmap_set_)( + mlx_closure_custom_vmap* cls, + const mlx_closure_custom_vmap src); +extern int (*mlx_closure_custom_vmap_apply_)( + mlx_vector_array* res_0, + mlx_vector_int* res_1, + mlx_closure_custom_vmap cls, + const mlx_vector_array input_0, + const int* input_1, + size_t input_1_num); +extern int (*mlx_compile_)(mlx_closure* res, const mlx_closure fun, bool shapeless); +extern int (*mlx_detail_compile_)( + mlx_closure* res, + const mlx_closure fun, + uintptr_t fun_id, + bool shapeless, + const uint64_t* constants, + size_t constants_num); +extern int (*mlx_detail_compile_clear_cache_)(void); +extern int (*mlx_detail_compile_erase_)(uintptr_t fun_id); +extern int (*mlx_disable_compile_)(void); +extern int (*mlx_enable_compile_)(void); +extern int (*mlx_set_compile_mode_)(mlx_compile_mode mode); +extern mlx_device (*mlx_device_new_)(void); +extern mlx_device (*mlx_device_new_type_)(mlx_device_type type, int index); +extern int (*mlx_device_free_)(mlx_device dev); +extern int (*mlx_device_set_)(mlx_device* dev, const mlx_device src); +extern int (*mlx_device_tostring_)(mlx_string* str, mlx_device dev); +extern bool (*mlx_device_equal_)(mlx_device lhs, mlx_device rhs); +extern int (*mlx_device_get_index_)(int* index, mlx_device dev); +extern int (*mlx_device_get_type_)(mlx_device_type* type, mlx_device dev); +extern int (*mlx_get_default_device_)(mlx_device* dev); +extern int (*mlx_set_default_device_)(mlx_device dev); +extern int (*mlx_distributed_group_rank_)(mlx_distributed_group group); +extern int (*mlx_distributed_group_size_)(mlx_distributed_group group); +extern mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group 
group, int color, int key); +extern bool (*mlx_distributed_is_available_)(void); +extern mlx_distributed_group (*mlx_distributed_init_)(bool strict); +extern int (*mlx_distributed_all_gather_)( + mlx_array* res, + const mlx_array x, + const mlx_distributed_group group /* may be null */, + const mlx_stream S); +extern int (*mlx_distributed_all_max_)( + mlx_array* res, + const mlx_array x, + const mlx_distributed_group group /* may be null */, + const mlx_stream s); +extern int (*mlx_distributed_all_min_)( + mlx_array* res, + const mlx_array x, + const mlx_distributed_group group /* may be null */, + const mlx_stream s); +extern int (*mlx_distributed_all_sum_)( + mlx_array* res, + const mlx_array x, + const mlx_distributed_group group /* may be null */, + const mlx_stream s); +extern int (*mlx_distributed_recv_)( + mlx_array* res, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + int src, + const mlx_distributed_group group /* may be null */, + const mlx_stream s); +extern int (*mlx_distributed_recv_like_)( + mlx_array* res, + const mlx_array x, + int src, + const mlx_distributed_group group /* may be null */, + const mlx_stream s); +extern int (*mlx_distributed_send_)( + mlx_array* res, + const mlx_array x, + int dst, + const mlx_distributed_group group /* may be null */, + const mlx_stream s); +extern int (*mlx_distributed_sum_scatter_)( + mlx_array* res, + const mlx_array x, + const mlx_distributed_group group /* may be null */, + const mlx_stream s); +extern void (*mlx_set_error_handler_)( + mlx_error_handler_func handler, + void* data, + void (*dtor)(void*)); +extern void (*_mlx_error_)(const char* file, const int line, const char* fmt, ...); +extern int (*mlx_export_function_)( + const char* file, + const mlx_closure fun, + const mlx_vector_array args, + bool shapeless); +extern int (*mlx_export_function_kwargs_)( + const char* file, + const mlx_closure_kwargs fun, + const mlx_vector_array args, + const mlx_map_string_to_array kwargs, + bool 
shapeless); +extern mlx_function_exporter (*mlx_function_exporter_new_)( + const char* file, + const mlx_closure fun, + bool shapeless); +extern int (*mlx_function_exporter_free_)(mlx_function_exporter xfunc); +extern int (*mlx_function_exporter_apply_)( + const mlx_function_exporter xfunc, + const mlx_vector_array args); +extern int (*mlx_function_exporter_apply_kwargs_)( + const mlx_function_exporter xfunc, + const mlx_vector_array args, + const mlx_map_string_to_array kwargs); +extern mlx_imported_function (*mlx_imported_function_new_)(const char* file); +extern int (*mlx_imported_function_free_)(mlx_imported_function xfunc); +extern int (*mlx_imported_function_apply_)( + mlx_vector_array* res, + const mlx_imported_function xfunc, + const mlx_vector_array args); +extern int (*mlx_imported_function_apply_kwargs_)( + mlx_vector_array* res, + const mlx_imported_function xfunc, + const mlx_vector_array args, + const mlx_map_string_to_array kwargs); +extern mlx_fast_cuda_kernel_config (*mlx_fast_cuda_kernel_config_new_)(void); +extern void (*mlx_fast_cuda_kernel_config_free_)(mlx_fast_cuda_kernel_config cls); +extern int (*mlx_fast_cuda_kernel_config_add_output_arg_)( + mlx_fast_cuda_kernel_config cls, + const int* shape, + size_t size, + mlx_dtype dtype); +extern int (*mlx_fast_cuda_kernel_config_set_grid_)( + mlx_fast_cuda_kernel_config cls, + int grid1, + int grid2, + int grid3); +extern int (*mlx_fast_cuda_kernel_config_set_thread_group_)( + mlx_fast_cuda_kernel_config cls, + int thread1, + int thread2, + int thread3); +extern int (*mlx_fast_cuda_kernel_config_set_init_value_)( + mlx_fast_cuda_kernel_config cls, + float value); +extern int (*mlx_fast_cuda_kernel_config_set_verbose_)( + mlx_fast_cuda_kernel_config cls, + bool verbose); +extern int (*mlx_fast_cuda_kernel_config_add_template_arg_dtype_)( + mlx_fast_cuda_kernel_config cls, + const char* name, + mlx_dtype dtype); +extern int (*mlx_fast_cuda_kernel_config_add_template_arg_int_)( + 
mlx_fast_cuda_kernel_config cls, + const char* name, + int value); +extern int (*mlx_fast_cuda_kernel_config_add_template_arg_bool_)( + mlx_fast_cuda_kernel_config cls, + const char* name, + bool value); +extern mlx_fast_cuda_kernel (*mlx_fast_cuda_kernel_new_)( + const char* name, + const mlx_vector_string input_names, + const mlx_vector_string output_names, + const char* source, + const char* header, + bool ensure_row_contiguous, + int shared_memory); +extern void (*mlx_fast_cuda_kernel_free_)(mlx_fast_cuda_kernel cls); +extern int (*mlx_fast_cuda_kernel_apply_)( + mlx_vector_array* outputs, + mlx_fast_cuda_kernel cls, + const mlx_vector_array inputs, + const mlx_fast_cuda_kernel_config config, + const mlx_stream stream); +extern int (*mlx_fast_layer_norm_)( + mlx_array* res, + const mlx_array x, + const mlx_array weight /* may be null */, + const mlx_array bias /* may be null */, + float eps, + const mlx_stream s); +extern mlx_fast_metal_kernel_config (*mlx_fast_metal_kernel_config_new_)(void); +extern void (*mlx_fast_metal_kernel_config_free_)(mlx_fast_metal_kernel_config cls); +extern int (*mlx_fast_metal_kernel_config_add_output_arg_)( + mlx_fast_metal_kernel_config cls, + const int* shape, + size_t size, + mlx_dtype dtype); +extern int (*mlx_fast_metal_kernel_config_set_grid_)( + mlx_fast_metal_kernel_config cls, + int grid1, + int grid2, + int grid3); +extern int (*mlx_fast_metal_kernel_config_set_thread_group_)( + mlx_fast_metal_kernel_config cls, + int thread1, + int thread2, + int thread3); +extern int (*mlx_fast_metal_kernel_config_set_init_value_)( + mlx_fast_metal_kernel_config cls, + float value); +extern int (*mlx_fast_metal_kernel_config_set_verbose_)( + mlx_fast_metal_kernel_config cls, + bool verbose); +extern int (*mlx_fast_metal_kernel_config_add_template_arg_dtype_)( + mlx_fast_metal_kernel_config cls, + const char* name, + mlx_dtype dtype); +extern int (*mlx_fast_metal_kernel_config_add_template_arg_int_)( + mlx_fast_metal_kernel_config cls, 
+ const char* name, + int value); +extern int (*mlx_fast_metal_kernel_config_add_template_arg_bool_)( + mlx_fast_metal_kernel_config cls, + const char* name, + bool value); +extern mlx_fast_metal_kernel (*mlx_fast_metal_kernel_new_)( + const char* name, + const mlx_vector_string input_names, + const mlx_vector_string output_names, + const char* source, + const char* header, + bool ensure_row_contiguous, + bool atomic_outputs); +extern void (*mlx_fast_metal_kernel_free_)(mlx_fast_metal_kernel cls); +extern int (*mlx_fast_metal_kernel_apply_)( + mlx_vector_array* outputs, + mlx_fast_metal_kernel cls, + const mlx_vector_array inputs, + const mlx_fast_metal_kernel_config config, + const mlx_stream stream); +extern int (*mlx_fast_rms_norm_)( + mlx_array* res, + const mlx_array x, + const mlx_array weight /* may be null */, + float eps, + const mlx_stream s); +extern int (*mlx_fast_rope_)( + mlx_array* res, + const mlx_array x, + int dims, + bool traditional, + mlx_optional_float base, + float scale, + int offset, + const mlx_array freqs /* may be null */, + const mlx_stream s); +extern int (*mlx_fast_scaled_dot_product_attention_)( + mlx_array* res, + const mlx_array queries, + const mlx_array keys, + const mlx_array values, + float scale, + const char* mask_mode, + const mlx_array mask_arr /* may be null */, + const mlx_array sinks /* may be null */, + const mlx_stream s); +extern int (*mlx_fft_fft_)( + mlx_array* res, + const mlx_array a, + int n, + int axis, + const mlx_stream s); +extern int (*mlx_fft_fft2_)( + mlx_array* res, + const mlx_array a, + const int* n, + size_t n_num, + const int* axes, + size_t axes_num, + const mlx_stream s); +extern int (*mlx_fft_fftn_)( + mlx_array* res, + const mlx_array a, + const int* n, + size_t n_num, + const int* axes, + size_t axes_num, + const mlx_stream s); +extern int (*mlx_fft_fftshift_)( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + const mlx_stream s); +extern int (*mlx_fft_ifft_)( + 
mlx_array* res, + const mlx_array a, + int n, + int axis, + const mlx_stream s); +extern int (*mlx_fft_ifft2_)( + mlx_array* res, + const mlx_array a, + const int* n, + size_t n_num, + const int* axes, + size_t axes_num, + const mlx_stream s); +extern int (*mlx_fft_ifftn_)( + mlx_array* res, + const mlx_array a, + const int* n, + size_t n_num, + const int* axes, + size_t axes_num, + const mlx_stream s); +extern int (*mlx_fft_ifftshift_)( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + const mlx_stream s); +extern int (*mlx_fft_irfft_)( + mlx_array* res, + const mlx_array a, + int n, + int axis, + const mlx_stream s); +extern int (*mlx_fft_irfft2_)( + mlx_array* res, + const mlx_array a, + const int* n, + size_t n_num, + const int* axes, + size_t axes_num, + const mlx_stream s); +extern int (*mlx_fft_irfftn_)( + mlx_array* res, + const mlx_array a, + const int* n, + size_t n_num, + const int* axes, + size_t axes_num, + const mlx_stream s); +extern int (*mlx_fft_rfft_)( + mlx_array* res, + const mlx_array a, + int n, + int axis, + const mlx_stream s); +extern int (*mlx_fft_rfft2_)( + mlx_array* res, + const mlx_array a, + const int* n, + size_t n_num, + const int* axes, + size_t axes_num, + const mlx_stream s); +extern int (*mlx_fft_rfftn_)( + mlx_array* res, + const mlx_array a, + const int* n, + size_t n_num, + const int* axes, + size_t axes_num, + const mlx_stream s); +extern mlx_io_reader (*mlx_io_reader_new_)(void* desc, mlx_io_vtable vtable); +extern int (*mlx_io_reader_descriptor_)(void** desc_, mlx_io_reader io); +extern int (*mlx_io_reader_tostring_)(mlx_string* str_, mlx_io_reader io); +extern int (*mlx_io_reader_free_)(mlx_io_reader io); +extern mlx_io_writer (*mlx_io_writer_new_)(void* desc, mlx_io_vtable vtable); +extern int (*mlx_io_writer_descriptor_)(void** desc_, mlx_io_writer io); +extern int (*mlx_io_writer_tostring_)(mlx_string* str_, mlx_io_writer io); +extern int (*mlx_io_writer_free_)(mlx_io_writer io); +extern 
int (*mlx_load_reader_)( + mlx_array* res, + mlx_io_reader in_stream, + const mlx_stream s); +extern int (*mlx_load_)(mlx_array* res, const char* file, const mlx_stream s); +extern int (*mlx_load_safetensors_reader_)( + mlx_map_string_to_array* res_0, + mlx_map_string_to_string* res_1, + mlx_io_reader in_stream, + const mlx_stream s); +extern int (*mlx_load_safetensors_)( + mlx_map_string_to_array* res_0, + mlx_map_string_to_string* res_1, + const char* file, + const mlx_stream s); +extern int (*mlx_save_writer_)(mlx_io_writer out_stream, const mlx_array a); +extern int (*mlx_save_)(const char* file, const mlx_array a); +extern int (*mlx_save_safetensors_writer_)( + mlx_io_writer in_stream, + const mlx_map_string_to_array param, + const mlx_map_string_to_string metadata); +extern int (*mlx_save_safetensors_)( + const char* file, + const mlx_map_string_to_array param, + const mlx_map_string_to_string metadata); +extern int (*mlx_linalg_cholesky_)( + mlx_array* res, + const mlx_array a, + bool upper, + const mlx_stream s); +extern int (*mlx_linalg_cholesky_inv_)( + mlx_array* res, + const mlx_array a, + bool upper, + const mlx_stream s); +extern int (*mlx_linalg_cross_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + int axis, + const mlx_stream s); +extern int (*mlx_linalg_eig_)( + mlx_array* res_0, + mlx_array* res_1, + const mlx_array a, + const mlx_stream s); +extern int (*mlx_linalg_eigh_)( + mlx_array* res_0, + mlx_array* res_1, + const mlx_array a, + const char* UPLO, + const mlx_stream s); +extern int (*mlx_linalg_eigvals_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_linalg_eigvalsh_)( + mlx_array* res, + const mlx_array a, + const char* UPLO, + const mlx_stream s); +extern int (*mlx_linalg_inv_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_linalg_lu_)(mlx_vector_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_linalg_lu_factor_)( + mlx_array* res_0, + mlx_array* 
res_1, + const mlx_array a, + const mlx_stream s); +extern int (*mlx_linalg_norm_)( + mlx_array* res, + const mlx_array a, + double ord, + const int* axis /* may be null */, + size_t axis_num, + bool keepdims, + const mlx_stream s); +extern int (*mlx_linalg_norm_matrix_)( + mlx_array* res, + const mlx_array a, + const char* ord, + const int* axis /* may be null */, + size_t axis_num, + bool keepdims, + const mlx_stream s); +extern int (*mlx_linalg_norm_l2_)( + mlx_array* res, + const mlx_array a, + const int* axis /* may be null */, + size_t axis_num, + bool keepdims, + const mlx_stream s); +extern int (*mlx_linalg_pinv_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_linalg_qr_)( + mlx_array* res_0, + mlx_array* res_1, + const mlx_array a, + const mlx_stream s); +extern int (*mlx_linalg_solve_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +extern int (*mlx_linalg_solve_triangular_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + bool upper, + const mlx_stream s); +extern int (*mlx_linalg_svd_)( + mlx_vector_array* res, + const mlx_array a, + bool compute_uv, + const mlx_stream s); +extern int (*mlx_linalg_tri_inv_)( + mlx_array* res, + const mlx_array a, + bool upper, + const mlx_stream s); +extern mlx_map_string_to_array (*mlx_map_string_to_array_new_)(void); +extern int (*mlx_map_string_to_array_set_)( + mlx_map_string_to_array* map, + const mlx_map_string_to_array src); +extern int (*mlx_map_string_to_array_free_)(mlx_map_string_to_array map); +extern int (*mlx_map_string_to_array_insert_)( + mlx_map_string_to_array map, + const char* key, + const mlx_array value); +extern int (*mlx_map_string_to_array_get_)( + mlx_array* value, + const mlx_map_string_to_array map, + const char* key); +extern mlx_map_string_to_array_iterator (*mlx_map_string_to_array_iterator_new_)( + mlx_map_string_to_array map); +extern int 
(*mlx_map_string_to_array_iterator_free_)(mlx_map_string_to_array_iterator it); +extern int (*mlx_map_string_to_array_iterator_next_)( + const char** key, + mlx_array* value, + mlx_map_string_to_array_iterator it); +extern mlx_map_string_to_string (*mlx_map_string_to_string_new_)(void); +extern int (*mlx_map_string_to_string_set_)( + mlx_map_string_to_string* map, + const mlx_map_string_to_string src); +extern int (*mlx_map_string_to_string_free_)(mlx_map_string_to_string map); +extern int (*mlx_map_string_to_string_insert_)( + mlx_map_string_to_string map, + const char* key, + const char* value); +extern int (*mlx_map_string_to_string_get_)( + const char** value, + const mlx_map_string_to_string map, + const char* key); +extern mlx_map_string_to_string_iterator (*mlx_map_string_to_string_iterator_new_)( + mlx_map_string_to_string map); +extern int (*mlx_map_string_to_string_iterator_free_)( + mlx_map_string_to_string_iterator it); +extern int (*mlx_map_string_to_string_iterator_next_)( + const char** key, + const char** value, + mlx_map_string_to_string_iterator it); +extern int (*mlx_clear_cache_)(void); +extern int (*mlx_get_active_memory_)(size_t* res); +extern int (*mlx_get_cache_memory_)(size_t* res); +extern int (*mlx_get_memory_limit_)(size_t* res); +extern int (*mlx_get_peak_memory_)(size_t* res); +extern int (*mlx_reset_peak_memory_)(void); +extern int (*mlx_set_cache_limit_)(size_t* res, size_t limit); +extern int (*mlx_set_memory_limit_)(size_t* res, size_t limit); +extern int (*mlx_set_wired_limit_)(size_t* res, size_t limit); +extern mlx_metal_device_info_t (*mlx_metal_device_info_)(void); +extern int (*mlx_metal_is_available_)(bool* res); +extern int (*mlx_metal_start_capture_)(const char* path); +extern int (*mlx_metal_stop_capture_)(void); +extern int (*mlx_abs_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_add_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +extern int 
(*mlx_addmm_)( + mlx_array* res, + const mlx_array c, + const mlx_array a, + const mlx_array b, + float alpha, + float beta, + const mlx_stream s); +extern int (*mlx_all_axes_)( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s); +extern int (*mlx_all_axis_)( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s); +extern int (*mlx_all_)( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s); +extern int (*mlx_allclose_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + double rtol, + double atol, + bool equal_nan, + const mlx_stream s); +extern int (*mlx_any_axes_)( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s); +extern int (*mlx_any_axis_)( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s); +extern int (*mlx_any_)( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s); +extern int (*mlx_arange_)( + mlx_array* res, + double start, + double stop, + double step, + mlx_dtype dtype, + const mlx_stream s); +extern int (*mlx_arccos_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_arccosh_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_arcsin_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_arcsinh_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_arctan_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_arctan2_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +extern int (*mlx_arctanh_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_argmax_axis_)( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s); +extern int (*mlx_argmax_)( + mlx_array* res, + const mlx_array a, + bool 
keepdims, + const mlx_stream s); +extern int (*mlx_argmin_axis_)( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s); +extern int (*mlx_argmin_)( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s); +extern int (*mlx_argpartition_axis_)( + mlx_array* res, + const mlx_array a, + int kth, + int axis, + const mlx_stream s); +extern int (*mlx_argpartition_)( + mlx_array* res, + const mlx_array a, + int kth, + const mlx_stream s); +extern int (*mlx_argsort_axis_)( + mlx_array* res, + const mlx_array a, + int axis, + const mlx_stream s); +extern int (*mlx_argsort_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_array_equal_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + bool equal_nan, + const mlx_stream s); +extern int (*mlx_as_strided_)( + mlx_array* res, + const mlx_array a, + const int* shape, + size_t shape_num, + const int64_t* strides, + size_t strides_num, + size_t offset, + const mlx_stream s); +extern int (*mlx_astype_)( + mlx_array* res, + const mlx_array a, + mlx_dtype dtype, + const mlx_stream s); +extern int (*mlx_atleast_1d_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_atleast_2d_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_atleast_3d_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_bitwise_and_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +extern int (*mlx_bitwise_invert_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_bitwise_or_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +extern int (*mlx_bitwise_xor_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +extern int (*mlx_block_masked_mm_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + int block_size, + const mlx_array mask_out /* may be null */, + const 
mlx_array mask_lhs /* may be null */, + const mlx_array mask_rhs /* may be null */, + const mlx_stream s); +extern int (*mlx_broadcast_arrays_)( + mlx_vector_array* res, + const mlx_vector_array inputs, + const mlx_stream s); +extern int (*mlx_broadcast_to_)( + mlx_array* res, + const mlx_array a, + const int* shape, + size_t shape_num, + const mlx_stream s); +extern int (*mlx_ceil_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_clip_)( + mlx_array* res, + const mlx_array a, + const mlx_array a_min /* may be null */, + const mlx_array a_max /* may be null */, + const mlx_stream s); +extern int (*mlx_concatenate_axis_)( + mlx_array* res, + const mlx_vector_array arrays, + int axis, + const mlx_stream s); +extern int (*mlx_concatenate_)( + mlx_array* res, + const mlx_vector_array arrays, + const mlx_stream s); +extern int (*mlx_conjugate_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_contiguous_)( + mlx_array* res, + const mlx_array a, + bool allow_col_major, + const mlx_stream s); +extern int (*mlx_conv1d_)( + mlx_array* res, + const mlx_array input, + const mlx_array weight, + int stride, + int padding, + int dilation, + int groups, + const mlx_stream s); +extern int (*mlx_conv2d_)( + mlx_array* res, + const mlx_array input, + const mlx_array weight, + int stride_0, + int stride_1, + int padding_0, + int padding_1, + int dilation_0, + int dilation_1, + int groups, + const mlx_stream s); +extern int (*mlx_conv3d_)( + mlx_array* res, + const mlx_array input, + const mlx_array weight, + int stride_0, + int stride_1, + int stride_2, + int padding_0, + int padding_1, + int padding_2, + int dilation_0, + int dilation_1, + int dilation_2, + int groups, + const mlx_stream s); +extern int (*mlx_conv_general_)( + mlx_array* res, + const mlx_array input, + const mlx_array weight, + const int* stride, + size_t stride_num, + const int* padding_lo, + size_t padding_lo_num, + const int* padding_hi, + size_t padding_hi_num, 
+ const int* kernel_dilation, + size_t kernel_dilation_num, + const int* input_dilation, + size_t input_dilation_num, + int groups, + bool flip, + const mlx_stream s); +extern int (*mlx_conv_transpose1d_)( + mlx_array* res, + const mlx_array input, + const mlx_array weight, + int stride, + int padding, + int dilation, + int output_padding, + int groups, + const mlx_stream s); +extern int (*mlx_conv_transpose2d_)( + mlx_array* res, + const mlx_array input, + const mlx_array weight, + int stride_0, + int stride_1, + int padding_0, + int padding_1, + int dilation_0, + int dilation_1, + int output_padding_0, + int output_padding_1, + int groups, + const mlx_stream s); +extern int (*mlx_conv_transpose3d_)( + mlx_array* res, + const mlx_array input, + const mlx_array weight, + int stride_0, + int stride_1, + int stride_2, + int padding_0, + int padding_1, + int padding_2, + int dilation_0, + int dilation_1, + int dilation_2, + int output_padding_0, + int output_padding_1, + int output_padding_2, + int groups, + const mlx_stream s); +extern int (*mlx_copy_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_cos_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_cosh_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_cummax_)( + mlx_array* res, + const mlx_array a, + int axis, + bool reverse, + bool inclusive, + const mlx_stream s); +extern int (*mlx_cummin_)( + mlx_array* res, + const mlx_array a, + int axis, + bool reverse, + bool inclusive, + const mlx_stream s); +extern int (*mlx_cumprod_)( + mlx_array* res, + const mlx_array a, + int axis, + bool reverse, + bool inclusive, + const mlx_stream s); +extern int (*mlx_cumsum_)( + mlx_array* res, + const mlx_array a, + int axis, + bool reverse, + bool inclusive, + const mlx_stream s); +extern int (*mlx_degrees_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_depends_)( + mlx_vector_array* res, + const mlx_vector_array inputs, 
+ const mlx_vector_array dependencies); +extern int (*mlx_dequantize_)( + mlx_array* res, + const mlx_array w, + const mlx_array scales, + const mlx_array biases /* may be null */, + mlx_optional_int group_size, + mlx_optional_int bits, + const char* mode, + mlx_optional_dtype dtype, + const mlx_stream s); +extern int (*mlx_diag_)(mlx_array* res, const mlx_array a, int k, const mlx_stream s); +extern int (*mlx_diagonal_)( + mlx_array* res, + const mlx_array a, + int offset, + int axis1, + int axis2, + const mlx_stream s); +extern int (*mlx_divide_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +extern int (*mlx_divmod_)( + mlx_vector_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +extern int (*mlx_einsum_)( + mlx_array* res, + const char* subscripts, + const mlx_vector_array operands, + const mlx_stream s); +extern int (*mlx_equal_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +extern int (*mlx_erf_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_erfinv_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_exp_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_expand_dims_axes_)( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + const mlx_stream s); +extern int (*mlx_expand_dims_)( + mlx_array* res, + const mlx_array a, + int axis, + const mlx_stream s); +extern int (*mlx_expm1_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_eye_)( + mlx_array* res, + int n, + int m, + int k, + mlx_dtype dtype, + const mlx_stream s); +extern int (*mlx_flatten_)( + mlx_array* res, + const mlx_array a, + int start_axis, + int end_axis, + const mlx_stream s); +extern int (*mlx_floor_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_floor_divide_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream 
s); +extern int (*mlx_from_fp8_)( + mlx_array* res, + const mlx_array x, + mlx_dtype dtype, + const mlx_stream s); +extern int (*mlx_full_)( + mlx_array* res, + const int* shape, + size_t shape_num, + const mlx_array vals, + mlx_dtype dtype, + const mlx_stream s); +extern int (*mlx_full_like_)( + mlx_array* res, + const mlx_array a, + const mlx_array vals, + mlx_dtype dtype, + const mlx_stream s); +extern int (*mlx_gather_)( + mlx_array* res, + const mlx_array a, + const mlx_vector_array indices, + const int* axes, + size_t axes_num, + const int* slice_sizes, + size_t slice_sizes_num, + const mlx_stream s); +extern int (*mlx_gather_mm_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_array lhs_indices /* may be null */, + const mlx_array rhs_indices /* may be null */, + bool sorted_indices, + const mlx_stream s); +extern int (*mlx_gather_qmm_)( + mlx_array* res, + const mlx_array x, + const mlx_array w, + const mlx_array scales, + const mlx_array biases /* may be null */, + const mlx_array lhs_indices /* may be null */, + const mlx_array rhs_indices /* may be null */, + bool transpose, + mlx_optional_int group_size, + mlx_optional_int bits, + const char* mode, + bool sorted_indices, + const mlx_stream s); +extern int (*mlx_greater_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +extern int (*mlx_greater_equal_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +extern int (*mlx_hadamard_transform_)( + mlx_array* res, + const mlx_array a, + mlx_optional_float scale, + const mlx_stream s); +extern int (*mlx_identity_)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s); +extern int (*mlx_imag_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_inner_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +extern int (*mlx_isclose_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + double rtol, + 
double atol, + bool equal_nan, + const mlx_stream s); +extern int (*mlx_isfinite_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_isinf_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_isnan_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_isneginf_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_isposinf_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_kron_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +extern int (*mlx_left_shift_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +extern int (*mlx_less_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +extern int (*mlx_less_equal_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +extern int (*mlx_linspace_)( + mlx_array* res, + double start, + double stop, + int num, + mlx_dtype dtype, + const mlx_stream s); +extern int (*mlx_log_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_log10_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_log1p_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_log2_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_logaddexp_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +extern int (*mlx_logcumsumexp_)( + mlx_array* res, + const mlx_array a, + int axis, + bool reverse, + bool inclusive, + const mlx_stream s); +extern int (*mlx_logical_and_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +extern int (*mlx_logical_not_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_logical_or_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +extern int (*mlx_logsumexp_axes_)( + 
mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s); +extern int (*mlx_logsumexp_axis_)( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s); +extern int (*mlx_logsumexp_)( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s); +extern int (*mlx_masked_scatter_)( + mlx_array* res, + const mlx_array a, + const mlx_array mask, + const mlx_array src, + const mlx_stream s); +extern int (*mlx_matmul_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +extern int (*mlx_max_axes_)( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s); +extern int (*mlx_max_axis_)( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s); +extern int (*mlx_max_)( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s); +extern int (*mlx_maximum_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +extern int (*mlx_mean_axes_)( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s); +extern int (*mlx_mean_axis_)( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s); +extern int (*mlx_mean_)( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s); +extern int (*mlx_median_)( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s); +extern int (*mlx_meshgrid_)( + mlx_vector_array* res, + const mlx_vector_array arrays, + bool sparse, + const char* indexing, + const mlx_stream s); +extern int (*mlx_min_axes_)( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s); +extern int (*mlx_min_axis_)( + mlx_array* res, + const mlx_array a, + int axis, + bool 
keepdims, + const mlx_stream s); +extern int (*mlx_min_)( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s); +extern int (*mlx_minimum_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +extern int (*mlx_moveaxis_)( + mlx_array* res, + const mlx_array a, + int source, + int destination, + const mlx_stream s); +extern int (*mlx_multiply_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +extern int (*mlx_nan_to_num_)( + mlx_array* res, + const mlx_array a, + float nan, + mlx_optional_float posinf, + mlx_optional_float neginf, + const mlx_stream s); +extern int (*mlx_negative_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_not_equal_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +extern int (*mlx_number_of_elements_)( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool inverted, + mlx_dtype dtype, + const mlx_stream s); +extern int (*mlx_ones_)( + mlx_array* res, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + const mlx_stream s); +extern int (*mlx_ones_like_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_outer_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +extern int (*mlx_pad_)( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + const int* low_pad_size, + size_t low_pad_size_num, + const int* high_pad_size, + size_t high_pad_size_num, + const mlx_array pad_value, + const char* mode, + const mlx_stream s); +extern int (*mlx_pad_symmetric_)( + mlx_array* res, + const mlx_array a, + int pad_width, + const mlx_array pad_value, + const char* mode, + const mlx_stream s); +extern int (*mlx_partition_axis_)( + mlx_array* res, + const mlx_array a, + int kth, + int axis, + const mlx_stream s); +extern int (*mlx_partition_)( + mlx_array* res, + const mlx_array a, + int kth, + 
const mlx_stream s); +extern int (*mlx_power_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +extern int (*mlx_prod_axes_)( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s); +extern int (*mlx_prod_axis_)( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s); +extern int (*mlx_prod_)( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s); +extern int (*mlx_put_along_axis_)( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + const mlx_array values, + int axis, + const mlx_stream s); +extern int (*mlx_quantize_)( + mlx_vector_array* res, + const mlx_array w, + mlx_optional_int group_size, + mlx_optional_int bits, + const char* mode, + const mlx_stream s); +extern int (*mlx_quantized_matmul_)( + mlx_array* res, + const mlx_array x, + const mlx_array w, + const mlx_array scales, + const mlx_array biases /* may be null */, + bool transpose, + mlx_optional_int group_size, + mlx_optional_int bits, + const char* mode, + const mlx_stream s); +extern int (*mlx_radians_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_real_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_reciprocal_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_remainder_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +extern int (*mlx_repeat_axis_)( + mlx_array* res, + const mlx_array arr, + int repeats, + int axis, + const mlx_stream s); +extern int (*mlx_repeat_)( + mlx_array* res, + const mlx_array arr, + int repeats, + const mlx_stream s); +extern int (*mlx_reshape_)( + mlx_array* res, + const mlx_array a, + const int* shape, + size_t shape_num, + const mlx_stream s); +extern int (*mlx_right_shift_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +extern int 
(*mlx_roll_axis_)( + mlx_array* res, + const mlx_array a, + const int* shift, + size_t shift_num, + int axis, + const mlx_stream s); +extern int (*mlx_roll_axes_)( + mlx_array* res, + const mlx_array a, + const int* shift, + size_t shift_num, + const int* axes, + size_t axes_num, + const mlx_stream s); +extern int (*mlx_roll_)( + mlx_array* res, + const mlx_array a, + const int* shift, + size_t shift_num, + const mlx_stream s); +extern int (*mlx_round_)( + mlx_array* res, + const mlx_array a, + int decimals, + const mlx_stream s); +extern int (*mlx_rsqrt_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_scatter_)( + mlx_array* res, + const mlx_array a, + const mlx_vector_array indices, + const mlx_array updates, + const int* axes, + size_t axes_num, + const mlx_stream s); +extern int (*mlx_scatter_add_)( + mlx_array* res, + const mlx_array a, + const mlx_vector_array indices, + const mlx_array updates, + const int* axes, + size_t axes_num, + const mlx_stream s); +extern int (*mlx_scatter_add_axis_)( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + const mlx_array values, + int axis, + const mlx_stream s); +extern int (*mlx_scatter_max_)( + mlx_array* res, + const mlx_array a, + const mlx_vector_array indices, + const mlx_array updates, + const int* axes, + size_t axes_num, + const mlx_stream s); +extern int (*mlx_scatter_min_)( + mlx_array* res, + const mlx_array a, + const mlx_vector_array indices, + const mlx_array updates, + const int* axes, + size_t axes_num, + const mlx_stream s); +extern int (*mlx_scatter_prod_)( + mlx_array* res, + const mlx_array a, + const mlx_vector_array indices, + const mlx_array updates, + const int* axes, + size_t axes_num, + const mlx_stream s); +extern int (*mlx_segmented_mm_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_array segments, + const mlx_stream s); +extern int (*mlx_sigmoid_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int 
(*mlx_sign_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_sin_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_sinh_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_slice_)( + mlx_array* res, + const mlx_array a, + const int* start, + size_t start_num, + const int* stop, + size_t stop_num, + const int* strides, + size_t strides_num, + const mlx_stream s); +extern int (*mlx_slice_dynamic_)( + mlx_array* res, + const mlx_array a, + const mlx_array start, + const int* axes, + size_t axes_num, + const int* slice_size, + size_t slice_size_num, + const mlx_stream s); +extern int (*mlx_slice_update_)( + mlx_array* res, + const mlx_array src, + const mlx_array update, + const int* start, + size_t start_num, + const int* stop, + size_t stop_num, + const int* strides, + size_t strides_num, + const mlx_stream s); +extern int (*mlx_slice_update_dynamic_)( + mlx_array* res, + const mlx_array src, + const mlx_array update, + const mlx_array start, + const int* axes, + size_t axes_num, + const mlx_stream s); +extern int (*mlx_softmax_axes_)( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool precise, + const mlx_stream s); +extern int (*mlx_softmax_axis_)( + mlx_array* res, + const mlx_array a, + int axis, + bool precise, + const mlx_stream s); +extern int (*mlx_softmax_)( + mlx_array* res, + const mlx_array a, + bool precise, + const mlx_stream s); +extern int (*mlx_sort_axis_)( + mlx_array* res, + const mlx_array a, + int axis, + const mlx_stream s); +extern int (*mlx_sort_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_split_)( + mlx_vector_array* res, + const mlx_array a, + int num_splits, + int axis, + const mlx_stream s); +extern int (*mlx_split_sections_)( + mlx_vector_array* res, + const mlx_array a, + const int* indices, + size_t indices_num, + int axis, + const mlx_stream s); +extern int (*mlx_sqrt_)(mlx_array* res, const 
mlx_array a, const mlx_stream s); +extern int (*mlx_square_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_squeeze_axes_)( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + const mlx_stream s); +extern int (*mlx_squeeze_axis_)( + mlx_array* res, + const mlx_array a, + int axis, + const mlx_stream s); +extern int (*mlx_squeeze_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_stack_axis_)( + mlx_array* res, + const mlx_vector_array arrays, + int axis, + const mlx_stream s); +extern int (*mlx_stack_)( + mlx_array* res, + const mlx_vector_array arrays, + const mlx_stream s); +extern int (*mlx_std_axes_)( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + int ddof, + const mlx_stream s); +extern int (*mlx_std_axis_)( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + int ddof, + const mlx_stream s); +extern int (*mlx_std_)( + mlx_array* res, + const mlx_array a, + bool keepdims, + int ddof, + const mlx_stream s); +extern int (*mlx_stop_gradient_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_subtract_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +extern int (*mlx_sum_axes_)( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s); +extern int (*mlx_sum_axis_)( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s); +extern int (*mlx_sum_)( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s); +extern int (*mlx_swapaxes_)( + mlx_array* res, + const mlx_array a, + int axis1, + int axis2, + const mlx_stream s); +extern int (*mlx_take_axis_)( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + int axis, + const mlx_stream s); +extern int (*mlx_take_)( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + const 
mlx_stream s); +extern int (*mlx_take_along_axis_)( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + int axis, + const mlx_stream s); +extern int (*mlx_tan_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_tanh_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_tensordot_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const int* axes_a, + size_t axes_a_num, + const int* axes_b, + size_t axes_b_num, + const mlx_stream s); +extern int (*mlx_tensordot_axis_)( + mlx_array* res, + const mlx_array a, + const mlx_array b, + int axis, + const mlx_stream s); +extern int (*mlx_tile_)( + mlx_array* res, + const mlx_array arr, + const int* reps, + size_t reps_num, + const mlx_stream s); +extern int (*mlx_to_fp8_)(mlx_array* res, const mlx_array x, const mlx_stream s); +extern int (*mlx_topk_axis_)( + mlx_array* res, + const mlx_array a, + int k, + int axis, + const mlx_stream s); +extern int (*mlx_topk_)(mlx_array* res, const mlx_array a, int k, const mlx_stream s); +extern int (*mlx_trace_)( + mlx_array* res, + const mlx_array a, + int offset, + int axis1, + int axis2, + mlx_dtype dtype, + const mlx_stream s); +extern int (*mlx_transpose_axes_)( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + const mlx_stream s); +extern int (*mlx_transpose_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_tri_)( + mlx_array* res, + int n, + int m, + int k, + mlx_dtype type, + const mlx_stream s); +extern int (*mlx_tril_)(mlx_array* res, const mlx_array x, int k, const mlx_stream s); +extern int (*mlx_triu_)(mlx_array* res, const mlx_array x, int k, const mlx_stream s); +extern int (*mlx_unflatten_)( + mlx_array* res, + const mlx_array a, + int axis, + const int* shape, + size_t shape_num, + const mlx_stream s); +extern int (*mlx_var_axes_)( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + int ddof, + 
const mlx_stream s); +extern int (*mlx_var_axis_)( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + int ddof, + const mlx_stream s); +extern int (*mlx_var_)( + mlx_array* res, + const mlx_array a, + bool keepdims, + int ddof, + const mlx_stream s); +extern int (*mlx_view_)( + mlx_array* res, + const mlx_array a, + mlx_dtype dtype, + const mlx_stream s); +extern int (*mlx_where_)( + mlx_array* res, + const mlx_array condition, + const mlx_array x, + const mlx_array y, + const mlx_stream s); +extern int (*mlx_zeros_)( + mlx_array* res, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + const mlx_stream s); +extern int (*mlx_zeros_like_)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_random_bernoulli_)( + mlx_array* res, + const mlx_array p, + const int* shape, + size_t shape_num, + const mlx_array key /* may be null */, + const mlx_stream s); +extern int (*mlx_random_bits_)( + mlx_array* res, + const int* shape, + size_t shape_num, + int width, + const mlx_array key /* may be null */, + const mlx_stream s); +extern int (*mlx_random_categorical_shape_)( + mlx_array* res, + const mlx_array logits, + int axis, + const int* shape, + size_t shape_num, + const mlx_array key /* may be null */, + const mlx_stream s); +extern int (*mlx_random_categorical_num_samples_)( + mlx_array* res, + const mlx_array logits_, + int axis, + int num_samples, + const mlx_array key /* may be null */, + const mlx_stream s); +extern int (*mlx_random_categorical_)( + mlx_array* res, + const mlx_array logits, + int axis, + const mlx_array key /* may be null */, + const mlx_stream s); +extern int (*mlx_random_gumbel_)( + mlx_array* res, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + const mlx_array key /* may be null */, + const mlx_stream s); +extern int (*mlx_random_key_)(mlx_array* res, uint64_t seed); +extern int (*mlx_random_laplace_)( + mlx_array* res, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + float 
loc, + float scale, + const mlx_array key /* may be null */, + const mlx_stream s); +extern int (*mlx_random_multivariate_normal_)( + mlx_array* res, + const mlx_array mean, + const mlx_array cov, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + const mlx_array key /* may be null */, + const mlx_stream s); +extern int (*mlx_random_normal_broadcast_)( + mlx_array* res, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + const mlx_array loc /* may be null */, + const mlx_array scale /* may be null */, + const mlx_array key /* may be null */, + const mlx_stream s); +extern int (*mlx_random_normal_)( + mlx_array* res, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + float loc, + float scale, + const mlx_array key /* may be null */, + const mlx_stream s); +extern int (*mlx_random_permutation_)( + mlx_array* res, + const mlx_array x, + int axis, + const mlx_array key /* may be null */, + const mlx_stream s); +extern int (*mlx_random_permutation_arange_)( + mlx_array* res, + int x, + const mlx_array key /* may be null */, + const mlx_stream s); +extern int (*mlx_random_randint_)( + mlx_array* res, + const mlx_array low, + const mlx_array high, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + const mlx_array key /* may be null */, + const mlx_stream s); +extern int (*mlx_random_seed_)(uint64_t seed); +extern int (*mlx_random_split_num_)( + mlx_array* res, + const mlx_array key, + int num, + const mlx_stream s); +extern int (*mlx_random_split_)( + mlx_array* res_0, + mlx_array* res_1, + const mlx_array key, + const mlx_stream s); +extern int (*mlx_random_truncated_normal_)( + mlx_array* res, + const mlx_array lower, + const mlx_array upper, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + const mlx_array key /* may be null */, + const mlx_stream s); +extern int (*mlx_random_uniform_)( + mlx_array* res, + const mlx_array low, + const mlx_array high, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + const 
mlx_array key /* may be null */, + const mlx_stream s); +extern mlx_stream (*mlx_stream_new_)(void); +extern mlx_stream (*mlx_stream_new_device_)(mlx_device dev); +extern int (*mlx_stream_set_)(mlx_stream* stream, const mlx_stream src); +extern int (*mlx_stream_free_)(mlx_stream stream); +extern int (*mlx_stream_tostring_)(mlx_string* str, mlx_stream stream); +extern bool (*mlx_stream_equal_)(mlx_stream lhs, mlx_stream rhs); +extern int (*mlx_stream_get_device_)(mlx_device* dev, mlx_stream stream); +extern int (*mlx_stream_get_index_)(int* index, mlx_stream stream); +extern int (*mlx_synchronize_)(mlx_stream stream); +extern int (*mlx_get_default_stream_)(mlx_stream* stream, mlx_device dev); +extern int (*mlx_set_default_stream_)(mlx_stream stream); +extern mlx_stream (*mlx_default_cpu_stream_new_)(void); +extern mlx_stream (*mlx_default_gpu_stream_new_)(void); +extern mlx_string (*mlx_string_new_)(void); +extern mlx_string (*mlx_string_new_data_)(const char* str); +extern int (*mlx_string_set_)(mlx_string* str, const mlx_string src); +extern const char * (*mlx_string_data_)(mlx_string str); +extern int (*mlx_string_free_)(mlx_string str); +extern int (*mlx_detail_vmap_replace_)( + mlx_vector_array* res, + const mlx_vector_array inputs, + const mlx_vector_array s_inputs, + const mlx_vector_array s_outputs, + const int* in_axes, + size_t in_axes_num, + const int* out_axes, + size_t out_axes_num); +extern int (*mlx_detail_vmap_trace_)( + mlx_vector_array* res_0, + mlx_vector_array* res_1, + const mlx_closure fun, + const mlx_vector_array inputs, + const int* in_axes, + size_t in_axes_num); +extern int (*mlx_async_eval_)(const mlx_vector_array outputs); +extern int (*mlx_checkpoint_)(mlx_closure* res, const mlx_closure fun); +extern int (*mlx_custom_function_)( + mlx_closure* res, + const mlx_closure fun, + const mlx_closure_custom fun_vjp /* may be null */, + const mlx_closure_custom_jvp fun_jvp /* may be null */, + const mlx_closure_custom_vmap fun_vmap /* may be 
null */); +extern int (*mlx_custom_vjp_)( + mlx_closure* res, + const mlx_closure fun, + const mlx_closure_custom fun_vjp); +extern int (*mlx_eval_)(const mlx_vector_array outputs); +extern int (*mlx_jvp_)( + mlx_vector_array* res_0, + mlx_vector_array* res_1, + const mlx_closure fun, + const mlx_vector_array primals, + const mlx_vector_array tangents); +extern int (*mlx_value_and_grad_)( + mlx_closure_value_and_grad* res, + const mlx_closure fun, + const int* argnums, + size_t argnums_num); +extern int (*mlx_vjp_)( + mlx_vector_array* res_0, + mlx_vector_array* res_1, + const mlx_closure fun, + const mlx_vector_array primals, + const mlx_vector_array cotangents); +extern mlx_vector_array (*mlx_vector_array_new_)(void); +extern int (*mlx_vector_array_set_)(mlx_vector_array* vec, const mlx_vector_array src); +extern int (*mlx_vector_array_free_)(mlx_vector_array vec); +extern mlx_vector_array (*mlx_vector_array_new_data_)(const mlx_array* data, size_t size); +extern mlx_vector_array (*mlx_vector_array_new_value_)(const mlx_array val); +extern int (*mlx_vector_array_set_data_)( + mlx_vector_array* vec, + const mlx_array* data, + size_t size); +extern int (*mlx_vector_array_set_value_)(mlx_vector_array* vec, const mlx_array val); +extern int (*mlx_vector_array_append_data_)( + mlx_vector_array vec, + const mlx_array* data, + size_t size); +extern int (*mlx_vector_array_append_value_)(mlx_vector_array vec, const mlx_array val); +extern size_t (*mlx_vector_array_size_)(mlx_vector_array vec); +extern int (*mlx_vector_array_get_)( + mlx_array* res, + const mlx_vector_array vec, + size_t idx); +extern mlx_vector_vector_array (*mlx_vector_vector_array_new_)(void); +extern int (*mlx_vector_vector_array_set_)( + mlx_vector_vector_array* vec, + const mlx_vector_vector_array src); +extern int (*mlx_vector_vector_array_free_)(mlx_vector_vector_array vec); +extern mlx_vector_vector_array (*mlx_vector_vector_array_new_data_)( + const mlx_vector_array* data, + size_t size); +extern 
mlx_vector_vector_array (*mlx_vector_vector_array_new_value_)( + const mlx_vector_array val); +extern int (*mlx_vector_vector_array_set_data_)( + mlx_vector_vector_array* vec, + const mlx_vector_array* data, + size_t size); +extern int (*mlx_vector_vector_array_set_value_)( + mlx_vector_vector_array* vec, + const mlx_vector_array val); +extern int (*mlx_vector_vector_array_append_data_)( + mlx_vector_vector_array vec, + const mlx_vector_array* data, + size_t size); +extern int (*mlx_vector_vector_array_append_value_)( + mlx_vector_vector_array vec, + const mlx_vector_array val); +extern size_t (*mlx_vector_vector_array_size_)(mlx_vector_vector_array vec); +extern int (*mlx_vector_vector_array_get_)( + mlx_vector_array* res, + const mlx_vector_vector_array vec, + size_t idx); +extern mlx_vector_int (*mlx_vector_int_new_)(void); +extern int (*mlx_vector_int_set_)(mlx_vector_int* vec, const mlx_vector_int src); +extern int (*mlx_vector_int_free_)(mlx_vector_int vec); +extern mlx_vector_int (*mlx_vector_int_new_data_)(int* data, size_t size); +extern mlx_vector_int (*mlx_vector_int_new_value_)(int val); +extern int (*mlx_vector_int_set_data_)(mlx_vector_int* vec, int* data, size_t size); +extern int (*mlx_vector_int_set_value_)(mlx_vector_int* vec, int val); +extern int (*mlx_vector_int_append_data_)(mlx_vector_int vec, int* data, size_t size); +extern int (*mlx_vector_int_append_value_)(mlx_vector_int vec, int val); +extern size_t (*mlx_vector_int_size_)(mlx_vector_int vec); +extern int (*mlx_vector_int_get_)(int* res, const mlx_vector_int vec, size_t idx); +extern mlx_vector_string (*mlx_vector_string_new_)(void); +extern int (*mlx_vector_string_set_)(mlx_vector_string* vec, const mlx_vector_string src); +extern int (*mlx_vector_string_free_)(mlx_vector_string vec); +extern mlx_vector_string (*mlx_vector_string_new_data_)(const char** data, size_t size); +extern mlx_vector_string (*mlx_vector_string_new_value_)(const char* val); +extern int 
(*mlx_vector_string_set_data_)( + mlx_vector_string* vec, + const char** data, + size_t size); +extern int (*mlx_vector_string_set_value_)(mlx_vector_string* vec, const char* val); +extern int (*mlx_vector_string_append_data_)( + mlx_vector_string vec, + const char** data, + size_t size); +extern int (*mlx_vector_string_append_value_)(mlx_vector_string vec, const char* val); +extern size_t (*mlx_vector_string_size_)(mlx_vector_string vec); +extern int (*mlx_vector_string_get_)(char** res, const mlx_vector_string vec, size_t idx); +extern int (*mlx_version_)(mlx_string* str_); + +int mlx_dynamic_load_symbols(mlx_dynamic_handle handle); + +static inline size_t mlx_dtype_size(mlx_dtype dtype) { + return mlx_dtype_size_(dtype); +} + +static inline int mlx_array_tostring(mlx_string* str, const mlx_array arr) { + return mlx_array_tostring_(str, arr); +} + +static inline mlx_array mlx_array_new(void) { + return mlx_array_new_(); +} + +static inline int mlx_array_free(mlx_array arr) { + return mlx_array_free_(arr); +} + +static inline mlx_array mlx_array_new_bool(bool val) { + return mlx_array_new_bool_(val); +} + +static inline mlx_array mlx_array_new_int(int val) { + return mlx_array_new_int_(val); +} + +static inline mlx_array mlx_array_new_float32(float val) { + return mlx_array_new_float32_(val); +} + +static inline mlx_array mlx_array_new_float(float val) { + return mlx_array_new_float_(val); +} + +static inline mlx_array mlx_array_new_float64(double val) { + return mlx_array_new_float64_(val); +} + +static inline mlx_array mlx_array_new_double(double val) { + return mlx_array_new_double_(val); +} + +static inline mlx_array mlx_array_new_complex(float real_val, float imag_val) { + return mlx_array_new_complex_(real_val, imag_val); +} + +static inline mlx_array mlx_array_new_data( + const void* data, + const int* shape, + int dim, + mlx_dtype dtype) { + return mlx_array_new_data_(data, shape, dim, dtype); +} + +static inline int mlx_array_set(mlx_array* arr, const 
mlx_array src) { + return mlx_array_set_(arr, src); +} + +static inline int mlx_array_set_bool(mlx_array* arr, bool val) { + return mlx_array_set_bool_(arr, val); +} + +static inline int mlx_array_set_int(mlx_array* arr, int val) { + return mlx_array_set_int_(arr, val); +} + +static inline int mlx_array_set_float32(mlx_array* arr, float val) { + return mlx_array_set_float32_(arr, val); +} + +static inline int mlx_array_set_float(mlx_array* arr, float val) { + return mlx_array_set_float_(arr, val); +} + +static inline int mlx_array_set_float64(mlx_array* arr, double val) { + return mlx_array_set_float64_(arr, val); +} + +static inline int mlx_array_set_double(mlx_array* arr, double val) { + return mlx_array_set_double_(arr, val); +} + +static inline int mlx_array_set_complex(mlx_array* arr, float real_val, float imag_val) { + return mlx_array_set_complex_(arr, real_val, imag_val); +} + +static inline int mlx_array_set_data( + mlx_array* arr, + const void* data, + const int* shape, + int dim, + mlx_dtype dtype) { + return mlx_array_set_data_(arr, data, shape, dim, dtype); +} + +static inline size_t mlx_array_itemsize(const mlx_array arr) { + return mlx_array_itemsize_(arr); +} + +static inline size_t mlx_array_size(const mlx_array arr) { + return mlx_array_size_(arr); +} + +static inline size_t mlx_array_nbytes(const mlx_array arr) { + return mlx_array_nbytes_(arr); +} + +static inline size_t mlx_array_ndim(const mlx_array arr) { + return mlx_array_ndim_(arr); +} + +static inline const int * mlx_array_shape(const mlx_array arr) { + return mlx_array_shape_(arr); +} + +static inline const size_t * mlx_array_strides(const mlx_array arr) { + return mlx_array_strides_(arr); +} + +static inline int mlx_array_dim(const mlx_array arr, int dim) { + return mlx_array_dim_(arr, dim); +} + +static inline mlx_dtype mlx_array_dtype(const mlx_array arr) { + return mlx_array_dtype_(arr); +} + +static inline int mlx_array_eval(mlx_array arr) { + return mlx_array_eval_(arr); +} + 
+static inline int mlx_array_item_bool(bool* res, const mlx_array arr) { + return mlx_array_item_bool_(res, arr); +} + +static inline int mlx_array_item_uint8(uint8_t* res, const mlx_array arr) { + return mlx_array_item_uint8_(res, arr); +} + +static inline int mlx_array_item_uint16(uint16_t* res, const mlx_array arr) { + return mlx_array_item_uint16_(res, arr); +} + +static inline int mlx_array_item_uint32(uint32_t* res, const mlx_array arr) { + return mlx_array_item_uint32_(res, arr); +} + +static inline int mlx_array_item_uint64(uint64_t* res, const mlx_array arr) { + return mlx_array_item_uint64_(res, arr); +} + +static inline int mlx_array_item_int8(int8_t* res, const mlx_array arr) { + return mlx_array_item_int8_(res, arr); +} + +static inline int mlx_array_item_int16(int16_t* res, const mlx_array arr) { + return mlx_array_item_int16_(res, arr); +} + +static inline int mlx_array_item_int32(int32_t* res, const mlx_array arr) { + return mlx_array_item_int32_(res, arr); +} + +static inline int mlx_array_item_int64(int64_t* res, const mlx_array arr) { + return mlx_array_item_int64_(res, arr); +} + +static inline int mlx_array_item_float32(float* res, const mlx_array arr) { + return mlx_array_item_float32_(res, arr); +} + +static inline int mlx_array_item_float64(double* res, const mlx_array arr) { + return mlx_array_item_float64_(res, arr); +} + +static inline int mlx_array_item_complex64(float _Complex* res, const mlx_array arr) { + return mlx_array_item_complex64_(res, arr); +} + +static inline int mlx_array_item_float16(float16_t* res, const mlx_array arr) { + return mlx_array_item_float16_(res, arr); +} + +static inline int mlx_array_item_bfloat16(bfloat16_t* res, const mlx_array arr) { + return mlx_array_item_bfloat16_(res, arr); +} + +static inline const bool * mlx_array_data_bool(const mlx_array arr) { + return mlx_array_data_bool_(arr); +} + +static inline const uint8_t * mlx_array_data_uint8(const mlx_array arr) { + return mlx_array_data_uint8_(arr); +} 
+ +static inline const uint16_t * mlx_array_data_uint16(const mlx_array arr) { + return mlx_array_data_uint16_(arr); +} + +static inline const uint32_t * mlx_array_data_uint32(const mlx_array arr) { + return mlx_array_data_uint32_(arr); +} + +static inline const uint64_t * mlx_array_data_uint64(const mlx_array arr) { + return mlx_array_data_uint64_(arr); +} + +static inline const int8_t * mlx_array_data_int8(const mlx_array arr) { + return mlx_array_data_int8_(arr); +} + +static inline const int16_t * mlx_array_data_int16(const mlx_array arr) { + return mlx_array_data_int16_(arr); +} + +static inline const int32_t * mlx_array_data_int32(const mlx_array arr) { + return mlx_array_data_int32_(arr); +} + +static inline const int64_t * mlx_array_data_int64(const mlx_array arr) { + return mlx_array_data_int64_(arr); +} + +static inline const float * mlx_array_data_float32(const mlx_array arr) { + return mlx_array_data_float32_(arr); +} + +static inline const double * mlx_array_data_float64(const mlx_array arr) { + return mlx_array_data_float64_(arr); +} + +static inline const float _Complex * mlx_array_data_complex64(const mlx_array arr) { + return mlx_array_data_complex64_(arr); +} + +static inline const float16_t * mlx_array_data_float16(const mlx_array arr) { + return mlx_array_data_float16_(arr); +} + +static inline const bfloat16_t * mlx_array_data_bfloat16(const mlx_array arr) { + return mlx_array_data_bfloat16_(arr); +} + +static inline int _mlx_array_is_available(bool* res, const mlx_array arr) { + return _mlx_array_is_available_(res, arr); +} + +static inline int _mlx_array_wait(const mlx_array arr) { + return _mlx_array_wait_(arr); +} + +static inline int _mlx_array_is_contiguous(bool* res, const mlx_array arr) { + return _mlx_array_is_contiguous_(res, arr); +} + +static inline int _mlx_array_is_row_contiguous(bool* res, const mlx_array arr) { + return _mlx_array_is_row_contiguous_(res, arr); +} + +static inline int _mlx_array_is_col_contiguous(bool* res, const 
mlx_array arr) { + return _mlx_array_is_col_contiguous_(res, arr); +} + +static inline mlx_closure mlx_closure_new(void) { + return mlx_closure_new_(); +} + +static inline int mlx_closure_free(mlx_closure cls) { + return mlx_closure_free_(cls); +} + +static inline mlx_closure mlx_closure_new_func( + int (*fun)(mlx_vector_array*, const mlx_vector_array)) { + return mlx_closure_new_func_(fun); +} + +static inline mlx_closure mlx_closure_new_func_payload( + int (*fun)(mlx_vector_array*, const mlx_vector_array, void*), + void* payload, + void (*dtor)(void*)) { + return mlx_closure_new_func_payload_(fun, payload, dtor); +} + +static inline int mlx_closure_set(mlx_closure* cls, const mlx_closure src) { + return mlx_closure_set_(cls, src); +} + +static inline int mlx_closure_apply( + mlx_vector_array* res, + mlx_closure cls, + const mlx_vector_array input) { + return mlx_closure_apply_(res, cls, input); +} + +static inline mlx_closure mlx_closure_new_unary(int (*fun)(mlx_array*, const mlx_array)) { + return mlx_closure_new_unary_(fun); +} + +static inline mlx_closure_kwargs mlx_closure_kwargs_new(void) { + return mlx_closure_kwargs_new_(); +} + +static inline int mlx_closure_kwargs_free(mlx_closure_kwargs cls) { + return mlx_closure_kwargs_free_(cls); +} + +static inline mlx_closure_kwargs mlx_closure_kwargs_new_func(int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_map_string_to_array)) { + return mlx_closure_kwargs_new_func_(fun); +} + +static inline mlx_closure_kwargs mlx_closure_kwargs_new_func_payload( + int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_map_string_to_array, + void*), + void* payload, + void (*dtor)(void*)) { + return mlx_closure_kwargs_new_func_payload_(fun, payload, dtor); +} + +static inline int mlx_closure_kwargs_set( + mlx_closure_kwargs* cls, + const mlx_closure_kwargs src) { + return mlx_closure_kwargs_set_(cls, src); +} + +static inline int mlx_closure_kwargs_apply( + mlx_vector_array* res, + 
mlx_closure_kwargs cls, + const mlx_vector_array input_0, + const mlx_map_string_to_array input_1) { + return mlx_closure_kwargs_apply_(res, cls, input_0, input_1); +} + +static inline mlx_closure_value_and_grad mlx_closure_value_and_grad_new(void) { + return mlx_closure_value_and_grad_new_(); +} + +static inline int mlx_closure_value_and_grad_free(mlx_closure_value_and_grad cls) { + return mlx_closure_value_and_grad_free_(cls); +} + +static inline mlx_closure_value_and_grad mlx_closure_value_and_grad_new_func( + int (*fun)(mlx_vector_array*, mlx_vector_array*, const mlx_vector_array)) { + return mlx_closure_value_and_grad_new_func_(fun); +} + +static inline mlx_closure_value_and_grad mlx_closure_value_and_grad_new_func_payload( + int (*fun)( + mlx_vector_array*, + mlx_vector_array*, + const mlx_vector_array, + void*), + void* payload, + void (*dtor)(void*)) { + return mlx_closure_value_and_grad_new_func_payload_(fun, payload, dtor); +} + +static inline int mlx_closure_value_and_grad_set( + mlx_closure_value_and_grad* cls, + const mlx_closure_value_and_grad src) { + return mlx_closure_value_and_grad_set_(cls, src); +} + +static inline int mlx_closure_value_and_grad_apply( + mlx_vector_array* res_0, + mlx_vector_array* res_1, + mlx_closure_value_and_grad cls, + const mlx_vector_array input) { + return mlx_closure_value_and_grad_apply_(res_0, res_1, cls, input); +} + +static inline mlx_closure_custom mlx_closure_custom_new(void) { + return mlx_closure_custom_new_(); +} + +static inline int mlx_closure_custom_free(mlx_closure_custom cls) { + return mlx_closure_custom_free_(cls); +} + +static inline mlx_closure_custom mlx_closure_custom_new_func(int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_vector_array, + const mlx_vector_array)) { + return mlx_closure_custom_new_func_(fun); +} + +static inline mlx_closure_custom mlx_closure_custom_new_func_payload( + int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_vector_array, + const 
mlx_vector_array, + void*), + void* payload, + void (*dtor)(void*)) { + return mlx_closure_custom_new_func_payload_(fun, payload, dtor); +} + +static inline int mlx_closure_custom_set( + mlx_closure_custom* cls, + const mlx_closure_custom src) { + return mlx_closure_custom_set_(cls, src); +} + +static inline int mlx_closure_custom_apply( + mlx_vector_array* res, + mlx_closure_custom cls, + const mlx_vector_array input_0, + const mlx_vector_array input_1, + const mlx_vector_array input_2) { + return mlx_closure_custom_apply_(res, cls, input_0, input_1, input_2); +} + +static inline mlx_closure_custom_jvp mlx_closure_custom_jvp_new(void) { + return mlx_closure_custom_jvp_new_(); +} + +static inline int mlx_closure_custom_jvp_free(mlx_closure_custom_jvp cls) { + return mlx_closure_custom_jvp_free_(cls); +} + +static inline mlx_closure_custom_jvp mlx_closure_custom_jvp_new_func(int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_vector_array, + const int*, + size_t _num)) { + return mlx_closure_custom_jvp_new_func_(fun); +} + +static inline mlx_closure_custom_jvp mlx_closure_custom_jvp_new_func_payload( + int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_vector_array, + const int*, + size_t _num, + void*), + void* payload, + void (*dtor)(void*)) { + return mlx_closure_custom_jvp_new_func_payload_(fun, payload, dtor); +} + +static inline int mlx_closure_custom_jvp_set( + mlx_closure_custom_jvp* cls, + const mlx_closure_custom_jvp src) { + return mlx_closure_custom_jvp_set_(cls, src); +} + +static inline int mlx_closure_custom_jvp_apply( + mlx_vector_array* res, + mlx_closure_custom_jvp cls, + const mlx_vector_array input_0, + const mlx_vector_array input_1, + const int* input_2, + size_t input_2_num) { + return mlx_closure_custom_jvp_apply_(res, cls, input_0, input_1, input_2, input_2_num); +} + +static inline mlx_closure_custom_vmap mlx_closure_custom_vmap_new(void) { + return mlx_closure_custom_vmap_new_(); +} + +static inline 
int mlx_closure_custom_vmap_free(mlx_closure_custom_vmap cls) { + return mlx_closure_custom_vmap_free_(cls); +} + +static inline mlx_closure_custom_vmap mlx_closure_custom_vmap_new_func(int (*fun)( + mlx_vector_array*, + mlx_vector_int*, + const mlx_vector_array, + const int*, + size_t _num)) { + return mlx_closure_custom_vmap_new_func_(fun); +} + +static inline mlx_closure_custom_vmap mlx_closure_custom_vmap_new_func_payload( + int (*fun)( + mlx_vector_array*, + mlx_vector_int*, + const mlx_vector_array, + const int*, + size_t _num, + void*), + void* payload, + void (*dtor)(void*)) { + return mlx_closure_custom_vmap_new_func_payload_(fun, payload, dtor); +} + +static inline int mlx_closure_custom_vmap_set( + mlx_closure_custom_vmap* cls, + const mlx_closure_custom_vmap src) { + return mlx_closure_custom_vmap_set_(cls, src); +} + +static inline int mlx_closure_custom_vmap_apply( + mlx_vector_array* res_0, + mlx_vector_int* res_1, + mlx_closure_custom_vmap cls, + const mlx_vector_array input_0, + const int* input_1, + size_t input_1_num) { + return mlx_closure_custom_vmap_apply_(res_0, res_1, cls, input_0, input_1, input_1_num); +} + +static inline int mlx_compile(mlx_closure* res, const mlx_closure fun, bool shapeless) { + return mlx_compile_(res, fun, shapeless); +} + +static inline int mlx_detail_compile( + mlx_closure* res, + const mlx_closure fun, + uintptr_t fun_id, + bool shapeless, + const uint64_t* constants, + size_t constants_num) { + return mlx_detail_compile_(res, fun, fun_id, shapeless, constants, constants_num); +} + +static inline int mlx_detail_compile_clear_cache(void) { + return mlx_detail_compile_clear_cache_(); +} + +static inline int mlx_detail_compile_erase(uintptr_t fun_id) { + return mlx_detail_compile_erase_(fun_id); +} + +static inline int mlx_disable_compile(void) { + return mlx_disable_compile_(); +} + +static inline int mlx_enable_compile(void) { + return mlx_enable_compile_(); +} + +static inline int 
mlx_set_compile_mode(mlx_compile_mode mode) { + return mlx_set_compile_mode_(mode); +} + +static inline mlx_device mlx_device_new(void) { + return mlx_device_new_(); +} + +static inline mlx_device mlx_device_new_type(mlx_device_type type, int index) { + return mlx_device_new_type_(type, index); +} + +static inline int mlx_device_free(mlx_device dev) { + return mlx_device_free_(dev); +} + +static inline int mlx_device_set(mlx_device* dev, const mlx_device src) { + return mlx_device_set_(dev, src); +} + +static inline int mlx_device_tostring(mlx_string* str, mlx_device dev) { + return mlx_device_tostring_(str, dev); +} + +static inline bool mlx_device_equal(mlx_device lhs, mlx_device rhs) { + return mlx_device_equal_(lhs, rhs); +} + +static inline int mlx_device_get_index(int* index, mlx_device dev) { + return mlx_device_get_index_(index, dev); +} + +static inline int mlx_device_get_type(mlx_device_type* type, mlx_device dev) { + return mlx_device_get_type_(type, dev); +} + +static inline int mlx_get_default_device(mlx_device* dev) { + return mlx_get_default_device_(dev); +} + +static inline int mlx_set_default_device(mlx_device dev) { + return mlx_set_default_device_(dev); +} + +static inline int mlx_distributed_group_rank(mlx_distributed_group group) { + return mlx_distributed_group_rank_(group); +} + +static inline int mlx_distributed_group_size(mlx_distributed_group group) { + return mlx_distributed_group_size_(group); +} + +static inline mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, int color, int key) { + return mlx_distributed_group_split_(group, color, key); +} + +static inline bool mlx_distributed_is_available(void) { + return mlx_distributed_is_available_(); +} + +static inline mlx_distributed_group mlx_distributed_init(bool strict) { + return mlx_distributed_init_(strict); +} + +static inline int mlx_distributed_all_gather( + mlx_array* res, + const mlx_array x, + const mlx_distributed_group group /* may be null */, + const 
mlx_stream s) { + /* NOTE(review): stream parameter renamed 'S' -> 's' for consistency with every sibling distributed shim; behavior unchanged. */ + return mlx_distributed_all_gather_(res, x, group, s); +} + +static inline int mlx_distributed_all_max( + mlx_array* res, + const mlx_array x, + const mlx_distributed_group group /* may be null */, + const mlx_stream s) { + return mlx_distributed_all_max_(res, x, group, s); +} + +static inline int mlx_distributed_all_min( + mlx_array* res, + const mlx_array x, + const mlx_distributed_group group /* may be null */, + const mlx_stream s) { + return mlx_distributed_all_min_(res, x, group, s); +} + +static inline int mlx_distributed_all_sum( + mlx_array* res, + const mlx_array x, + const mlx_distributed_group group /* may be null */, + const mlx_stream s) { + return mlx_distributed_all_sum_(res, x, group, s); +} + +static inline int mlx_distributed_recv( + mlx_array* res, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + int src, + const mlx_distributed_group group /* may be null */, + const mlx_stream s) { + return mlx_distributed_recv_(res, shape, shape_num, dtype, src, group, s); +} + +static inline int mlx_distributed_recv_like( + mlx_array* res, + const mlx_array x, + int src, + const mlx_distributed_group group /* may be null */, + const mlx_stream s) { + return mlx_distributed_recv_like_(res, x, src, group, s); +} + +static inline int mlx_distributed_send( + mlx_array* res, + const mlx_array x, + int dst, + const mlx_distributed_group group /* may be null */, + const mlx_stream s) { + return mlx_distributed_send_(res, x, dst, group, s); +} + +static inline int mlx_distributed_sum_scatter( + mlx_array* res, + const mlx_array x, + const mlx_distributed_group group /* may be null */, + const mlx_stream s) { + return mlx_distributed_sum_scatter_(res, x, group, s); +} + +static inline void mlx_set_error_handler( + mlx_error_handler_func handler, + void* data, + void (*dtor)(void*)) { + mlx_set_error_handler_(handler, data, dtor); +} + +#define _mlx_error(file, line, fmt, ...) 
_mlx_error_(file, line, fmt, __VA_ARGS__) + +static inline int mlx_export_function( + const char* file, + const mlx_closure fun, + const mlx_vector_array args, + bool shapeless) { + return mlx_export_function_(file, fun, args, shapeless); +} + +static inline int mlx_export_function_kwargs( + const char* file, + const mlx_closure_kwargs fun, + const mlx_vector_array args, + const mlx_map_string_to_array kwargs, + bool shapeless) { + return mlx_export_function_kwargs_(file, fun, args, kwargs, shapeless); +} + +static inline mlx_function_exporter mlx_function_exporter_new( + const char* file, + const mlx_closure fun, + bool shapeless) { + return mlx_function_exporter_new_(file, fun, shapeless); +} + +static inline int mlx_function_exporter_free(mlx_function_exporter xfunc) { + return mlx_function_exporter_free_(xfunc); +} + +static inline int mlx_function_exporter_apply( + const mlx_function_exporter xfunc, + const mlx_vector_array args) { + return mlx_function_exporter_apply_(xfunc, args); +} + +static inline int mlx_function_exporter_apply_kwargs( + const mlx_function_exporter xfunc, + const mlx_vector_array args, + const mlx_map_string_to_array kwargs) { + return mlx_function_exporter_apply_kwargs_(xfunc, args, kwargs); +} + +static inline mlx_imported_function mlx_imported_function_new(const char* file) { + return mlx_imported_function_new_(file); +} + +static inline int mlx_imported_function_free(mlx_imported_function xfunc) { + return mlx_imported_function_free_(xfunc); +} + +static inline int mlx_imported_function_apply( + mlx_vector_array* res, + const mlx_imported_function xfunc, + const mlx_vector_array args) { + return mlx_imported_function_apply_(res, xfunc, args); +} + +static inline int mlx_imported_function_apply_kwargs( + mlx_vector_array* res, + const mlx_imported_function xfunc, + const mlx_vector_array args, + const mlx_map_string_to_array kwargs) { + return mlx_imported_function_apply_kwargs_(res, xfunc, args, kwargs); +} + +static inline 
mlx_fast_cuda_kernel_config mlx_fast_cuda_kernel_config_new(void) { + return mlx_fast_cuda_kernel_config_new_(); +} + +static inline void mlx_fast_cuda_kernel_config_free(mlx_fast_cuda_kernel_config cls) { + mlx_fast_cuda_kernel_config_free_(cls); +} + +static inline int mlx_fast_cuda_kernel_config_add_output_arg( + mlx_fast_cuda_kernel_config cls, + const int* shape, + size_t size, + mlx_dtype dtype) { + return mlx_fast_cuda_kernel_config_add_output_arg_(cls, shape, size, dtype); +} + +static inline int mlx_fast_cuda_kernel_config_set_grid( + mlx_fast_cuda_kernel_config cls, + int grid1, + int grid2, + int grid3) { + return mlx_fast_cuda_kernel_config_set_grid_(cls, grid1, grid2, grid3); +} + +static inline int mlx_fast_cuda_kernel_config_set_thread_group( + mlx_fast_cuda_kernel_config cls, + int thread1, + int thread2, + int thread3) { + return mlx_fast_cuda_kernel_config_set_thread_group_(cls, thread1, thread2, thread3); +} + +static inline int mlx_fast_cuda_kernel_config_set_init_value( + mlx_fast_cuda_kernel_config cls, + float value) { + return mlx_fast_cuda_kernel_config_set_init_value_(cls, value); +} + +static inline int mlx_fast_cuda_kernel_config_set_verbose( + mlx_fast_cuda_kernel_config cls, + bool verbose) { + return mlx_fast_cuda_kernel_config_set_verbose_(cls, verbose); +} + +static inline int mlx_fast_cuda_kernel_config_add_template_arg_dtype( + mlx_fast_cuda_kernel_config cls, + const char* name, + mlx_dtype dtype) { + return mlx_fast_cuda_kernel_config_add_template_arg_dtype_(cls, name, dtype); +} + +static inline int mlx_fast_cuda_kernel_config_add_template_arg_int( + mlx_fast_cuda_kernel_config cls, + const char* name, + int value) { + return mlx_fast_cuda_kernel_config_add_template_arg_int_(cls, name, value); +} + +static inline int mlx_fast_cuda_kernel_config_add_template_arg_bool( + mlx_fast_cuda_kernel_config cls, + const char* name, + bool value) { + return mlx_fast_cuda_kernel_config_add_template_arg_bool_(cls, name, value); +} + +static 
inline mlx_fast_cuda_kernel mlx_fast_cuda_kernel_new( + const char* name, + const mlx_vector_string input_names, + const mlx_vector_string output_names, + const char* source, + const char* header, + bool ensure_row_contiguous, + int shared_memory) { + return mlx_fast_cuda_kernel_new_(name, input_names, output_names, source, header, ensure_row_contiguous, shared_memory); +} + +static inline void mlx_fast_cuda_kernel_free(mlx_fast_cuda_kernel cls) { + mlx_fast_cuda_kernel_free_(cls); +} + +static inline int mlx_fast_cuda_kernel_apply( + mlx_vector_array* outputs, + mlx_fast_cuda_kernel cls, + const mlx_vector_array inputs, + const mlx_fast_cuda_kernel_config config, + const mlx_stream stream) { + return mlx_fast_cuda_kernel_apply_(outputs, cls, inputs, config, stream); +} + +static inline int mlx_fast_layer_norm( + mlx_array* res, + const mlx_array x, + const mlx_array weight /* may be null */, + const mlx_array bias /* may be null */, + float eps, + const mlx_stream s) { + return mlx_fast_layer_norm_(res, x, weight, bias, eps, s); +} + +static inline mlx_fast_metal_kernel_config mlx_fast_metal_kernel_config_new(void) { + return mlx_fast_metal_kernel_config_new_(); +} + +static inline void mlx_fast_metal_kernel_config_free(mlx_fast_metal_kernel_config cls) { + mlx_fast_metal_kernel_config_free_(cls); +} + +static inline int mlx_fast_metal_kernel_config_add_output_arg( + mlx_fast_metal_kernel_config cls, + const int* shape, + size_t size, + mlx_dtype dtype) { + return mlx_fast_metal_kernel_config_add_output_arg_(cls, shape, size, dtype); +} + +static inline int mlx_fast_metal_kernel_config_set_grid( + mlx_fast_metal_kernel_config cls, + int grid1, + int grid2, + int grid3) { + return mlx_fast_metal_kernel_config_set_grid_(cls, grid1, grid2, grid3); +} + +static inline int mlx_fast_metal_kernel_config_set_thread_group( + mlx_fast_metal_kernel_config cls, + int thread1, + int thread2, + int thread3) { + return mlx_fast_metal_kernel_config_set_thread_group_(cls, thread1, 
thread2, thread3); +} + +static inline int mlx_fast_metal_kernel_config_set_init_value( + mlx_fast_metal_kernel_config cls, + float value) { + return mlx_fast_metal_kernel_config_set_init_value_(cls, value); +} + +static inline int mlx_fast_metal_kernel_config_set_verbose( + mlx_fast_metal_kernel_config cls, + bool verbose) { + return mlx_fast_metal_kernel_config_set_verbose_(cls, verbose); +} + +static inline int mlx_fast_metal_kernel_config_add_template_arg_dtype( + mlx_fast_metal_kernel_config cls, + const char* name, + mlx_dtype dtype) { + return mlx_fast_metal_kernel_config_add_template_arg_dtype_(cls, name, dtype); +} + +static inline int mlx_fast_metal_kernel_config_add_template_arg_int( + mlx_fast_metal_kernel_config cls, + const char* name, + int value) { + return mlx_fast_metal_kernel_config_add_template_arg_int_(cls, name, value); +} + +static inline int mlx_fast_metal_kernel_config_add_template_arg_bool( + mlx_fast_metal_kernel_config cls, + const char* name, + bool value) { + return mlx_fast_metal_kernel_config_add_template_arg_bool_(cls, name, value); +} + +static inline mlx_fast_metal_kernel mlx_fast_metal_kernel_new( + const char* name, + const mlx_vector_string input_names, + const mlx_vector_string output_names, + const char* source, + const char* header, + bool ensure_row_contiguous, + bool atomic_outputs) { + return mlx_fast_metal_kernel_new_(name, input_names, output_names, source, header, ensure_row_contiguous, atomic_outputs); +} + +static inline void mlx_fast_metal_kernel_free(mlx_fast_metal_kernel cls) { + mlx_fast_metal_kernel_free_(cls); +} + +static inline int mlx_fast_metal_kernel_apply( + mlx_vector_array* outputs, + mlx_fast_metal_kernel cls, + const mlx_vector_array inputs, + const mlx_fast_metal_kernel_config config, + const mlx_stream stream) { + return mlx_fast_metal_kernel_apply_(outputs, cls, inputs, config, stream); +} + +static inline int mlx_fast_rms_norm( + mlx_array* res, + const mlx_array x, + const mlx_array weight /* may 
be null */, + float eps, + const mlx_stream s) { + return mlx_fast_rms_norm_(res, x, weight, eps, s); +} + +static inline int mlx_fast_rope( + mlx_array* res, + const mlx_array x, + int dims, + bool traditional, + mlx_optional_float base, + float scale, + int offset, + const mlx_array freqs /* may be null */, + const mlx_stream s) { + return mlx_fast_rope_(res, x, dims, traditional, base, scale, offset, freqs, s); +} + +static inline int mlx_fast_scaled_dot_product_attention( + mlx_array* res, + const mlx_array queries, + const mlx_array keys, + const mlx_array values, + float scale, + const char* mask_mode, + const mlx_array mask_arr /* may be null */, + const mlx_array sinks /* may be null */, + const mlx_stream s) { + return mlx_fast_scaled_dot_product_attention_(res, queries, keys, values, scale, mask_mode, mask_arr, sinks, s); +} + +static inline int mlx_fft_fft( + mlx_array* res, + const mlx_array a, + int n, + int axis, + const mlx_stream s) { + return mlx_fft_fft_(res, a, n, axis, s); +} + +static inline int mlx_fft_fft2( + mlx_array* res, + const mlx_array a, + const int* n, + size_t n_num, + const int* axes, + size_t axes_num, + const mlx_stream s) { + return mlx_fft_fft2_(res, a, n, n_num, axes, axes_num, s); +} + +static inline int mlx_fft_fftn( + mlx_array* res, + const mlx_array a, + const int* n, + size_t n_num, + const int* axes, + size_t axes_num, + const mlx_stream s) { + return mlx_fft_fftn_(res, a, n, n_num, axes, axes_num, s); +} + +static inline int mlx_fft_fftshift( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + const mlx_stream s) { + return mlx_fft_fftshift_(res, a, axes, axes_num, s); +} + +static inline int mlx_fft_ifft( + mlx_array* res, + const mlx_array a, + int n, + int axis, + const mlx_stream s) { + return mlx_fft_ifft_(res, a, n, axis, s); +} + +static inline int mlx_fft_ifft2( + mlx_array* res, + const mlx_array a, + const int* n, + size_t n_num, + const int* axes, + size_t axes_num, + const 
mlx_stream s) { + return mlx_fft_ifft2_(res, a, n, n_num, axes, axes_num, s); +} + +static inline int mlx_fft_ifftn( + mlx_array* res, + const mlx_array a, + const int* n, + size_t n_num, + const int* axes, + size_t axes_num, + const mlx_stream s) { + return mlx_fft_ifftn_(res, a, n, n_num, axes, axes_num, s); +} + +static inline int mlx_fft_ifftshift( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + const mlx_stream s) { + return mlx_fft_ifftshift_(res, a, axes, axes_num, s); +} + +static inline int mlx_fft_irfft( + mlx_array* res, + const mlx_array a, + int n, + int axis, + const mlx_stream s) { + return mlx_fft_irfft_(res, a, n, axis, s); +} + +static inline int mlx_fft_irfft2( + mlx_array* res, + const mlx_array a, + const int* n, + size_t n_num, + const int* axes, + size_t axes_num, + const mlx_stream s) { + return mlx_fft_irfft2_(res, a, n, n_num, axes, axes_num, s); +} + +static inline int mlx_fft_irfftn( + mlx_array* res, + const mlx_array a, + const int* n, + size_t n_num, + const int* axes, + size_t axes_num, + const mlx_stream s) { + return mlx_fft_irfftn_(res, a, n, n_num, axes, axes_num, s); +} + +static inline int mlx_fft_rfft( + mlx_array* res, + const mlx_array a, + int n, + int axis, + const mlx_stream s) { + return mlx_fft_rfft_(res, a, n, axis, s); +} + +static inline int mlx_fft_rfft2( + mlx_array* res, + const mlx_array a, + const int* n, + size_t n_num, + const int* axes, + size_t axes_num, + const mlx_stream s) { + return mlx_fft_rfft2_(res, a, n, n_num, axes, axes_num, s); +} + +static inline int mlx_fft_rfftn( + mlx_array* res, + const mlx_array a, + const int* n, + size_t n_num, + const int* axes, + size_t axes_num, + const mlx_stream s) { + return mlx_fft_rfftn_(res, a, n, n_num, axes, axes_num, s); +} + +static inline mlx_io_reader mlx_io_reader_new(void* desc, mlx_io_vtable vtable) { + return mlx_io_reader_new_(desc, vtable); +} + +static inline int mlx_io_reader_descriptor(void** desc_, mlx_io_reader io) { 
+ return mlx_io_reader_descriptor_(desc_, io); +} + +static inline int mlx_io_reader_tostring(mlx_string* str_, mlx_io_reader io) { + return mlx_io_reader_tostring_(str_, io); +} + +static inline int mlx_io_reader_free(mlx_io_reader io) { + return mlx_io_reader_free_(io); +} + +static inline mlx_io_writer mlx_io_writer_new(void* desc, mlx_io_vtable vtable) { + return mlx_io_writer_new_(desc, vtable); +} + +static inline int mlx_io_writer_descriptor(void** desc_, mlx_io_writer io) { + return mlx_io_writer_descriptor_(desc_, io); +} + +static inline int mlx_io_writer_tostring(mlx_string* str_, mlx_io_writer io) { + return mlx_io_writer_tostring_(str_, io); +} + +static inline int mlx_io_writer_free(mlx_io_writer io) { + return mlx_io_writer_free_(io); +} + +static inline int mlx_load_reader( + mlx_array* res, + mlx_io_reader in_stream, + const mlx_stream s) { + return mlx_load_reader_(res, in_stream, s); +} + +static inline int mlx_load(mlx_array* res, const char* file, const mlx_stream s) { + return mlx_load_(res, file, s); +} + +static inline int mlx_load_safetensors_reader( + mlx_map_string_to_array* res_0, + mlx_map_string_to_string* res_1, + mlx_io_reader in_stream, + const mlx_stream s) { + return mlx_load_safetensors_reader_(res_0, res_1, in_stream, s); +} + +static inline int mlx_load_safetensors( + mlx_map_string_to_array* res_0, + mlx_map_string_to_string* res_1, + const char* file, + const mlx_stream s) { + return mlx_load_safetensors_(res_0, res_1, file, s); +} + +static inline int mlx_save_writer(mlx_io_writer out_stream, const mlx_array a) { + return mlx_save_writer_(out_stream, a); +} + +static inline int mlx_save(const char* file, const mlx_array a) { + return mlx_save_(file, a); +} + +static inline int mlx_save_safetensors_writer( + mlx_io_writer in_stream, + const mlx_map_string_to_array param, + const mlx_map_string_to_string metadata) { + return mlx_save_safetensors_writer_(in_stream, param, metadata); +} + +static inline int mlx_save_safetensors( 
+ const char* file, + const mlx_map_string_to_array param, + const mlx_map_string_to_string metadata) { + return mlx_save_safetensors_(file, param, metadata); +} + +static inline int mlx_linalg_cholesky( + mlx_array* res, + const mlx_array a, + bool upper, + const mlx_stream s) { + return mlx_linalg_cholesky_(res, a, upper, s); +} + +static inline int mlx_linalg_cholesky_inv( + mlx_array* res, + const mlx_array a, + bool upper, + const mlx_stream s) { + return mlx_linalg_cholesky_inv_(res, a, upper, s); +} + +static inline int mlx_linalg_cross( + mlx_array* res, + const mlx_array a, + const mlx_array b, + int axis, + const mlx_stream s) { + return mlx_linalg_cross_(res, a, b, axis, s); +} + +static inline int mlx_linalg_eig( + mlx_array* res_0, + mlx_array* res_1, + const mlx_array a, + const mlx_stream s) { + return mlx_linalg_eig_(res_0, res_1, a, s); +} + +static inline int mlx_linalg_eigh( + mlx_array* res_0, + mlx_array* res_1, + const mlx_array a, + const char* UPLO, + const mlx_stream s) { + return mlx_linalg_eigh_(res_0, res_1, a, UPLO, s); +} + +static inline int mlx_linalg_eigvals(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_linalg_eigvals_(res, a, s); +} + +static inline int mlx_linalg_eigvalsh( + mlx_array* res, + const mlx_array a, + const char* UPLO, + const mlx_stream s) { + return mlx_linalg_eigvalsh_(res, a, UPLO, s); +} + +static inline int mlx_linalg_inv(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_linalg_inv_(res, a, s); +} + +static inline int mlx_linalg_lu(mlx_vector_array* res, const mlx_array a, const mlx_stream s) { + return mlx_linalg_lu_(res, a, s); +} + +static inline int mlx_linalg_lu_factor( + mlx_array* res_0, + mlx_array* res_1, + const mlx_array a, + const mlx_stream s) { + return mlx_linalg_lu_factor_(res_0, res_1, a, s); +} + +static inline int mlx_linalg_norm( + mlx_array* res, + const mlx_array a, + double ord, + const int* axis /* may be null */, + size_t axis_num, + bool 
keepdims, + const mlx_stream s) { + return mlx_linalg_norm_(res, a, ord, axis, axis_num, keepdims, s); +} + +static inline int mlx_linalg_norm_matrix( + mlx_array* res, + const mlx_array a, + const char* ord, + const int* axis /* may be null */, + size_t axis_num, + bool keepdims, + const mlx_stream s) { + return mlx_linalg_norm_matrix_(res, a, ord, axis, axis_num, keepdims, s); +} + +static inline int mlx_linalg_norm_l2( + mlx_array* res, + const mlx_array a, + const int* axis /* may be null */, + size_t axis_num, + bool keepdims, + const mlx_stream s) { + return mlx_linalg_norm_l2_(res, a, axis, axis_num, keepdims, s); +} + +static inline int mlx_linalg_pinv(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_linalg_pinv_(res, a, s); +} + +static inline int mlx_linalg_qr( + mlx_array* res_0, + mlx_array* res_1, + const mlx_array a, + const mlx_stream s) { + return mlx_linalg_qr_(res_0, res_1, a, s); +} + +static inline int mlx_linalg_solve( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) { + return mlx_linalg_solve_(res, a, b, s); +} + +static inline int mlx_linalg_solve_triangular( + mlx_array* res, + const mlx_array a, + const mlx_array b, + bool upper, + const mlx_stream s) { + return mlx_linalg_solve_triangular_(res, a, b, upper, s); +} + +static inline int mlx_linalg_svd( + mlx_vector_array* res, + const mlx_array a, + bool compute_uv, + const mlx_stream s) { + return mlx_linalg_svd_(res, a, compute_uv, s); +} + +static inline int mlx_linalg_tri_inv( + mlx_array* res, + const mlx_array a, + bool upper, + const mlx_stream s) { + return mlx_linalg_tri_inv_(res, a, upper, s); +} + +static inline mlx_map_string_to_array mlx_map_string_to_array_new(void) { + return mlx_map_string_to_array_new_(); +} + +static inline int mlx_map_string_to_array_set( + mlx_map_string_to_array* map, + const mlx_map_string_to_array src) { + return mlx_map_string_to_array_set_(map, src); +} + +static inline int 
mlx_map_string_to_array_free(mlx_map_string_to_array map) { + return mlx_map_string_to_array_free_(map); +} + +static inline int mlx_map_string_to_array_insert( + mlx_map_string_to_array map, + const char* key, + const mlx_array value) { + return mlx_map_string_to_array_insert_(map, key, value); +} + +static inline int mlx_map_string_to_array_get( + mlx_array* value, + const mlx_map_string_to_array map, + const char* key) { + return mlx_map_string_to_array_get_(value, map, key); +} + +static inline mlx_map_string_to_array_iterator mlx_map_string_to_array_iterator_new( + mlx_map_string_to_array map) { + return mlx_map_string_to_array_iterator_new_(map); +} + +static inline int mlx_map_string_to_array_iterator_free(mlx_map_string_to_array_iterator it) { + return mlx_map_string_to_array_iterator_free_(it); +} + +static inline int mlx_map_string_to_array_iterator_next( + const char** key, + mlx_array* value, + mlx_map_string_to_array_iterator it) { + return mlx_map_string_to_array_iterator_next_(key, value, it); +} + +static inline mlx_map_string_to_string mlx_map_string_to_string_new(void) { + return mlx_map_string_to_string_new_(); +} + +static inline int mlx_map_string_to_string_set( + mlx_map_string_to_string* map, + const mlx_map_string_to_string src) { + return mlx_map_string_to_string_set_(map, src); +} + +static inline int mlx_map_string_to_string_free(mlx_map_string_to_string map) { + return mlx_map_string_to_string_free_(map); +} + +static inline int mlx_map_string_to_string_insert( + mlx_map_string_to_string map, + const char* key, + const char* value) { + return mlx_map_string_to_string_insert_(map, key, value); +} + +static inline int mlx_map_string_to_string_get( + const char** value, + const mlx_map_string_to_string map, + const char* key) { + return mlx_map_string_to_string_get_(value, map, key); +} + +static inline mlx_map_string_to_string_iterator mlx_map_string_to_string_iterator_new( + mlx_map_string_to_string map) { + return 
mlx_map_string_to_string_iterator_new_(map); +} + +static inline int mlx_map_string_to_string_iterator_free( + mlx_map_string_to_string_iterator it) { + return mlx_map_string_to_string_iterator_free_(it); +} + +static inline int mlx_map_string_to_string_iterator_next( + const char** key, + const char** value, + mlx_map_string_to_string_iterator it) { + return mlx_map_string_to_string_iterator_next_(key, value, it); +} + +static inline int mlx_clear_cache(void) { + return mlx_clear_cache_(); +} + +static inline int mlx_get_active_memory(size_t* res) { + return mlx_get_active_memory_(res); +} + +static inline int mlx_get_cache_memory(size_t* res) { + return mlx_get_cache_memory_(res); +} + +static inline int mlx_get_memory_limit(size_t* res) { + return mlx_get_memory_limit_(res); +} + +static inline int mlx_get_peak_memory(size_t* res) { + return mlx_get_peak_memory_(res); +} + +static inline int mlx_reset_peak_memory(void) { + return mlx_reset_peak_memory_(); +} + +static inline int mlx_set_cache_limit(size_t* res, size_t limit) { + return mlx_set_cache_limit_(res, limit); +} + +static inline int mlx_set_memory_limit(size_t* res, size_t limit) { + return mlx_set_memory_limit_(res, limit); +} + +static inline int mlx_set_wired_limit(size_t* res, size_t limit) { + return mlx_set_wired_limit_(res, limit); +} + +static inline mlx_metal_device_info_t mlx_metal_device_info(void) { + return mlx_metal_device_info_(); +} + +static inline int mlx_metal_is_available(bool* res) { + return mlx_metal_is_available_(res); +} + +static inline int mlx_metal_start_capture(const char* path) { + return mlx_metal_start_capture_(path); +} + +static inline int mlx_metal_stop_capture(void) { + return mlx_metal_stop_capture_(); +} + +static inline int mlx_abs(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_abs_(res, a, s); +} + +static inline int mlx_add( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) { + return mlx_add_(res, a, b, s); 
+} + +static inline int mlx_addmm( + mlx_array* res, + const mlx_array c, + const mlx_array a, + const mlx_array b, + float alpha, + float beta, + const mlx_stream s) { + return mlx_addmm_(res, c, a, b, alpha, beta, s); +} + +static inline int mlx_all_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s) { + return mlx_all_axes_(res, a, axes, axes_num, keepdims, s); +} + +static inline int mlx_all_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s) { + return mlx_all_axis_(res, a, axis, keepdims, s); +} + +static inline int mlx_all( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s) { + return mlx_all_(res, a, keepdims, s); +} + +static inline int mlx_allclose( + mlx_array* res, + const mlx_array a, + const mlx_array b, + double rtol, + double atol, + bool equal_nan, + const mlx_stream s) { + return mlx_allclose_(res, a, b, rtol, atol, equal_nan, s); +} + +static inline int mlx_any_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s) { + return mlx_any_axes_(res, a, axes, axes_num, keepdims, s); +} + +static inline int mlx_any_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s) { + return mlx_any_axis_(res, a, axis, keepdims, s); +} + +static inline int mlx_any( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s) { + return mlx_any_(res, a, keepdims, s); +} + +static inline int mlx_arange( + mlx_array* res, + double start, + double stop, + double step, + mlx_dtype dtype, + const mlx_stream s) { + return mlx_arange_(res, start, stop, step, dtype, s); +} + +static inline int mlx_arccos(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_arccos_(res, a, s); +} + +static inline int mlx_arccosh(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_arccosh_(res, a, 
s); +} + +static inline int mlx_arcsin(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_arcsin_(res, a, s); +} + +static inline int mlx_arcsinh(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_arcsinh_(res, a, s); +} + +static inline int mlx_arctan(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_arctan_(res, a, s); +} + +static inline int mlx_arctan2( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) { + return mlx_arctan2_(res, a, b, s); +} + +static inline int mlx_arctanh(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_arctanh_(res, a, s); +} + +static inline int mlx_argmax_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s) { + return mlx_argmax_axis_(res, a, axis, keepdims, s); +} + +static inline int mlx_argmax( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s) { + return mlx_argmax_(res, a, keepdims, s); +} + +static inline int mlx_argmin_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s) { + return mlx_argmin_axis_(res, a, axis, keepdims, s); +} + +static inline int mlx_argmin( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s) { + return mlx_argmin_(res, a, keepdims, s); +} + +static inline int mlx_argpartition_axis( + mlx_array* res, + const mlx_array a, + int kth, + int axis, + const mlx_stream s) { + return mlx_argpartition_axis_(res, a, kth, axis, s); +} + +static inline int mlx_argpartition( + mlx_array* res, + const mlx_array a, + int kth, + const mlx_stream s) { + return mlx_argpartition_(res, a, kth, s); +} + +static inline int mlx_argsort_axis( + mlx_array* res, + const mlx_array a, + int axis, + const mlx_stream s) { + return mlx_argsort_axis_(res, a, axis, s); +} + +static inline int mlx_argsort(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_argsort_(res, a, s); +} + 
+static inline int mlx_array_equal( + mlx_array* res, + const mlx_array a, + const mlx_array b, + bool equal_nan, + const mlx_stream s) { + return mlx_array_equal_(res, a, b, equal_nan, s); +} + +static inline int mlx_as_strided( + mlx_array* res, + const mlx_array a, + const int* shape, + size_t shape_num, + const int64_t* strides, + size_t strides_num, + size_t offset, + const mlx_stream s) { + return mlx_as_strided_(res, a, shape, shape_num, strides, strides_num, offset, s); +} + +static inline int mlx_astype( + mlx_array* res, + const mlx_array a, + mlx_dtype dtype, + const mlx_stream s) { + return mlx_astype_(res, a, dtype, s); +} + +static inline int mlx_atleast_1d(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_atleast_1d_(res, a, s); +} + +static inline int mlx_atleast_2d(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_atleast_2d_(res, a, s); +} + +static inline int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_atleast_3d_(res, a, s); +} + +static inline int mlx_bitwise_and( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) { + return mlx_bitwise_and_(res, a, b, s); +} + +static inline int mlx_bitwise_invert(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_bitwise_invert_(res, a, s); +} + +static inline int mlx_bitwise_or( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) { + return mlx_bitwise_or_(res, a, b, s); +} + +static inline int mlx_bitwise_xor( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) { + return mlx_bitwise_xor_(res, a, b, s); +} + +static inline int mlx_block_masked_mm( + mlx_array* res, + const mlx_array a, + const mlx_array b, + int block_size, + const mlx_array mask_out /* may be null */, + const mlx_array mask_lhs /* may be null */, + const mlx_array mask_rhs /* may be null */, + const mlx_stream s) { + return mlx_block_masked_mm_(res, a, 
b, block_size, mask_out, mask_lhs, mask_rhs, s); +} + +static inline int mlx_broadcast_arrays( + mlx_vector_array* res, + const mlx_vector_array inputs, + const mlx_stream s) { + return mlx_broadcast_arrays_(res, inputs, s); +} + +static inline int mlx_broadcast_to( + mlx_array* res, + const mlx_array a, + const int* shape, + size_t shape_num, + const mlx_stream s) { + return mlx_broadcast_to_(res, a, shape, shape_num, s); +} + +static inline int mlx_ceil(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_ceil_(res, a, s); +} + +static inline int mlx_clip( + mlx_array* res, + const mlx_array a, + const mlx_array a_min /* may be null */, + const mlx_array a_max /* may be null */, + const mlx_stream s) { + return mlx_clip_(res, a, a_min, a_max, s); +} + +static inline int mlx_concatenate_axis( + mlx_array* res, + const mlx_vector_array arrays, + int axis, + const mlx_stream s) { + return mlx_concatenate_axis_(res, arrays, axis, s); +} + +static inline int mlx_concatenate( + mlx_array* res, + const mlx_vector_array arrays, + const mlx_stream s) { + return mlx_concatenate_(res, arrays, s); +} + +static inline int mlx_conjugate(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_conjugate_(res, a, s); +} + +static inline int mlx_contiguous( + mlx_array* res, + const mlx_array a, + bool allow_col_major, + const mlx_stream s) { + return mlx_contiguous_(res, a, allow_col_major, s); +} + +static inline int mlx_conv1d( + mlx_array* res, + const mlx_array input, + const mlx_array weight, + int stride, + int padding, + int dilation, + int groups, + const mlx_stream s) { + return mlx_conv1d_(res, input, weight, stride, padding, dilation, groups, s); +} + +static inline int mlx_conv2d( + mlx_array* res, + const mlx_array input, + const mlx_array weight, + int stride_0, + int stride_1, + int padding_0, + int padding_1, + int dilation_0, + int dilation_1, + int groups, + const mlx_stream s) { + return mlx_conv2d_(res, input, weight, stride_0, 
stride_1, padding_0, padding_1, dilation_0, dilation_1, groups, s); +} + +static inline int mlx_conv3d( + mlx_array* res, + const mlx_array input, + const mlx_array weight, + int stride_0, + int stride_1, + int stride_2, + int padding_0, + int padding_1, + int padding_2, + int dilation_0, + int dilation_1, + int dilation_2, + int groups, + const mlx_stream s) { + return mlx_conv3d_(res, input, weight, stride_0, stride_1, stride_2, padding_0, padding_1, padding_2, dilation_0, dilation_1, dilation_2, groups, s); +} + +static inline int mlx_conv_general( + mlx_array* res, + const mlx_array input, + const mlx_array weight, + const int* stride, + size_t stride_num, + const int* padding_lo, + size_t padding_lo_num, + const int* padding_hi, + size_t padding_hi_num, + const int* kernel_dilation, + size_t kernel_dilation_num, + const int* input_dilation, + size_t input_dilation_num, + int groups, + bool flip, + const mlx_stream s) { + return mlx_conv_general_(res, input, weight, stride, stride_num, padding_lo, padding_lo_num, padding_hi, padding_hi_num, kernel_dilation, kernel_dilation_num, input_dilation, input_dilation_num, groups, flip, s); +} + +static inline int mlx_conv_transpose1d( + mlx_array* res, + const mlx_array input, + const mlx_array weight, + int stride, + int padding, + int dilation, + int output_padding, + int groups, + const mlx_stream s) { + return mlx_conv_transpose1d_(res, input, weight, stride, padding, dilation, output_padding, groups, s); +} + +static inline int mlx_conv_transpose2d( + mlx_array* res, + const mlx_array input, + const mlx_array weight, + int stride_0, + int stride_1, + int padding_0, + int padding_1, + int dilation_0, + int dilation_1, + int output_padding_0, + int output_padding_1, + int groups, + const mlx_stream s) { + return mlx_conv_transpose2d_(res, input, weight, stride_0, stride_1, padding_0, padding_1, dilation_0, dilation_1, output_padding_0, output_padding_1, groups, s); +} + +static inline int mlx_conv_transpose3d( + 
mlx_array* res, + const mlx_array input, + const mlx_array weight, + int stride_0, + int stride_1, + int stride_2, + int padding_0, + int padding_1, + int padding_2, + int dilation_0, + int dilation_1, + int dilation_2, + int output_padding_0, + int output_padding_1, + int output_padding_2, + int groups, + const mlx_stream s) { + return mlx_conv_transpose3d_(res, input, weight, stride_0, stride_1, stride_2, padding_0, padding_1, padding_2, dilation_0, dilation_1, dilation_2, output_padding_0, output_padding_1, output_padding_2, groups, s); +} + +static inline int mlx_copy(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_copy_(res, a, s); +} + +static inline int mlx_cos(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_cos_(res, a, s); +} + +static inline int mlx_cosh(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_cosh_(res, a, s); +} + +static inline int mlx_cummax( + mlx_array* res, + const mlx_array a, + int axis, + bool reverse, + bool inclusive, + const mlx_stream s) { + return mlx_cummax_(res, a, axis, reverse, inclusive, s); +} + +static inline int mlx_cummin( + mlx_array* res, + const mlx_array a, + int axis, + bool reverse, + bool inclusive, + const mlx_stream s) { + return mlx_cummin_(res, a, axis, reverse, inclusive, s); +} + +static inline int mlx_cumprod( + mlx_array* res, + const mlx_array a, + int axis, + bool reverse, + bool inclusive, + const mlx_stream s) { + return mlx_cumprod_(res, a, axis, reverse, inclusive, s); +} + +static inline int mlx_cumsum( + mlx_array* res, + const mlx_array a, + int axis, + bool reverse, + bool inclusive, + const mlx_stream s) { + return mlx_cumsum_(res, a, axis, reverse, inclusive, s); +} + +static inline int mlx_degrees(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_degrees_(res, a, s); +} + +static inline int mlx_depends( + mlx_vector_array* res, + const mlx_vector_array inputs, + const mlx_vector_array dependencies) { + return 
mlx_depends_(res, inputs, dependencies); +} + +static inline int mlx_dequantize( + mlx_array* res, + const mlx_array w, + const mlx_array scales, + const mlx_array biases /* may be null */, + mlx_optional_int group_size, + mlx_optional_int bits, + const char* mode, + mlx_optional_dtype dtype, + const mlx_stream s) { + return mlx_dequantize_(res, w, scales, biases, group_size, bits, mode, dtype, s); +} + +static inline int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s) { + return mlx_diag_(res, a, k, s); +} + +static inline int mlx_diagonal( + mlx_array* res, + const mlx_array a, + int offset, + int axis1, + int axis2, + const mlx_stream s) { + return mlx_diagonal_(res, a, offset, axis1, axis2, s); +} + +static inline int mlx_divide( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) { + return mlx_divide_(res, a, b, s); +} + +static inline int mlx_divmod( + mlx_vector_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) { + return mlx_divmod_(res, a, b, s); +} + +static inline int mlx_einsum( + mlx_array* res, + const char* subscripts, + const mlx_vector_array operands, + const mlx_stream s) { + return mlx_einsum_(res, subscripts, operands, s); +} + +static inline int mlx_equal( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) { + return mlx_equal_(res, a, b, s); +} + +static inline int mlx_erf(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_erf_(res, a, s); +} + +static inline int mlx_erfinv(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_erfinv_(res, a, s); +} + +static inline int mlx_exp(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_exp_(res, a, s); +} + +static inline int mlx_expand_dims_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + const mlx_stream s) { + return mlx_expand_dims_axes_(res, a, axes, axes_num, s); +} + +static inline int 
mlx_expand_dims( + mlx_array* res, + const mlx_array a, + int axis, + const mlx_stream s) { + return mlx_expand_dims_(res, a, axis, s); +} + +static inline int mlx_expm1(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_expm1_(res, a, s); +} + +static inline int mlx_eye( + mlx_array* res, + int n, + int m, + int k, + mlx_dtype dtype, + const mlx_stream s) { + return mlx_eye_(res, n, m, k, dtype, s); +} + +static inline int mlx_flatten( + mlx_array* res, + const mlx_array a, + int start_axis, + int end_axis, + const mlx_stream s) { + return mlx_flatten_(res, a, start_axis, end_axis, s); +} + +static inline int mlx_floor(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_floor_(res, a, s); +} + +static inline int mlx_floor_divide( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) { + return mlx_floor_divide_(res, a, b, s); +} + +static inline int mlx_from_fp8( + mlx_array* res, + const mlx_array x, + mlx_dtype dtype, + const mlx_stream s) { + return mlx_from_fp8_(res, x, dtype, s); +} + +static inline int mlx_full( + mlx_array* res, + const int* shape, + size_t shape_num, + const mlx_array vals, + mlx_dtype dtype, + const mlx_stream s) { + return mlx_full_(res, shape, shape_num, vals, dtype, s); +} + +static inline int mlx_full_like( + mlx_array* res, + const mlx_array a, + const mlx_array vals, + mlx_dtype dtype, + const mlx_stream s) { + return mlx_full_like_(res, a, vals, dtype, s); +} + +static inline int mlx_gather( + mlx_array* res, + const mlx_array a, + const mlx_vector_array indices, + const int* axes, + size_t axes_num, + const int* slice_sizes, + size_t slice_sizes_num, + const mlx_stream s) { + return mlx_gather_(res, a, indices, axes, axes_num, slice_sizes, slice_sizes_num, s); +} + +static inline int mlx_gather_mm( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_array lhs_indices /* may be null */, + const mlx_array rhs_indices /* may be null */, + bool 
sorted_indices, + const mlx_stream s) { + return mlx_gather_mm_(res, a, b, lhs_indices, rhs_indices, sorted_indices, s); +} + +static inline int mlx_gather_qmm( + mlx_array* res, + const mlx_array x, + const mlx_array w, + const mlx_array scales, + const mlx_array biases /* may be null */, + const mlx_array lhs_indices /* may be null */, + const mlx_array rhs_indices /* may be null */, + bool transpose, + mlx_optional_int group_size, + mlx_optional_int bits, + const char* mode, + bool sorted_indices, + const mlx_stream s) { + return mlx_gather_qmm_(res, x, w, scales, biases, lhs_indices, rhs_indices, transpose, group_size, bits, mode, sorted_indices, s); +} + +static inline int mlx_greater( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) { + return mlx_greater_(res, a, b, s); +} + +static inline int mlx_greater_equal( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) { + return mlx_greater_equal_(res, a, b, s); +} + +static inline int mlx_hadamard_transform( + mlx_array* res, + const mlx_array a, + mlx_optional_float scale, + const mlx_stream s) { + return mlx_hadamard_transform_(res, a, scale, s); +} + +static inline int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) { + return mlx_identity_(res, n, dtype, s); +} + +static inline int mlx_imag(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_imag_(res, a, s); +} + +static inline int mlx_inner( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) { + return mlx_inner_(res, a, b, s); +} + +static inline int mlx_isclose( + mlx_array* res, + const mlx_array a, + const mlx_array b, + double rtol, + double atol, + bool equal_nan, + const mlx_stream s) { + return mlx_isclose_(res, a, b, rtol, atol, equal_nan, s); +} + +static inline int mlx_isfinite(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_isfinite_(res, a, s); +} + +static inline int 
mlx_isinf(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_isinf_(res, a, s); +} + +static inline int mlx_isnan(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_isnan_(res, a, s); +} + +static inline int mlx_isneginf(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_isneginf_(res, a, s); +} + +static inline int mlx_isposinf(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_isposinf_(res, a, s); +} + +static inline int mlx_kron( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) { + return mlx_kron_(res, a, b, s); +} + +static inline int mlx_left_shift( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) { + return mlx_left_shift_(res, a, b, s); +} + +static inline int mlx_less( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) { + return mlx_less_(res, a, b, s); +} + +static inline int mlx_less_equal( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) { + return mlx_less_equal_(res, a, b, s); +} + +static inline int mlx_linspace( + mlx_array* res, + double start, + double stop, + int num, + mlx_dtype dtype, + const mlx_stream s) { + return mlx_linspace_(res, start, stop, num, dtype, s); +} + +static inline int mlx_log(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_log_(res, a, s); +} + +static inline int mlx_log10(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_log10_(res, a, s); +} + +static inline int mlx_log1p(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_log1p_(res, a, s); +} + +static inline int mlx_log2(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_log2_(res, a, s); +} + +static inline int mlx_logaddexp( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) { + return mlx_logaddexp_(res, a, b, s); +} + +static inline int mlx_logcumsumexp( 
+ mlx_array* res, + const mlx_array a, + int axis, + bool reverse, + bool inclusive, + const mlx_stream s) { + return mlx_logcumsumexp_(res, a, axis, reverse, inclusive, s); +} + +static inline int mlx_logical_and( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) { + return mlx_logical_and_(res, a, b, s); +} + +static inline int mlx_logical_not(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_logical_not_(res, a, s); +} + +static inline int mlx_logical_or( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) { + return mlx_logical_or_(res, a, b, s); +} + +static inline int mlx_logsumexp_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s) { + return mlx_logsumexp_axes_(res, a, axes, axes_num, keepdims, s); +} + +static inline int mlx_logsumexp_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s) { + return mlx_logsumexp_axis_(res, a, axis, keepdims, s); +} + +static inline int mlx_logsumexp( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s) { + return mlx_logsumexp_(res, a, keepdims, s); +} + +static inline int mlx_masked_scatter( + mlx_array* res, + const mlx_array a, + const mlx_array mask, + const mlx_array src, + const mlx_stream s) { + return mlx_masked_scatter_(res, a, mask, src, s); +} + +static inline int mlx_matmul( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) { + return mlx_matmul_(res, a, b, s); +} + +static inline int mlx_max_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s) { + return mlx_max_axes_(res, a, axes, axes_num, keepdims, s); +} + +static inline int mlx_max_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s) { + return mlx_max_axis_(res, a, axis, keepdims, s); +} + +static 
inline int mlx_max( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s) { + return mlx_max_(res, a, keepdims, s); +} + +static inline int mlx_maximum( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) { + return mlx_maximum_(res, a, b, s); +} + +static inline int mlx_mean_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s) { + return mlx_mean_axes_(res, a, axes, axes_num, keepdims, s); +} + +static inline int mlx_mean_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s) { + return mlx_mean_axis_(res, a, axis, keepdims, s); +} + +static inline int mlx_mean( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s) { + return mlx_mean_(res, a, keepdims, s); +} + +static inline int mlx_median( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s) { + return mlx_median_(res, a, axes, axes_num, keepdims, s); +} + +static inline int mlx_meshgrid( + mlx_vector_array* res, + const mlx_vector_array arrays, + bool sparse, + const char* indexing, + const mlx_stream s) { + return mlx_meshgrid_(res, arrays, sparse, indexing, s); +} + +static inline int mlx_min_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s) { + return mlx_min_axes_(res, a, axes, axes_num, keepdims, s); +} + +static inline int mlx_min_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s) { + return mlx_min_axis_(res, a, axis, keepdims, s); +} + +static inline int mlx_min( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s) { + return mlx_min_(res, a, keepdims, s); +} + +static inline int mlx_minimum( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) { + return mlx_minimum_(res, a, b, s); +} + 
+static inline int mlx_moveaxis( + mlx_array* res, + const mlx_array a, + int source, + int destination, + const mlx_stream s) { + return mlx_moveaxis_(res, a, source, destination, s); +} + +static inline int mlx_multiply( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) { + return mlx_multiply_(res, a, b, s); +} + +static inline int mlx_nan_to_num( + mlx_array* res, + const mlx_array a, + float nan, + mlx_optional_float posinf, + mlx_optional_float neginf, + const mlx_stream s) { + return mlx_nan_to_num_(res, a, nan, posinf, neginf, s); +} + +static inline int mlx_negative(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_negative_(res, a, s); +} + +static inline int mlx_not_equal( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) { + return mlx_not_equal_(res, a, b, s); +} + +static inline int mlx_number_of_elements( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool inverted, + mlx_dtype dtype, + const mlx_stream s) { + return mlx_number_of_elements_(res, a, axes, axes_num, inverted, dtype, s); +} + +static inline int mlx_ones( + mlx_array* res, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + const mlx_stream s) { + return mlx_ones_(res, shape, shape_num, dtype, s); +} + +static inline int mlx_ones_like(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_ones_like_(res, a, s); +} + +static inline int mlx_outer( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) { + return mlx_outer_(res, a, b, s); +} + +static inline int mlx_pad( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + const int* low_pad_size, + size_t low_pad_size_num, + const int* high_pad_size, + size_t high_pad_size_num, + const mlx_array pad_value, + const char* mode, + const mlx_stream s) { + return mlx_pad_(res, a, axes, axes_num, low_pad_size, low_pad_size_num, high_pad_size, 
high_pad_size_num, pad_value, mode, s); +} + +static inline int mlx_pad_symmetric( + mlx_array* res, + const mlx_array a, + int pad_width, + const mlx_array pad_value, + const char* mode, + const mlx_stream s) { + return mlx_pad_symmetric_(res, a, pad_width, pad_value, mode, s); +} + +static inline int mlx_partition_axis( + mlx_array* res, + const mlx_array a, + int kth, + int axis, + const mlx_stream s) { + return mlx_partition_axis_(res, a, kth, axis, s); +} + +static inline int mlx_partition( + mlx_array* res, + const mlx_array a, + int kth, + const mlx_stream s) { + return mlx_partition_(res, a, kth, s); +} + +static inline int mlx_power( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) { + return mlx_power_(res, a, b, s); +} + +static inline int mlx_prod_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s) { + return mlx_prod_axes_(res, a, axes, axes_num, keepdims, s); +} + +static inline int mlx_prod_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s) { + return mlx_prod_axis_(res, a, axis, keepdims, s); +} + +static inline int mlx_prod( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s) { + return mlx_prod_(res, a, keepdims, s); +} + +static inline int mlx_put_along_axis( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + const mlx_array values, + int axis, + const mlx_stream s) { + return mlx_put_along_axis_(res, a, indices, values, axis, s); +} + +static inline int mlx_quantize( + mlx_vector_array* res, + const mlx_array w, + mlx_optional_int group_size, + mlx_optional_int bits, + const char* mode, + const mlx_stream s) { + return mlx_quantize_(res, w, group_size, bits, mode, s); +} + +static inline int mlx_quantized_matmul( + mlx_array* res, + const mlx_array x, + const mlx_array w, + const mlx_array scales, + const mlx_array biases /* may be null */, + bool transpose, 
+ mlx_optional_int group_size, + mlx_optional_int bits, + const char* mode, + const mlx_stream s) { + return mlx_quantized_matmul_(res, x, w, scales, biases, transpose, group_size, bits, mode, s); +} + +static inline int mlx_radians(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_radians_(res, a, s); +} + +static inline int mlx_real(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_real_(res, a, s); +} + +static inline int mlx_reciprocal(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_reciprocal_(res, a, s); +} + +static inline int mlx_remainder( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) { + return mlx_remainder_(res, a, b, s); +} + +static inline int mlx_repeat_axis( + mlx_array* res, + const mlx_array arr, + int repeats, + int axis, + const mlx_stream s) { + return mlx_repeat_axis_(res, arr, repeats, axis, s); +} + +static inline int mlx_repeat( + mlx_array* res, + const mlx_array arr, + int repeats, + const mlx_stream s) { + return mlx_repeat_(res, arr, repeats, s); +} + +static inline int mlx_reshape( + mlx_array* res, + const mlx_array a, + const int* shape, + size_t shape_num, + const mlx_stream s) { + return mlx_reshape_(res, a, shape, shape_num, s); +} + +static inline int mlx_right_shift( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) { + return mlx_right_shift_(res, a, b, s); +} + +static inline int mlx_roll_axis( + mlx_array* res, + const mlx_array a, + const int* shift, + size_t shift_num, + int axis, + const mlx_stream s) { + return mlx_roll_axis_(res, a, shift, shift_num, axis, s); +} + +static inline int mlx_roll_axes( + mlx_array* res, + const mlx_array a, + const int* shift, + size_t shift_num, + const int* axes, + size_t axes_num, + const mlx_stream s) { + return mlx_roll_axes_(res, a, shift, shift_num, axes, axes_num, s); +} + +static inline int mlx_roll( + mlx_array* res, + const mlx_array a, + const int* 
shift, + size_t shift_num, + const mlx_stream s) { + return mlx_roll_(res, a, shift, shift_num, s); +} + +static inline int mlx_round( + mlx_array* res, + const mlx_array a, + int decimals, + const mlx_stream s) { + return mlx_round_(res, a, decimals, s); +} + +static inline int mlx_rsqrt(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_rsqrt_(res, a, s); +} + +static inline int mlx_scatter( + mlx_array* res, + const mlx_array a, + const mlx_vector_array indices, + const mlx_array updates, + const int* axes, + size_t axes_num, + const mlx_stream s) { + return mlx_scatter_(res, a, indices, updates, axes, axes_num, s); +} + +static inline int mlx_scatter_add( + mlx_array* res, + const mlx_array a, + const mlx_vector_array indices, + const mlx_array updates, + const int* axes, + size_t axes_num, + const mlx_stream s) { + return mlx_scatter_add_(res, a, indices, updates, axes, axes_num, s); +} + +static inline int mlx_scatter_add_axis( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + const mlx_array values, + int axis, + const mlx_stream s) { + return mlx_scatter_add_axis_(res, a, indices, values, axis, s); +} + +static inline int mlx_scatter_max( + mlx_array* res, + const mlx_array a, + const mlx_vector_array indices, + const mlx_array updates, + const int* axes, + size_t axes_num, + const mlx_stream s) { + return mlx_scatter_max_(res, a, indices, updates, axes, axes_num, s); +} + +static inline int mlx_scatter_min( + mlx_array* res, + const mlx_array a, + const mlx_vector_array indices, + const mlx_array updates, + const int* axes, + size_t axes_num, + const mlx_stream s) { + return mlx_scatter_min_(res, a, indices, updates, axes, axes_num, s); +} + +static inline int mlx_scatter_prod( + mlx_array* res, + const mlx_array a, + const mlx_vector_array indices, + const mlx_array updates, + const int* axes, + size_t axes_num, + const mlx_stream s) { + return mlx_scatter_prod_(res, a, indices, updates, axes, axes_num, s); +} + 
+static inline int mlx_segmented_mm( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_array segments, + const mlx_stream s) { + return mlx_segmented_mm_(res, a, b, segments, s); +} + +static inline int mlx_sigmoid(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_sigmoid_(res, a, s); +} + +static inline int mlx_sign(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_sign_(res, a, s); +} + +static inline int mlx_sin(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_sin_(res, a, s); +} + +static inline int mlx_sinh(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_sinh_(res, a, s); +} + +static inline int mlx_slice( + mlx_array* res, + const mlx_array a, + const int* start, + size_t start_num, + const int* stop, + size_t stop_num, + const int* strides, + size_t strides_num, + const mlx_stream s) { + return mlx_slice_(res, a, start, start_num, stop, stop_num, strides, strides_num, s); +} + +static inline int mlx_slice_dynamic( + mlx_array* res, + const mlx_array a, + const mlx_array start, + const int* axes, + size_t axes_num, + const int* slice_size, + size_t slice_size_num, + const mlx_stream s) { + return mlx_slice_dynamic_(res, a, start, axes, axes_num, slice_size, slice_size_num, s); +} + +static inline int mlx_slice_update( + mlx_array* res, + const mlx_array src, + const mlx_array update, + const int* start, + size_t start_num, + const int* stop, + size_t stop_num, + const int* strides, + size_t strides_num, + const mlx_stream s) { + return mlx_slice_update_(res, src, update, start, start_num, stop, stop_num, strides, strides_num, s); +} + +static inline int mlx_slice_update_dynamic( + mlx_array* res, + const mlx_array src, + const mlx_array update, + const mlx_array start, + const int* axes, + size_t axes_num, + const mlx_stream s) { + return mlx_slice_update_dynamic_(res, src, update, start, axes, axes_num, s); +} + +static inline int mlx_softmax_axes( + 
mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool precise, + const mlx_stream s) { + return mlx_softmax_axes_(res, a, axes, axes_num, precise, s); +} + +static inline int mlx_softmax_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool precise, + const mlx_stream s) { + return mlx_softmax_axis_(res, a, axis, precise, s); +} + +static inline int mlx_softmax( + mlx_array* res, + const mlx_array a, + bool precise, + const mlx_stream s) { + return mlx_softmax_(res, a, precise, s); +} + +static inline int mlx_sort_axis( + mlx_array* res, + const mlx_array a, + int axis, + const mlx_stream s) { + return mlx_sort_axis_(res, a, axis, s); +} + +static inline int mlx_sort(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_sort_(res, a, s); +} + +static inline int mlx_split( + mlx_vector_array* res, + const mlx_array a, + int num_splits, + int axis, + const mlx_stream s) { + return mlx_split_(res, a, num_splits, axis, s); +} + +static inline int mlx_split_sections( + mlx_vector_array* res, + const mlx_array a, + const int* indices, + size_t indices_num, + int axis, + const mlx_stream s) { + return mlx_split_sections_(res, a, indices, indices_num, axis, s); +} + +static inline int mlx_sqrt(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_sqrt_(res, a, s); +} + +static inline int mlx_square(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_square_(res, a, s); +} + +static inline int mlx_squeeze_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + const mlx_stream s) { + return mlx_squeeze_axes_(res, a, axes, axes_num, s); +} + +static inline int mlx_squeeze_axis( + mlx_array* res, + const mlx_array a, + int axis, + const mlx_stream s) { + return mlx_squeeze_axis_(res, a, axis, s); +} + +static inline int mlx_squeeze(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_squeeze_(res, a, s); +} + +static inline int mlx_stack_axis( 
+ mlx_array* res, + const mlx_vector_array arrays, + int axis, + const mlx_stream s) { + return mlx_stack_axis_(res, arrays, axis, s); +} + +static inline int mlx_stack( + mlx_array* res, + const mlx_vector_array arrays, + const mlx_stream s) { + return mlx_stack_(res, arrays, s); +} + +static inline int mlx_std_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + int ddof, + const mlx_stream s) { + return mlx_std_axes_(res, a, axes, axes_num, keepdims, ddof, s); +} + +static inline int mlx_std_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + int ddof, + const mlx_stream s) { + return mlx_std_axis_(res, a, axis, keepdims, ddof, s); +} + +static inline int mlx_std( + mlx_array* res, + const mlx_array a, + bool keepdims, + int ddof, + const mlx_stream s) { + return mlx_std_(res, a, keepdims, ddof, s); +} + +static inline int mlx_stop_gradient(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_stop_gradient_(res, a, s); +} + +static inline int mlx_subtract( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s) { + return mlx_subtract_(res, a, b, s); +} + +static inline int mlx_sum_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s) { + return mlx_sum_axes_(res, a, axes, axes_num, keepdims, s); +} + +static inline int mlx_sum_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s) { + return mlx_sum_axis_(res, a, axis, keepdims, s); +} + +static inline int mlx_sum( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s) { + return mlx_sum_(res, a, keepdims, s); +} + +static inline int mlx_swapaxes( + mlx_array* res, + const mlx_array a, + int axis1, + int axis2, + const mlx_stream s) { + return mlx_swapaxes_(res, a, axis1, axis2, s); +} + +static inline int mlx_take_axis( + mlx_array* res, + const mlx_array a, + const 
mlx_array indices, + int axis, + const mlx_stream s) { + return mlx_take_axis_(res, a, indices, axis, s); +} + +static inline int mlx_take( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + const mlx_stream s) { + return mlx_take_(res, a, indices, s); +} + +static inline int mlx_take_along_axis( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + int axis, + const mlx_stream s) { + return mlx_take_along_axis_(res, a, indices, axis, s); +} + +static inline int mlx_tan(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_tan_(res, a, s); +} + +static inline int mlx_tanh(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_tanh_(res, a, s); +} + +static inline int mlx_tensordot( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const int* axes_a, + size_t axes_a_num, + const int* axes_b, + size_t axes_b_num, + const mlx_stream s) { + return mlx_tensordot_(res, a, b, axes_a, axes_a_num, axes_b, axes_b_num, s); +} + +static inline int mlx_tensordot_axis( + mlx_array* res, + const mlx_array a, + const mlx_array b, + int axis, + const mlx_stream s) { + return mlx_tensordot_axis_(res, a, b, axis, s); +} + +static inline int mlx_tile( + mlx_array* res, + const mlx_array arr, + const int* reps, + size_t reps_num, + const mlx_stream s) { + return mlx_tile_(res, arr, reps, reps_num, s); +} + +static inline int mlx_to_fp8(mlx_array* res, const mlx_array x, const mlx_stream s) { + return mlx_to_fp8_(res, x, s); +} + +static inline int mlx_topk_axis( + mlx_array* res, + const mlx_array a, + int k, + int axis, + const mlx_stream s) { + return mlx_topk_axis_(res, a, k, axis, s); +} + +static inline int mlx_topk(mlx_array* res, const mlx_array a, int k, const mlx_stream s) { + return mlx_topk_(res, a, k, s); +} + +static inline int mlx_trace( + mlx_array* res, + const mlx_array a, + int offset, + int axis1, + int axis2, + mlx_dtype dtype, + const mlx_stream s) { + return mlx_trace_(res, a, offset, axis1, 
axis2, dtype, s); +} + +static inline int mlx_transpose_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + const mlx_stream s) { + return mlx_transpose_axes_(res, a, axes, axes_num, s); +} + +static inline int mlx_transpose(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_transpose_(res, a, s); +} + +static inline int mlx_tri( + mlx_array* res, + int n, + int m, + int k, + mlx_dtype type, + const mlx_stream s) { + return mlx_tri_(res, n, m, k, type, s); +} + +static inline int mlx_tril(mlx_array* res, const mlx_array x, int k, const mlx_stream s) { + return mlx_tril_(res, x, k, s); +} + +static inline int mlx_triu(mlx_array* res, const mlx_array x, int k, const mlx_stream s) { + return mlx_triu_(res, x, k, s); +} + +static inline int mlx_unflatten( + mlx_array* res, + const mlx_array a, + int axis, + const int* shape, + size_t shape_num, + const mlx_stream s) { + return mlx_unflatten_(res, a, axis, shape, shape_num, s); +} + +static inline int mlx_var_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + int ddof, + const mlx_stream s) { + return mlx_var_axes_(res, a, axes, axes_num, keepdims, ddof, s); +} + +static inline int mlx_var_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + int ddof, + const mlx_stream s) { + return mlx_var_axis_(res, a, axis, keepdims, ddof, s); +} + +static inline int mlx_var( + mlx_array* res, + const mlx_array a, + bool keepdims, + int ddof, + const mlx_stream s) { + return mlx_var_(res, a, keepdims, ddof, s); +} + +static inline int mlx_view( + mlx_array* res, + const mlx_array a, + mlx_dtype dtype, + const mlx_stream s) { + return mlx_view_(res, a, dtype, s); +} + +static inline int mlx_where( + mlx_array* res, + const mlx_array condition, + const mlx_array x, + const mlx_array y, + const mlx_stream s) { + return mlx_where_(res, condition, x, y, s); +} + +static inline int mlx_zeros( + mlx_array* res, + 
const int* shape, + size_t shape_num, + mlx_dtype dtype, + const mlx_stream s) { + return mlx_zeros_(res, shape, shape_num, dtype, s); +} + +static inline int mlx_zeros_like(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_zeros_like_(res, a, s); +} + +static inline int mlx_random_bernoulli( + mlx_array* res, + const mlx_array p, + const int* shape, + size_t shape_num, + const mlx_array key /* may be null */, + const mlx_stream s) { + return mlx_random_bernoulli_(res, p, shape, shape_num, key, s); +} + +static inline int mlx_random_bits( + mlx_array* res, + const int* shape, + size_t shape_num, + int width, + const mlx_array key /* may be null */, + const mlx_stream s) { + return mlx_random_bits_(res, shape, shape_num, width, key, s); +} + +static inline int mlx_random_categorical_shape( + mlx_array* res, + const mlx_array logits, + int axis, + const int* shape, + size_t shape_num, + const mlx_array key /* may be null */, + const mlx_stream s) { + return mlx_random_categorical_shape_(res, logits, axis, shape, shape_num, key, s); +} + +static inline int mlx_random_categorical_num_samples( + mlx_array* res, + const mlx_array logits_, + int axis, + int num_samples, + const mlx_array key /* may be null */, + const mlx_stream s) { + return mlx_random_categorical_num_samples_(res, logits_, axis, num_samples, key, s); +} + +static inline int mlx_random_categorical( + mlx_array* res, + const mlx_array logits, + int axis, + const mlx_array key /* may be null */, + const mlx_stream s) { + return mlx_random_categorical_(res, logits, axis, key, s); +} + +static inline int mlx_random_gumbel( + mlx_array* res, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + const mlx_array key /* may be null */, + const mlx_stream s) { + return mlx_random_gumbel_(res, shape, shape_num, dtype, key, s); +} + +static inline int mlx_random_key(mlx_array* res, uint64_t seed) { + return mlx_random_key_(res, seed); +} + +static inline int mlx_random_laplace( + 
mlx_array* res, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + float loc, + float scale, + const mlx_array key /* may be null */, + const mlx_stream s) { + return mlx_random_laplace_(res, shape, shape_num, dtype, loc, scale, key, s); +} + +static inline int mlx_random_multivariate_normal( + mlx_array* res, + const mlx_array mean, + const mlx_array cov, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + const mlx_array key /* may be null */, + const mlx_stream s) { + return mlx_random_multivariate_normal_(res, mean, cov, shape, shape_num, dtype, key, s); +} + +static inline int mlx_random_normal_broadcast( + mlx_array* res, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + const mlx_array loc /* may be null */, + const mlx_array scale /* may be null */, + const mlx_array key /* may be null */, + const mlx_stream s) { + return mlx_random_normal_broadcast_(res, shape, shape_num, dtype, loc, scale, key, s); +} + +static inline int mlx_random_normal( + mlx_array* res, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + float loc, + float scale, + const mlx_array key /* may be null */, + const mlx_stream s) { + return mlx_random_normal_(res, shape, shape_num, dtype, loc, scale, key, s); +} + +static inline int mlx_random_permutation( + mlx_array* res, + const mlx_array x, + int axis, + const mlx_array key /* may be null */, + const mlx_stream s) { + return mlx_random_permutation_(res, x, axis, key, s); +} + +static inline int mlx_random_permutation_arange( + mlx_array* res, + int x, + const mlx_array key /* may be null */, + const mlx_stream s) { + return mlx_random_permutation_arange_(res, x, key, s); +} + +static inline int mlx_random_randint( + mlx_array* res, + const mlx_array low, + const mlx_array high, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + const mlx_array key /* may be null */, + const mlx_stream s) { + return mlx_random_randint_(res, low, high, shape, shape_num, dtype, key, s); +} + +static inline 
int mlx_random_seed(uint64_t seed) { + return mlx_random_seed_(seed); +} + +static inline int mlx_random_split_num( + mlx_array* res, + const mlx_array key, + int num, + const mlx_stream s) { + return mlx_random_split_num_(res, key, num, s); +} + +static inline int mlx_random_split( + mlx_array* res_0, + mlx_array* res_1, + const mlx_array key, + const mlx_stream s) { + return mlx_random_split_(res_0, res_1, key, s); +} + +static inline int mlx_random_truncated_normal( + mlx_array* res, + const mlx_array lower, + const mlx_array upper, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + const mlx_array key /* may be null */, + const mlx_stream s) { + return mlx_random_truncated_normal_(res, lower, upper, shape, shape_num, dtype, key, s); +} + +static inline int mlx_random_uniform( + mlx_array* res, + const mlx_array low, + const mlx_array high, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + const mlx_array key /* may be null */, + const mlx_stream s) { + return mlx_random_uniform_(res, low, high, shape, shape_num, dtype, key, s); +} + +static inline mlx_stream mlx_stream_new(void) { + return mlx_stream_new_(); +} + +static inline mlx_stream mlx_stream_new_device(mlx_device dev) { + return mlx_stream_new_device_(dev); +} + +static inline int mlx_stream_set(mlx_stream* stream, const mlx_stream src) { + return mlx_stream_set_(stream, src); +} + +static inline int mlx_stream_free(mlx_stream stream) { + return mlx_stream_free_(stream); +} + +static inline int mlx_stream_tostring(mlx_string* str, mlx_stream stream) { + return mlx_stream_tostring_(str, stream); +} + +static inline bool mlx_stream_equal(mlx_stream lhs, mlx_stream rhs) { + return mlx_stream_equal_(lhs, rhs); +} + +static inline int mlx_stream_get_device(mlx_device* dev, mlx_stream stream) { + return mlx_stream_get_device_(dev, stream); +} + +static inline int mlx_stream_get_index(int* index, mlx_stream stream) { + return mlx_stream_get_index_(index, stream); +} + +static inline int 
mlx_synchronize(mlx_stream stream) { + return mlx_synchronize_(stream); +} + +static inline int mlx_get_default_stream(mlx_stream* stream, mlx_device dev) { + return mlx_get_default_stream_(stream, dev); +} + +static inline int mlx_set_default_stream(mlx_stream stream) { + return mlx_set_default_stream_(stream); +} + +static inline mlx_stream mlx_default_cpu_stream_new(void) { + return mlx_default_cpu_stream_new_(); +} + +static inline mlx_stream mlx_default_gpu_stream_new(void) { + return mlx_default_gpu_stream_new_(); +} + +static inline mlx_string mlx_string_new(void) { + return mlx_string_new_(); +} + +static inline mlx_string mlx_string_new_data(const char* str) { + return mlx_string_new_data_(str); +} + +static inline int mlx_string_set(mlx_string* str, const mlx_string src) { + return mlx_string_set_(str, src); +} + +static inline const char * mlx_string_data(mlx_string str) { + return mlx_string_data_(str); +} + +static inline int mlx_string_free(mlx_string str) { + return mlx_string_free_(str); +} + +static inline int mlx_detail_vmap_replace( + mlx_vector_array* res, + const mlx_vector_array inputs, + const mlx_vector_array s_inputs, + const mlx_vector_array s_outputs, + const int* in_axes, + size_t in_axes_num, + const int* out_axes, + size_t out_axes_num) { + return mlx_detail_vmap_replace_(res, inputs, s_inputs, s_outputs, in_axes, in_axes_num, out_axes, out_axes_num); +} + +static inline int mlx_detail_vmap_trace( + mlx_vector_array* res_0, + mlx_vector_array* res_1, + const mlx_closure fun, + const mlx_vector_array inputs, + const int* in_axes, + size_t in_axes_num) { + return mlx_detail_vmap_trace_(res_0, res_1, fun, inputs, in_axes, in_axes_num); +} + +static inline int mlx_async_eval(const mlx_vector_array outputs) { + return mlx_async_eval_(outputs); +} + +static inline int mlx_checkpoint(mlx_closure* res, const mlx_closure fun) { + return mlx_checkpoint_(res, fun); +} + +static inline int mlx_custom_function( + mlx_closure* res, + const 
mlx_closure fun, + const mlx_closure_custom fun_vjp /* may be null */, + const mlx_closure_custom_jvp fun_jvp /* may be null */, + const mlx_closure_custom_vmap fun_vmap /* may be null */) { + return mlx_custom_function_(res, fun, fun_vjp, fun_jvp, fun_vmap); +} + +static inline int mlx_custom_vjp( + mlx_closure* res, + const mlx_closure fun, + const mlx_closure_custom fun_vjp) { + return mlx_custom_vjp_(res, fun, fun_vjp); +} + +static inline int mlx_eval(const mlx_vector_array outputs) { + return mlx_eval_(outputs); +} + +static inline int mlx_jvp( + mlx_vector_array* res_0, + mlx_vector_array* res_1, + const mlx_closure fun, + const mlx_vector_array primals, + const mlx_vector_array tangents) { + return mlx_jvp_(res_0, res_1, fun, primals, tangents); +} + +static inline int mlx_value_and_grad( + mlx_closure_value_and_grad* res, + const mlx_closure fun, + const int* argnums, + size_t argnums_num) { + return mlx_value_and_grad_(res, fun, argnums, argnums_num); +} + +static inline int mlx_vjp( + mlx_vector_array* res_0, + mlx_vector_array* res_1, + const mlx_closure fun, + const mlx_vector_array primals, + const mlx_vector_array cotangents) { + return mlx_vjp_(res_0, res_1, fun, primals, cotangents); +} + +static inline mlx_vector_array mlx_vector_array_new(void) { + return mlx_vector_array_new_(); +} + +static inline int mlx_vector_array_set(mlx_vector_array* vec, const mlx_vector_array src) { + return mlx_vector_array_set_(vec, src); +} + +static inline int mlx_vector_array_free(mlx_vector_array vec) { + return mlx_vector_array_free_(vec); +} + +static inline mlx_vector_array mlx_vector_array_new_data(const mlx_array* data, size_t size) { + return mlx_vector_array_new_data_(data, size); +} + +static inline mlx_vector_array mlx_vector_array_new_value(const mlx_array val) { + return mlx_vector_array_new_value_(val); +} + +static inline int mlx_vector_array_set_data( + mlx_vector_array* vec, + const mlx_array* data, + size_t size) { + return 
mlx_vector_array_set_data_(vec, data, size); +} + +static inline int mlx_vector_array_set_value(mlx_vector_array* vec, const mlx_array val) { + return mlx_vector_array_set_value_(vec, val); +} + +static inline int mlx_vector_array_append_data( + mlx_vector_array vec, + const mlx_array* data, + size_t size) { + return mlx_vector_array_append_data_(vec, data, size); +} + +static inline int mlx_vector_array_append_value(mlx_vector_array vec, const mlx_array val) { + return mlx_vector_array_append_value_(vec, val); +} + +static inline size_t mlx_vector_array_size(mlx_vector_array vec) { + return mlx_vector_array_size_(vec); +} + +static inline int mlx_vector_array_get( + mlx_array* res, + const mlx_vector_array vec, + size_t idx) { + return mlx_vector_array_get_(res, vec, idx); +} + +static inline mlx_vector_vector_array mlx_vector_vector_array_new(void) { + return mlx_vector_vector_array_new_(); +} + +static inline int mlx_vector_vector_array_set( + mlx_vector_vector_array* vec, + const mlx_vector_vector_array src) { + return mlx_vector_vector_array_set_(vec, src); +} + +static inline int mlx_vector_vector_array_free(mlx_vector_vector_array vec) { + return mlx_vector_vector_array_free_(vec); +} + +static inline mlx_vector_vector_array mlx_vector_vector_array_new_data( + const mlx_vector_array* data, + size_t size) { + return mlx_vector_vector_array_new_data_(data, size); +} + +static inline mlx_vector_vector_array mlx_vector_vector_array_new_value( + const mlx_vector_array val) { + return mlx_vector_vector_array_new_value_(val); +} + +static inline int mlx_vector_vector_array_set_data( + mlx_vector_vector_array* vec, + const mlx_vector_array* data, + size_t size) { + return mlx_vector_vector_array_set_data_(vec, data, size); +} + +static inline int mlx_vector_vector_array_set_value( + mlx_vector_vector_array* vec, + const mlx_vector_array val) { + return mlx_vector_vector_array_set_value_(vec, val); +} + +static inline int mlx_vector_vector_array_append_data( + 
mlx_vector_vector_array vec, + const mlx_vector_array* data, + size_t size) { + return mlx_vector_vector_array_append_data_(vec, data, size); +} + +static inline int mlx_vector_vector_array_append_value( + mlx_vector_vector_array vec, + const mlx_vector_array val) { + return mlx_vector_vector_array_append_value_(vec, val); +} + +static inline size_t mlx_vector_vector_array_size(mlx_vector_vector_array vec) { + return mlx_vector_vector_array_size_(vec); +} + +static inline int mlx_vector_vector_array_get( + mlx_vector_array* res, + const mlx_vector_vector_array vec, + size_t idx) { + return mlx_vector_vector_array_get_(res, vec, idx); +} + +static inline mlx_vector_int mlx_vector_int_new(void) { + return mlx_vector_int_new_(); +} + +static inline int mlx_vector_int_set(mlx_vector_int* vec, const mlx_vector_int src) { + return mlx_vector_int_set_(vec, src); +} + +static inline int mlx_vector_int_free(mlx_vector_int vec) { + return mlx_vector_int_free_(vec); +} + +static inline mlx_vector_int mlx_vector_int_new_data(int* data, size_t size) { + return mlx_vector_int_new_data_(data, size); +} + +static inline mlx_vector_int mlx_vector_int_new_value(int val) { + return mlx_vector_int_new_value_(val); +} + +static inline int mlx_vector_int_set_data(mlx_vector_int* vec, int* data, size_t size) { + return mlx_vector_int_set_data_(vec, data, size); +} + +static inline int mlx_vector_int_set_value(mlx_vector_int* vec, int val) { + return mlx_vector_int_set_value_(vec, val); +} + +static inline int mlx_vector_int_append_data(mlx_vector_int vec, int* data, size_t size) { + return mlx_vector_int_append_data_(vec, data, size); +} + +static inline int mlx_vector_int_append_value(mlx_vector_int vec, int val) { + return mlx_vector_int_append_value_(vec, val); +} + +static inline size_t mlx_vector_int_size(mlx_vector_int vec) { + return mlx_vector_int_size_(vec); +} + +static inline int mlx_vector_int_get(int* res, const mlx_vector_int vec, size_t idx) { + return 
mlx_vector_int_get_(res, vec, idx); +} + +static inline mlx_vector_string mlx_vector_string_new(void) { + return mlx_vector_string_new_(); +} + +static inline int mlx_vector_string_set(mlx_vector_string* vec, const mlx_vector_string src) { + return mlx_vector_string_set_(vec, src); +} + +static inline int mlx_vector_string_free(mlx_vector_string vec) { + return mlx_vector_string_free_(vec); +} + +static inline mlx_vector_string mlx_vector_string_new_data(const char** data, size_t size) { + return mlx_vector_string_new_data_(data, size); +} + +static inline mlx_vector_string mlx_vector_string_new_value(const char* val) { + return mlx_vector_string_new_value_(val); +} + +static inline int mlx_vector_string_set_data( + mlx_vector_string* vec, + const char** data, + size_t size) { + return mlx_vector_string_set_data_(vec, data, size); +} + +static inline int mlx_vector_string_set_value(mlx_vector_string* vec, const char* val) { + return mlx_vector_string_set_value_(vec, val); +} + +static inline int mlx_vector_string_append_data( + mlx_vector_string vec, + const char** data, + size_t size) { + return mlx_vector_string_append_data_(vec, data, size); +} + +static inline int mlx_vector_string_append_value(mlx_vector_string vec, const char* val) { + return mlx_vector_string_append_value_(vec, val); +} + +static inline size_t mlx_vector_string_size(mlx_vector_string vec) { + return mlx_vector_string_size_(vec); +} + +static inline int mlx_vector_string_get(char** res, const mlx_vector_string vec, size_t idx) { + return mlx_vector_string_get_(res, vec, idx); +} + +static inline int mlx_version(mlx_string* str_) { + return mlx_version_(str_); +} + +#endif // MLX_GENERATED_H \ No newline at end of file diff --git a/x/mlxrunner/mlx/generator/generated.c.gotmpl b/x/mlxrunner/mlx/generator/generated.c.gotmpl new file mode 100644 index 000000000..c31b34a76 --- /dev/null +++ b/x/mlxrunner/mlx/generator/generated.c.gotmpl @@ -0,0 +1,17 @@ +// This code is auto-generated; DO NOT 
EDIT. + +#include "generated.h" + +#include +#include +#include +{{ range .Functions }} +{{ .Type }} (*{{ .Name }}_){{ .Parameters }} = NULL; +{{- end }} + +int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) { +{{- range .Functions }} + CHECK_LOAD(handle, {{ .Name }}); +{{- end }} + return 0; +} diff --git a/x/mlxrunner/mlx/generator/generated.h.gotmpl b/x/mlxrunner/mlx/generator/generated.h.gotmpl new file mode 100644 index 000000000..8f043573b --- /dev/null +++ b/x/mlxrunner/mlx/generator/generated.h.gotmpl @@ -0,0 +1,22 @@ +// This code is auto-generated; DO NOT EDIT. + +#ifndef MLX_GENERATED_H +#define MLX_GENERATED_H + +#include "dynamic.h" +#include "mlx/c/mlx.h" +{{ range .Functions }} +#undef {{ .Name }} +{{- end }} +{{ range .Functions }} +extern {{ .Type }} (*{{ .Name }}_){{ .Parameters }}; +{{- end }} + +int mlx_dynamic_load_symbols(mlx_dynamic_handle handle); +{{ range .Functions }} +static inline {{ .Type }} {{ .Name }}{{ .Parameters }} {{ "{" }} + return {{ .Name }}_({{ .Args }}); +{{ "}" }} +{{- end }} + +#endif // MLX_GENERATED_H diff --git a/x/mlxrunner/mlx/generator/main.go b/x/mlxrunner/mlx/generator/main.go new file mode 100644 index 000000000..a98046a2f --- /dev/null +++ b/x/mlxrunner/mlx/generator/main.go @@ -0,0 +1,135 @@ +package main + +import ( + "embed" + "flag" + "fmt" + "os" + "path/filepath" + "slices" + "strings" + "text/template" + + tree_sitter "github.com/tree-sitter/go-tree-sitter" + tree_sitter_cpp "github.com/tree-sitter/tree-sitter-cpp/bindings/go" +) + +//go:embed *.gotmpl +var fsys embed.FS + +type Function struct { + Type, + Name, + Parameters, + Args string +} + +func ParseFunction(node *tree_sitter.Node, tc *tree_sitter.TreeCursor, source []byte) Function { + var fn Function + fn.Name = node.ChildByFieldName("declarator").Utf8Text(source) + if params := node.ChildByFieldName("parameters"); params != nil { + fn.Parameters = params.Utf8Text(source) + fn.Args = ParseParameters(params, tc, source) + } + + var types 
[]string + for node.Parent() != nil && node.Parent().Kind() != "declaration" { + if node.Parent().Kind() == "pointer_declarator" { + types = append(types, "*") + } + node = node.Parent() + } + + for sibling := node.PrevSibling(); sibling != nil; sibling = sibling.PrevSibling() { + types = append(types, sibling.Utf8Text(source)) + } + + slices.Reverse(types) + fn.Type = strings.Join(types, " ") + return fn +} + +func ParseParameters(node *tree_sitter.Node, tc *tree_sitter.TreeCursor, source []byte) string { + var s []string + for _, child := range node.Children(tc) { + if child.IsNamed() { + child := child.ChildByFieldName("declarator") + for child != nil && child.Kind() != "identifier" { + if child.Kind() == "parenthesized_declarator" { + child = child.Child(1) + } else { + child = child.ChildByFieldName("declarator") + } + } + + if child != nil { + s = append(s, child.Utf8Text(source)) + } + } + } + return strings.Join(s, ", ") +} + +func main() { + var output string + flag.StringVar(&output, "output", ".", "Output directory for generated files") + flag.Parse() + + parser := tree_sitter.NewParser() + defer parser.Close() + + language := tree_sitter.NewLanguage(tree_sitter_cpp.Language()) + parser.SetLanguage(language) + + query, _ := tree_sitter.NewQuery(language, `(function_declarator declarator: (identifier)) @func`) + defer query.Close() + + qc := tree_sitter.NewQueryCursor() + defer qc.Close() + + var funs []Function + for _, arg := range flag.Args() { + bts, err := os.ReadFile(arg) + if err != nil { + fmt.Fprintf(os.Stderr, "Error reading file %s: %v\n", arg, err) + continue + } + + tree := parser.Parse(bts, nil) + defer tree.Close() + + tc := tree.Walk() + defer tc.Close() + + matches := qc.Matches(query, tree.RootNode(), bts) + for match := matches.Next(); match != nil; match = matches.Next() { + for _, capture := range match.Captures { + funs = append(funs, ParseFunction(&capture.Node, tc, bts)) + } + } + } + + tmpl, err := template.New("").ParseFS(fsys, 
"*.gotmpl") + if err != nil { + fmt.Fprintf(os.Stderr, "Error parsing template: %v\n", err) + return + } + + for _, tmpl := range tmpl.Templates() { + name := filepath.Join(output, strings.TrimSuffix(tmpl.Name(), ".gotmpl")) + + fmt.Println("Generating", name) + f, err := os.Create(name) + if err != nil { + fmt.Fprintf(os.Stderr, "Error creating file %s: %v\n", name, err) + continue + } + defer f.Close() + + if err := tmpl.Execute(f, map[string]any{ + "Functions": funs, + }); err != nil { + fmt.Fprintf(os.Stderr, "Error executing template %s: %v\n", tmpl.Name(), err) + } + } +} diff --git a/x/mlxrunner/mlx/io.go b/x/mlxrunner/mlx/io.go new file mode 100644 index 000000000..304cfcd2c --- /dev/null +++ b/x/mlxrunner/mlx/io.go @@ -0,0 +1,45 @@ +//go:build mlx + +package mlx + +// #include "generated.h" +import "C" + +import ( + "iter" + "unsafe" +) + +func Load(path string) iter.Seq2[string, *Array] { + return func(yield func(string, *Array) bool) { + string2array := C.mlx_map_string_to_array_new() + defer C.mlx_map_string_to_array_free(string2array) + + string2string := C.mlx_map_string_to_string_new() + defer C.mlx_map_string_to_string_free(string2string) + + cPath := C.CString(path) + defer C.free(unsafe.Pointer(cPath)) + + cpu := C.mlx_default_cpu_stream_new() + defer C.mlx_stream_free(cpu) + + C.mlx_load_safetensors(&string2array, &string2string, cPath, cpu) + + it := C.mlx_map_string_to_array_iterator_new(string2array) + defer C.mlx_map_string_to_array_iterator_free(it) + + for { + var key *C.char + value := C.mlx_array_new() + if C.mlx_map_string_to_array_iterator_next(&key, &value, it) != 0 { + break + } + + name := C.GoString(key) + if !yield(name, &Array{ctx: value, desc: tensorDesc{name: name, numRefs: 1000}}) { + break + } + } + } +} diff --git a/x/mlxrunner/mlx/memory.go b/x/mlxrunner/mlx/memory.go new file mode 100644 index 000000000..e9a174b1e --- /dev/null +++ b/x/mlxrunner/mlx/memory.go @@ -0,0 +1,87 @@ +//go:build mlx + +package mlx + +// #include 
"generated.h" +import "C" + +import ( + "fmt" + "log/slog" + "strconv" +) + +func (b Byte) String() string { + return strconv.FormatInt(int64(b), 10) + " B" +} + +func (b KibiByte) String() string { + return strconv.FormatFloat(float64(b)/(1<<10), 'f', 2, 64) + " KiB" +} + +func (b MebiByte) String() string { + return strconv.FormatFloat(float64(b)/(1<<(2*10)), 'f', 2, 64) + " MiB" +} + +func (b GibiByte) String() string { + return strconv.FormatFloat(float64(b)/(1<<(3*10)), 'f', 2, 64) + " GiB" +} + +func (b TebiByte) String() string { + return strconv.FormatFloat(float64(b)/(1<<(4*10)), 'f', 2, 64) + " TiB" +} + +func PrettyBytes(n int) fmt.Stringer { + switch { + case n < 1<<10: + return Byte(n) + case n < 1<<(2*10): + return KibiByte(n) + case n < 1<<(3*10): + return MebiByte(n) + case n < 1<<(4*10): + return GibiByte(n) + default: + return TebiByte(n) + } +} + +func ActiveMemory() int { + var active C.size_t + C.mlx_get_active_memory(&active) + return int(active) +} + +func CacheMemory() int { + var cache C.size_t + C.mlx_get_cache_memory(&cache) + return int(cache) +} + +func PeakMemory() int { + var peak C.size_t + C.mlx_get_peak_memory(&peak) + return int(peak) +} + +type Memory struct{} + +func (Memory) LogValue() slog.Value { + return slog.GroupValue( + slog.Any("active", PrettyBytes(ActiveMemory())), + slog.Any("cache", PrettyBytes(CacheMemory())), + slog.Any("peak", PrettyBytes(PeakMemory())), + ) +} + +type ( + Byte int + KibiByte int + MebiByte int + GibiByte int + TebiByte int +) + +func ClearCache() { + C.mlx_clear_cache() +} diff --git a/x/mlxrunner/mlx/mlx.go b/x/mlxrunner/mlx/mlx.go new file mode 100644 index 000000000..0bf43830c --- /dev/null +++ b/x/mlxrunner/mlx/mlx.go @@ -0,0 +1,40 @@ +//go:build mlx + +package mlx + +//go:generate cmake -S . -B build -DCMAKE_INSTALL_PREFIX=dist -DCMAKE_BUILD_TYPE=Release +//go:generate cmake --build build --parallel +//go:generate cmake --install build +//go:generate sh -c "go run generator/main.go -output=. 
./dist/include/mlx/c/*.h" + +// #cgo CXXFLAGS: -std=c++17 +// #cgo CPPFLAGS: -I${SRCDIR}/dist/include +// #cgo LDFLAGS: -L${SRCDIR}/dist/lib -lstdc++ +// #cgo darwin LDFLAGS: -framework Foundation -framework Metal -framework Accelerate +// #include "generated.h" +import "C" + +func doEval(outputs []*Array, async bool) { + vector := C.mlx_vector_array_new() + defer C.mlx_vector_array_free(vector) + + for _, output := range outputs { + if output.Valid() { + C.mlx_vector_array_append_value(vector, output.ctx) + } + } + + if async { + C.mlx_async_eval(vector) + } else { + C.mlx_eval(vector) + } +} + +func AsyncEval(outputs ...*Array) { + doEval(outputs, true) +} + +func Eval(outputs ...*Array) { + doEval(outputs, false) +} diff --git a/x/mlxrunner/mlx/nn.go b/x/mlxrunner/mlx/nn.go new file mode 100644 index 000000000..3d5691368 --- /dev/null +++ b/x/mlxrunner/mlx/nn.go @@ -0,0 +1,38 @@ +//go:build mlx + +package mlx + +type Linear struct { + Weight Array `weight:"weight"` + Bias Array `weight:"bias"` +} + +// Forward computes the linear transformation: x @ Weight.T + Bias +func (m Linear) Forward(x *Array) *Array { + w := m.Weight.Transpose(1, 0) + if m.Bias.Valid() { + return m.Bias.Addmm(x, w, 1.0, 1.0) + } + + return x.Matmul(w) +} + +func (m Linear) Gather(x, lhs, rhs *Array, sorted bool) *Array { + w := m.Weight.Transpose(0, 2, 1) + // TODO: bias + return x.GatherMM(w, lhs, rhs, sorted) +} + +type Embedding struct { + Weight Array `weight:"weight"` +} + +func (e *Embedding) Forward(indices *Array) *Array { + return e.Weight.TakeAxis(indices, 0) +} + +func (e *Embedding) AsLinear() Linear { + return Linear{ + Weight: e.Weight, + } +} diff --git a/x/mlxrunner/mlx/ops.go b/x/mlxrunner/mlx/ops.go new file mode 100644 index 000000000..01a7f4835 --- /dev/null +++ b/x/mlxrunner/mlx/ops.go @@ -0,0 +1,256 @@ +//go:build mlx + +package mlx + +// #include "generated.h" +import "C" + +import ( + "unsafe" +) + +func (t *Array) Abs() *Array { + out := New("ABS", t) + 
C.mlx_abs(&out.ctx, t.ctx, DefaultStream().ctx) + return out +} + +func (t *Array) Add(other *Array) *Array { + out := New("ADD", t, other) + C.mlx_add(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx) + return out +} + +func (t *Array) Addmm(a, b *Array, alpha, beta float32) *Array { + out := New("ADDMM", t, a, b) + C.mlx_addmm(&out.ctx, t.ctx, a.ctx, b.ctx, C.float(alpha), C.float(beta), DefaultStream().ctx) + return out +} + +func (t *Array) Argmax(axis int, keepDims bool) *Array { + out := New("ARGMAX", t) + C.mlx_argmax_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx) + return out +} + +func (t *Array) ArgpartitionAxis(kth int, axis int) *Array { + out := New("ARGPARTITION", t) + C.mlx_argpartition_axis(&out.ctx, t.ctx, C.int(kth), C.int(axis), DefaultStream().ctx) + return out +} + +func (t *Array) ArgsortAxis(axis int) *Array { + out := New("ARGSORT_AXIS", t) + C.mlx_argsort_axis(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx) + return out +} + +func (t *Array) AsType(dtype DType) *Array { + out := New("AS_TYPE", t) + C.mlx_astype(&out.ctx, t.ctx, C.mlx_dtype(dtype), DefaultStream().ctx) + return out +} + +func (t *Array) AsStrided(shape []int, strides []int, offset int) *Array { + cShape := make([]C.int, len(shape)) + for i, s := range shape { + cShape[i] = C.int(s) + } + + cStrides := make([]C.int64_t, len(strides)) + for i, s := range strides { + cStrides[i] = C.int64_t(s) + } + + out := New("AS_STRIDED", t) + C.mlx_as_strided( + &out.ctx, t.ctx, + unsafe.SliceData(cShape), C.size_t(len(shape)), + unsafe.SliceData(cStrides), C.size_t(len(strides)), + C.size_t(offset), + DefaultStream().ctx, + ) + return out +} + +func (t *Array) Concatenate(axis int, others ...*Array) *Array { + vector := C.mlx_vector_array_new() + defer C.mlx_vector_array_free(vector) + + s := append([]*Array{t}, others...) + for _, other := range s { + C.mlx_vector_array_append_value(vector, other.ctx) + } + + out := New("CONCATENATE", s...) 
+ C.mlx_concatenate_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx) + return out +} + +func (t *Array) Divide(other *Array) *Array { + out := New("DIVIDE", t, other) + C.mlx_divide(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx) + return out +} + +func (t *Array) ExpandDims(axis int) *Array { + out := New("EXPAND_DIMS", t) + C.mlx_expand_dims(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx) + return out +} + +func (t *Array) Flatten(startAxis, endAxis int) *Array { + out := New("FLATTEN", t) + C.mlx_flatten(&out.ctx, t.ctx, C.int(startAxis), C.int(endAxis), DefaultStream().ctx) + return out +} + +func (t *Array) FloorDivide(other *Array) *Array { + out := New("FLOOR_DIVIDE", t, other) + C.mlx_floor_divide(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx) + return out +} + +func (t *Array) GatherMM(other, lhs, rhs *Array, sorted bool) *Array { + if lhs == nil { + lhs = New("") + } + if rhs == nil { + rhs = New("") + } + out := New("GATHER_MM", t, other, lhs, rhs) + C.mlx_gather_mm(&out.ctx, t.ctx, other.ctx, lhs.ctx, rhs.ctx, C.bool(sorted), DefaultStream().ctx) + return out +} + +func (t *Array) Logsumexp(keepDims bool) *Array { + out := New("LOGSUMEXP", t) + C.mlx_logsumexp(&out.ctx, t.ctx, C.bool(keepDims), DefaultStream().ctx) + return out +} + +func (t *Array) Matmul(other *Array) *Array { + out := New("MATMUL", t, other) + C.mlx_matmul(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx) + return out +} + +func (t *Array) Multiply(other *Array) *Array { + out := New("MULTIPLY", t, other) + C.mlx_multiply(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx) + return out +} + +func (t *Array) Negative() *Array { + out := New("NEGATIVE", t) + C.mlx_negative(&out.ctx, t.ctx, DefaultStream().ctx) + return out +} + +func (t *Array) Power(exponent *Array) *Array { + out := New("POWER", t, exponent) + C.mlx_power(&out.ctx, t.ctx, exponent.ctx, DefaultStream().ctx) + return out +} + +func (t *Array) PutAlongAxis(indices, values *Array, axis int) *Array { + out := 
New("PUT_ALONG_AXIS", t, indices, values) + C.mlx_put_along_axis(&out.ctx, t.ctx, indices.ctx, values.ctx, C.int(axis), DefaultStream().ctx) + return out +} + +func (t *Array) Reshape(axes ...int) *Array { + cAxes := make([]C.int, len(axes)) + for i := range axes { + cAxes[i] = C.int(axes[i]) + } + + out := New("RESHAPE", t) + C.mlx_reshape(&out.ctx, t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), DefaultStream().ctx) + return out +} + +func (t *Array) Sigmoid() *Array { + out := New("SIGMOID", t) + C.mlx_sigmoid(&out.ctx, t.ctx, DefaultStream().ctx) + return out +} + +func (t *Array) Sqrt() *Array { + out := New("SQRT", t) + C.mlx_sqrt(&out.ctx, t.ctx, DefaultStream().ctx) + return out +} + +func (t *Array) Squeeze(axis int) *Array { + out := New("SQUEEZE", t) + C.mlx_squeeze_axis(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx) + return out +} + +func (t *Array) StackAxis(axis int, others ...*Array) *Array { + vectorData := make([]C.mlx_array, len(others)+1) + vectorData[0] = t.ctx + for i := range others { + vectorData[i+1] = others[i].ctx + } + + vector := C.mlx_vector_array_new_data(unsafe.SliceData(vectorData), C.size_t(len(vectorData))) + defer C.mlx_vector_array_free(vector) + + out := New("STACK_AXIS", append(others, t)...) 
+ C.mlx_stack_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx) + return out +} + +func (t *Array) Subtract(other *Array) *Array { + out := New("SUBTRACT", t, other) + C.mlx_subtract(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx) + return out +} + +func (t *Array) SumAxis(axis int, keepDims bool) *Array { + out := New("SUM_AXIS", t) + C.mlx_sum_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx) + return out +} + +func (t *Array) TakeAxis(indices *Array, axis int) *Array { + out := New("TAKE_AXIS", t, indices) + C.mlx_take_axis(&out.ctx, t.ctx, indices.ctx, C.int(axis), DefaultStream().ctx) + return out +} + +func (t *Array) TakeAlongAxis(indices *Array, axis int) *Array { + out := New("TAKE_ALONG_AXIS", t, indices) + C.mlx_take_along_axis(&out.ctx, t.ctx, indices.ctx, C.int(axis), DefaultStream().ctx) + return out +} + +func (t *Array) Tanh() *Array { + out := New("TANH", t) + C.mlx_tanh(&out.ctx, t.ctx, DefaultStream().ctx) + return out +} + +func (t *Array) Transpose(axes ...int) *Array { + cAxes := make([]C.int, len(axes)) + for i, axis := range axes { + cAxes[i] = C.int(axis) + } + + out := New("TRANSPOSE", t) + C.mlx_transpose_axes(&out.ctx, t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), DefaultStream().ctx) + return out +} + +func Zeros(dtype DType, shape ...int) *Array { + cAxes := make([]C.int, len(shape)) + for i := range shape { + cAxes[i] = C.int(shape[i]) + } + + t := New("ZEROS") + C.mlx_zeros(&t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), C.mlx_dtype(dtype), DefaultStream().ctx) + return t +} diff --git a/x/mlxrunner/mlx/ops_extra.go b/x/mlxrunner/mlx/ops_extra.go new file mode 100644 index 000000000..e5444a4f8 --- /dev/null +++ b/x/mlxrunner/mlx/ops_extra.go @@ -0,0 +1,427 @@ +//go:build mlx + +package mlx + +// #include "generated.h" +import "C" + +import ( + "reflect" + "unsafe" +) + +// Quantization operations + +func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, biases 
*Array) { + cMode := C.CString(mode) + defer C.free(unsafe.Pointer(cMode)) + optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true} + optBits := C.mlx_optional_int{value: C.int(bits), has_value: true} + res := C.mlx_vector_array_new() + defer C.mlx_vector_array_free(res) + C.mlx_quantize(&res, w.ctx, optGroupSize, optBits, cMode, DefaultStream().ctx) + + vecSize := int(C.mlx_vector_array_size(res)) + w0 := New("QUANTIZE_W") + C.mlx_vector_array_get(&w0.ctx, res, 0) + w1 := New("QUANTIZE_S") + C.mlx_vector_array_get(&w1.ctx, res, 1) + if vecSize >= 3 { + w2 := New("QUANTIZE_B") + C.mlx_vector_array_get(&w2.ctx, res, 2) + return w0, w1, w2 + } + return w0, w1, nil +} + +func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Array { + cMode := C.CString(mode) + defer C.free(unsafe.Pointer(cMode)) + optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true} + optBits := C.mlx_optional_int{value: C.int(bits), has_value: true} + optDtype := C.mlx_optional_dtype{has_value: false} + + inputs := []*Array{w, scales} + var b C.mlx_array + if biases != nil { + b = biases.ctx + inputs = append(inputs, biases) + } + + out := New("DEQUANTIZE", inputs...) + C.mlx_dequantize(&out.ctx, w.ctx, scales.ctx, b, optGroupSize, optBits, cMode, optDtype, DefaultStream().ctx) + return out +} + +func QuantizedMatmul(x, w, scales, biases *Array, transpose bool, groupSize, bits int, mode string) *Array { + cMode := C.CString(mode) + defer C.free(unsafe.Pointer(cMode)) + optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true} + optBits := C.mlx_optional_int{value: C.int(bits), has_value: true} + + inputs := []*Array{x, w, scales} + var b C.mlx_array + if biases != nil { + b = biases.ctx + inputs = append(inputs, biases) + } + + out := New("QUANTIZED_MATMUL", inputs...) 
+ C.mlx_quantized_matmul(&out.ctx, x.ctx, w.ctx, scales.ctx, b, C.bool(transpose), optGroupSize, optBits, cMode, DefaultStream().ctx) + return out +} + +func GatherQMM(x, w, scales *Array, biases, lhsIndices, rhsIndices *Array, transpose bool, groupSize, bits int, mode string, sortedIndices bool) *Array { + cMode := C.CString(mode) + defer C.free(unsafe.Pointer(cMode)) + optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true} + optBits := C.mlx_optional_int{value: C.int(bits), has_value: true} + + inputs := []*Array{x, w, scales} + var b, lhs, rhs C.mlx_array + if biases != nil { + b = biases.ctx + inputs = append(inputs, biases) + } + if lhsIndices != nil { + lhs = lhsIndices.ctx + inputs = append(inputs, lhsIndices) + } + if rhsIndices != nil { + rhs = rhsIndices.ctx + inputs = append(inputs, rhsIndices) + } + + out := New("GATHER_QMM", inputs...) + C.mlx_gather_qmm(&out.ctx, x.ctx, w.ctx, scales.ctx, b, lhs, rhs, C.bool(transpose), optGroupSize, optBits, cMode, C.bool(sortedIndices), DefaultStream().ctx) + return out +} + +// Missing tensor ops + +func Tile(a *Array, reps []int32) *Array { + cReps := make([]C.int, len(reps)) + for i, r := range reps { + cReps[i] = C.int(r) + } + out := New("TILE", a) + C.mlx_tile(&out.ctx, a.ctx, unsafe.SliceData(cReps), C.size_t(len(reps)), DefaultStream().ctx) + return out +} + +func Tri(n, m int32, k int) *Array { + out := New("TRI") + C.mlx_tri(&out.ctx, C.int(n), C.int(m), C.int(k), C.mlx_dtype(DTypeFloat32), DefaultStream().ctx) + return out +} + +func Where(condition, a, b *Array) *Array { + out := New("WHERE", condition, a, b) + C.mlx_where(&out.ctx, condition.ctx, a.ctx, b.ctx, DefaultStream().ctx) + return out +} + +// Convenience wrappers (function-style for the model code) + +func Stack(arrays []*Array, axis int) *Array { + vectorData := make([]C.mlx_array, len(arrays)) + for i := range arrays { + vectorData[i] = arrays[i].ctx + } + vector := 
C.mlx_vector_array_new_data(unsafe.SliceData(vectorData), C.size_t(len(vectorData))) + defer C.mlx_vector_array_free(vector) + + out := New("STACK", arrays...) + C.mlx_stack_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx) + return out +} + +func Neg(a *Array) *Array { + return a.Negative() +} + +func Sum(a *Array, axis int, keepDims bool) *Array { + return a.SumAxis(axis, keepDims) +} + +func Argsort(a *Array, axis int) *Array { + return a.ArgsortAxis(axis) +} + +func Take(a *Array, indices *Array, axis int) *Array { + return a.TakeAxis(indices, axis) +} + +func RSqrt(a *Array) *Array { + out := New("RSQRT", a) + C.mlx_rsqrt(&out.ctx, a.ctx, DefaultStream().ctx) + return out +} + +func Mean(a *Array, axis int, keepDims bool) *Array { + out := New("MEAN_AXIS", a) + C.mlx_mean_axis(&out.ctx, a.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx) + return out +} + +func Argpartition(a *Array, kth int, axis int) *Array { + return a.ArgpartitionAxis(kth, axis) +} + +func TakeAlongAxis(a, indices *Array, axis int) *Array { + return a.TakeAlongAxis(indices, axis) +} + +// Function-style wrappers matching imagegen API + +func Add(a, b *Array) *Array { + return a.Add(b) +} + +func Sub(a, b *Array) *Array { + return a.Subtract(b) +} + +func Mul(a, b *Array) *Array { + return a.Multiply(b) +} + +func Div(a, b *Array) *Array { + return a.Divide(b) +} + +func Matmul(a, b *Array) *Array { + return a.Matmul(b) +} + +func Reshape(a *Array, shape ...int32) *Array { + axes := make([]int, len(shape)) + for i, s := range shape { + axes[i] = int(s) + } + return a.Reshape(axes...) +} + +func Transpose(a *Array, axes ...int) *Array { + return a.Transpose(axes...) 
+} + +func ExpandDims(a *Array, axis int) *Array { + return a.ExpandDims(axis) +} + +func Squeeze(a *Array, axis int) *Array { + return a.Squeeze(axis) +} + +func Flatten(a *Array) *Array { + return a.Flatten(0, -1) +} + +func Concatenate(arrays []*Array, axis int) *Array { + if len(arrays) == 0 { + return nil + } + return arrays[0].Concatenate(axis, arrays[1:]...) +} + +func SliceStartStop(a *Array, start, stop []int32) *Array { + n := len(start) + cStart := make([]C.int, n) + cStop := make([]C.int, n) + cStrides := make([]C.int, n) + for i := 0; i < n; i++ { + cStart[i] = C.int(start[i]) + cStop[i] = C.int(stop[i]) + cStrides[i] = 1 + } + out := New("SLICE", a) + C.mlx_slice(&out.ctx, a.ctx, unsafe.SliceData(cStart), C.size_t(n), unsafe.SliceData(cStop), C.size_t(n), unsafe.SliceData(cStrides), C.size_t(n), DefaultStream().ctx) + return out +} + +func GatherMM(a, b *Array, lhsIndices, rhsIndices *Array, sortedIndices bool) *Array { + if lhsIndices == nil { + lhsIndices = New("") + } + if rhsIndices == nil { + rhsIndices = New("") + } + return a.GatherMM(b, lhsIndices, rhsIndices, sortedIndices) +} + +func SiLU(a *Array) *Array { + sig := a.Sigmoid() + return a.Multiply(sig) +} + +func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, offset int) *Array { + freqs := New("") + out := New("FAST_ROPE", x, freqs) + C.mlx_fast_rope( + &out.ctx, + x.ctx, + C.int(dims), + C.bool(traditional), + C.mlx_optional_float{ + value: C.float(base), + has_value: C.bool(func() bool { return base != 0 }()), + }, + C.float(scale), + C.int(offset), + freqs.ctx, + DefaultStream().ctx, + ) + return out +} + +func Sigmoid(a *Array) *Array { + return a.Sigmoid() +} + +func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask bool) *Array { + mask := New("") + sinks := New("") + mode := "" + if causalMask { + mode = "causal" + } + cMode := C.CString(mode) + defer C.free(unsafe.Pointer(cMode)) + + out := New("FAST_SDPA", q, k, v, mask, sinks) + 
C.mlx_fast_scaled_dot_product_attention(&out.ctx, q.ctx, k.ctx, v.ctx, C.float(scale), cMode, mask.ctx, sinks.ctx, DefaultStream().ctx) + return out +} + +func RMSNormFn(x, weight *Array, eps float32) *Array { + out := New("FAST_RMSNORM", x) + C.mlx_fast_rms_norm(&out.ctx, x.ctx, weight.ctx, C.float(eps), DefaultStream().ctx) + return out +} + +func AddMM(c, a, b *Array, alpha, beta float32) *Array { + return c.Addmm(a, b, alpha, beta) +} + +// Scalar helpers + +func AddScalar(a *Array, s float32) *Array { + scalar := FromValue(s) + return a.Add(scalar) +} + +func MulScalar(a *Array, s float32) *Array { + scalar := FromValue(s) + return a.Multiply(scalar) +} + +func DivScalar(a *Array, s float32) *Array { + scalar := FromValue(s) + return a.Divide(scalar) +} + +func FloorDivideScalar(a *Array, s int32) *Array { + scalar := FromValue(int(s)) + return a.FloorDivide(scalar) +} + +// Array constructors + +func NewArrayInt32(data []int32, shape []int32) *Array { + cShape := make([]C.int, len(shape)) + for i, s := range shape { + cShape[i] = C.int(s) + } + out := New("NEW_ARRAY_INT32") + out.ctx = C.mlx_array_new_data(unsafe.Pointer(&data[0]), unsafe.SliceData(cShape), C.int(len(shape)), C.mlx_dtype(DTypeInt32)) + return out +} + +func NewScalarArray(value float32) *Array { + out := New("SCALAR") + out.ctx = C.mlx_array_new_float32(C.float(value)) + return out +} + +func ZerosF32(shape []int32) *Array { + return Zeros(DTypeFloat32, func() []int { + ints := make([]int, len(shape)) + for i, s := range shape { + ints[i] = int(s) + } + return ints + }()...) 
+} + +// Utility + +func Collect(v any) []*Array { + var arrays []*Array + seen := make(map[uintptr]bool) + collect(reflect.ValueOf(v), &arrays, seen) + return arrays +} + +func collect(v reflect.Value, arrays *[]*Array, seen map[uintptr]bool) { + if !v.IsValid() { + return + } + + if v.Kind() == reflect.Ptr { + if v.IsNil() { + return + } + ptr := v.Pointer() + if seen[ptr] { + return + } + seen[ptr] = true + + if arr, ok := v.Interface().(*Array); ok { + if arr != nil && arr.Valid() { + *arrays = append(*arrays, arr) + } + return + } + collect(v.Elem(), arrays, seen) + return + } + + switch v.Kind() { + case reflect.Struct: + // Check if this struct IS an Array (not a pointer to one) + if arr, ok := v.Addr().Interface().(*Array); ok { + if arr != nil && arr.Valid() { + *arrays = append(*arrays, arr) + } + return + } + for i := 0; i < v.NumField(); i++ { + field := v.Field(i) + if field.CanInterface() { + collect(field, arrays, seen) + } + } + case reflect.Slice: + for i := 0; i < v.Len(); i++ { + collect(v.Index(i), arrays, seen) + } + case reflect.Map: + for _, key := range v.MapKeys() { + collect(v.MapIndex(key), arrays, seen) + } + case reflect.Interface: + if !v.IsNil() { + collect(v.Elem(), arrays, seen) + } + } +} + +func EnableCompile() { + C.mlx_enable_compile() +} + +func DisableCompile() { + C.mlx_disable_compile() +} diff --git a/x/mlxrunner/mlx/random.go b/x/mlxrunner/mlx/random.go new file mode 100644 index 000000000..805308b4a --- /dev/null +++ b/x/mlxrunner/mlx/random.go @@ -0,0 +1,13 @@ +//go:build mlx + +package mlx + +// #include "generated.h" +import "C" + +func (t *Array) Categorical(axis int) *Array { + key := New("") + out := New("", t, key) + C.mlx_random_categorical(&out.ctx, t.ctx, C.int(axis), key.ctx, DefaultStream().ctx) + return out +} diff --git a/x/mlxrunner/mlx/slice.go b/x/mlxrunner/mlx/slice.go new file mode 100644 index 000000000..7ab7e2031 --- /dev/null +++ b/x/mlxrunner/mlx/slice.go @@ -0,0 +1,86 @@ +//go:build mlx + +package 
mlx + +// #include "generated.h" +import "C" + +import ( + "cmp" + "unsafe" +) + +type slice struct { + args []int +} + +func Slice(args ...int) slice { + return slice{args: args} +} + +func makeSlices(dims []int, slices ...slice) (starts, stops, strides []C.int) { + if len(slices) != len(dims) { + panic("number of slice arguments must match number of tensor dimensions") + } + + args := [3][]C.int{ + make([]C.int, len(slices)), + make([]C.int, len(slices)), + make([]C.int, len(slices)), + } + + for i, s := range slices { + switch len(s.args) { + case 0: + // slice[:] + args[0][i] = C.int(0) + args[1][i] = C.int(dims[i]) + args[2][i] = C.int(1) + case 1: + // slice[i] + args[0][i] = C.int(s.args[0]) + args[1][i] = C.int(s.args[0] + 1) + args[2][i] = C.int(1) + case 2: + // slice[i:j] + args[0][i] = C.int(s.args[0]) + args[1][i] = cmp.Or(C.int(s.args[1]), C.int(dims[i])) + args[2][i] = C.int(1) + case 3: + // slice[i:j:k] + args[0][i] = C.int(s.args[0]) + args[1][i] = cmp.Or(C.int(s.args[1]), C.int(dims[i])) + args[2][i] = C.int(s.args[2]) + default: + panic("invalid slice arguments") + } + } + + return args[0], args[1], args[2] +} + +func (t *Array) Slice(slices ...slice) *Array { + starts, stops, strides := makeSlices(t.Dims(), slices...) + out := New("SLICE", t) + C.mlx_slice( + &out.ctx, t.ctx, + unsafe.SliceData(starts), C.size_t(len(starts)), + unsafe.SliceData(stops), C.size_t(len(stops)), + unsafe.SliceData(strides), C.size_t(len(strides)), + DefaultStream().ctx, + ) + return out +} + +func (t *Array) SliceUpdate(other *Array, slices ...slice) *Array { + starts, stops, strides := makeSlices(t.Dims(), slices...) 
+ out := New("SLICE_UPDATE", t, other) + C.mlx_slice_update( + &out.ctx, t.ctx, other.ctx, + unsafe.SliceData(starts), C.size_t(len(starts)), + unsafe.SliceData(stops), C.size_t(len(stops)), + unsafe.SliceData(strides), C.size_t(len(strides)), + DefaultStream().ctx, + ) + return out +} diff --git a/x/mlxrunner/mlx/stream.go b/x/mlxrunner/mlx/stream.go new file mode 100644 index 000000000..83a3eeffd --- /dev/null +++ b/x/mlxrunner/mlx/stream.go @@ -0,0 +1,45 @@ +//go:build mlx + +package mlx + +// #include "generated.h" +import "C" + +import ( + "log/slog" + "sync" +) + +type Device struct { + ctx C.mlx_device +} + +func (d Device) LogValue() slog.Value { + str := C.mlx_string_new() + defer C.mlx_string_free(str) + C.mlx_device_tostring(&str, d.ctx) + return slog.StringValue(C.GoString(C.mlx_string_data(str))) +} + +var DefaultDevice = sync.OnceValue(func() Device { + d := C.mlx_device_new() + C.mlx_get_default_device(&d) + return Device{d} +}) + +type Stream struct { + ctx C.mlx_stream +} + +func (s Stream) LogValue() slog.Value { + str := C.mlx_string_new() + defer C.mlx_string_free(str) + C.mlx_stream_tostring(&str, s.ctx) + return slog.StringValue(C.GoString(C.mlx_string_data(str))) +} + +var DefaultStream = sync.OnceValue(func() Stream { + s := C.mlx_stream_new() + C.mlx_get_default_stream(&s, DefaultDevice().ctx) + return Stream{s} +}) diff --git a/x/mlxrunner/pipeline.go b/x/mlxrunner/pipeline.go new file mode 100644 index 000000000..c094e5a3b --- /dev/null +++ b/x/mlxrunner/pipeline.go @@ -0,0 +1,123 @@ +//go:build mlx + +package mlxrunner + +import ( + "bytes" + "errors" + "log/slog" + "time" + "unicode/utf8" + + "github.com/ollama/ollama/x/mlxrunner/cache" + "github.com/ollama/ollama/x/mlxrunner/mlx" +) + +func (r *Runner) TextGenerationPipeline(request Request) error { + if r.Model == nil { + return errors.New("model not loaded") + } + + inputs := r.Tokenizer.Encode(request.Prompt, true) + + caches, tokens := r.FindNearestCache(inputs) + if len(caches) == 
0 { + caches = make([]cache.Cache, r.Model.NumLayers()) + for i := range caches { + caches[i] = cache.NewKVCache() + } + } + + total, processed := len(tokens), 0 + slog.Info("Prompt processing progress", "processed", processed, "total", total) + for total-processed > 1 { + n := min(2<<10, total-processed-1) + temp := r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches) + defer mlx.Free(temp) + mlx.Eval(func() []*mlx.Array { + s := make([]*mlx.Array, 2*len(caches)) + for i, c := range caches { + s[2*i], s[2*i+1] = c.State() + } + return s + }()...) + processed += n + slog.Info("Prompt processing progress", "processed", processed, "total", total) + mlx.ClearCache() + } + + step := func(token *mlx.Array) (*mlx.Array, *mlx.Array) { + logits := r.Model.Unembed(r.Model.Forward(token.ExpandDims(0), caches)) + logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1) + + logprobs := logits.Subtract(logits.Logsumexp(true)) + return request.Sample(logprobs), logprobs + } + + sample, logprobs := step(mlx.FromValues(tokens[processed:], total-processed)) + mlx.AsyncEval(sample, logprobs) + + var b bytes.Buffer + + now := time.Now() + final := Response{PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1} + outputs := make([]int32, 0, request.Options.MaxTokens) + for i := range request.Options.MaxTokens { + nextSample, nextLogprobs := step(sample) + mlx.AsyncEval(nextSample, nextLogprobs) + + if i == 0 { + slog.Info("Prompt processing progress", "processed", total, "total", total) + mlx.Eval(sample) + final.PromptTokensDuration = time.Since(now) + now = time.Now() + } + + output := int32(sample.Int()) + outputs = append(outputs, output) + + if r.Tokenizer.IsEOS(output) { + final.Token = int(output) + final.DoneReason = 0 + final.CompletionTokens = i + break + } + + request.Responses <- Response{ + Text: r.Decode(output, &b), + Token: int(output), + } + + mlx.Free(sample, logprobs) + if 
i%256 == 0 { + mlx.ClearCache() + } + + sample, logprobs = nextSample, nextLogprobs + } + + mlx.Free(sample, logprobs) + final.CompletionTokensDuration = time.Since(now) + request.Responses <- final + r.InsertCache(append(inputs, outputs...), caches) + return nil +} + +func (r Runner) Decode(sample int32, b *bytes.Buffer) string { + token := r.Tokenizer.Decode([]int32{sample}) + + if _, err := b.WriteString(token); err != nil { + slog.Error("Failed to write token to buffer", "error", err) + return "" + } + + if text := b.String(); utf8.ValidString(text) { + b.Reset() + return text + } else if b.Len() >= utf8.UTFMax { + b.Reset() + return text + } + + return "" +} diff --git a/x/mlxrunner/runner.go b/x/mlxrunner/runner.go new file mode 100644 index 000000000..0b84b5a44 --- /dev/null +++ b/x/mlxrunner/runner.go @@ -0,0 +1,139 @@ +//go:build mlx + +package mlxrunner + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "net" + "net/http" + "time" + + "golang.org/x/sync/errgroup" + + "github.com/ollama/ollama/x/imagegen/manifest" + "github.com/ollama/ollama/x/imagegen/tokenizer" + "github.com/ollama/ollama/x/mlxrunner/cache" + "github.com/ollama/ollama/x/mlxrunner/mlx" + "github.com/ollama/ollama/x/mlxrunner/sample" + "github.com/ollama/ollama/x/models/glm4_moe_lite" +) + +// TextModel is the interface that model implementations must satisfy. 
+type TextModel interface { + Forward(inputs *mlx.Array, cache []cache.Cache) *mlx.Array + Unembed(x *mlx.Array) *mlx.Array + NumLayers() int +} + +type Request struct { + TextCompletionsRequest + Responses chan Response + Pipeline func(Request) error + + sample.Sampler + caches []cache.Cache +} + +type TextCompletionsRequest struct { + Prompt string `json:"prompt"` + Options struct { + Temperature float32 `json:"temperature"` + TopP float32 `json:"top_p"` + MinP float32 `json:"min_p"` + TopK int `json:"top_k"` + MaxTokens int `json:"max_tokens"` + + // Deprecated: use MaxTokens instead + NumPredict int `json:"num_predict"` + } `json:"options"` +} + +type Response struct { + Text string `json:"content,omitempty"` + Token int `json:"token,omitempty"` + Logprobs []float32 `json:"logprobs,omitempty"` + Done bool `json:"done,omitempty"` + DoneReason int `json:"done_reason,omitempty"` + + PromptTokens int `json:"prompt_eval_count,omitempty"` + PromptTokensDuration time.Duration `json:"prompt_eval_duration,omitempty"` + CompletionTokens int `json:"eval_count,omitempty"` + CompletionTokensDuration time.Duration `json:"eval_duration,omitempty"` + TotalTokens int `json:"total_tokens,omitempty"` +} + +type Runner struct { + Model TextModel + Tokenizer *tokenizer.Tokenizer + Requests chan Request + CacheEntries map[int32]*CacheEntry +} + +func (r *Runner) Load(modelName string) error { + modelManifest, err := manifest.LoadManifest(modelName) + if err != nil { + return err + } + + // Read config to detect architecture + configData, err := modelManifest.ReadConfig("config.json") + if err != nil { + return fmt.Errorf("failed to read config.json: %w", err) + } + + var archConfig struct { + Architectures []string `json:"architectures"` + } + if err := json.Unmarshal(configData, &archConfig); err != nil { + return fmt.Errorf("failed to parse config.json: %w", err) + } + + if len(archConfig.Architectures) == 0 { + return fmt.Errorf("no architectures found in config.json") + } + + 
slog.Info("Model architecture", "arch", archConfig.Architectures[0]) + + switch archConfig.Architectures[0] { + case "Glm4MoeLiteForCausalLM", "GLM4MoeLite": + model, err := glm4_moe_lite.LoadFromManifest(modelManifest) + if err != nil { + return fmt.Errorf("failed to load GLM4-MoE-Lite model: %w", err) + } + r.Model = model + r.Tokenizer = model.Tokenizer() + default: + return fmt.Errorf("unsupported architecture: %s", archConfig.Architectures[0]) + } + + return nil +} + +func (r *Runner) Run(host, port string, mux http.Handler) error { + g, ctx := errgroup.WithContext(context.Background()) + + g.Go(func() error { + for { + select { + case <-ctx.Done(): + return nil + case request := <-r.Requests: + if err := request.Pipeline(request); err != nil { + slog.Error("Pipeline failed", "error", err) + } + + close(request.Responses) + } + } + }) + + g.Go(func() error { + slog.Info("Starting HTTP server", "host", host, "port", port) + return http.ListenAndServe(net.JoinHostPort(host, port), mux) + }) + + return g.Wait() +} diff --git a/x/mlxrunner/sample/sample.go b/x/mlxrunner/sample/sample.go new file mode 100644 index 000000000..3a2e7577d --- /dev/null +++ b/x/mlxrunner/sample/sample.go @@ -0,0 +1,77 @@ +//go:build mlx + +package sample + +import ( + "math" + + "github.com/ollama/ollama/x/mlxrunner/mlx" +) + +type Sampler interface { + Sample(*mlx.Array) *mlx.Array +} + +func New(temp, top_p, min_p float32, top_k int) Sampler { + if temp == 0 { + return greedy{} + } + + var samplers []Sampler + if top_p > 0 && top_p < 1 { + samplers = append(samplers, TopP(top_p)) + } + + if min_p != 0 { + samplers = append(samplers, MinP(min_p)) + } + + if top_k > 0 { + samplers = append(samplers, TopK(top_k)) + } + + samplers = append(samplers, Temperature(temp)) + return chain(samplers) +} + +type greedy struct{} + +func (greedy) Sample(logits *mlx.Array) *mlx.Array { + return logits.Argmax(-1, false) +} + +type chain []Sampler + +func (c chain) Sample(logits *mlx.Array) *mlx.Array { + for _, sampler := range c { 
logits = sampler.Sample(logits) + } + return logits +} + +type Temperature float32 + +func (t Temperature) Sample(logits *mlx.Array) *mlx.Array { + return logits.Multiply(mlx.FromValue(1 / float32(t))).Categorical(-1) +} + +type TopP float32 + +func (p TopP) Sample(logprobs *mlx.Array) *mlx.Array { + // TODO: implement + return logprobs +} + +type MinP float32 + +func (p MinP) Sample(logprobs *mlx.Array) *mlx.Array { + // TODO: implement + return logprobs +} + +type TopK int + +func (k TopK) Sample(logprobs *mlx.Array) *mlx.Array { + mask := logprobs.Negative().ArgpartitionAxis(int(k)-1, -1).Slice(mlx.Slice(), mlx.Slice(int(k), 0)) + return logprobs.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1) +} diff --git a/x/mlxrunner/server.go b/x/mlxrunner/server.go new file mode 100644 index 000000000..d8460b8d7 --- /dev/null +++ b/x/mlxrunner/server.go @@ -0,0 +1,176 @@ +//go:build mlx + +package mlxrunner + +import ( + "bytes" + "cmp" + "encoding/json" + "flag" + "io" + "log/slog" + "net/http" + "os" + "strconv" + "time" + + "github.com/ollama/ollama/envconfig" + "github.com/ollama/ollama/logutil" + "github.com/ollama/ollama/x/mlxrunner/sample" +) + +func Execute(args []string) error { + slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel())) + + var ( + modelName string + port int + ) + + flagSet := flag.NewFlagSet("mlxrunner", flag.ExitOnError) + flagSet.StringVar(&modelName, "model", "", "Model name") + flagSet.IntVar(&port, "port", 0, "Port to listen on") + _ = flagSet.Bool("verbose", false, "Enable debug logging") + flagSet.Parse(args) + + runner := Runner{ + Requests: make(chan Request), + CacheEntries: make(map[int32]*CacheEntry), + } + + if err := runner.Load(modelName); err != nil { + return err + } + + mux := http.NewServeMux() + mux.HandleFunc("GET /v1/status", func(w http.ResponseWriter, r *http.Request) { + if err := json.NewEncoder(w).Encode(map[string]any{ + "status": 0, + "progress": 100, + }); err != nil { + slog.Error("Failed to 
encode response", "error", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + }) + + mux.HandleFunc("/v1/models", func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case "POST": + fallthrough + case "GET": + if err := json.NewEncoder(w).Encode(map[string]any{ + "Success": true, + }); err != nil { + slog.Error("Failed to encode response", "error", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + case "DELETE": + // TODO: cleanup model and cache + } + }) + + mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) { + request := Request{Responses: make(chan Response)} + + if err := json.NewDecoder(r.Body).Decode(&request.TextCompletionsRequest); err != nil { + slog.Error("Failed to decode request", "error", err) + http.Error(w, "Bad Request", http.StatusBadRequest) + return + } + + request.Options.MaxTokens = cmp.Or(request.Options.MaxTokens, request.Options.NumPredict) + if request.Options.MaxTokens < 1 { + request.Options.MaxTokens = 16 << 10 + } + + request.Pipeline = runner.TextGenerationPipeline + request.Sampler = sample.New( + request.Options.Temperature, + request.Options.TopP, + request.Options.MinP, + request.Options.TopK, + ) + + runner.Requests <- request + + w.Header().Set("Content-Type", "application/jsonl") + w.WriteHeader(http.StatusOK) + enc := json.NewEncoder(w) + for response := range request.Responses { + if err := enc.Encode(response); err != nil { + slog.Error("Failed to encode response", "error", err) + return + } + + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + } + }) + + mux.HandleFunc("POST /v1/tokenize", func(w http.ResponseWriter, r *http.Request) { + var b bytes.Buffer + if _, err := io.Copy(&b, r.Body); err != nil { + slog.Error("Failed to read request body", "error", err) + http.Error(w, "Bad Request", http.StatusBadRequest) + return + } + + tokens := runner.Tokenizer.Encode(b.String(), true) + + if 
err := json.NewEncoder(w).Encode(tokens); err != nil { + slog.Error("Failed to encode response", "error", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + }) + + for source, target := range map[string]string{ + "GET /health": "/v1/status", + "POST /load": "/v1/models", + "POST /completion": "/v1/completions", + } { + mux.Handle(source, http.RedirectHandler(target, http.StatusPermanentRedirect)) + } + + return runner.Run("127.0.0.1", strconv.Itoa(port), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + recorder := &statusRecorder{ResponseWriter: w, code: http.StatusOK} + t := time.Now() + mux.ServeHTTP(recorder, r) + + var level slog.Level + switch { + case recorder.code >= 500: + level = slog.LevelError + case recorder.code >= 400: + level = slog.LevelWarn + case recorder.code >= 300: + return + } + + slog.Log(r.Context(), level, "ServeHTTP", "method", r.Method, "path", r.URL.Path, "took", time.Since(t), "status", recorder.Status()) + })) +} + +type statusRecorder struct { + http.ResponseWriter + code int +} + +func (w *statusRecorder) WriteHeader(code int) { + w.code = code + w.ResponseWriter.WriteHeader(code) +} + +func (w *statusRecorder) Status() string { + return strconv.Itoa(w.code) + " " + http.StatusText(w.code) +} + +func (w *statusRecorder) Flush() { + if f, ok := w.ResponseWriter.(http.Flusher); ok { + f.Flush() + } +} diff --git a/x/mlxrunner/server_stub.go b/x/mlxrunner/server_stub.go new file mode 100644 index 000000000..3b0f35500 --- /dev/null +++ b/x/mlxrunner/server_stub.go @@ -0,0 +1,10 @@ +//go:build !mlx + +package mlxrunner + +import "errors" + +// Execute returns an error when not built with MLX support. 
+func Execute(args []string) error { + return errors.New("MLX runner not available: build with mlx tag") +} diff --git a/x/models/glm4_moe_lite/glm4_moe_lite.go b/x/models/glm4_moe_lite/glm4_moe_lite.go new file mode 100644 index 000000000..091e95839 --- /dev/null +++ b/x/models/glm4_moe_lite/glm4_moe_lite.go @@ -0,0 +1,860 @@ +//go:build mlx + +// Package glm4_moe_lite provides the GLM4-MoE-Lite implementation for MLX. +// This model uses Multi-head Latent Attention (MLA) and Mixture of Experts (MoE). +package glm4_moe_lite + +import ( + "encoding/binary" + "encoding/json" + "fmt" + "io" + "math" + "os" + "strings" + + "github.com/ollama/ollama/x/imagegen/manifest" + "github.com/ollama/ollama/x/imagegen/tokenizer" + "github.com/ollama/ollama/x/mlxrunner/cache" + "github.com/ollama/ollama/x/mlxrunner/mlx" + "github.com/ollama/ollama/x/models/nn" +) + +// RopeScaling holds RoPE scaling configuration +type RopeScaling struct { + Factor float32 `json:"factor"` + MscaleAllDim float32 `json:"mscale_all_dim"` +} + +// Config holds GLM4-MoE-Lite model configuration +type Config struct { + HiddenSize int32 `json:"hidden_size"` + NumHiddenLayers int32 `json:"num_hidden_layers"` + IntermediateSize int32 `json:"intermediate_size"` + MoEIntermediateSize int32 `json:"moe_intermediate_size"` + NumAttentionHeads int32 `json:"num_attention_heads"` + NumKeyValueHeads int32 `json:"num_key_value_heads"` + VocabSize int32 `json:"vocab_size"` + RMSNormEps float32 `json:"rms_norm_eps"` + RopeTheta float32 `json:"rope_theta"` + MaxPositionEmbeddings int32 `json:"max_position_embeddings"` + AttentionBias bool `json:"attention_bias"` + + // MLA (Multi-head Latent Attention) parameters + QLoraRank int32 `json:"q_lora_rank"` + KVLoraRank int32 `json:"kv_lora_rank"` + QKRopeHeadDim int32 `json:"qk_rope_head_dim"` + QKNopeHeadDim int32 `json:"qk_nope_head_dim"` + VHeadDim int32 `json:"v_head_dim"` + + // MoE parameters + NRoutedExperts int32 `json:"n_routed_experts"` + NSharedExperts int32 
`json:"n_shared_experts"` + NumExpertsPerTok int32 `json:"num_experts_per_tok"` + RoutedScalingFactor float32 `json:"routed_scaling_factor"` + NormTopKProb bool `json:"norm_topk_prob"` + FirstKDenseReplace int32 `json:"first_k_dense_replace"` + NGroup int32 `json:"n_group"` + TopKGroup int32 `json:"topk_group"` + + // RoPE scaling + RopeScaling *RopeScaling `json:"rope_scaling"` + + // Quantization parameters (set during load based on model quantization) + QuantGroupSize int `json:"-"` // Group size for quantization (default 64) + QuantBits int `json:"-"` // Bits per weight (4 or 8) + QuantMode string `json:"-"` // Quantization mode ("affine", etc.) + + // Computed fields + QHeadDim int32 `json:"-"` // qk_nope_head_dim + qk_rope_head_dim + Scale float32 `json:"-"` // 1/sqrt(QHeadDim) with mscale adjustment +} + +// MLAAttention implements Multi-head Latent Attention with absorption. +type MLAAttention struct { + QAProj nn.LinearLayer + QALayerNorm *nn.RMSNorm + QBProj nn.LinearLayer + + KVAProjWithMQA nn.LinearLayer + KVALayerNorm *nn.RMSNorm + + EmbedQ *nn.MultiLinear + UnembedOut *nn.MultiLinear + + OProj nn.LinearLayer +} + +// Forward computes absorbed MLA attention output. 
+func (a *MLAAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array { + q := a.QAProj.Forward(x) + q = a.QALayerNorm.Forward(q, cfg.RMSNormEps) + q = a.QBProj.Forward(q) + + q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.QHeadDim) + q = mlx.Transpose(q, 0, 2, 1, 3) + + qNope := mlx.SliceStartStop(q, []int32{0, 0, 0, 0}, []int32{B, cfg.NumAttentionHeads, L, cfg.QKNopeHeadDim}) + qPE := mlx.SliceStartStop(q, []int32{0, 0, 0, cfg.QKNopeHeadDim}, []int32{B, cfg.NumAttentionHeads, L, cfg.QHeadDim}) + + compressedKV := a.KVAProjWithMQA.Forward(x) + + kvCompressed := mlx.SliceStartStop(compressedKV, []int32{0, 0, 0}, []int32{B, L, cfg.KVLoraRank}) + kPE := mlx.SliceStartStop(compressedKV, []int32{0, 0, cfg.KVLoraRank}, []int32{B, L, cfg.KVLoraRank + cfg.QKRopeHeadDim}) + + kPE = mlx.Reshape(kPE, B, L, 1, cfg.QKRopeHeadDim) + kPE = mlx.Transpose(kPE, 0, 2, 1, 3) + + kvLatent := a.KVALayerNorm.Forward(kvCompressed, cfg.RMSNormEps) + kvLatent = mlx.ExpandDims(kvLatent, 1) + + offset := 0 + if c != nil { + offset = c.Offset() + } + qPE = mlx.RoPEWithBase(qPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset) + kPE = mlx.RoPEWithBase(kPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset) + + qLatent := a.EmbedQ.Forward(qNope) + + keys := mlx.Concatenate([]*mlx.Array{kvLatent, kPE}, 3) + + cachedL := L + if c != nil { + placeholderValues := mlx.ZerosF32([]int32{B, 1, L, 0}) + keys, _ = c.Update(keys, placeholderValues) + cachedL = int32(keys.Dim(2)) + } + + values := mlx.SliceStartStop(keys, []int32{0, 0, 0, 0}, []int32{B, 1, cachedL, cfg.KVLoraRank}) + + queries := mlx.Concatenate([]*mlx.Array{qLatent, qPE}, 3) + + out := mlx.ScaledDotProductAttentionCausal(queries, keys, values, cfg.Scale, L > 1) + + out = a.UnembedOut.Forward(out) + + out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.VHeadDim) + + return a.OProj.Forward(out) +} + +// DenseMLP implements the standard SwiGLU MLP for dense layers 
+type DenseMLP struct { + GateProj nn.LinearLayer + UpProj nn.LinearLayer + DownProj nn.LinearLayer +} + +// Forward applies the SwiGLU MLP +func (m *DenseMLP) Forward(x *mlx.Array) *mlx.Array { + gate := mlx.SiLU(m.GateProj.Forward(x)) + up := m.UpProj.Forward(x) + return m.DownProj.Forward(mlx.Mul(gate, up)) +} + +// MoEGate implements the expert gating mechanism +type MoEGate struct { + Gate nn.LinearLayer + EScoreCorrectionBias *mlx.Array +} + +// Forward computes expert selection indices and scores +func (g *MoEGate) Forward(x *mlx.Array, cfg *Config) (*mlx.Array, *mlx.Array) { + gates := g.Gate.Forward(x) + + scores := mlx.Sigmoid(gates) + origScores := scores + + if g.EScoreCorrectionBias != nil { + scores = mlx.Add(scores, g.EScoreCorrectionBias) + } + + topK := cfg.NumExpertsPerTok + negScores := mlx.Neg(scores) + inds := mlx.Argpartition(negScores, int(topK)-1, -1) + + dims := inds.Dims() + inds = mlx.SliceStartStop(inds, []int32{0, 0, 0}, []int32{int32(dims[0]), int32(dims[1]), topK}) + + scores = mlx.TakeAlongAxis(origScores, inds, -1) + + if topK > 1 && cfg.NormTopKProb { + sumScores := mlx.Sum(scores, -1, true) + scores = mlx.Div(scores, sumScores) + } + + scores = mlx.MulScalar(scores, cfg.RoutedScalingFactor) + + return inds, scores +} + +// SwitchMLP implements the MoE expert computation using stacked weights +type SwitchMLP struct { + GateWeight *mlx.Array + UpWeight *mlx.Array + DownWeight *mlx.Array + + GateWeightQ, GateScales, GateBiases *mlx.Array + UpWeightQ, UpScales, UpBiases *mlx.Array + DownWeightQ, DownScales, DownBiases *mlx.Array + + GateBits int + UpBits int + DownBits int + + GateGroupSize int + UpGroupSize int + DownGroupSize int + + UseQuantized bool +} + +// Forward applies the switched expert MLP +func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.Array { + dims := x.Dims() + B, L := int32(dims[0]), int32(dims[1]) + topK := cfg.NumExpertsPerTok + + xExpanded := mlx.ExpandDims(mlx.ExpandDims(x, -2), 
-2) + + xFlat := mlx.Reshape(xExpanded, B*L, 1, 1, cfg.HiddenSize) + + idxFlat := mlx.Reshape(indices, B*L, topK) + + doSort := B*L >= 64 + var invOrder *mlx.Array + n := B * L * topK + + if doSort { + idxAll := mlx.Flatten(idxFlat) + order := mlx.Argsort(idxAll, 0) + invOrder = mlx.Argsort(order, 0) + xFlat = mlx.ExpandDims(mlx.Take(mlx.Squeeze(xFlat, 1), mlx.FloorDivideScalar(order, topK), 0), 1) + idxFlat = mlx.Reshape(mlx.Take(idxAll, order, 0), n, 1) + } + + var gate, up, hidden, down *mlx.Array + + if s.UseQuantized { + gate = mlx.GatherQMM(xFlat, s.GateWeightQ, s.GateScales, s.GateBiases, + nil, idxFlat, true, s.GateGroupSize, s.GateBits, cfg.QuantMode, doSort) + up = mlx.GatherQMM(xFlat, s.UpWeightQ, s.UpScales, s.UpBiases, + nil, idxFlat, true, s.UpGroupSize, s.UpBits, cfg.QuantMode, doSort) + + hidden = mlx.Mul(mlx.SiLU(gate), up) + + down = mlx.GatherQMM(hidden, s.DownWeightQ, s.DownScales, s.DownBiases, + nil, idxFlat, true, s.DownGroupSize, s.DownBits, cfg.QuantMode, doSort) + } else { + gate = mlx.GatherMM(xFlat, mlx.Transpose(s.GateWeight, 0, 2, 1), nil, idxFlat, doSort) + up = mlx.GatherMM(xFlat, mlx.Transpose(s.UpWeight, 0, 2, 1), nil, idxFlat, doSort) + + hidden = mlx.Mul(mlx.SiLU(gate), up) + + down = mlx.GatherMM(hidden, mlx.Transpose(s.DownWeight, 0, 2, 1), nil, idxFlat, doSort) + } + + if doSort { + down = mlx.Reshape(mlx.Take(mlx.Squeeze(mlx.Squeeze(down, 2), 1), invOrder, 0), B*L, topK, cfg.HiddenSize) + } else { + down = mlx.Squeeze(down, 2) + } + + return mlx.Reshape(down, B, L, topK, cfg.HiddenSize) +} + +// SharedExperts implements the shared expert MLP +type SharedExperts struct { + GateProj nn.LinearLayer + UpProj nn.LinearLayer + DownProj nn.LinearLayer +} + +// Forward applies the shared expert MLP +func (s *SharedExperts) Forward(x *mlx.Array) *mlx.Array { + gate := mlx.SiLU(s.GateProj.Forward(x)) + up := s.UpProj.Forward(x) + return s.DownProj.Forward(mlx.Mul(gate, up)) +} + +// MoE implements the full Mixture of Experts layer 
+type MoE struct { + Gate *MoEGate + SwitchMLP *SwitchMLP + SharedExperts *SharedExperts +} + +// Forward applies the MoE layer +func (m *MoE) Forward(x *mlx.Array, cfg *Config) *mlx.Array { + dims := x.Dims() + B, L := int32(dims[0]), int32(dims[1]) + + inds, scores := m.Gate.Forward(x, cfg) + + expertOut := m.SwitchMLP.Forward(x, inds, cfg) + + scoresExpanded := mlx.ExpandDims(scores, -1) + y := mlx.Sum(mlx.Mul(expertOut, scoresExpanded), 2, false) + + if m.SharedExperts != nil { + y = mlx.Add(y, m.SharedExperts.Forward(x)) + } + + return mlx.Reshape(y, B, L, cfg.HiddenSize) +} + +// DenseBlock represents a dense transformer block (for first_k_dense_replace layers) +type DenseBlock struct { + Attention *MLAAttention + MLP *DenseMLP + InputLayerNorm *nn.RMSNorm + PostAttentionLayerNorm *nn.RMSNorm +} + +// Forward applies the dense block +func (b *DenseBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array { + r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg) + h := mlx.Add(x, r) + + r = b.MLP.Forward(b.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps)) + return mlx.Add(h, r) +} + +// MoEBlock represents a MoE transformer block +type MoEBlock struct { + Attention *MLAAttention + MoE *MoE + InputLayerNorm *nn.RMSNorm + PostAttentionLayerNorm *nn.RMSNorm +} + +// Forward applies the MoE block +func (b *MoEBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array { + r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg) + h := mlx.Add(x, r) + + r = b.MoE.Forward(b.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps), cfg) + return mlx.Add(h, r) +} + +// Block interface for both dense and MoE blocks +type Block interface { + Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array +} + +// Model represents the complete GLM4-MoE-Lite model +type Model struct { + EmbedTokens *nn.Embedding + Layers []Block + Norm *nn.RMSNorm + LMHead nn.LinearLayer 
+ + tok *tokenizer.Tokenizer + *Config +} + +// computeScale computes the attention scale. +func computeScale(cfg *Config) float32 { + keyLength := cfg.QKNopeHeadDim + cfg.QKRopeHeadDim + scale := float32(1.0 / math.Sqrt(float64(keyLength))) + if cfg.RopeScaling != nil && cfg.RopeScaling.MscaleAllDim > 0 && cfg.RopeScaling.Factor > 1 { + s := 0.1*cfg.RopeScaling.MscaleAllDim*float32(math.Log(float64(cfg.RopeScaling.Factor))) + 1.0 + scale *= s * s + } + return scale +} + +// supportsGatherQMM returns true if the quantization mode has GatherQMM kernel support. +func supportsGatherQMM(mode string, bits int) bool { + return mode == "affine" && (bits == 4 || bits == 8) +} + +// quantizationParams returns groupSize, bits, mode for a quantization type string. +func quantizationParams(quantization string) (groupSize, bits int, mode string) { + switch strings.ToUpper(quantization) { + case "NVFP4": + return 16, 4, "nvfp4" + case "FP4", "Q4", "INT4": + return 32, 4, "affine" + case "MXFP8": + return 32, 8, "mxfp8" + case "FP8", "Q8", "INT8", "": + return 64, 8, "affine" + default: + return 32, 8, "affine" + } +} + +// readBlobMetadata reads the __metadata__ from a safetensors blob header. 
+func readBlobMetadata(path string) (map[string]string, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + + var headerSize uint64 + if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil { + return nil, err + } + if headerSize > 1024*1024 { + return nil, fmt.Errorf("header too large: %d", headerSize) + } + + data := make([]byte, headerSize) + if _, err := io.ReadFull(f, data); err != nil { + return nil, err + } + + var header map[string]json.RawMessage + if err := json.Unmarshal(data, &header); err != nil { + return nil, err + } + + metaRaw, ok := header["__metadata__"] + if !ok { + return nil, nil + } + + var meta map[string]string + if err := json.Unmarshal(metaRaw, &meta); err != nil { + return nil, err + } + return meta, nil +} + +// ExpertWeight holds a single expert's weight with optional quantization components. +type ExpertWeight struct { + Weight *mlx.Array + Scales *mlx.Array + Biases *mlx.Array + Bits int + GroupSize int +} + +// loadExpertWeight loads an expert weight from the tensor map. +func loadExpertWeight(tensors map[string]*mlx.Array, path string, useQuantized bool, cfg *Config) *ExpertWeight { + w := tensors[path+".weight"] + if w == nil { + return nil + } + + scales := tensors[path+".weight_scale"] + if scales != nil { + qbiases := tensors[path+".weight_qbias"] + + groupSize, bits, mode := cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode + + if useQuantized && supportsGatherQMM(mode, bits) { + return &ExpertWeight{Weight: w, Scales: scales, Biases: qbiases, Bits: bits, GroupSize: groupSize} + } + + return &ExpertWeight{Weight: mlx.Dequantize(w, scales, qbiases, groupSize, bits, mode)} + } + + return &ExpertWeight{Weight: w} +} + +// StackedExpertWeights holds stacked weights for all experts. 
+type StackedExpertWeights struct { + Weight *mlx.Array + Scales *mlx.Array + Biases *mlx.Array + Bits int + GroupSize int +} + +// collectAndStackExpertWeights loads and stacks expert weights for one projection type. +func collectAndStackExpertWeights( + tensors map[string]*mlx.Array, + prefix string, + projName string, + numExperts int32, + useQuantized bool, + cfg *Config, +) *StackedExpertWeights { + var w, s, b []*mlx.Array + var bits, groupSize int + + for e := int32(0); e < numExperts; e++ { + path := fmt.Sprintf("%s.mlp.experts.%d.%s", prefix, e, projName) + ew := loadExpertWeight(tensors, path, useQuantized, cfg) + if ew == nil { + continue + } + w = append(w, ew.Weight) + if ew.Scales != nil { + s = append(s, ew.Scales) + } + if ew.Biases != nil { + b = append(b, ew.Biases) + } + if e == 0 { + bits = ew.Bits + groupSize = ew.GroupSize + } + } + + result := &StackedExpertWeights{Bits: bits, GroupSize: groupSize} + if len(w) > 0 { + result.Weight = mlx.Stack(w, 0) + if len(s) > 0 { + result.Scales = mlx.Stack(s, 0) + } + if len(b) > 0 { + result.Biases = mlx.Stack(b, 0) + } + } + return result +} + +// sanitizeExpertWeights stacks individual expert weights into tensors. +func sanitizeExpertWeights(tensors map[string]*mlx.Array, prefix string, numExperts int32, useQuantized bool, cfg *Config) (gate, up, down *StackedExpertWeights) { + gate = collectAndStackExpertWeights(tensors, prefix, "gate_proj", numExperts, useQuantized, cfg) + up = collectAndStackExpertWeights(tensors, prefix, "up_proj", numExperts, useQuantized, cfg) + down = collectAndStackExpertWeights(tensors, prefix, "down_proj", numExperts, useQuantized, cfg) + return gate, up, down +} + +// sanitizeMLAWeights transforms kv_b_proj weights into absorbed MLA format. 
+func sanitizeMLAWeights(tensors map[string]*mlx.Array, prefix string, cfg *Config) (*mlx.Array, *mlx.Array) { + path := prefix + ".self_attn.kv_b_proj" + w := tensors[path+".weight"] + if w == nil { + return nil, nil + } + + // Check if quantized and dequantize + if scales := tensors[path+".weight_scale"]; scales != nil { + qbiases := tensors[path+".weight_qbias"] + w = mlx.Dequantize(w, scales, qbiases, cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode) + } + + headDim := cfg.QKNopeHeadDim + cfg.VHeadDim + w = mlx.Reshape(w, cfg.NumAttentionHeads, headDim, cfg.KVLoraRank) + + wk := mlx.SliceStartStop(w, []int32{0, 0, 0}, []int32{cfg.NumAttentionHeads, cfg.QKNopeHeadDim, cfg.KVLoraRank}) + wv := mlx.SliceStartStop(w, []int32{0, cfg.QKNopeHeadDim, 0}, []int32{cfg.NumAttentionHeads, headDim, cfg.KVLoraRank}) + + embedQ := mlx.Transpose(wk, 0, 2, 1) + unembedOut := wv + + return embedQ, unembedOut +} + +// makeLinear creates a Linear or QuantizedLinear layer from the tensor map. +func makeLinear(tensors map[string]*mlx.Array, path string, cfg *Config) nn.LinearLayer { + w := tensors[path+".weight"] + if w == nil { + return nil + } + + scales := tensors[path+".weight_scale"] + if scales != nil { + qbiases := tensors[path+".weight_qbias"] + bias := tensors[path+".bias"] + return &nn.QuantizedLinear{ + Weight: w, + Scales: scales, + QBiases: qbiases, + Bias: bias, + GroupSize: cfg.QuantGroupSize, + Bits: cfg.QuantBits, + Mode: cfg.QuantMode, + } + } + + bias := tensors[path+".bias"] + return nn.NewLinear(w, bias) +} + +// LoadFromManifest loads a GLM4-MoE-Lite model from a manifest (Ollama blob storage). 
+func LoadFromManifest(modelManifest *manifest.ModelManifest) (*Model, error) { + configData, err := modelManifest.ReadConfig("config.json") + if err != nil { + return nil, fmt.Errorf("load config: %w", err) + } + + var cfg Config + if err := json.Unmarshal(configData, &cfg); err != nil { + return nil, fmt.Errorf("parse config: %w", err) + } + + cfg.QHeadDim = cfg.QKNopeHeadDim + cfg.QKRopeHeadDim + cfg.Scale = computeScale(&cfg) + + // Load all tensors from manifest blobs into a flat map + allTensors := make(map[string]*mlx.Array) + seen := make(map[string]bool) // dedupe by digest + var quantType string + var quantGroupSize int + + for _, layer := range modelManifest.GetTensorLayers("") { + if seen[layer.Digest] { + continue + } + seen[layer.Digest] = true + blobPath := modelManifest.BlobPath(layer.Digest) + + // Read quantization metadata from first blob + if quantType == "" { + if meta, err := readBlobMetadata(blobPath); err == nil && meta != nil { + if qt := meta["quant_type"]; qt != "" { + quantType = strings.ToUpper(qt) + } + if gs := meta["group_size"]; gs != "" { + fmt.Sscanf(gs, "%d", &quantGroupSize) + } + } + } + + for name, arr := range mlx.Load(blobPath) { + // Map safetensors key naming to our naming convention + // Combined blobs use ".scale" and ".bias" suffixes + if strings.HasSuffix(name, ".scale") { + baseName := strings.TrimSuffix(name, ".scale") + allTensors[baseName+"_scale"] = arr + } else if strings.HasSuffix(name, ".bias") && !strings.HasSuffix(name, ".weight_qbias") { + // Check if this is a quantization bias or a regular bias + // by checking if there's a corresponding weight + baseName := strings.TrimSuffix(name, ".bias") + if _, hasScale := allTensors[baseName+"_scale"]; hasScale { + allTensors[baseName+"_qbias"] = arr + } else { + allTensors[name] = arr + } + } else { + allTensors[name] = arr + } + } + } + + // Set up quantization parameters + useQuantized := false + if quantType != "" { + _, cfg.QuantBits, cfg.QuantMode = 
quantizationParams(quantType) + if quantGroupSize > 0 { + cfg.QuantGroupSize = quantGroupSize + } else { + cfg.QuantGroupSize, _, _ = quantizationParams(quantType) + } + useQuantized = supportsGatherQMM(cfg.QuantMode, cfg.QuantBits) + } + + // Load tokenizer + tokData, err := modelManifest.ReadConfig("tokenizer.json") + if err != nil { + return nil, fmt.Errorf("load tokenizer config: %w", err) + } + + tokConfig := &tokenizer.TokenizerConfig{ + ConfigJSON: configData, + } + + if genConfigData, err := modelManifest.ReadConfig("generation_config.json"); err == nil { + tokConfig.GenerationConfigJSON = genConfigData + } + + if tokConfigData, err := modelManifest.ReadConfig("tokenizer_config.json"); err == nil { + tokConfig.TokenizerConfigJSON = tokConfigData + } + + tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig) + if err != nil { + return nil, fmt.Errorf("parse tokenizer: %w", err) + } + + m := &Model{ + Layers: make([]Block, cfg.NumHiddenLayers), + Config: &cfg, + tok: tok, + } + + // Load embedding + if w := allTensors["model.embed_tokens.weight"]; w != nil { + m.EmbedTokens = nn.NewEmbedding(w) + } + + // Load final norm + if w := allTensors["model.norm.weight"]; w != nil { + m.Norm = nn.NewRMSNorm(w, cfg.RMSNormEps) + } + + // Load LM head + m.LMHead = makeLinear(allTensors, "lm_head", &cfg) + + // Load layers + for i := int32(0); i < cfg.NumHiddenLayers; i++ { + prefix := fmt.Sprintf("model.layers.%d", i) + + // Load attention (same for both block types) + attn := &MLAAttention{} + attn.QAProj = makeLinear(allTensors, prefix+".self_attn.q_a_proj", &cfg) + if w := allTensors[prefix+".self_attn.q_a_layernorm.weight"]; w != nil { + attn.QALayerNorm = nn.NewRMSNorm(w, cfg.RMSNormEps) + } + attn.QBProj = makeLinear(allTensors, prefix+".self_attn.q_b_proj", &cfg) + attn.KVAProjWithMQA = makeLinear(allTensors, prefix+".self_attn.kv_a_proj_with_mqa", &cfg) + if w := allTensors[prefix+".self_attn.kv_a_layernorm.weight"]; w != nil { + attn.KVALayerNorm = 
nn.NewRMSNorm(w, cfg.RMSNormEps) + } + attn.OProj = makeLinear(allTensors, prefix+".self_attn.o_proj", &cfg) + + // Sanitize MLA weights for absorbed attention + embedQ, unembedOut := sanitizeMLAWeights(allTensors, prefix, &cfg) + attn.EmbedQ = nn.NewMultiLinear(embedQ) + attn.UnembedOut = nn.NewMultiLinear(unembedOut) + + inputLN := allTensors[prefix+".input_layernorm.weight"] + postAttnLN := allTensors[prefix+".post_attention_layernorm.weight"] + + if i < cfg.FirstKDenseReplace { + // Dense block + block := &DenseBlock{Attention: attn} + if inputLN != nil { + block.InputLayerNorm = nn.NewRMSNorm(inputLN, cfg.RMSNormEps) + } + if postAttnLN != nil { + block.PostAttentionLayerNorm = nn.NewRMSNorm(postAttnLN, cfg.RMSNormEps) + } + + block.MLP = &DenseMLP{ + GateProj: makeLinear(allTensors, prefix+".mlp.gate_proj", &cfg), + UpProj: makeLinear(allTensors, prefix+".mlp.up_proj", &cfg), + DownProj: makeLinear(allTensors, prefix+".mlp.down_proj", &cfg), + } + + m.Layers[i] = block + } else { + // MoE block + block := &MoEBlock{Attention: attn} + if inputLN != nil { + block.InputLayerNorm = nn.NewRMSNorm(inputLN, cfg.RMSNormEps) + } + if postAttnLN != nil { + block.PostAttentionLayerNorm = nn.NewRMSNorm(postAttnLN, cfg.RMSNormEps) + } + + // Stack expert weights + gate, up, down := sanitizeExpertWeights(allTensors, prefix, cfg.NRoutedExperts, useQuantized, &cfg) + + switchMLP := &SwitchMLP{UseQuantized: useQuantized} + if useQuantized { + switchMLP.GateWeightQ = gate.Weight + switchMLP.GateScales = gate.Scales + switchMLP.GateBiases = gate.Biases + switchMLP.GateBits = gate.Bits + switchMLP.GateGroupSize = gate.GroupSize + switchMLP.UpWeightQ = up.Weight + switchMLP.UpScales = up.Scales + switchMLP.UpBiases = up.Biases + switchMLP.UpBits = up.Bits + switchMLP.UpGroupSize = up.GroupSize + switchMLP.DownWeightQ = down.Weight + switchMLP.DownScales = down.Scales + switchMLP.DownBiases = down.Biases + switchMLP.DownBits = down.Bits + switchMLP.DownGroupSize = down.GroupSize + 
} else { + switchMLP.GateWeight = gate.Weight + switchMLP.UpWeight = up.Weight + switchMLP.DownWeight = down.Weight + } + + moeGate := &MoEGate{} + moeGate.Gate = makeLinear(allTensors, prefix+".mlp.gate", &cfg) + if bias := allTensors[prefix+".mlp.gate.e_score_correction_bias"]; bias != nil { + moeGate.EScoreCorrectionBias = bias + } + + block.MoE = &MoE{ + Gate: moeGate, + SwitchMLP: switchMLP, + } + + // Load shared experts if present + if cfg.NSharedExperts > 0 { + block.MoE.SharedExperts = &SharedExperts{ + GateProj: makeLinear(allTensors, prefix+".mlp.shared_experts.gate_proj", &cfg), + UpProj: makeLinear(allTensors, prefix+".mlp.shared_experts.up_proj", &cfg), + DownProj: makeLinear(allTensors, prefix+".mlp.shared_experts.down_proj", &cfg), + } + } + + m.Layers[i] = block + } + } + + mlx.Eval(mlx.Collect(m)...) + + return m, nil +} + +// Forward computes the forward pass of the model +func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array { + dims := tokens.Dims() + B, L := int32(dims[0]), int32(dims[1]) + + h := m.EmbedTokens.Forward(tokens) + + for i, layer := range m.Layers { + var c cache.Cache + if caches != nil { + c = caches[i] + } + h = layer.Forward(h, c, B, L, m.Config) + } + + h = m.Norm.Forward(h, m.RMSNormEps) + return h +} + +// Unembed applies the LM head to get logits. 
+func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
+	return m.LMHead.Forward(x)
+}
+
+// NumLayers returns the number of transformer layers
+func (m *Model) NumLayers() int { return len(m.Layers) }
+
+// MaxContextLength returns the maximum context length
+func (m *Model) MaxContextLength() int32 { return m.MaxPositionEmbeddings }
+
+// VocabSize returns the vocabulary size
+func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
+
+// Tokenizer returns the model's tokenizer
+func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }
+
+// NewCache creates a new KV cache for the model
+func (m *Model) NewCache(maxSeqLen int32) []cache.Cache {
+	caches := make([]cache.Cache, len(m.Layers))
+	for i := range caches {
+		caches[i] = cache.NewKVCache()
+	}
+	return caches
+}
+
+// FormatPrompt applies the GLM-4 chat template with thinking enabled by default.
+func (m *Model) FormatPrompt(prompt string) string {
+	return "[gMASK]<|user|>" + prompt + "<|assistant|><think>"
+}
+
+// FormatPromptWithThinking applies the GLM-4 chat template with explicit thinking control.
+func (m *Model) FormatPromptWithThinking(prompt string, think bool) string {
+	if think {
+		return "[gMASK]<|user|>" + prompt + "<|assistant|><think>"
+	}
+	return "[gMASK]<|user|>" + prompt + "<|assistant|>"
+}
+
+// NewRenderer returns a new Renderer for formatting multi-turn conversations.
+func (m *Model) NewRenderer() *Renderer {
+	return &Renderer{}
+}
+
+// NewParser returns a new Parser for extracting thinking and tool calls from output.
+func (m *Model) NewParser() *Parser {
+	return &Parser{}
+}
diff --git a/x/models/glm4_moe_lite/parser.go b/x/models/glm4_moe_lite/parser.go
new file mode 100644
index 000000000..c81ec5a40
--- /dev/null
+++ b/x/models/glm4_moe_lite/parser.go
@@ -0,0 +1,479 @@
+//go:build mlx
+
+package glm4_moe_lite
+
+import (
+	"context"
+	"encoding/json"
+	"encoding/xml"
+	"fmt"
+	"log/slog"
+	"strings"
+	"unicode"
+
+	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/logutil"
+)
+
+type parserState int
+
+const (
+	parserState_LookingForThinkingOpen parserState = iota
+	parserState_ThinkingStartedEatingWhitespace
+	parserState_CollectingThinking
+	parserState_ThinkingDoneEatingWhitespace
+	parserState_CollectingContent
+	parserState_ToolStartedEatingWhitespace
+	parserState_CollectingToolContent
+)
+
+const (
+	thinkingOpenTag  = "<think>"
+	thinkingCloseTag = "</think>"
+	toolOpenTag      = "<tool_call>"
+	toolCloseTag     = "</tool_call>"
+)
+
+// Parser parses GLM4-MoE-Lite model output to extract thinking and tool calls.
+// GLM-4's prompt ends with <think> when thinking is enabled, so the parser
+// must start in CollectingThinking state (the model outputs thinking content directly).
+type Parser struct {
+	state  parserState
+	buffer strings.Builder
+	tools  []api.Tool
+}
+
+// HasToolSupport returns true as GLM4 supports tool calling.
+func (p *Parser) HasToolSupport() bool {
+	return true
+}
+
+// HasThinkingSupport returns true as GLM4 supports thinking mode.
+func (p *Parser) HasThinkingSupport() bool {
+	return true
+}
+
+// Init initializes the parser with tools and thinking configuration.
+func (p *Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
+	p.tools = tools
+	// When thinking is enabled (nil or true), the prompt ends with <think>,
+	// so model output starts directly with thinking content (no opening tag).
+ if thinkValue == nil || thinkValue.Bool() { + p.state = parserState_CollectingThinking + } + return tools +} + +type parserEvent interface { + isParserEvent() +} + +type eventContent struct { + content string +} + +func (eventContent) isParserEvent() {} + +type eventRawToolCall struct { + raw string +} + +func (eventRawToolCall) isParserEvent() {} + +type eventThinkingContent struct { + content string +} + +func (eventThinkingContent) isParserEvent() {} + +// Add processes new output text and returns parsed content, thinking, and tool calls. +func (p *Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) { + p.buffer.WriteString(s) + events := p.parseEvents() + + var toolCalls []api.ToolCall + var contentSb strings.Builder + var thinkingSb strings.Builder + + for _, event := range events { + switch event := event.(type) { + case eventRawToolCall: + toolCall, err := parseToolCall(event, p.tools) + if err != nil { + slog.Warn("glm-4 tool call parsing failed", "error", err) + return "", "", nil, err + } + toolCalls = append(toolCalls, toolCall) + case eventThinkingContent: + thinkingSb.WriteString(event.content) + case eventContent: + contentSb.WriteString(event.content) + } + } + + return contentSb.String(), thinkingSb.String(), toolCalls, nil +} + +func (p *Parser) parseEvents() []parserEvent { + var all []parserEvent + + keepLooping := true + for keepLooping { + var events []parserEvent + events, keepLooping = p.eat() + if len(events) > 0 { + all = append(all, events...) + } + } + + if len(all) > 0 { + slog.Log(context.TODO(), logutil.LevelTrace, "glm-4 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String()) + } + + return all +} + +// eatLeadingWhitespaceAndTransitionTo consumes leading whitespace from the buffer +// and transitions to the next state. Returns (nil, false) if only whitespace remains +// in the buffer (needs more input), or (nil, true) if we successfully transitioned. 
+func (p *Parser) eatLeadingWhitespaceAndTransitionTo(nextState parserState) ([]parserEvent, bool) { + trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace) + p.buffer.Reset() + if trimmed == "" { + return nil, false // Still only whitespace, keep waiting for more input + } + p.state = nextState + p.buffer.WriteString(trimmed) + return nil, true // Successfully transitioned +} + +// splitAtTag splits the buffer at the given tag, returns the content before (trimmed of trailing whitespace), +// the content after (optionally trimmed of leading whitespace), and updates the buffer +func (p *Parser) splitAtTag(tag string, trimAfter bool) (string, string) { + split := strings.SplitN(p.buffer.String(), tag, 2) + before := split[0] + before = strings.TrimRightFunc(before, unicode.IsSpace) + after := split[1] + if trimAfter { + after = strings.TrimLeftFunc(after, unicode.IsSpace) + } + p.buffer.Reset() + p.buffer.WriteString(after) + return before, after +} + +func (p *Parser) eat() ([]parserEvent, bool) { + var events []parserEvent + + switch p.state { + case parserState_LookingForThinkingOpen: + trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace) + if strings.HasPrefix(trimmed, thinkingOpenTag) { + // Found opening tag + after := strings.TrimPrefix(trimmed, thinkingOpenTag) + after = strings.TrimLeftFunc(after, unicode.IsSpace) + p.buffer.Reset() + p.buffer.WriteString(after) + if after == "" { + p.state = parserState_ThinkingStartedEatingWhitespace + } else { + p.state = parserState_CollectingThinking + } + return events, true + } else if strings.HasPrefix(thinkingOpenTag, trimmed) { + // Partial opening tag seen, keep accumulating + return events, false + } else if trimmed == "" { + // Only whitespace, keep accumulating + return events, false + } else { + // No thinking tag found, skip to content collection + p.state = parserState_CollectingContent + // Don't trim - we want to keep the original content + return events, true + } + + case 
parserState_ThinkingStartedEatingWhitespace: + return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingThinking) + + case parserState_CollectingThinking: + acc := p.buffer.String() + if strings.Contains(acc, thinkingCloseTag) { + thinking, remaining := p.splitAtTag(thinkingCloseTag, true) + if len(thinking) > 0 { + events = append(events, eventThinkingContent{content: thinking}) + } + if remaining == "" { + p.state = parserState_ThinkingDoneEatingWhitespace + } else { + p.state = parserState_CollectingContent + } + return events, true + } else if overlapLen := overlap(acc, thinkingCloseTag); overlapLen > 0 { + // Partial closing tag - withhold it along with any trailing whitespace before it + beforePartialTag := acc[:len(acc)-overlapLen] + trailingWsLen := trailingWhitespaceLen(beforePartialTag) + ambiguousStart := len(beforePartialTag) - trailingWsLen + + unambiguous := acc[:ambiguousStart] + ambiguous := acc[ambiguousStart:] + p.buffer.Reset() + p.buffer.WriteString(ambiguous) + if len(unambiguous) > 0 { + events = append(events, eventThinkingContent{content: unambiguous}) + } + return events, false + } else { + // Pure thinking content - withhold trailing whitespace (might precede closing tag) + whitespaceLen := trailingWhitespaceLen(acc) + ambiguousStart := len(acc) - whitespaceLen + + unambiguous := acc[:ambiguousStart] + ambiguous := acc[ambiguousStart:] + p.buffer.Reset() + p.buffer.WriteString(ambiguous) + if len(unambiguous) > 0 { + events = append(events, eventThinkingContent{content: unambiguous}) + } + return events, false + } + + case parserState_ThinkingDoneEatingWhitespace: + return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingContent) + + case parserState_CollectingContent: + if strings.Contains(p.buffer.String(), toolOpenTag) { + before, after := p.splitAtTag(toolOpenTag, true) + if len(before) > 0 { + events = append(events, eventContent{content: before}) + } + if after == "" { + p.state = 
parserState_ToolStartedEatingWhitespace + } else { + p.state = parserState_CollectingToolContent + } + return events, true + } else if overlapLen := overlap(p.buffer.String(), toolOpenTag); overlapLen > 0 { + beforePartialTag := p.buffer.String()[:len(p.buffer.String())-overlapLen] + trailingWsLen := trailingWhitespaceLen(beforePartialTag) + ambiguousStart := len(beforePartialTag) - trailingWsLen + + unambiguous := p.buffer.String()[:ambiguousStart] + ambiguous := p.buffer.String()[ambiguousStart:] + p.buffer.Reset() + p.buffer.WriteString(ambiguous) + if len(unambiguous) > 0 { + events = append(events, eventContent{content: unambiguous}) + } + return events, false + } else { + whitespaceLen := trailingWhitespaceLen(p.buffer.String()) + ambiguousStart := len(p.buffer.String()) - whitespaceLen + + unambiguous := p.buffer.String()[:ambiguousStart] + ambiguous := p.buffer.String()[ambiguousStart:] + p.buffer.Reset() + p.buffer.WriteString(ambiguous) + if len(unambiguous) > 0 { + events = append(events, eventContent{content: unambiguous}) + } + return events, false + } + + case parserState_ToolStartedEatingWhitespace: + return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingToolContent) + + case parserState_CollectingToolContent: + acc := p.buffer.String() + if strings.Contains(acc, toolCloseTag) { + toolContent, _ := p.splitAtTag(toolCloseTag, true) + if len(toolContent) == 0 { + slog.Warn("glm4 tool call closing tag found but no content before it") + } + events = append(events, eventRawToolCall{raw: toolContent}) + p.state = parserState_CollectingContent + return events, true + } else { + // Keep accumulating - tool calls are not streamed + // We just wait for the closing tag + return events, false + } + + default: + panic("unreachable") + } +} + +// overlap returns the length of the overlap between the end of s and the start of tag. 
+func overlap(s, tag string) int {
+	for i := 1; i <= len(tag) && i <= len(s); i++ {
+		if strings.HasSuffix(s, tag[:i]) {
+			return i
+		}
+	}
+	return 0
+}
+
+// trailingWhitespaceLen returns the length of trailing whitespace in s.
+func trailingWhitespaceLen(s string) int {
+	trimmed := strings.TrimRightFunc(s, unicode.IsSpace)
+	return len(s) - len(trimmed)
+}
+
+// ToolCallXML represents the structure of a GLM-4 tool call for XML parsing
+type ToolCallXML struct {
+	XMLName xml.Name `xml:"tool_call"`
+	Content string   `xml:",chardata"` // Function name (text nodes between tags)
+	Keys    []string `xml:"arg_key"`   // All arg_key elements in document order
+	Values  []string `xml:"arg_value"` // All arg_value elements in document order
+}
+
+// escapeContent escapes XML entities in text content while preserving arg_key/arg_value tags
+func escapeContent(s string) string {
+	var result strings.Builder
+	inTag := false
+
+	for i := range len(s) {
+		ch := s[i]
+
+		if ch == '<' {
+			// Check if this is a known tag
+			if strings.HasPrefix(s[i:], "<arg_key>") ||
+				strings.HasPrefix(s[i:], "</arg_key>") ||
+				strings.HasPrefix(s[i:], "<arg_value>") ||
+				strings.HasPrefix(s[i:], "</arg_value>") {
+				inTag = true
+			}
+		}
+
+		if inTag {
+			result.WriteByte(ch)
+			if ch == '>' {
+				inTag = false
+			}
+		} else {
+			// Escape special characters in text content
+			switch ch {
+			case '&':
+				result.WriteString("&amp;")
+			case '<':
+				result.WriteString("&lt;")
+			case '>':
+				result.WriteString("&gt;")
+			default:
+				result.WriteByte(ch)
+			}
+		}
+	}
+
+	return result.String()
+}
+
+func parseToolCall(raw eventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
+	// Escape any unescaped entities in text content
+	escaped := escapeContent(raw.raw)
+
+	// Wrap the content in a root element to make it valid XML
+	xmlString := "<tool_call>" + escaped + "</tool_call>"
+
+	// Parse XML into struct
+	var parsed ToolCallXML
+	if err := xml.Unmarshal([]byte(xmlString), &parsed); err != nil {
+		return api.ToolCall{}, fmt.Errorf("failed to parse XML: %w", err)
+	}
+
+	// Extract and trim function name
+	functionName := strings.TrimSpace(parsed.Content)
+	if functionName == "" {
+		return api.ToolCall{}, fmt.Errorf("empty function name")
+	}
+
+	// Verify keys and values are paired correctly
+	if len(parsed.Keys) != len(parsed.Values) {
+		return api.ToolCall{}, fmt.Errorf("mismatched arg_key and arg_value counts: %d keys, %d values", len(parsed.Keys), len(parsed.Values))
+	}
+
+	// Find the matching tool to get parameter types
+	var matchedTool *api.Tool
+	for i := range tools {
+		if tools[i].Function.Name == functionName {
+			matchedTool = &tools[i]
+			break
+		}
+	}
+
+	// Build arguments map by pairing keys and values
+	toolCall := api.ToolCall{
+		Function: api.ToolCallFunction{
+			Name:      functionName,
+			Arguments: api.NewToolCallFunctionArguments(),
+		},
+	}
+
+	for i := range parsed.Keys {
+		key := strings.TrimSpace(parsed.Keys[i])
+		value := parsed.Values[i] // Don't trim here - parseValue handles it
+
+		// Look up parameter type
+		var paramType api.PropertyType
+		if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil {
+			if prop, ok := matchedTool.Function.Parameters.Properties.Get(key); ok {
+				// Handle anyOf by collecting all types from the union
+				if len(prop.AnyOf) > 0 {
+					for _, anyOfProp := range prop.AnyOf {
+						paramType = append(paramType, anyOfProp.Type...)
+					}
+				} else {
+					paramType = prop.Type
+				}
+			}
+		}
+
+		// Parse value with type coercion
+		toolCall.Function.Arguments.Set(key, parseValue(value, paramType))
+	}
+
+	return toolCall, nil
+}
+
+// parseValue parses a string value and coerces it to the appropriate type based on paramType.
+func parseValue(value string, paramType api.PropertyType) any {
+	value = strings.TrimSpace(value)
+
+	// If no type specified, return as string
+	if len(paramType) == 0 {
+		return value
+	}
+
+	// Try to parse based on specified types
+	for _, t := range paramType {
+		switch t {
+		case "boolean":
+			if value == "true" {
+				return true
+			}
+			if value == "false" {
+				return false
+			}
+		case "integer":
+			var i int64
+			if _, err := fmt.Sscanf(value, "%d", &i); err == nil {
+				return i
+			}
+		case "number":
+			var f float64
+			if _, err := fmt.Sscanf(value, "%f", &f); err == nil {
+				return f
+			}
+		case "array", "object":
+			// Try to parse as JSON
+			var result any
+			if err := json.Unmarshal([]byte(value), &result); err == nil {
+				return result
+			}
+		}
+	}
+
+	// Default to string
+	return value
+}
diff --git a/x/models/glm4_moe_lite/parser_test.go b/x/models/glm4_moe_lite/parser_test.go
new file mode 100644
index 000000000..0ce382709
--- /dev/null
+++ b/x/models/glm4_moe_lite/parser_test.go
@@ -0,0 +1,192 @@
+//go:build mlx
+
+package glm4_moe_lite
+
+import (
+	"testing"
+
+	"github.com/ollama/ollama/api"
+)
+
+func TestParserThinking(t *testing.T) {
+	tests := []struct {
+		name          string
+		input         string
+		thinkEnabled  bool
+		wantContent   string
+		wantThinking  string
+		wantToolCalls int
+	}{
+		{
+			name:         "thinking enabled - simple thinking then content",
+			input:        "Let me think about this...</think>Here is my answer.",
+			thinkEnabled: true,
+			wantThinking: "Let me think about this...",
+			wantContent:  "Here is my answer.",
+		},
+		{
+			name:         "thinking enabled - only thinking",
+			input:        "I need to consider multiple factors...</think>",
+			thinkEnabled: true,
+			wantThinking: "I need to consider multiple factors...",
+			wantContent:  "",
+		},
+		{
+			name:         "thinking disabled - direct content",
+			input:        "Here is my direct answer.",
+			thinkEnabled: false,
+			wantThinking: "",
+			wantContent:  "Here is my direct answer.",
+		},
+		{
+			name:          "thinking with tool call",
+			input:         "Let me search for that...</think>I'll use a tool.<tool_call>search<arg_key>query</arg_key><arg_value>test</arg_value></tool_call>",
+			thinkEnabled:  true,
+			wantThinking:  "Let me search for that...",
+			wantContent:   "I'll use a tool.",
+			wantToolCalls: 1,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			p := &Parser{}
+
+			var thinkValue *api.ThinkValue
+			if tt.thinkEnabled {
+				thinkValue = &api.ThinkValue{Value: true}
+			} else {
+				thinkValue = &api.ThinkValue{Value: false}
+			}
+
+			// Define tools for tool call tests
+			props := api.NewToolPropertiesMap()
+			props.Set("query", api.ToolProperty{Type: api.PropertyType{"string"}})
+			tools := []api.Tool{
+				{
+					Function: api.ToolFunction{
+						Name: "search",
+						Parameters: api.ToolFunctionParameters{
+							Properties: props,
+						},
+					},
+				},
+			}
+
+			p.Init(tools, nil, thinkValue)
+
+			content, thinking, calls, err := p.Add(tt.input, true)
+			if err != nil {
+				t.Fatalf("unexpected error: %v", err)
+			}
+
+			if thinking != tt.wantThinking {
+				t.Errorf("thinking = %q, want %q", thinking, tt.wantThinking)
+			}
+			if content != tt.wantContent {
+				t.Errorf("content = %q, want %q", content, tt.wantContent)
+			}
+			if len(calls) != tt.wantToolCalls {
+				t.Errorf("len(calls) = %d, want %d", len(calls), tt.wantToolCalls)
+			}
+		})
+	}
+}
+
+func TestParserToolCall(t *testing.T) {
+	p := &Parser{}
+
+	props := api.NewToolPropertiesMap()
+	props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}})
+	props.Set("unit", api.ToolProperty{Type: api.PropertyType{"string"}})
+	tools := []api.Tool{
+		{
+			Function: api.ToolFunction{
+				Name: "get_weather",
+				Parameters: api.ToolFunctionParameters{
+					Properties: props,
+				},
+			},
+		},
+	}
+
+	// Initialize with thinking disabled
+	tv := &api.ThinkValue{Value: false}
+	p.Init(tools, nil, tv)
+
+	input := "<tool_call>get_weather<arg_key>location</arg_key><arg_value>San Francisco</arg_value><arg_key>unit</arg_key><arg_value>celsius</arg_value></tool_call>"
+
+	_, _, calls, err := p.Add(input, true)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	if len(calls) != 1 {
+		t.Fatalf("expected 1 tool call, got %d", len(calls))
+	}
+
+	call := calls[0]
+	if call.Function.Name
!= "get_weather" { + t.Errorf("function name = %q, want %q", call.Function.Name, "get_weather") + } + + location, ok := call.Function.Arguments.Get("location") + if !ok || location != "San Francisco" { + t.Errorf("location = %v, want %q", location, "San Francisco") + } + + unit, ok := call.Function.Arguments.Get("unit") + if !ok || unit != "celsius" { + t.Errorf("unit = %v, want %q", unit, "celsius") + } +} + +func TestOverlap(t *testing.T) { + tests := []struct { + s string + tag string + want int + }{ + {"hello<", "", 1}, + {"hello", 2}, + {"hello", 3}, + {"hello", 4}, + {"hello", 5}, + {"hello", 6}, + {"hello", 7}, + {"hello", "", 8}, // Complete tag at end returns full length + {"hello", "", 0}, + {"", "", 0}, + } + + for _, tt := range tests { + t.Run(tt.s+"_"+tt.tag, func(t *testing.T) { + got := overlap(tt.s, tt.tag) + if got != tt.want { + t.Errorf("overlap(%q, %q) = %d, want %d", tt.s, tt.tag, got, tt.want) + } + }) + } +} + +func TestTrailingWhitespaceLen(t *testing.T) { + tests := []struct { + s string + want int + }{ + {"hello ", 3}, + {"hello\n\t ", 3}, + {"hello", 0}, + {"", 0}, + {" ", 3}, + } + + for _, tt := range tests { + t.Run(tt.s, func(t *testing.T) { + got := trailingWhitespaceLen(tt.s) + if got != tt.want { + t.Errorf("trailingWhitespaceLen(%q) = %d, want %d", tt.s, got, tt.want) + } + }) + } +} diff --git a/x/models/glm4_moe_lite/render.go b/x/models/glm4_moe_lite/render.go new file mode 100644 index 000000000..4998604bf --- /dev/null +++ b/x/models/glm4_moe_lite/render.go @@ -0,0 +1,175 @@ +//go:build mlx + +package glm4_moe_lite + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/ollama/ollama/api" +) + +// Renderer renders messages for GLM4-MoE-Lite models. +// +// GLM-4 Thinking Modes (ref: https://docs.z.ai/guides/capabilities/thinking-mode): +// +// 1. INTERLEAVED THINKING +// The model thinks between tool calls and after receiving tool results. 
+// This enables complex step-by-step reasoning: interpreting each tool output +// before deciding what to do next. Thinking blocks are preserved and returned +// with tool results to maintain reasoning continuity. +// +// 2. PRESERVED THINKING +// The model retains reasoning content from previous assistant turns in context. +// This preserves reasoning continuity across multi-turn conversations. The +// upstream API has a "clear_thinking" parameter to control this: +// - clear_thinking=true: clears reasoning from previous turns (outputs ) +// - clear_thinking=false: preserves ... blocks from previous turns +// +// 3. TURN-LEVEL THINKING +// Controls whether the model should reason on each turn. The upstream API +// uses "enable_thinking" parameter: +// - enable_thinking=true: outputs to start reasoning +// - enable_thinking=false: outputs to skip reasoning +// +// OLLAMA DEFAULTS: +// - Thinking is ENABLED by default (thinkValue=nil or true outputs ) +// - Thinking is PRESERVED by default (reasoning content from previous turns is always +// included in ... blocks, equivalent to clear_thinking=false) +// - Users can disable thinking per-turn via thinkValue=false +type Renderer struct{} + +// Render renders messages into the GLM4 chat format. 
+func (r *Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) { + var sb strings.Builder + + sb.WriteString("[gMASK]") + + if len(tools) > 0 { + sb.WriteString("<|system|>\n") + sb.WriteString("# Tools\n\n") + sb.WriteString("You may call one or more functions to assist with the user query.\n\n") + sb.WriteString("You are provided with function signatures within XML tags:\n") + sb.WriteString("\n") + for _, tool := range tools { + d, _ := json.Marshal(tool) + sb.WriteString(formatToolJSON(d)) + sb.WriteString("\n") + } + sb.WriteString("\n\n") + sb.WriteString("For each function call, output the function name and arguments within the following XML format:\n") + sb.WriteString("{function-name}{arg-key-1}{arg-value-1}{arg-key-2}{arg-value-2}...") + } + + think := true + if thinkValue != nil && !thinkValue.Bool() { + think = false + } + + for i, message := range messages { + switch message.Role { + case "user": + sb.WriteString("<|user|>") + sb.WriteString(message.Content) + case "assistant": + sb.WriteString("<|assistant|>") + if message.Thinking != "" { + sb.WriteString("" + message.Thinking + "") + } else { + sb.WriteString("") + } + if message.Content != "" { + sb.WriteString(message.Content) + } + if len(message.ToolCalls) > 0 { + for _, toolCall := range message.ToolCalls { + sb.WriteString("" + toolCall.Function.Name) + sb.WriteString(renderToolArguments(toolCall.Function.Arguments)) + sb.WriteString("") + } + } + case "tool": + if i == 0 || messages[i-1].Role != "tool" { + sb.WriteString("<|observation|>") + } + sb.WriteString("") + sb.WriteString(message.Content) + sb.WriteString("") + case "system": + sb.WriteString("<|system|>") + sb.WriteString(message.Content) + } + } + + sb.WriteString("<|assistant|>") + if think { + sb.WriteString("") + } else { + sb.WriteString("") + } + + return sb.String(), nil +} + +// renderToolArguments converts tool call arguments to GLM4 XML format. 
// formatToolJSON formats compact JSON for GLM4 tool definitions by inserting
// a space after every ':' and ',' that appears outside of a string literal.
// String contents, including escaped quotes, are emitted untouched.
func formatToolJSON(raw []byte) string {
	var out strings.Builder
	// Reserve ~10% headroom for the inserted spaces.
	out.Grow(len(raw) + len(raw)/10)

	const (
		stateJSON   = iota // outside any string literal
		stateString        // inside a double-quoted string
		stateEscape        // immediately after a backslash inside a string
	)
	state := stateJSON

	for _, ch := range raw {
		out.WriteByte(ch)

		switch state {
		case stateEscape:
			// The escaped character was just emitted; resume normal
			// in-string scanning.
			state = stateString
		case stateString:
			switch ch {
			case '\\':
				state = stateEscape
			case '"':
				state = stateJSON
			}
		default: // stateJSON
			switch ch {
			case '"':
				state = stateString
			case ':', ',':
				out.WriteByte(' ')
			}
		}
	}

	return out.String()
}
result, err := r.Render(messages, nil, tv) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + expected := "[gMASK]<|user|>Hello<|assistant|>" + if result != expected { + t.Errorf("result = %q, want %q", result, expected) + } +} + +func TestRendererMultiTurn(t *testing.T) { + r := &Renderer{} + + messages := []api.Message{ + {Role: "user", Content: "What is 2+2?"}, + {Role: "assistant", Content: "4", Thinking: "Let me calculate: 2+2=4"}, + {Role: "user", Content: "And 3+3?"}, + } + + result, err := r.Render(messages, nil, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Check key parts + if !strings.Contains(result, "[gMASK]") { + t.Error("missing [gMASK] prefix") + } + if !strings.Contains(result, "<|user|>What is 2+2?") { + t.Error("missing first user message") + } + if !strings.Contains(result, "<|assistant|>Let me calculate: 2+2=44") { + t.Error("missing assistant message with thinking") + } + if !strings.Contains(result, "<|user|>And 3+3?") { + t.Error("missing second user message") + } + if !strings.HasSuffix(result, "<|assistant|>") { + t.Errorf("should end with <|assistant|>, got suffix: %q", result[len(result)-30:]) + } +} + +func TestRendererWithSystem(t *testing.T) { + r := &Renderer{} + + messages := []api.Message{ + {Role: "system", Content: "You are a helpful assistant."}, + {Role: "user", Content: "Hello"}, + } + + result, err := r.Render(messages, nil, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !strings.Contains(result, "<|system|>You are a helpful assistant.") { + t.Error("missing system message") + } +} + +func TestRendererWithTools(t *testing.T) { + r := &Renderer{} + + messages := []api.Message{ + {Role: "user", Content: "What's the weather?"}, + } + + props := api.NewToolPropertiesMap() + props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}, Description: "The city"}) + tools := []api.Tool{ + { + Function: api.ToolFunction{ + Name: "get_weather", + 
Description: "Get the weather for a location", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: props, + Required: []string{"location"}, + }, + }, + }, + } + + result, err := r.Render(messages, tools, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Check for tool system prompt + if !strings.Contains(result, "<|system|>") { + t.Error("missing system tag for tools") + } + if !strings.Contains(result, "# Tools") { + t.Error("missing tools header") + } + if !strings.Contains(result, "") { + t.Error("missing tools tag") + } + if !strings.Contains(result, "get_weather") { + t.Error("missing tool name") + } + if !strings.Contains(result, "") { + t.Error("missing closing tools tag") + } +} + +func TestRendererWithToolCalls(t *testing.T) { + r := &Renderer{} + + args := api.NewToolCallFunctionArguments() + args.Set("location", "San Francisco") + + messages := []api.Message{ + {Role: "user", Content: "What's the weather in SF?"}, + { + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: args, + }, + }, + }, + }, + {Role: "tool", Content: "Sunny, 72F"}, + } + + result, err := r.Render(messages, nil, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !strings.Contains(result, "get_weather") { + t.Error("missing tool call") + } + if !strings.Contains(result, "location") { + t.Error("missing arg_key") + } + if !strings.Contains(result, "San Francisco") { + t.Error("missing arg_value") + } + if !strings.Contains(result, "") { + t.Error("missing tool call closing tag") + } + if !strings.Contains(result, "<|observation|>") { + t.Error("missing observation tag") + } + if !strings.Contains(result, "Sunny, 72F") { + t.Error("missing tool response") + } +} + +func TestFormatToolJSON(t *testing.T) { + input := []byte(`{"name":"test","value":123}`) + result := formatToolJSON(input) + + // Should add spaces after : and , + if 
!strings.Contains(result, ": ") { + t.Error("should add space after colon") + } + if !strings.Contains(result, ", ") { + t.Error("should add space after comma") + } +} diff --git a/x/models/nn/nn.go b/x/models/nn/nn.go new file mode 100644 index 000000000..3f57d483a --- /dev/null +++ b/x/models/nn/nn.go @@ -0,0 +1,188 @@ +//go:build mlx + +package nn + +import "github.com/ollama/ollama/x/mlxrunner/mlx" + +// Layer is the interface for neural network layers with a Forward method. +type Layer interface { + Forward(x *mlx.Array) *mlx.Array +} + +// LinearLayer is an interface for linear layers (both regular and quantized). +type LinearLayer interface { + Forward(x *mlx.Array) *mlx.Array + OutputDim() int32 +} + +// Linear applies an affine transformation: y = x @ W.T + b +type Linear struct { + Weight *mlx.Array + Bias *mlx.Array +} + +func NewLinear(weight *mlx.Array, bias *mlx.Array) *Linear { + return &Linear{Weight: weight, Bias: bias} +} + +func (l *Linear) Forward(x *mlx.Array) *mlx.Array { + w := l.Weight.Transpose(1, 0) + if l.Bias != nil && l.Bias.Valid() { + return l.Bias.Addmm(x, w, 1.0, 1.0) + } + return x.Matmul(w) +} + +func (l *Linear) OutputDim() int32 { + return int32(l.Weight.Dim(0)) +} + +// QuantizedLinear applies an affine transformation using quantized weights. 
+type QuantizedLinear struct { + Weight *mlx.Array // Quantized weight data + Scales *mlx.Array // Scale factors for dequantization + QBiases *mlx.Array // Quantization biases (nil for nvfp4) + Bias *mlx.Array // Layer bias [output_dims] or nil + GroupSize int + Bits int + Mode string +} + +func NewQuantizedLinear(weight *mlx.Array, bias *mlx.Array, groupSize, bits int, mode string) *QuantizedLinear { + qw, scales, qbiases := mlx.Quantize(weight, groupSize, bits, mode) + if qbiases != nil { + mlx.Eval(qw, scales, qbiases) + } else { + mlx.Eval(qw, scales) + } + return &QuantizedLinear{ + Weight: qw, + Scales: scales, + QBiases: qbiases, + Bias: bias, + GroupSize: groupSize, + Bits: bits, + Mode: mode, + } +} + +func (ql *QuantizedLinear) Forward(x *mlx.Array) *mlx.Array { + out := mlx.QuantizedMatmul(x, ql.Weight, ql.Scales, ql.QBiases, true, ql.GroupSize, ql.Bits, ql.Mode) + if ql.Bias != nil && ql.Bias.Valid() { + out = out.Add(ql.Bias) + } + return out +} + +func (ql *QuantizedLinear) OutputDim() int32 { + return int32(ql.Weight.Dim(0)) +} + +// RMSNorm represents an RMS normalization layer. +type RMSNorm struct { + Weight *mlx.Array + Eps float32 +} + +func NewRMSNorm(weight *mlx.Array, eps float32) *RMSNorm { + return &RMSNorm{Weight: weight, Eps: eps} +} + +func (rn *RMSNorm) Forward(x *mlx.Array, eps float32) *mlx.Array { + if eps == 0 { + eps = rn.Eps + } + return mlx.RMSNormFn(x, rn.Weight, eps) +} + +// Embedding represents an embedding layer. +type Embedding struct { + Weight *mlx.Array +} + +func NewEmbedding(weight *mlx.Array) *Embedding { + return &Embedding{Weight: weight} +} + +func (e *Embedding) Forward(indices *mlx.Array) *mlx.Array { + return e.Weight.TakeAxis(indices, 0) +} + +// LayerNorm represents a standard layer normalization layer (with bias). 
+type LayerNorm struct { + Weight *mlx.Array + Bias *mlx.Array + Eps float32 +} + +func (ln *LayerNorm) Forward(x *mlx.Array) *mlx.Array { + eps := ln.Eps + if eps == 0 { + eps = 1e-5 + } + mean := mlx.Mean(x, -1, true) + centered := x.Subtract(mean) + variance := mlx.Mean(centered.Multiply(centered), -1, true) + normalized := centered.Multiply(mlx.RSqrt(mlx.AddScalar(variance, eps))) + out := normalized.Multiply(ln.Weight) + if ln.Bias != nil && ln.Bias.Valid() { + out = out.Add(ln.Bias) + } + return out +} + +// MultiLinearLayer is an interface for per-head linear layers. +type MultiLinearLayer interface { + Forward(x *mlx.Array) *mlx.Array +} + +// MultiLinear performs per-head linear projections. +// Weight shape: [num_heads, output_dims, input_dims] +type MultiLinear struct { + Weight *mlx.Array +} + +func NewMultiLinear(weight *mlx.Array) *MultiLinear { + return &MultiLinear{Weight: weight} +} + +func (ml *MultiLinear) Forward(x *mlx.Array) *mlx.Array { + wT := ml.Weight.Transpose(0, 2, 1) + return x.Matmul(wT) +} + +// RepeatKV repeats K/V tensors for grouped query attention. +func RepeatKV(x *mlx.Array, repeatFactor int32) *mlx.Array { + if repeatFactor == 1 { + return x + } + shape := x.Dims() + x = x.ExpandDims(2) + reps := []int32{1, 1, repeatFactor, 1, 1} + x = mlx.Tile(x, reps) + return mlx.Reshape(x, int32(shape[0]), int32(shape[1])*repeatFactor, int32(shape[2]), int32(shape[3])) +} + +// ApplyCausalMask applies causal (lower triangular) mask to attention scores. +func ApplyCausalMask(scores *mlx.Array) *mlx.Array { + shape := scores.Dims() + seqLen := int32(shape[2]) + mask := mlx.Tri(seqLen, seqLen, 0) + negInf := mlx.NewScalarArray(float32(-1e9)) + mask = mask.ExpandDims(0).ExpandDims(0) + return mlx.Where(mask, scores, negInf) +} + +// ApplyCausalMaskWithOffset applies causal mask for cached attention. 
+func ApplyCausalMaskWithOffset(scores *mlx.Array, offset int32) *mlx.Array { + if offset == 0 { + return ApplyCausalMask(scores) + } + shape := scores.Dims() + queryLen := int32(shape[2]) + keyLen := int32(shape[3]) + mask := mlx.Tri(queryLen, keyLen, int(offset)) + negInf := mlx.NewScalarArray(float32(-1e9)) + mask = mask.ExpandDims(0).ExpandDims(0) + return mlx.Where(mask, scores, negInf) +}