mirror of
https://github.com/ollama/ollama.git
synced 2026-04-18 08:13:29 +02:00
* mlx: perf improvements Fix nn.go to call mlx_fast_layer_norm instead of manually implementing (mean, subtract, variance, rsqrt, multiply, add — 6 ops) Fix llama.go, gemma3.go to remove RepeatKV to tile K/V tensors to match the Q head count, since scaled_dot_product_attention natively handles GQA (it just requires n_q_heads % n_kv_heads == 0) * review comments
147 lines
3.8 KiB
Go
package nn
|
|
|
|
import (
|
|
"math"
|
|
"testing"
|
|
|
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
|
)
|
|
|
|
func skipIfNoMLX(t *testing.T) {
|
|
t.Helper()
|
|
if err := mlx.CheckInit(); err != nil {
|
|
t.Skipf("MLX not available: %v", err)
|
|
}
|
|
}
|
|
|
|
// approxEqual reports whether a and b differ by strictly less than tol.
func approxEqual(a, b, tol float32) bool {
	diff := a - b
	return float32(math.Abs(float64(diff))) < tol
}
|
|
|
|
// TestLayerNormNoBias verifies LayerNorm without bias against manual computation.
|
|
func TestLayerNormNoBias(t *testing.T) {
|
|
skipIfNoMLX(t)
|
|
|
|
// Input: [1, 4] — single row, 4 features
|
|
x := mlx.FromValues([]float32{1, 2, 3, 4}, 1, 4)
|
|
weight := mlx.FromValues([]float32{1, 1, 1, 1}, 4)
|
|
mlx.Eval(x, weight)
|
|
|
|
ln := &LayerNorm{Weight: weight, Eps: 1e-5}
|
|
out := ln.Forward(x)
|
|
mlx.Eval(out)
|
|
|
|
data := out.Floats()
|
|
if len(data) != 4 {
|
|
t.Fatalf("expected 4 values, got %d", len(data))
|
|
}
|
|
|
|
// Manual LayerNorm: mean=2.5, var=1.25, std=sqrt(1.25+1e-5)
|
|
// normalized = (x - mean) / std
|
|
mean := float32(2.5)
|
|
variance := float32(1.25)
|
|
std := float32(math.Sqrt(float64(variance + 1e-5)))
|
|
for i, v := range []float32{1, 2, 3, 4} {
|
|
expected := (v - mean) / std
|
|
if !approxEqual(data[i], expected, 1e-4) {
|
|
t.Errorf("index %d: expected %.6f, got %.6f", i, expected, data[i])
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestLayerNormWithBias verifies LayerNorm with weight and bias.
|
|
func TestLayerNormWithBias(t *testing.T) {
|
|
skipIfNoMLX(t)
|
|
|
|
x := mlx.FromValues([]float32{1, 2, 3, 4}, 1, 4)
|
|
weight := mlx.FromValues([]float32{2, 2, 2, 2}, 4)
|
|
bias := mlx.FromValues([]float32{10, 20, 30, 40}, 4)
|
|
mlx.Eval(x, weight, bias)
|
|
|
|
ln := &LayerNorm{Weight: weight, Bias: bias, Eps: 1e-5}
|
|
out := ln.Forward(x)
|
|
mlx.Eval(out)
|
|
|
|
data := out.Floats()
|
|
if len(data) != 4 {
|
|
t.Fatalf("expected 4 values, got %d", len(data))
|
|
}
|
|
|
|
mean := float32(2.5)
|
|
variance := float32(1.25)
|
|
std := float32(math.Sqrt(float64(variance + 1e-5)))
|
|
biases := []float32{10, 20, 30, 40}
|
|
for i, v := range []float32{1, 2, 3, 4} {
|
|
expected := ((v-mean)/std)*2 + biases[i]
|
|
if !approxEqual(data[i], expected, 1e-4) {
|
|
t.Errorf("index %d: expected %.6f, got %.6f", i, expected, data[i])
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestLayerNormBatched verifies LayerNorm normalizes each row independently.
|
|
func TestLayerNormBatched(t *testing.T) {
|
|
skipIfNoMLX(t)
|
|
|
|
// Input: [2, 3] — two rows
|
|
x := mlx.FromValues([]float32{
|
|
1, 2, 3,
|
|
10, 20, 30,
|
|
}, 2, 3)
|
|
weight := mlx.FromValues([]float32{1, 1, 1}, 3)
|
|
mlx.Eval(x, weight)
|
|
|
|
ln := &LayerNorm{Weight: weight, Eps: 1e-5}
|
|
out := ln.Forward(x)
|
|
mlx.Eval(out)
|
|
|
|
data := out.Floats()
|
|
if len(data) != 6 {
|
|
t.Fatalf("expected 6 values, got %d", len(data))
|
|
}
|
|
|
|
// Each row should be independently normalized.
|
|
// Row 0: [1,2,3] mean=2, var=2/3
|
|
// Row 1: [10,20,30] mean=20, var=200/3
|
|
// After normalization both rows should have the same pattern
|
|
// since [10,20,30] = 10*[1,2,3], the normalized values are identical.
|
|
for i := range 3 {
|
|
if !approxEqual(data[i], data[i+3], 1e-4) {
|
|
t.Errorf("row 0 elem %d (%.6f) != row 1 elem %d (%.6f); expected identical normalized values",
|
|
i, data[i], i, data[i+3])
|
|
}
|
|
}
|
|
|
|
// Verify the normalized values sum to ~0 (mean-centered)
|
|
sum := data[0] + data[1] + data[2]
|
|
if !approxEqual(sum, 0, 1e-4) {
|
|
t.Errorf("normalized row sum should be ~0, got %.6f", sum)
|
|
}
|
|
}
|
|
|
|
// TestLayerNormDefaultEps verifies the default epsilon of 1e-5 is used when Eps is 0.
|
|
func TestLayerNormDefaultEps(t *testing.T) {
|
|
skipIfNoMLX(t)
|
|
|
|
x := mlx.FromValues([]float32{1, 2, 3, 4}, 1, 4)
|
|
weight := mlx.FromValues([]float32{1, 1, 1, 1}, 4)
|
|
mlx.Eval(x, weight)
|
|
|
|
// Eps=0 should use default 1e-5
|
|
ln0 := &LayerNorm{Weight: weight, Eps: 0}
|
|
out0 := ln0.Forward(x)
|
|
mlx.Eval(out0)
|
|
|
|
lnExplicit := &LayerNorm{Weight: weight, Eps: 1e-5}
|
|
outExplicit := lnExplicit.Forward(x)
|
|
mlx.Eval(outExplicit)
|
|
|
|
d0 := out0.Floats()
|
|
dE := outExplicit.Floats()
|
|
for i := range d0 {
|
|
if !approxEqual(d0[i], dE[i], 1e-6) {
|
|
t.Errorf("index %d: Eps=0 gave %.6f, Eps=1e-5 gave %.6f", i, d0[i], dE[i])
|
|
}
|
|
}
|
|
}
|