mirror of
https://github.com/ollama/ollama.git
synced 2026-04-26 02:36:09 +02:00
* prefer rocm v6 on windows Avoid building with v7 - more changes are needed * MLX: add header vendoring and remove go build tag This switches to using a vendoring approach for the mlx-c headers so that Go can build without requiring a cmake first. This enables building the new MLX based code by default. Every time cmake runs, the headers are refreshed, so we can easily keep them in sync when we bump mlx versions. Basic Windows and Linux support are verified. * ci: harden for flaky choco repo servers CI sometimes fails due to choco not actually installing cache. Since it just speeds up the build, we can proceed without. * review comments
501 lines
12 KiB
Go
501 lines
12 KiB
Go
package mlx
|
|
|
|
// #include "generated.h"
|
|
import "C"
|
|
|
|
import (
|
|
"reflect"
|
|
"unsafe"
|
|
)
|
|
|
|
// Quantization operations
|
|
|
|
func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, biases *Array) {
|
|
cMode := C.CString(mode)
|
|
defer C.free(unsafe.Pointer(cMode))
|
|
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
|
|
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
|
|
res := C.mlx_vector_array_new()
|
|
defer C.mlx_vector_array_free(res)
|
|
C.mlx_quantize(&res, w.ctx, optGroupSize, optBits, cMode, DefaultStream().ctx)
|
|
|
|
vecSize := int(C.mlx_vector_array_size(res))
|
|
w0 := New("QUANTIZE_W")
|
|
C.mlx_vector_array_get(&w0.ctx, res, 0)
|
|
w1 := New("QUANTIZE_S")
|
|
C.mlx_vector_array_get(&w1.ctx, res, 1)
|
|
if vecSize >= 3 {
|
|
w2 := New("QUANTIZE_B")
|
|
C.mlx_vector_array_get(&w2.ctx, res, 2)
|
|
return w0, w1, w2
|
|
}
|
|
return w0, w1, nil
|
|
}
|
|
|
|
func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Array {
|
|
cMode := C.CString(mode)
|
|
defer C.free(unsafe.Pointer(cMode))
|
|
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
|
|
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
|
|
optDtype := C.mlx_optional_dtype{has_value: false}
|
|
|
|
var b C.mlx_array
|
|
if biases != nil {
|
|
b = biases.ctx
|
|
}
|
|
|
|
out := New("DEQUANTIZE")
|
|
C.mlx_dequantize(&out.ctx, w.ctx, scales.ctx, b, optGroupSize, optBits, cMode, optDtype, DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
func QuantizedMatmul(x, w, scales, biases *Array, transpose bool, groupSize, bits int, mode string) *Array {
|
|
cMode := C.CString(mode)
|
|
defer C.free(unsafe.Pointer(cMode))
|
|
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
|
|
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
|
|
|
|
var b C.mlx_array
|
|
if biases != nil {
|
|
b = biases.ctx
|
|
}
|
|
|
|
out := New("QUANTIZED_MATMUL")
|
|
C.mlx_quantized_matmul(&out.ctx, x.ctx, w.ctx, scales.ctx, b, C.bool(transpose), optGroupSize, optBits, cMode, DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
func GatherQMM(x, w, scales *Array, biases, lhsIndices, rhsIndices *Array, transpose bool, groupSize, bits int, mode string, sortedIndices bool) *Array {
|
|
cMode := C.CString(mode)
|
|
defer C.free(unsafe.Pointer(cMode))
|
|
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
|
|
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
|
|
|
|
var b, lhs, rhs C.mlx_array
|
|
if biases != nil {
|
|
b = biases.ctx
|
|
}
|
|
if lhsIndices != nil {
|
|
lhs = lhsIndices.ctx
|
|
}
|
|
if rhsIndices != nil {
|
|
rhs = rhsIndices.ctx
|
|
}
|
|
|
|
out := New("GATHER_QMM")
|
|
C.mlx_gather_qmm(&out.ctx, x.ctx, w.ctx, scales.ctx, b, lhs, rhs, C.bool(transpose), optGroupSize, optBits, cMode, C.bool(sortedIndices), DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
// Missing tensor ops
|
|
|
|
func Tile(a *Array, reps []int32) *Array {
|
|
cReps := make([]C.int, len(reps))
|
|
for i, r := range reps {
|
|
cReps[i] = C.int(r)
|
|
}
|
|
out := New("TILE")
|
|
C.mlx_tile(&out.ctx, a.ctx, unsafe.SliceData(cReps), C.size_t(len(reps)), DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
func Tri(n, m int32, k int) *Array {
|
|
out := New("TRI")
|
|
C.mlx_tri(&out.ctx, C.int(n), C.int(m), C.int(k), C.mlx_dtype(DTypeFloat32), DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
func Where(condition, a, b *Array) *Array {
|
|
out := New("WHERE")
|
|
C.mlx_where(&out.ctx, condition.ctx, a.ctx, b.ctx, DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
func Conv1d(x, weight *Array, bias *Array, stride, padding, dilation, groups int32) *Array {
|
|
out := New("CONV1D")
|
|
C.mlx_conv1d(
|
|
&out.ctx,
|
|
x.ctx,
|
|
weight.ctx,
|
|
C.int(stride),
|
|
C.int(padding),
|
|
C.int(dilation),
|
|
C.int(groups),
|
|
DefaultStream().ctx,
|
|
)
|
|
if bias != nil && bias.Valid() {
|
|
out = Add(out, bias)
|
|
}
|
|
return out
|
|
}
|
|
|
|
func Contiguous(a *Array, allowColMajor bool) *Array {
|
|
out := New("CONTIGUOUS")
|
|
C.mlx_contiguous(&out.ctx, a.ctx, C.bool(allowColMajor), DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
func DepthwiseConv1d(x, weight *Array, bias *Array) *Array {
|
|
groups := int32(x.Dim(x.NumDims() - 1))
|
|
return Conv1d(x, weight, bias, 1, 0, 1, groups)
|
|
}
|
|
|
|
// Convenience wrappers (function-style for the model code)
|
|
|
|
func Stack(arrays []*Array, axis int) *Array {
|
|
vectorData := make([]C.mlx_array, len(arrays))
|
|
for i := range arrays {
|
|
vectorData[i] = arrays[i].ctx
|
|
}
|
|
vector := C.mlx_vector_array_new_data(unsafe.SliceData(vectorData), C.size_t(len(vectorData)))
|
|
defer C.mlx_vector_array_free(vector)
|
|
|
|
out := New("STACK")
|
|
C.mlx_stack_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
func Neg(a *Array) *Array {
|
|
return a.Negative()
|
|
}
|
|
|
|
func Sum(a *Array, axis int, keepDims bool) *Array {
|
|
return a.SumAxis(axis, keepDims)
|
|
}
|
|
|
|
func Argsort(a *Array, axis int) *Array {
|
|
return a.ArgsortAxis(axis)
|
|
}
|
|
|
|
func Take(a *Array, indices *Array, axis int) *Array {
|
|
return a.TakeAxis(indices, axis)
|
|
}
|
|
|
|
func RSqrt(a *Array) *Array {
|
|
out := New("RSQRT")
|
|
C.mlx_rsqrt(&out.ctx, a.ctx, DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
func Mean(a *Array, axis int, keepDims bool) *Array {
|
|
out := New("MEAN_AXIS")
|
|
C.mlx_mean_axis(&out.ctx, a.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
func Argpartition(a *Array, kth int, axis int) *Array {
|
|
return a.ArgpartitionAxis(kth, axis)
|
|
}
|
|
|
|
func TakeAlongAxis(a, indices *Array, axis int) *Array {
|
|
return a.TakeAlongAxis(indices, axis)
|
|
}
|
|
|
|
// Function-style wrappers matching imagegen API
|
|
|
|
func Add(a, b *Array) *Array {
|
|
return a.Add(b)
|
|
}
|
|
|
|
func Sub(a, b *Array) *Array {
|
|
return a.Subtract(b)
|
|
}
|
|
|
|
func Mul(a, b *Array) *Array {
|
|
return a.Multiply(b)
|
|
}
|
|
|
|
func Div(a, b *Array) *Array {
|
|
return a.Divide(b)
|
|
}
|
|
|
|
func Matmul(a, b *Array) *Array {
|
|
return a.Matmul(b)
|
|
}
|
|
|
|
func Reshape(a *Array, shape ...int32) *Array {
|
|
axes := make([]int, len(shape))
|
|
for i, s := range shape {
|
|
axes[i] = int(s)
|
|
}
|
|
return a.Reshape(axes...)
|
|
}
|
|
|
|
func Transpose(a *Array, axes ...int) *Array {
|
|
return a.Transpose(axes...)
|
|
}
|
|
|
|
func ExpandDims(a *Array, axis int) *Array {
|
|
return a.ExpandDims(axis)
|
|
}
|
|
|
|
func Squeeze(a *Array, axis int) *Array {
|
|
return a.Squeeze(axis)
|
|
}
|
|
|
|
func Flatten(a *Array) *Array {
|
|
return a.Flatten(0, -1)
|
|
}
|
|
|
|
func Concatenate(arrays []*Array, axis int) *Array {
|
|
if len(arrays) == 0 {
|
|
return nil
|
|
}
|
|
return arrays[0].Concatenate(axis, arrays[1:]...)
|
|
}
|
|
|
|
func SliceStartStop(a *Array, start, stop []int32) *Array {
|
|
n := len(start)
|
|
cStart := make([]C.int, n)
|
|
cStop := make([]C.int, n)
|
|
cStrides := make([]C.int, n)
|
|
for i := 0; i < n; i++ {
|
|
cStart[i] = C.int(start[i])
|
|
cStop[i] = C.int(stop[i])
|
|
cStrides[i] = 1
|
|
}
|
|
out := New("SLICE")
|
|
C.mlx_slice(&out.ctx, a.ctx, unsafe.SliceData(cStart), C.size_t(n), unsafe.SliceData(cStop), C.size_t(n), unsafe.SliceData(cStrides), C.size_t(n), DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
func GatherMM(a, b *Array, lhsIndices, rhsIndices *Array, sortedIndices bool) *Array {
|
|
if lhsIndices == nil {
|
|
lhsIndices = New("")
|
|
}
|
|
if rhsIndices == nil {
|
|
rhsIndices = New("")
|
|
}
|
|
return a.GatherMM(b, lhsIndices, rhsIndices, sortedIndices)
|
|
}
|
|
|
|
func SiLU(a *Array) *Array {
|
|
sig := a.Sigmoid()
|
|
return a.Multiply(sig)
|
|
}
|
|
|
|
func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, offset int) *Array {
|
|
freqs := New("")
|
|
out := New("FAST_ROPE")
|
|
C.mlx_fast_rope(
|
|
&out.ctx,
|
|
x.ctx,
|
|
C.int(dims),
|
|
C.bool(traditional),
|
|
C.mlx_optional_float{
|
|
value: C.float(base),
|
|
has_value: C.bool(func() bool { return base != 0 }()),
|
|
},
|
|
C.float(scale),
|
|
C.int(offset),
|
|
freqs.ctx,
|
|
DefaultStream().ctx,
|
|
)
|
|
return out
|
|
}
|
|
|
|
func Sigmoid(a *Array) *Array {
|
|
return a.Sigmoid()
|
|
}
|
|
|
|
func Exp(a *Array) *Array {
|
|
out := New("EXP")
|
|
C.mlx_exp(&out.ctx, a.ctx, DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
func Log(a *Array) *Array {
|
|
out := New("LOG")
|
|
C.mlx_log(&out.ctx, a.ctx, DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
func SoftmaxAxis(a *Array, axis int, precise bool) *Array {
|
|
out := New("SOFTMAX_AXIS")
|
|
C.mlx_softmax_axis(&out.ctx, a.ctx, C.int(axis), C.bool(precise), DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask bool) *Array {
|
|
mask := New("")
|
|
sinks := New("")
|
|
mode := ""
|
|
if causalMask {
|
|
mode = "causal"
|
|
}
|
|
cMode := C.CString(mode)
|
|
defer C.free(unsafe.Pointer(cMode))
|
|
|
|
out := New("FAST_SDPA")
|
|
C.mlx_fast_scaled_dot_product_attention(&out.ctx, q.ctx, k.ctx, v.ctx, C.float(scale), cMode, mask.ctx, sinks.ctx, DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
func RMSNormFn(x, weight *Array, eps float32) *Array {
|
|
out := New("FAST_RMSNORM")
|
|
var w C.mlx_array
|
|
if weight != nil {
|
|
w = weight.ctx
|
|
}
|
|
C.mlx_fast_rms_norm(&out.ctx, x.ctx, w, C.float(eps), DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
func AddMM(c, a, b *Array, alpha, beta float32) *Array {
|
|
return c.Addmm(a, b, alpha, beta)
|
|
}
|
|
|
|
// Scalar helpers
|
|
|
|
// scalarWithDtype creates a scalar array matching the dtype of a.
|
|
// Matching dtype is important for graph fusion and avoiding implicit casts.
|
|
func scalarWithDtype(s float32, a *Array) C.mlx_array {
|
|
f32 := C.mlx_array_new_float(C.float(s))
|
|
dtype := a.DType()
|
|
if dtype == DTypeFloat32 {
|
|
return f32
|
|
}
|
|
casted := C.mlx_array_new()
|
|
C.mlx_astype(&casted, f32, C.mlx_dtype(dtype), DefaultStream().ctx)
|
|
C.mlx_array_free(f32)
|
|
return casted
|
|
}
|
|
|
|
func AddScalar(a *Array, s float32) *Array {
|
|
scalar := scalarWithDtype(s, a)
|
|
out := New("ADD_SCALAR")
|
|
C.mlx_add(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
|
|
C.mlx_array_free(scalar)
|
|
return out
|
|
}
|
|
|
|
func MulScalar(a *Array, s float32) *Array {
|
|
scalar := scalarWithDtype(s, a)
|
|
out := New("MUL_SCALAR")
|
|
C.mlx_multiply(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
|
|
C.mlx_array_free(scalar)
|
|
return out
|
|
}
|
|
|
|
func DivScalar(a *Array, s float32) *Array {
|
|
scalar := scalarWithDtype(s, a)
|
|
out := New("DIV_SCALAR")
|
|
C.mlx_divide(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
|
|
C.mlx_array_free(scalar)
|
|
return out
|
|
}
|
|
|
|
func FloorDivideScalar(a *Array, s int32) *Array {
|
|
scalar := FromValue(int(s))
|
|
return a.FloorDivide(scalar)
|
|
}
|
|
|
|
// Array constructors
|
|
|
|
func NewArrayInt32(data []int32, shape []int32) *Array {
|
|
cShape := make([]C.int, len(shape))
|
|
for i, s := range shape {
|
|
cShape[i] = C.int(s)
|
|
}
|
|
out := New("NEW_ARRAY_INT32")
|
|
out.ctx = C.mlx_array_new_data(unsafe.Pointer(&data[0]), unsafe.SliceData(cShape), C.int(len(shape)), C.mlx_dtype(DTypeInt32))
|
|
return out
|
|
}
|
|
|
|
func NewScalarArray(value float32) *Array {
|
|
out := New("SCALAR")
|
|
out.ctx = C.mlx_array_new_float32(C.float(value))
|
|
return out
|
|
}
|
|
|
|
func ZerosF32(shape []int32) *Array {
|
|
return Zeros(DTypeFloat32, func() []int {
|
|
ints := make([]int, len(shape))
|
|
for i, s := range shape {
|
|
ints[i] = int(s)
|
|
}
|
|
return ints
|
|
}()...)
|
|
}
|
|
|
|
// Utility
|
|
|
|
func Collect(v any) []*Array {
|
|
var arrays []*Array
|
|
seen := make(map[uintptr]bool)
|
|
collect(reflect.ValueOf(v), &arrays, seen)
|
|
return arrays
|
|
}
|
|
|
|
func Copy(a *Array) *Array {
|
|
if a == nil || !a.Valid() {
|
|
return a
|
|
}
|
|
out := New("COPY")
|
|
C.mlx_copy(&out.ctx, a.ctx, DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
func collect(v reflect.Value, arrays *[]*Array, seen map[uintptr]bool) {
|
|
if !v.IsValid() {
|
|
return
|
|
}
|
|
|
|
if v.Kind() == reflect.Ptr {
|
|
if v.IsNil() {
|
|
return
|
|
}
|
|
ptr := v.Pointer()
|
|
if seen[ptr] {
|
|
return
|
|
}
|
|
seen[ptr] = true
|
|
|
|
if arr, ok := v.Interface().(*Array); ok {
|
|
if arr != nil && arr.Valid() {
|
|
*arrays = append(*arrays, arr)
|
|
}
|
|
return
|
|
}
|
|
collect(v.Elem(), arrays, seen)
|
|
return
|
|
}
|
|
|
|
switch v.Kind() {
|
|
case reflect.Struct:
|
|
// Check if this struct IS an Array (not a pointer to one)
|
|
if arr, ok := v.Addr().Interface().(*Array); ok {
|
|
if arr != nil && arr.Valid() {
|
|
*arrays = append(*arrays, arr)
|
|
}
|
|
return
|
|
}
|
|
for i := 0; i < v.NumField(); i++ {
|
|
field := v.Field(i)
|
|
if field.CanInterface() {
|
|
collect(field, arrays, seen)
|
|
}
|
|
}
|
|
case reflect.Slice:
|
|
for i := 0; i < v.Len(); i++ {
|
|
collect(v.Index(i), arrays, seen)
|
|
}
|
|
case reflect.Map:
|
|
for _, key := range v.MapKeys() {
|
|
collect(v.MapIndex(key), arrays, seen)
|
|
}
|
|
case reflect.Interface:
|
|
if !v.IsNil() {
|
|
collect(v.Elem(), arrays, seen)
|
|
}
|
|
}
|
|
}
|
|
|
|
func EnableCompile() {
|
|
C.mlx_enable_compile()
|
|
}
|
|
|
|
func DisableCompile() {
|
|
C.mlx_disable_compile()
|
|
}
|