Daniel Hiltgen
|
539741199e
|
mlx: perf improvements (#14768)
* mlx: perf improvements
Fix nn.go to call mlx_fast_layer_norm instead of manually implementing (mean,
subtract, variance, rsqrt, multiply, add — 6 ops)
Fix llama.go, gemma3.go to remove RepeatKV to tile K/V tensors to match the Q
head count, since scaled_dot_product_attention natively handles GQA (it just
requires n_q_heads % n_kv_heads == 0)
* review comments
|
2026-03-12 12:01:28 -07:00 |
|