mirror of
https://github.com/ollama/ollama.git
synced 2026-04-22 16:55:44 +02:00
Match the ollamarunner and OpenAI semantics: raw, full-vocab log-softmax with the top-K ranked by probability. Skipped on the GPU when the request doesn't ask for logprobs so decode doesn't pay for it otherwise.
188 lines
4.9 KiB
Go
package nn
|
|
|
|
import (
|
|
"math"
|
|
"testing"
|
|
|
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
|
)
|
|
|
|
func skipIfNoMLX(t *testing.T) {
|
|
t.Helper()
|
|
if err := mlx.CheckInit(); err != nil {
|
|
t.Skipf("MLX not available: %v", err)
|
|
}
|
|
}
|
|
|
|
func approxEqual(a, b, tol float32) bool {
|
|
return float32(math.Abs(float64(a-b))) < tol
|
|
}
|
|
|
|
// TestLayerNormNoBias verifies LayerNorm without bias against manual computation.
|
|
func TestLayerNormNoBias(t *testing.T) {
|
|
skipIfNoMLX(t)
|
|
|
|
// Input: [1, 4] — single row, 4 features
|
|
x := mlx.FromValues([]float32{1, 2, 3, 4}, 1, 4)
|
|
weight := mlx.FromValues([]float32{1, 1, 1, 1}, 4)
|
|
mlx.Eval(x, weight)
|
|
|
|
ln := &LayerNorm{Weight: weight, Eps: 1e-5}
|
|
out := ln.Forward(x)
|
|
mlx.Eval(out)
|
|
|
|
data := out.Floats()
|
|
if len(data) != 4 {
|
|
t.Fatalf("expected 4 values, got %d", len(data))
|
|
}
|
|
|
|
// Manual LayerNorm: mean=2.5, var=1.25, std=sqrt(1.25+1e-5)
|
|
// normalized = (x - mean) / std
|
|
mean := float32(2.5)
|
|
variance := float32(1.25)
|
|
std := float32(math.Sqrt(float64(variance + 1e-5)))
|
|
for i, v := range []float32{1, 2, 3, 4} {
|
|
expected := (v - mean) / std
|
|
if !approxEqual(data[i], expected, 1e-4) {
|
|
t.Errorf("index %d: expected %.6f, got %.6f", i, expected, data[i])
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestLayerNormWithBias verifies LayerNorm with weight and bias.
|
|
func TestLayerNormWithBias(t *testing.T) {
|
|
skipIfNoMLX(t)
|
|
|
|
x := mlx.FromValues([]float32{1, 2, 3, 4}, 1, 4)
|
|
weight := mlx.FromValues([]float32{2, 2, 2, 2}, 4)
|
|
bias := mlx.FromValues([]float32{10, 20, 30, 40}, 4)
|
|
mlx.Eval(x, weight, bias)
|
|
|
|
ln := &LayerNorm{Weight: weight, Bias: bias, Eps: 1e-5}
|
|
out := ln.Forward(x)
|
|
mlx.Eval(out)
|
|
|
|
data := out.Floats()
|
|
if len(data) != 4 {
|
|
t.Fatalf("expected 4 values, got %d", len(data))
|
|
}
|
|
|
|
mean := float32(2.5)
|
|
variance := float32(1.25)
|
|
std := float32(math.Sqrt(float64(variance + 1e-5)))
|
|
biases := []float32{10, 20, 30, 40}
|
|
for i, v := range []float32{1, 2, 3, 4} {
|
|
expected := ((v-mean)/std)*2 + biases[i]
|
|
if !approxEqual(data[i], expected, 1e-4) {
|
|
t.Errorf("index %d: expected %.6f, got %.6f", i, expected, data[i])
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestLayerNormBatched verifies LayerNorm normalizes each row independently.
|
|
func TestLayerNormBatched(t *testing.T) {
|
|
skipIfNoMLX(t)
|
|
|
|
// Input: [2, 3] — two rows
|
|
x := mlx.FromValues([]float32{
|
|
1, 2, 3,
|
|
10, 20, 30,
|
|
}, 2, 3)
|
|
weight := mlx.FromValues([]float32{1, 1, 1}, 3)
|
|
mlx.Eval(x, weight)
|
|
|
|
ln := &LayerNorm{Weight: weight, Eps: 1e-5}
|
|
out := ln.Forward(x)
|
|
mlx.Eval(out)
|
|
|
|
data := out.Floats()
|
|
if len(data) != 6 {
|
|
t.Fatalf("expected 6 values, got %d", len(data))
|
|
}
|
|
|
|
// Each row should be independently normalized.
|
|
// Row 0: [1,2,3] mean=2, var=2/3
|
|
// Row 1: [10,20,30] mean=20, var=200/3
|
|
// After normalization both rows should have the same pattern
|
|
// since [10,20,30] = 10*[1,2,3], the normalized values are identical.
|
|
for i := range 3 {
|
|
if !approxEqual(data[i], data[i+3], 1e-4) {
|
|
t.Errorf("row 0 elem %d (%.6f) != row 1 elem %d (%.6f); expected identical normalized values",
|
|
i, data[i], i, data[i+3])
|
|
}
|
|
}
|
|
|
|
// Verify the normalized values sum to ~0 (mean-centered)
|
|
sum := data[0] + data[1] + data[2]
|
|
if !approxEqual(sum, 0, 1e-4) {
|
|
t.Errorf("normalized row sum should be ~0, got %.6f", sum)
|
|
}
|
|
}
|
|
|
|
// TestLayerNormDefaultEps verifies the default epsilon of 1e-5 is used when Eps is 0.
|
|
func TestLayerNormDefaultEps(t *testing.T) {
|
|
skipIfNoMLX(t)
|
|
|
|
x := mlx.FromValues([]float32{1, 2, 3, 4}, 1, 4)
|
|
weight := mlx.FromValues([]float32{1, 1, 1, 1}, 4)
|
|
mlx.Eval(x, weight)
|
|
|
|
// Eps=0 should use default 1e-5
|
|
ln0 := &LayerNorm{Weight: weight, Eps: 0}
|
|
out0 := ln0.Forward(x)
|
|
mlx.Eval(out0)
|
|
|
|
lnExplicit := &LayerNorm{Weight: weight, Eps: 1e-5}
|
|
outExplicit := lnExplicit.Forward(x)
|
|
mlx.Eval(outExplicit)
|
|
|
|
d0 := out0.Floats()
|
|
dE := outExplicit.Floats()
|
|
for i := range d0 {
|
|
if !approxEqual(d0[i], dE[i], 1e-6) {
|
|
t.Errorf("index %d: Eps=0 gave %.6f, Eps=1e-5 gave %.6f", i, d0[i], dE[i])
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestQuantizedLinearMXFP4MatchesDequantizedWeight(t *testing.T) {
|
|
skipIfNoMLX(t)
|
|
|
|
weightVals := make([]float32, 3*32)
|
|
for i := range weightVals {
|
|
weightVals[i] = float32((i%11)-5) / 7
|
|
}
|
|
inputVals := make([]float32, 2*32)
|
|
for i := range inputVals {
|
|
inputVals[i] = float32((i%7)-3) / 5
|
|
}
|
|
|
|
weight := mlx.FromValues(weightVals, 3, 32).AsType(mlx.DTypeBFloat16)
|
|
input := mlx.FromValues(inputVals, 2, 32).AsType(mlx.DTypeBFloat16)
|
|
mlx.Eval(weight, input)
|
|
|
|
ql := NewQuantizedLinear(weight, nil, 32, 4, "mxfp4")
|
|
if ql.QBiases != nil {
|
|
t.Fatalf("mxfp4 qbiases = %v, want nil", ql.QBiases)
|
|
}
|
|
|
|
dequantizedWeight := mlx.Dequantize(ql.Weight, ql.Scales, ql.QBiases, 32, 4, "mxfp4")
|
|
mlx.Eval(dequantizedWeight)
|
|
|
|
qOut := ql.Forward(input).AsType(mlx.DTypeFloat32)
|
|
dOut := NewLinear(dequantizedWeight, nil).Forward(input).AsType(mlx.DTypeFloat32)
|
|
mlx.Eval(qOut, dOut)
|
|
|
|
got := qOut.Floats()
|
|
want := dOut.Floats()
|
|
if len(got) != len(want) {
|
|
t.Fatalf("output length = %d, want %d", len(got), len(want))
|
|
}
|
|
|
|
for i := range got {
|
|
if !approxEqual(got[i], want[i], 1e-3) {
|
|
t.Fatalf("output[%d] = %.6f, want %.6f", i, got[i], want[i])
|
|
}
|
|
}
|
|
}
|