diff --git a/x/mlxrunner/mlx/ops_extra.go b/x/mlxrunner/mlx/ops_extra.go index c61e6fdf9..ff06092e9 100644 --- a/x/mlxrunner/mlx/ops_extra.go +++ b/x/mlxrunner/mlx/ops_extra.go @@ -310,6 +310,12 @@ func Log(a *Array) *Array { return out } +func Logaddexp(a, b *Array) *Array { + out := New("LOGADDEXP") + C.mlx_logaddexp(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx) + return out +} + func SoftmaxAxis(a *Array, axis int, precise bool) *Array { out := New("SOFTMAX_AXIS") C.mlx_softmax_axis(&out.ctx, a.ctx, C.int(axis), C.bool(precise), DefaultStream().ctx) diff --git a/x/mlxrunner/model/embedding.go b/x/mlxrunner/model/embedding.go new file mode 100644 index 000000000..d45d48e2f --- /dev/null +++ b/x/mlxrunner/model/embedding.go @@ -0,0 +1,42 @@ +package model + +import ( + "github.com/ollama/ollama/x/mlxrunner/mlx" + "github.com/ollama/ollama/x/models/nn" +) + +// MakeEmbeddingLayer constructs an embedding layer from a tensor map. +// +// For quantized tensors (path.weight + path.weight_scale), it returns a +// QuantizedEmbedding using the same quant metadata path that linear layers use. +// For non-quantized tensors, it returns a standard dense embedding. +func MakeEmbeddingLayer( + tensors map[string]*mlx.Array, + path string, + defaultGroupSize, defaultBits int, + defaultMode string, + tensorQuant map[string]*TensorQuantInfo, +) nn.EmbeddingLayer { + w := tensors[path+".weight"] + if w == nil { + return nil + } + + scales := tensors[path+".weight_scale"] + if scales != nil { + qbiases := tensors[path+".weight_qbias"] + groupSize, bits, mode := ResolveLinearQuantParams( + defaultGroupSize, + defaultBits, + defaultMode, + tensorQuant, + path+".weight", + w, + scales, + ) + + return nn.NewQuantizedEmbedding(w, scales, qbiases, groupSize, bits, mode) + } + + return nn.NewEmbedding(w) +} diff --git a/x/mlxrunner/model/embedding_test.go b/x/mlxrunner/model/embedding_test.go new file mode 100644 index 000000000..2935bdc3f --- /dev/null +++ b/x/mlxrunner/model/embedding_test.go @@ -0,0 +1,78 @@ +package model + +import ( + "testing" + + "github.com/ollama/ollama/x/mlxrunner/mlx" + "github.com/ollama/ollama/x/models/nn" +) + +func skipIfNoMLX(t *testing.T) { + t.Helper() + if err := mlx.CheckInit(); err != nil { + t.Skipf("MLX not available: %v", err) + } +} + +func TestMakeEmbeddingLayerDense(t *testing.T) { + skipIfNoMLX(t) + + weight := mlx.FromValues([]float32{ + 1, 2, 3, 4, + 5, 6, 7, 8, + }, 2, 4).AsType(mlx.DTypeBFloat16) + + emb := MakeEmbeddingLayer(map[string]*mlx.Array{ + "model.embed_tokens.weight": weight, + }, "model.embed_tokens", 0, 0, "", nil) + + dense, ok := emb.(*nn.Embedding) + if !ok { + t.Fatalf("embedding type = %T, want *nn.Embedding", emb) + } + if dense.Weight.DType() != mlx.DTypeBFloat16 { + t.Fatalf("embedding dtype = %v, want %v", dense.Weight.DType(), mlx.DTypeBFloat16) + } + if _, ok := emb.AsLinear().(*nn.Linear); !ok { + t.Fatalf("AsLinear type = %T, want *nn.Linear", emb.AsLinear()) + } +} + +func TestMakeEmbeddingLayerQuantized(t *testing.T) { + skipIfNoMLX(t) + + denseWeight := mlx.FromValues(func() []float32 { + out := make([]float32, 2*64) + for i := range out { + out[i] = float32(i%17) / 8 + } + return out + }(), 2, 64).AsType(mlx.DTypeBFloat16) + + qw, scales, qbiases := mlx.Quantize(denseWeight, 64, 4, "affine") + mlx.Eval(qw, scales, qbiases) + + emb := MakeEmbeddingLayer(map[string]*mlx.Array{ + "model.embed_tokens.weight": qw, + "model.embed_tokens.weight_scale": scales, + "model.embed_tokens.weight_qbias": qbiases, + }, "model.embed_tokens", 64, 4, "affine", nil) + + qemb, ok := emb.(*nn.QuantizedEmbedding) + if !ok { + t.Fatalf("embedding type = %T, want *nn.QuantizedEmbedding", emb) + } + if qemb.GroupSize != 64 || qemb.Bits != 4 || qemb.Mode != "affine" { + t.Fatalf("quant params = (%d, %d, %q), want (64, 4, %q)", qemb.GroupSize, qemb.Bits, qemb.Mode, "affine") + } + + indices := mlx.FromValues([]int32{1, 0}, 2) + out := emb.Forward(indices) + mlx.Eval(out) + if dims := out.Dims(); len(dims) != 2 || dims[0] != 2 || dims[1] != 64 { + t.Fatalf("embedding output dims = %v, want [2 64]", dims) + } + if _, ok := emb.AsLinear().(*nn.QuantizedLinear); !ok { + t.Fatalf("AsLinear type = %T, want *nn.QuantizedLinear", emb.AsLinear()) + } +} diff --git a/x/mlxrunner/server.go b/x/mlxrunner/server.go index a9972bfdc..42261fc3e 100644 --- a/x/mlxrunner/server.go +++ b/x/mlxrunner/server.go @@ -147,7 +147,7 @@ func Execute(args []string) error { return } - tokens := runner.Tokenizer.Encode(b.String(), true) + tokens := runner.Tokenizer.Encode(b.String(), runner.Tokenizer.AddBOS()) if err := json.NewEncoder(w).Encode(tokens); err != nil { slog.Error("Failed to encode response", "error", err) diff --git a/x/models/gemma3/gemma3.go b/x/models/gemma3/gemma3.go index 453d040ad..266222b69 100644 --- a/x/models/gemma3/gemma3.go +++ b/x/models/gemma3/gemma3.go @@ -91,7 +91,7 @@ type DecoderLayer struct { // Model is the Gemma 3 text-only model. type Model struct { - EmbedTokens *nn.Embedding + EmbedTokens nn.EmbeddingLayer Layers []*DecoderLayer Norm *nn.RMSNorm LMHead nn.LinearLayer @@ -310,11 +310,11 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error { prefix := m.weightPrefix linears := model.NewLinearFactory(tensors, m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant) - embedWeight := tensors[prefix+"model.embed_tokens.weight"] - if embedWeight == nil { + embedTokens := model.MakeEmbeddingLayer(tensors, prefix+"model.embed_tokens", m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant) + if embedTokens == nil { return fmt.Errorf("missing embedding weight: %smodel.embed_tokens.weight", prefix) } - m.EmbedTokens = nn.NewEmbedding(embedWeight) + m.EmbedTokens = embedTokens normWeight := tensors[prefix+"model.norm.weight"] if normWeight == nil { @@ -328,7 +328,7 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error { m.LMHead = lmHead } else { // Gemma usually ties output projection to embeddings. - m.LMHead = nn.NewLinear(embedWeight, nil) + m.LMHead = m.EmbedTokens.AsLinear() } for i := int32(0); i < m.NumHiddenLayers; i++ { diff --git a/x/models/glm4_moe_lite/glm4_moe_lite.go b/x/models/glm4_moe_lite/glm4_moe_lite.go index 2e8365580..a0a37a3f0 100644 --- a/x/models/glm4_moe_lite/glm4_moe_lite.go +++ b/x/models/glm4_moe_lite/glm4_moe_lite.go @@ -345,7 +345,7 @@ type Block interface { // Model represents the complete GLM4-MoE-Lite model type Model struct { - EmbedTokens *nn.Embedding + EmbedTokens nn.EmbeddingLayer Layers []Block Norm *nn.RMSNorm LMHead nn.LinearLayer @@ -586,9 +586,7 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error { } // Load embedding - if w := tensors["model.embed_tokens.weight"]; w != nil { - m.EmbedTokens = nn.NewEmbedding(w) - } + m.EmbedTokens = model.MakeEmbeddingLayer(tensors, "model.embed_tokens", cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode, cfg.TensorQuant) // Load final norm if w := tensors["model.norm.weight"]; w != nil { diff --git a/x/models/llama/llama.go b/x/models/llama/llama.go index 1f5a1dad5..ca99d9148 100644 --- a/x/models/llama/llama.go +++ b/x/models/llama/llama.go @@ -44,7 +44,7 @@ type Config struct { // Model is a Llama text model. type Model struct { - EmbedTokens *nn.Embedding + EmbedTokens nn.EmbeddingLayer Layers []*Layer Norm *nn.RMSNorm LMHead nn.LinearLayer @@ -170,11 +170,11 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error { prefix := m.weightPrefix linears := model.NewLinearFactory(tensors, m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant) - embedWeight := tensors[prefix+"model.embed_tokens.weight"] - if embedWeight == nil { + embedTokens := model.MakeEmbeddingLayer(tensors, prefix+"model.embed_tokens", m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant) + if embedTokens == nil { return fmt.Errorf("missing embedding weight: %smodel.embed_tokens.weight", prefix) } - m.EmbedTokens = nn.NewEmbedding(embedWeight) + m.EmbedTokens = embedTokens normWeight := tensors[prefix+"model.norm.weight"] if normWeight == nil { @@ -183,14 +183,14 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error { m.Norm = nn.NewRMSNorm(normWeight, m.RMSNormEps) if m.TieWordEmbeddings { - m.LMHead = nn.NewLinear(embedWeight, nil) + m.LMHead = m.EmbedTokens.AsLinear() } else if lmHead := linears.Make(prefix + "lm_head"); lmHead != nil { m.LMHead = lmHead } else if lmHead := linears.Make("lm_head"); lmHead != nil { m.LMHead = lmHead } else { // Fallback used by many Llama checkpoints where output is tied. - m.LMHead = nn.NewLinear(embedWeight, nil) + m.LMHead = m.EmbedTokens.AsLinear() } for i := int32(0); i < m.NumHiddenLayers; i++ { diff --git a/x/models/nn/nn.go b/x/models/nn/nn.go index 3f41f4bd8..56e727617 100644 --- a/x/models/nn/nn.go +++ b/x/models/nn/nn.go @@ -13,6 +13,13 @@ type LinearLayer interface { OutputDim() int32 } +// EmbeddingLayer is an interface for embedding layers that can also expose a +// tied-output projection when the model reuses embedding weights as the LM head. +type EmbeddingLayer interface { + Forward(indices *mlx.Array) *mlx.Array + AsLinear() LinearLayer +} + // Conv1d applies 1D convolution over NLC input. type Conv1d struct { Weight *mlx.Array @@ -140,6 +147,53 @@ func (e *Embedding) Forward(indices *mlx.Array) *mlx.Array { return e.Weight.TakeAxis(indices, 0) } +func (e *Embedding) AsLinear() LinearLayer { + return NewLinear(e.Weight, nil) +} + +// QuantizedEmbedding performs row-wise embedding lookup from affine/nvfp4/etc. +// packed weights and dequantizes only the selected rows. +type QuantizedEmbedding struct { + Weight *mlx.Array + Scales *mlx.Array + QBiases *mlx.Array + GroupSize int + Bits int + Mode string +} + +func NewQuantizedEmbedding(weight, scales, qbiases *mlx.Array, groupSize, bits int, mode string) *QuantizedEmbedding { + return &QuantizedEmbedding{ + Weight: weight, + Scales: scales, + QBiases: qbiases, + GroupSize: groupSize, + Bits: bits, + Mode: mode, + } +} + +func (qe *QuantizedEmbedding) Forward(indices *mlx.Array) *mlx.Array { + weight := qe.Weight.TakeAxis(indices, 0) + scales := qe.Scales.TakeAxis(indices, 0) + var qbiases *mlx.Array + if qe.QBiases != nil && qe.QBiases.Valid() { + qbiases = qe.QBiases.TakeAxis(indices, 0) + } + return mlx.Dequantize(weight, scales, qbiases, qe.GroupSize, qe.Bits, qe.Mode) +} + +func (qe *QuantizedEmbedding) AsLinear() LinearLayer { + return &QuantizedLinear{ + Weight: qe.Weight, + Scales: qe.Scales, + QBiases: qe.QBiases, + GroupSize: qe.GroupSize, + Bits: qe.Bits, + Mode: qe.Mode, + } +} + // LayerNorm represents a standard layer normalization layer (with bias). type LayerNorm struct { Weight *mlx.Array @@ -175,7 +229,6 @@ func (ml *MultiLinear) Forward(x *mlx.Array) *mlx.Array { return x.Matmul(wT) } - // ApplyCausalMask applies causal (lower triangular) mask to attention scores. func ApplyCausalMask(scores *mlx.Array) *mlx.Array { shape := scores.Dims() diff --git a/x/models/qwen3/qwen3.go b/x/models/qwen3/qwen3.go index 71596af98..6773f0eb5 100644 --- a/x/models/qwen3/qwen3.go +++ b/x/models/qwen3/qwen3.go @@ -45,7 +45,7 @@ type Config struct { // Model is the Qwen3 text-only model. type Model struct { - EmbedTokens *nn.Embedding + EmbedTokens nn.EmbeddingLayer Layers []*Layer Norm *nn.RMSNorm LMHead nn.LinearLayer @@ -177,11 +177,11 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error { prefix := m.weightPrefix linears := model.NewLinearFactory(tensors, m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant) - embedWeight := tensors[prefix+"model.embed_tokens.weight"] - if embedWeight == nil { + embedTokens := model.MakeEmbeddingLayer(tensors, prefix+"model.embed_tokens", m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant) + if embedTokens == nil { return fmt.Errorf("missing embedding weight: %smodel.embed_tokens.weight", prefix) } - m.EmbedTokens = nn.NewEmbedding(embedWeight) + m.EmbedTokens = embedTokens normWeight := tensors[prefix+"model.norm.weight"] if normWeight == nil { @@ -190,14 +190,14 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error { m.Norm = nn.NewRMSNorm(normWeight, m.RMSNormEps) if m.TieWordEmbeddings { - m.LMHead = nn.NewLinear(embedWeight, nil) + m.LMHead = m.EmbedTokens.AsLinear() } else if lmHead := linears.Make(prefix + "lm_head"); lmHead != nil { m.LMHead = lmHead } else if lmHead := linears.Make("lm_head"); lmHead != nil { m.LMHead = lmHead } else { // Qwen3 checkpoints commonly tie output projection to embeddings. - m.LMHead = nn.NewLinear(embedWeight, nil) + m.LMHead = m.EmbedTokens.AsLinear() } for i := int32(0); i < m.NumHiddenLayers; i++ { diff --git a/x/models/qwen3_5/qwen3_5.go b/x/models/qwen3_5/qwen3_5.go index 642ea1bba..b98830300 100644 --- a/x/models/qwen3_5/qwen3_5.go +++ b/x/models/qwen3_5/qwen3_5.go @@ -81,7 +81,7 @@ type Config struct { // Model is the Qwen 3.5 model. type Model struct { - EmbedTokens *nn.Embedding + EmbedTokens nn.EmbeddingLayer Layers []*Layer Norm *nn.RMSNorm LMHead nn.LinearLayer @@ -824,12 +824,11 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error { freeTensorKeys(tensors, mtpKeys...) } - embedKey := modelPrefix + "embed_tokens.weight" - embedWeight := tensors[embedKey] - if embedWeight == nil { + embedTokens := model.MakeEmbeddingLayer(tensors, modelPrefix+"embed_tokens", cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode, cfg.TensorQuant) + if embedTokens == nil { return fmt.Errorf("missing embedding weight: %sembed_tokens.weight", modelPrefix) } - m.EmbedTokens = nn.NewEmbedding(embedWeight) + m.EmbedTokens = embedTokens normKey := modelPrefix + "norm.weight" normWeight := maybeShiftNormWeight(normKey, tensors[normKey], shouldShiftNormWeights) @@ -839,13 +838,13 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error { m.Norm = nn.NewRMSNorm(normWeight, cfg.RMSNormEps) if cfg.TieWordEmbeddings { - m.LMHead = nn.NewLinear(embedWeight, nil) + m.LMHead = m.EmbedTokens.AsLinear() } else if lmHead := linears.Make(prefix + "lm_head"); lmHead != nil { m.LMHead = lmHead } else if lmHead := linears.Make("lm_head"); lmHead != nil { m.LMHead = lmHead } else { - m.LMHead = nn.NewLinear(embedWeight, nil) + m.LMHead = m.EmbedTokens.AsLinear() } useQuantizedExperts := supportsGatherQMM(cfg.QuantMode, cfg.QuantBits) @@ -1065,7 +1064,7 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error { } func softplus(x *mlx.Array) *mlx.Array { - return mlx.Log(mlx.AddScalar(mlx.Exp(x), 1.0)) + return mlx.Logaddexp(x, mlx.Zeros(x.DType(), x.Dims()...)) } func depthwiseCausalConv1d(x, w *mlx.Array, outLen int32) *mlx.Array { @@ -1150,7 +1149,8 @@ func (a *FullAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Co out := mlx.ScaledDotProductAttentionCausal(q, k, v, cfg.Scale, L > 1) out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim) - out = mlx.Mul(out, mlx.Sigmoid(gate)) + gateSigmoid := mlx.Sigmoid(gate) + out = mlx.Mul(out, gateSigmoid) out = a.OProj.Forward(out) return out } @@ -1175,7 +1175,6 @@ func (g *GatedDeltaNet) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Co mlx.Reshape(v, B, L, cfg.LinearNumValueHeads*cfg.LinearValueHeadDim), }, -1) } - convTail := cfg.LinearConvKernelDim - 1 var convState *mlx.Array var rc *cache.RecurrentCache @@ -1216,9 +1215,7 @@ func (g *GatedDeltaNet) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Co q = mlx.MulScalar(mlx.RMSNormFn(q, nil, 1e-6), invScale*invScale) k = mlx.MulScalar(mlx.RMSNormFn(k, nil, 1e-6), invScale) - aF32 := a.AsType(mlx.DTypeFloat32) - dtBiasF32 := g.DtBias.AsType(mlx.DTypeFloat32) - gDecay := softplus(mlx.Add(aF32, dtBiasF32)) + gDecay := softplus(mlx.Add(a, g.DtBias)) gDecay = mlx.Mul(gDecay, g.AExp) gDecay = mlx.Exp(mlx.MulScalar(gDecay, -1)) gDecay = gDecay.AsType(a.DType()) @@ -1234,8 +1231,9 @@ func (g *GatedDeltaNet) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Co } out, state := mlx.GatedDelta(q, k, v, gDecay, beta, state) + outDType := out.DType() out = mlx.RMSNormFn(out, g.NormWeight, cfg.RMSNormEps) - out = mlx.Mul(out, mlx.SiLU(z)) + out = mlx.Mul(out.AsType(mlx.DTypeFloat32), mlx.SiLU(z.AsType(mlx.DTypeFloat32))).AsType(outDType) out = mlx.Reshape(out, B, L, valueDim) out = g.OutProj.Forward(out) if rc != nil { diff --git a/x/models/qwen3_5/qwen3_5_test.go b/x/models/qwen3_5/qwen3_5_test.go index 8165cd484..f425ee5c5 100644 --- a/x/models/qwen3_5/qwen3_5_test.go +++ b/x/models/qwen3_5/qwen3_5_test.go @@ -7,6 +7,13 @@ import ( "github.com/ollama/ollama/x/mlxrunner/mlx" ) +func skipIfNoMLX(t *testing.T) { + t.Helper() + if err := mlx.CheckInit(); err != nil { + t.Skipf("MLX not available: %v", err) + } +} + func TestParseConfigNestedDefaults(t *testing.T) { data := []byte(`{ "model_type": "Qwen3_5MoeForConditionalGeneration", @@ -155,3 +162,184 @@ func TestNewCachesLayout(t *testing.T) { t.Fatalf("cache[2] = %T, want *cache.RecurrentCache", caches[2]) } } + +func TestLoadWeightsPreservesLinearAttentionNormWeightDType(t *testing.T) { + skipIfNoMLX(t) + + cfg := &Config{ + HiddenSize: 4, + IntermediateSize: 8, + NumHiddenLayers: 2, + NumAttentionHeads: 1, + NumKeyValueHeads: 1, + HeadDim: 4, + RMSNormEps: 1e-6, + TieWordEmbeddings: true, + LayerTypes: []string{"linear", "full"}, + LinearNumValueHeads: 1, + LinearNumKeyHeads: 1, + LinearKeyHeadDim: 2, + LinearValueHeadDim: 2, + LinearConvKernelDim: 4, + FullAttentionInterval: 2, + } + + m := &Model{ + Config: cfg, + Layers: make([]*Layer, cfg.NumHiddenLayers), + } + + bf16 := mlx.DTypeBFloat16 + f32 := mlx.DTypeFloat32 + tensors := map[string]*mlx.Array{ + "model.embed_tokens.weight": mlx.FromValues([]float32{1, 2, 3, 4, 5, 6, 7, 8}, 2, 4).AsType(bf16), + "model.norm.weight": mlx.FromValues([]float32{1, 1, 1, 1}, 4), + "model.layers.0.input_layernorm.weight": mlx.FromValues([]float32{1, 1, 1, 1}, 4), + "model.layers.0.post_attention_layernorm.weight": mlx.FromValues([]float32{1, 1, 1, 1}, 4), + "model.layers.0.linear_attn.in_proj_qkv.weight": mlx.FromValues([]float32{ + 1, 0, 0, 0, + 0, 1, 0, 0, + 0, 0, 1, 0, + 0, 0, 0, 1, + 1, 1, 0, 0, + 0, 1, 1, 0, + }, 6, 4), + "model.layers.0.linear_attn.in_proj_z.weight": mlx.FromValues([]float32{ + 1, 0, 0, 0, + 0, 1, 0, 0, + }, 2, 4), + "model.layers.0.linear_attn.in_proj_b.weight": mlx.FromValues([]float32{1, 0, 0, 0}, 1, 4), + "model.layers.0.linear_attn.in_proj_a.weight": mlx.FromValues([]float32{0, 1, 0, 0}, 1, 4), + "model.layers.0.linear_attn.out_proj.weight": mlx.FromValues([]float32{ + 1, 0, + 0, 1, + 1, 1, + 0, 0, + }, 4, 2), + "model.layers.0.linear_attn.conv1d.weight": mlx.FromValues([]float32{ + 1, 0, 0, 0, + 0, 1, 0, 0, + 0, 0, 1, 0, + 0, 0, 0, 1, + 1, 1, 0, 0, + 0, 1, 1, 0, + }, 6, 4), + "model.layers.0.linear_attn.norm.weight": mlx.FromValues([]float32{1, 1}, 2), + "model.layers.0.linear_attn.dt_bias": mlx.FromValues([]float32{0}, 1), + "model.layers.0.linear_attn.A_log": mlx.FromValues([]float32{0}, 1), + "model.layers.0.mlp.gate_proj.weight": mlx.FromValues([]float32{ + 1, 0, 0, 0, + 0, 1, 0, 0, + 0, 0, 1, 0, + 0, 0, 0, 1, + 1, 1, 0, 0, + 0, 1, 1, 0, + 0, 0, 1, 1, + 1, 0, 0, 1, + }, 8, 4), + "model.layers.0.mlp.up_proj.weight": mlx.FromValues([]float32{ + 1, 0, 0, 0, + 0, 1, 0, 0, + 0, 0, 1, 0, + 0, 0, 0, 1, + 1, 1, 0, 0, + 0, 1, 1, 0, + 0, 0, 1, 1, + 1, 0, 0, 1, + }, 8, 4), + "model.layers.0.mlp.down_proj.weight": mlx.FromValues([]float32{ + 1, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 0, 0, 0, 0, 0, 0, + 0, 0, 1, 0, 0, 0, 0, 0, + 0, 0, 0, 1, 0, 0, 0, 0, + }, 4, 8), + "model.layers.1.input_layernorm.weight": mlx.FromValues([]float32{1, 1, 1, 1}, 4), + "model.layers.1.post_attention_layernorm.weight": mlx.FromValues([]float32{1, 1, 1, 1}, 4), + "model.layers.1.self_attn.q_proj.weight": mlx.FromValues([]float32{ + 1, 0, 0, 0, + 0, 1, 0, 0, + 0, 0, 1, 0, + 0, 0, 0, 1, + 1, 1, 0, 0, + 0, 1, 1, 0, + 0, 0, 1, 1, + 1, 0, 0, 1, + }, 8, 4), + "model.layers.1.self_attn.k_proj.weight": mlx.FromValues([]float32{ + 1, 0, 0, 0, + 0, 1, 0, 0, + 0, 0, 1, 0, + 0, 0, 0, 1, + }, 4, 4), + "model.layers.1.self_attn.v_proj.weight": mlx.FromValues([]float32{ + 1, 0, 0, 0, + 0, 1, 0, 0, + 0, 0, 1, 0, + 0, 0, 0, 1, + }, 4, 4), + "model.layers.1.self_attn.o_proj.weight": mlx.FromValues([]float32{ + 1, 0, 0, 0, + 0, 1, 0, 0, + 0, 0, 1, 0, + 0, 0, 0, 1, + }, 4, 4), + "model.layers.1.self_attn.q_norm.weight": mlx.FromValues([]float32{1, 1, 1, 1}, 4), + "model.layers.1.self_attn.k_norm.weight": mlx.FromValues([]float32{1, 1, 1, 1}, 4), + "model.layers.1.mlp.gate_proj.weight": mlx.FromValues([]float32{ + 1, 0, 0, 0, + 0, 1, 0, 0, + 0, 0, 1, 0, + 0, 0, 0, 1, + 1, 1, 0, 0, + 0, 1, 1, 0, + 0, 0, 1, 1, + 1, 0, 0, 1, + }, 8, 4), + "model.layers.1.mlp.up_proj.weight": mlx.FromValues([]float32{ + 1, 0, 0, 0, + 0, 1, 0, 0, + 0, 0, 1, 0, + 0, 0, 0, 1, + 1, 1, 0, 0, + 0, 1, 1, 0, + 0, 0, 1, 1, + 1, 0, 0, 1, + }, 8, 4), + "model.layers.1.mlp.down_proj.weight": mlx.FromValues([]float32{ + 1, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 0, 0, 0, 0, 0, 0, + 0, 0, 1, 0, 0, 0, 0, 0, + 0, 0, 0, 1, 0, 0, 0, 0, + }, 4, 8), + } + + if err := m.LoadWeights(tensors); err != nil { + t.Fatalf("LoadWeights failed: %v", err) + } + + if got := m.Layers[0].InputNorm.Weight.DType(); got != f32 { + t.Fatalf("layer 0 input norm dtype = %v, want %v", got, f32) + } + if got := m.Layers[0].PostAttentionNorm.Weight.DType(); got != f32 { + t.Fatalf("layer 0 post-attn norm dtype = %v, want %v", got, f32) + } + if got := m.Layers[1].InputNorm.Weight.DType(); got != f32 { + t.Fatalf("layer 1 input norm dtype = %v, want %v", got, f32) + } + if got := m.Layers[1].PostAttentionNorm.Weight.DType(); got != f32 { + t.Fatalf("layer 1 post-attn norm dtype = %v, want %v", got, f32) + } + + if got := m.Norm.Weight.DType(); got != f32 { + t.Fatalf("final norm dtype = %v, want %v", got, f32) + } + if got := m.Layers[0].Linear.NormWeight.DType(); got != f32 { + t.Fatalf("linear-attn norm dtype = %v, want %v", got, f32) + } + if got := m.Layers[1].FullAttn.QNorm.Weight.DType(); got != f32 { + t.Fatalf("q norm dtype = %v, want %v", got, f32) + } + if got := m.Layers[1].FullAttn.KNorm.Weight.DType(); got != f32 { + t.Fatalf("k norm dtype = %v, want %v", got, f32) + } +} diff --git a/x/tokenizer/tokenizer.go b/x/tokenizer/tokenizer.go index a1ce5e8ee..81c76fc32 100644 --- a/x/tokenizer/tokenizer.go +++ b/x/tokenizer/tokenizer.go @@ -71,6 +71,11 @@ func (t *Tokenizer) BOS() int32 { return t.vocab.BOS } +// AddBOS returns whether a BOS token should be prepended during encoding. +func (t *Tokenizer) AddBOS() bool { + return t.vocab.AddBOS +} + // EOS returns the first end of sequence token ID (for backwards compatibility) func (t *Tokenizer) EOS() int32 { if len(t.vocab.EOS) > 0 {