Files
ollama/x/mlxrunner/model/embedding_test.go
Patrick Devine d727aacd04 mlx: quantized embeddings, fast SwiGLU, and runtime fixes (#14884)
Add QuantizedEmbedding and EmbeddingLayer interface so models can
use quantized embedding weights and expose tied output projections.
This change updates gemma3, glm4_moe_lite, llama, qwen3, and qwen3_5
to use the new interface.
2026-03-17 11:21:38 -07:00

79 lines
2.1 KiB
Go

package model
import (
"testing"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/models/nn"
)
// skipIfNoMLX marks the calling test as skipped when mlx.CheckInit reports an
// error, so the embedding tests only run where the MLX runtime can be
// initialized. It is registered as a helper so skip locations point at the
// caller.
func skipIfNoMLX(t *testing.T) {
	t.Helper()

	err := mlx.CheckInit()
	if err == nil {
		return
	}
	t.Skipf("MLX not available: %v", err)
}
// TestMakeEmbeddingLayerDense checks that MakeEmbeddingLayer returns a plain
// *nn.Embedding when only "<prefix>.weight" is supplied and the quantization
// arguments are zero/empty. It also checks three behaviors of that layer:
//   - the weight keeps its bfloat16 dtype;
//   - a lookup returns one row per index;
//   - the tied output projection from AsLinear is a dense *nn.Linear.
func TestMakeEmbeddingLayerDense(t *testing.T) {
	skipIfNoMLX(t)

	// 2 rows x 4 columns of distinct values, stored as bfloat16.
	weight := mlx.FromValues([]float32{
		1, 2, 3, 4,
		5, 6, 7, 8,
	}, 2, 4).AsType(mlx.DTypeBFloat16)

	emb := MakeEmbeddingLayer(map[string]*mlx.Array{
		"model.embed_tokens.weight": weight,
	}, "model.embed_tokens", 0, 0, "", nil)

	dense, ok := emb.(*nn.Embedding)
	if !ok {
		t.Fatalf("embedding type = %T, want *nn.Embedding", emb)
	}
	if dense.Weight.DType() != mlx.DTypeBFloat16 {
		t.Fatalf("embedding dtype = %v, want %v", dense.Weight.DType(), mlx.DTypeBFloat16)
	}

	// Exercise the lookup path as well, mirroring the quantized test: two
	// indices into a 4-wide table should yield a [2 4] result.
	indices := mlx.FromValues([]int32{1, 0}, 2)
	out := emb.Forward(indices)
	mlx.Eval(out)
	if dims := out.Dims(); len(dims) != 2 || dims[0] != 2 || dims[1] != 4 {
		t.Fatalf("embedding output dims = %v, want [2 4]", dims)
	}

	// Call AsLinear once so the failure message reports the exact value that
	// was type-checked rather than the result of a second call.
	linear := emb.AsLinear()
	if _, ok := linear.(*nn.Linear); !ok {
		t.Fatalf("AsLinear type = %T, want *nn.Linear", linear)
	}
}
// TestMakeEmbeddingLayerQuantized checks that MakeEmbeddingLayer returns an
// *nn.QuantizedEmbedding when "<prefix>.weight_scale" and
// "<prefix>.weight_qbias" are supplied alongside the packed
// "<prefix>.weight". It also checks three behaviors of that layer:
//   - the group size, bit width and mode are carried through;
//   - a lookup returns one 64-wide row per index;
//   - the tied output projection from AsLinear is an *nn.QuantizedLinear.
func TestMakeEmbeddingLayerQuantized(t *testing.T) {
	skipIfNoMLX(t)

	// Build a 2x64 table of varied values (a repeating 0..16 ramp scaled by
	// 1/8) so the quantizer has a non-trivial range to encode.
	values := make([]float32, 2*64)
	for i := range values {
		values[i] = float32(i%17) / 8
	}
	denseWeight := mlx.FromValues(values, 2, 64).AsType(mlx.DTypeBFloat16)

	qw, scales, qbiases := mlx.Quantize(denseWeight, 64, 4, "affine")
	mlx.Eval(qw, scales, qbiases)

	emb := MakeEmbeddingLayer(map[string]*mlx.Array{
		"model.embed_tokens.weight":       qw,
		"model.embed_tokens.weight_scale": scales,
		"model.embed_tokens.weight_qbias": qbiases,
	}, "model.embed_tokens", 64, 4, "affine", nil)

	qemb, ok := emb.(*nn.QuantizedEmbedding)
	if !ok {
		t.Fatalf("embedding type = %T, want *nn.QuantizedEmbedding", emb)
	}
	if qemb.GroupSize != 64 || qemb.Bits != 4 || qemb.Mode != "affine" {
		t.Fatalf("quant params = (%d, %d, %q), want (64, 4, %q)", qemb.GroupSize, qemb.Bits, qemb.Mode, "affine")
	}

	// Two indices into a 64-wide table should yield a [2 64] result.
	indices := mlx.FromValues([]int32{1, 0}, 2)
	out := emb.Forward(indices)
	mlx.Eval(out)
	if dims := out.Dims(); len(dims) != 2 || dims[0] != 2 || dims[1] != 64 {
		t.Fatalf("embedding output dims = %v, want [2 64]", dims)
	}

	// Call AsLinear once so the failure message reports the exact value that
	// was type-checked rather than the result of a second call.
	linear := emb.AsLinear()
	if _, ok := linear.(*nn.QuantizedLinear); !ok {
		t.Fatalf("AsLinear type = %T, want *nn.QuantizedLinear", linear)
	}
}