mirror of
https://github.com/ollama/ollama.git
synced 2026-04-20 07:54:25 +02:00
Use atomic.Int32 for Array.pinned and a sync.Mutex for the global arrays slice so MLX arrays can be safely created and managed from multiple goroutines. Convert Array value receivers to pointer receivers and struct fields from Array to *Array to avoid copying the atomic.
73 lines
1.5 KiB
Go
73 lines
1.5 KiB
Go
package mlx
|
|
|
|
// #include "generated.h"
|
|
import "C"
|
|
|
|
import (
|
|
"unsafe"
|
|
)
|
|
|
|
func ScaledDotProductAttention(query, key, value, mask *Array, scale float32) *Array {
|
|
if mask == nil {
|
|
mask = New("")
|
|
}
|
|
|
|
sinks := New("")
|
|
|
|
mode := "causal"
|
|
cMode := C.CString(mode)
|
|
defer C.free(unsafe.Pointer(cMode))
|
|
|
|
out := New("FAST_SDPA")
|
|
C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, mask.ctx, sinks.ctx, DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
type LayerNorm struct {
|
|
Weight *Array `weight:"weight"`
|
|
Bias *Array `weight:"bias"`
|
|
}
|
|
|
|
func (r *LayerNorm) Forward(x *Array, eps float32) *Array {
|
|
out := New("FAST_LAYERNORM")
|
|
C.mlx_fast_layer_norm(&out.ctx, x.ctx, r.Weight.ctx, r.Bias.ctx, C.float(eps), DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
type RMSNorm struct {
|
|
Weight *Array `weight:"weight"`
|
|
}
|
|
|
|
func (r *RMSNorm) Forward(x *Array, eps float32) *Array {
|
|
out := New("FAST_RMSNORM")
|
|
C.mlx_fast_rms_norm(&out.ctx, x.ctx, r.Weight.ctx, C.float(eps), DefaultStream().ctx)
|
|
return out
|
|
}
|
|
|
|
type RoPE struct {
|
|
Dims int
|
|
Traditional bool
|
|
Base float32 `json:"rope_theta"`
|
|
Scale float32
|
|
}
|
|
|
|
func (r RoPE) Forward(t *Array, offset int) *Array {
|
|
freqs := New("")
|
|
out := New("FAST_ROPE")
|
|
C.mlx_fast_rope(
|
|
&out.ctx,
|
|
t.ctx,
|
|
C.int(r.Dims),
|
|
C._Bool(r.Traditional),
|
|
C.mlx_optional_float{
|
|
value: C.float(r.Base),
|
|
has_value: C._Bool(func() bool { return r.Base != 0 }()),
|
|
},
|
|
C.float(r.Scale),
|
|
C.int(offset),
|
|
freqs.ctx,
|
|
DefaultStream().ctx,
|
|
)
|
|
return out
|
|
}
|