// Package glm4_moe_lite provides the GLM4-MoE-Lite implementation for MLX.
// This model uses Multi-head Latent Attention (MLA) and Mixture of Experts (MoE).
package glm4_moe_lite

import (
	"encoding/json"
	"fmt"
	"math"

	"github.com/ollama/ollama/x/mlxrunner/cache"
	"github.com/ollama/ollama/x/mlxrunner/mlx"
	"github.com/ollama/ollama/x/mlxrunner/model"
	"github.com/ollama/ollama/x/mlxrunner/model/base"
	"github.com/ollama/ollama/x/models/nn"
	"github.com/ollama/ollama/x/tokenizer"
)

func init() {
	base.Register("Glm4MoeLiteForCausalLM", newModel)
	base.Register("GLM4MoeLite", newModel)
}

// RopeScaling holds RoPE scaling configuration
type RopeScaling struct {
	Factor       float32 `json:"factor"`
	MscaleAllDim float32 `json:"mscale_all_dim"`
}

// Config holds GLM4-MoE-Lite model configuration
type Config struct {
	HiddenSize            int32   `json:"hidden_size"`
	NumHiddenLayers       int32   `json:"num_hidden_layers"`
	IntermediateSize      int32   `json:"intermediate_size"`
	MoEIntermediateSize   int32   `json:"moe_intermediate_size"`
	NumAttentionHeads     int32   `json:"num_attention_heads"`
	NumKeyValueHeads      int32   `json:"num_key_value_heads"`
	VocabSize             int32   `json:"vocab_size"`
	RMSNormEps            float32 `json:"rms_norm_eps"`
	RopeTheta             float32 `json:"rope_theta"`
	MaxPositionEmbeddings int32   `json:"max_position_embeddings"`
	AttentionBias         bool    `json:"attention_bias"`

	// MLA (Multi-head Latent Attention) parameters
	QLoraRank     int32 `json:"q_lora_rank"`
	KVLoraRank    int32 `json:"kv_lora_rank"`
	QKRopeHeadDim int32 `json:"qk_rope_head_dim"`
	QKNopeHeadDim int32 `json:"qk_nope_head_dim"`
	VHeadDim      int32 `json:"v_head_dim"`

	// MoE parameters
	NRoutedExperts      int32   `json:"n_routed_experts"`
	NSharedExperts      int32   `json:"n_shared_experts"`
	NumExpertsPerTok    int32   `json:"num_experts_per_tok"`
	RoutedScalingFactor float32 `json:"routed_scaling_factor"`
	NormTopKProb        bool    `json:"norm_topk_prob"`
	FirstKDenseReplace  int32   `json:"first_k_dense_replace"`
	NGroup              int32   `json:"n_group"`
	TopKGroup           int32   `json:"topk_group"`

	// RoPE scaling
	RopeScaling *RopeScaling `json:"rope_scaling"`

	// Quantization parameters (set during load based on model quantization)
	QuantGroupSize int                               `json:"-"` // Group size for quantization (default 64)
	QuantBits      int                               `json:"-"` // Bits per weight (4 or 8)
	QuantMode      string                            `json:"-"` // Quantization mode ("affine", etc.)
	TensorQuant    map[string]*model.TensorQuantInfo `json:"-"`

	// Computed fields
	QHeadDim int32   `json:"-"` // qk_nope_head_dim + qk_rope_head_dim
	Scale    float32 `json:"-"` // 1/sqrt(QHeadDim) with mscale adjustment
}
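
// Illustrative config.json excerpt showing how checkpoint fields map onto
// Config through the json tags above. The values are hypothetical, not taken
// from any particular checkpoint:
//
//	{
//	  "hidden_size": 2048,
//	  "num_hidden_layers": 28,
//	  "q_lora_rank": 768,
//	  "kv_lora_rank": 512,
//	  "qk_nope_head_dim": 128,
//	  "qk_rope_head_dim": 64,
//	  "n_routed_experts": 64,
//	  "num_experts_per_tok": 6,
//	  "first_k_dense_replace": 1
//	}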

// MLAAttention implements Multi-head Latent Attention with absorption.
type MLAAttention struct {
	QAProj      nn.LinearLayer
	QALayerNorm *nn.RMSNorm
	QBProj      nn.LinearLayer

	KVAProjWithMQA nn.LinearLayer
	KVALayerNorm   *nn.RMSNorm

	EmbedQ     *nn.MultiLinear
	UnembedOut *nn.MultiLinear

	OProj nn.LinearLayer
}

// Forward computes absorbed MLA attention output.
func (a *MLAAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
	q := a.QAProj.Forward(x)
	q = a.QALayerNorm.Forward(q, cfg.RMSNormEps)
	q = a.QBProj.Forward(q)

	q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.QHeadDim)
	q = mlx.Transpose(q, 0, 2, 1, 3)

	qNope := mlx.SliceStartStop(q, []int32{0, 0, 0, 0}, []int32{B, cfg.NumAttentionHeads, L, cfg.QKNopeHeadDim})
	qPE := mlx.SliceStartStop(q, []int32{0, 0, 0, cfg.QKNopeHeadDim}, []int32{B, cfg.NumAttentionHeads, L, cfg.QHeadDim})

	compressedKV := a.KVAProjWithMQA.Forward(x)

	kvCompressed := mlx.SliceStartStop(compressedKV, []int32{0, 0, 0}, []int32{B, L, cfg.KVLoraRank})
	kPE := mlx.SliceStartStop(compressedKV, []int32{0, 0, cfg.KVLoraRank}, []int32{B, L, cfg.KVLoraRank + cfg.QKRopeHeadDim})

	kPE = mlx.Reshape(kPE, B, L, 1, cfg.QKRopeHeadDim)
	kPE = mlx.Transpose(kPE, 0, 2, 1, 3)

	kvLatent := a.KVALayerNorm.Forward(kvCompressed, cfg.RMSNormEps)
	kvLatent = mlx.ExpandDims(kvLatent, 1)

	offset := 0
	if c != nil {
		offset = c.Offset()
	}
	qPE = mlx.RoPEWithBase(qPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset)
	kPE = mlx.RoPEWithBase(kPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset)

	qLatent := a.EmbedQ.Forward(qNope)

	keys := mlx.Concatenate([]*mlx.Array{kvLatent, kPE}, 3)

	cachedL := L
	if c != nil {
		placeholderValues := mlx.ZerosF32([]int32{B, 1, L, 0})
		keys, _ = c.Update(keys, placeholderValues)
		cachedL = int32(keys.Dim(2))
	}

	values := mlx.SliceStartStop(keys, []int32{0, 0, 0, 0}, []int32{B, 1, cachedL, cfg.KVLoraRank})

	queries := mlx.Concatenate([]*mlx.Array{qLatent, qPE}, 3)

	out := mlx.ScaledDotProductAttentionCausal(queries, keys, values, cfg.Scale, L > 1)
	out = a.UnembedOut.Forward(out)

	out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.VHeadDim)

	return a.OProj.Forward(out)
}
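
// Note on the absorbed formulation above: rather than expanding the cached
// latent back into per-head keys and values, the no-RoPE part of each query
// is projected *into* the latent space (EmbedQ), attention runs directly
// against the cached [kv_latent ; k_pe] rows, and UnembedOut maps the latent
// attention output back to per-head value space. The cache therefore stores
// a single "head" of width kv_lora_rank + qk_rope_head_dim per position; the
// zero-width placeholder values array exists only to satisfy the
// cache.Update signature, and the real values are re-sliced from the cached
// keys.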

// DenseMLP implements the standard SwiGLU MLP for dense layers
type DenseMLP struct {
	GateProj nn.LinearLayer
	UpProj   nn.LinearLayer
	DownProj nn.LinearLayer
}

// Forward applies the SwiGLU MLP
func (m *DenseMLP) Forward(x *mlx.Array) *mlx.Array {
	gate := mlx.SiLU(m.GateProj.Forward(x))
	up := m.UpProj.Forward(x)
	return m.DownProj.Forward(mlx.Mul(gate, up))
}
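
// Both DenseMLP here and SharedExperts below compute the same SwiGLU
// combination: down_proj(silu(gate_proj(x)) * up_proj(x)).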

// MoEGate implements the expert gating mechanism
type MoEGate struct {
	Gate                 nn.LinearLayer
	EScoreCorrectionBias *mlx.Array
}

// Forward computes expert selection indices and scores
func (g *MoEGate) Forward(x *mlx.Array, cfg *Config) (*mlx.Array, *mlx.Array) {
	gates := g.Gate.Forward(x)

	scores := mlx.Sigmoid(gates)
	origScores := scores

	if g.EScoreCorrectionBias != nil {
		scores = mlx.Add(scores, g.EScoreCorrectionBias)
	}

	topK := cfg.NumExpertsPerTok
	negScores := mlx.Neg(scores)
	inds := mlx.Argpartition(negScores, int(topK)-1, -1)

	dims := inds.Dims()
	inds = mlx.SliceStartStop(inds, []int32{0, 0, 0}, []int32{int32(dims[0]), int32(dims[1]), topK})

	scores = mlx.TakeAlongAxis(origScores, inds, -1)

	if topK > 1 && cfg.NormTopKProb {
		sumScores := mlx.Sum(scores, -1, true)
		scores = mlx.Div(scores, sumScores)
	}

	scores = mlx.MulScalar(scores, cfg.RoutedScalingFactor)

	return inds, scores
}
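
// Routing recap: expert affinities are sigmoid scores; the optional
// e_score_correction_bias shifts scores only for top-k *selection*, while
// the unbiased origScores are what get gathered as mixing weights. When
// norm_topk_prob is set the top-k weights are renormalized to sum to 1, and
// the result is scaled by routed_scaling_factor.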

// SwitchMLP implements the MoE expert computation using stacked weights
type SwitchMLP struct {
	GateWeight *mlx.Array
	UpWeight   *mlx.Array
	DownWeight *mlx.Array

	GateWeightQ, GateScales, GateBiases *mlx.Array
	UpWeightQ, UpScales, UpBiases       *mlx.Array
	DownWeightQ, DownScales, DownBiases *mlx.Array

	GateBits int
	UpBits   int
	DownBits int

	GateGroupSize int
	UpGroupSize   int
	DownGroupSize int

	UseQuantized bool
}

// Forward applies the switched expert MLP
func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.Array {
	dims := x.Dims()
	B, L := int32(dims[0]), int32(dims[1])
	topK := cfg.NumExpertsPerTok

	xExpanded := mlx.ExpandDims(mlx.ExpandDims(x, -2), -2)

	xFlat := mlx.Reshape(xExpanded, B*L, 1, 1, cfg.HiddenSize)

	idxFlat := mlx.Reshape(indices, B*L, topK)

	doSort := B*L >= 64
	var invOrder *mlx.Array
	n := B * L * topK

	if doSort {
		idxAll := mlx.Flatten(idxFlat)
		order := mlx.Argsort(idxAll, 0)
		invOrder = mlx.Argsort(order, 0)
		xFlat = mlx.ExpandDims(mlx.Take(mlx.Squeeze(xFlat, 1), mlx.FloorDivideScalar(order, topK), 0), 1)
		idxFlat = mlx.Reshape(mlx.Take(idxAll, order, 0), n, 1)
	}

	var gate, up, hidden, down *mlx.Array

	if s.UseQuantized {
		gate = mlx.GatherQMM(xFlat, s.GateWeightQ, s.GateScales, s.GateBiases,
			nil, idxFlat, true, s.GateGroupSize, s.GateBits, cfg.QuantMode, doSort)
		up = mlx.GatherQMM(xFlat, s.UpWeightQ, s.UpScales, s.UpBiases,
			nil, idxFlat, true, s.UpGroupSize, s.UpBits, cfg.QuantMode, doSort)

		hidden = mlx.Mul(mlx.SiLU(gate), up)

		down = mlx.GatherQMM(hidden, s.DownWeightQ, s.DownScales, s.DownBiases,
			nil, idxFlat, true, s.DownGroupSize, s.DownBits, cfg.QuantMode, doSort)
	} else {
		gate = mlx.GatherMM(xFlat, mlx.Transpose(s.GateWeight, 0, 2, 1), nil, idxFlat, doSort)
		up = mlx.GatherMM(xFlat, mlx.Transpose(s.UpWeight, 0, 2, 1), nil, idxFlat, doSort)

		hidden = mlx.Mul(mlx.SiLU(gate), up)

		down = mlx.GatherMM(hidden, mlx.Transpose(s.DownWeight, 0, 2, 1), nil, idxFlat, doSort)
	}

	if doSort {
		down = mlx.Reshape(mlx.Take(mlx.Squeeze(mlx.Squeeze(down, 2), 1), invOrder, 0), B*L, topK, cfg.HiddenSize)
	} else {
		down = mlx.Squeeze(down, 2)
	}

	return mlx.Reshape(down, B, L, topK, cfg.HiddenSize)
}
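
// When B*L >= 64 the token/expert pairs are sorted by expert index before
// the gathered matmuls so rows hitting the same expert are contiguous
// (GatherMM and GatherQMM take the doSort hint for this layout), and
// invOrder undoes the permutation afterwards. A sketch of the round trip,
// with n = B*L*topK:
//
//	order   = argsort(flatten(indices))   // group rows by expert
//	ySorted = expertMLP(x[order / topK])  // contiguous per-expert batches
//	y       = ySorted[invOrder]           // restore original token order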

// SharedExperts implements the shared expert MLP
type SharedExperts struct {
	GateProj nn.LinearLayer
	UpProj   nn.LinearLayer
	DownProj nn.LinearLayer
}

// Forward applies the shared expert MLP
func (s *SharedExperts) Forward(x *mlx.Array) *mlx.Array {
	gate := mlx.SiLU(s.GateProj.Forward(x))
	up := s.UpProj.Forward(x)
	return s.DownProj.Forward(mlx.Mul(gate, up))
}

// MoE implements the full Mixture of Experts layer
type MoE struct {
	Gate          *MoEGate
	SwitchMLP     *SwitchMLP
	SharedExperts *SharedExperts
}

// Forward applies the MoE layer
func (m *MoE) Forward(x *mlx.Array, cfg *Config) *mlx.Array {
	dims := x.Dims()
	B, L := int32(dims[0]), int32(dims[1])

	inds, scores := m.Gate.Forward(x, cfg)

	expertOut := m.SwitchMLP.Forward(x, inds, cfg)

	scoresExpanded := mlx.ExpandDims(scores, -1)
	y := mlx.Sum(mlx.Mul(expertOut, scoresExpanded), 2, false)

	if m.SharedExperts != nil {
		y = mlx.Add(y, m.SharedExperts.Forward(x))
	}

	return mlx.Reshape(y, B, L, cfg.HiddenSize)
}
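
// The final mixture is y = sum over the top-k routed experts of
// score_k * expert_k(x) (the Sum over axis 2 above), plus the always-on
// shared-expert branch when n_shared_experts > 0.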

// DenseBlock represents a dense transformer block (for first_k_dense_replace layers)
type DenseBlock struct {
	Attention              *MLAAttention
	MLP                    *DenseMLP
	InputLayerNorm         *nn.RMSNorm
	PostAttentionLayerNorm *nn.RMSNorm
}

// Forward applies the dense block
func (b *DenseBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
	r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg)
	h := mlx.Add(x, r)

	r = b.MLP.Forward(b.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps))
	return mlx.Add(h, r)
}

// MoEBlock represents a MoE transformer block
type MoEBlock struct {
	Attention              *MLAAttention
	MoE                    *MoE
	InputLayerNorm         *nn.RMSNorm
	PostAttentionLayerNorm *nn.RMSNorm
}

// Forward applies the MoE block
func (b *MoEBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
	r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg)
	h := mlx.Add(x, r)

	r = b.MoE.Forward(b.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps), cfg)
	return mlx.Add(h, r)
}

// Block interface for both dense and MoE blocks
type Block interface {
	Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array
}

// Model represents the complete GLM4-MoE-Lite model
type Model struct {
	EmbedTokens nn.EmbeddingLayer
	Layers      []Block
	Norm        *nn.RMSNorm
	LMHead      nn.LinearLayer

	tok *tokenizer.Tokenizer
	*Config
}

// computeScale computes the attention scale.
func computeScale(cfg *Config) float32 {
	keyLength := cfg.QKNopeHeadDim + cfg.QKRopeHeadDim
	scale := float32(1.0 / math.Sqrt(float64(keyLength)))
	if cfg.RopeScaling != nil && cfg.RopeScaling.MscaleAllDim > 0 && cfg.RopeScaling.Factor > 1 {
		s := 0.1*cfg.RopeScaling.MscaleAllDim*float32(math.Log(float64(cfg.RopeScaling.Factor))) + 1.0
		scale *= s * s
	}
	return scale
}
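
// Worked example with hypothetical dims: for qk_nope_head_dim=128 and
// qk_rope_head_dim=64, keyLength is 192 and the base scale is
// 1/sqrt(192) ≈ 0.0722. With rope_scaling {factor: 4, mscale_all_dim: 1.0},
// s = 0.1*1.0*ln(4) + 1 ≈ 1.139, so the scale is multiplied by s*s ≈ 1.297.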

// supportsGatherQMM returns true if the quantization mode has GatherQMM kernel support.
func supportsGatherQMM(mode string, bits int) bool {
	return mode == "affine" && (bits == 4 || bits == 8)
}

// ExpertWeight holds a single expert's weight with optional quantization components.
type ExpertWeight struct {
	Weight    *mlx.Array
	Scales    *mlx.Array
	Biases    *mlx.Array
	Bits      int
	GroupSize int
}

// loadExpertWeight loads an expert weight from the tensor map.
func loadExpertWeight(tensors map[string]*mlx.Array, path string, useQuantized bool, cfg *Config) *ExpertWeight {
	w := tensors[path+".weight"]
	if w == nil {
		return nil
	}

	scales := tensors[path+".weight_scale"]
	if scales != nil {
		qbiases := tensors[path+".weight_qbias"]

		groupSize, bits, mode := model.ResolveLinearQuantParams(
			cfg.QuantGroupSize,
			cfg.QuantBits,
			cfg.QuantMode,
			cfg.TensorQuant,
			path+".weight",
			w,
			scales,
		)

		if useQuantized && supportsGatherQMM(mode, bits) {
			return &ExpertWeight{Weight: w, Scales: scales, Biases: qbiases, Bits: bits, GroupSize: groupSize}
		}

		return &ExpertWeight{Weight: mlx.Dequantize(w, scales, qbiases, groupSize, bits, mode)}
	}

	return &ExpertWeight{Weight: w}
}
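
// Expert weights therefore follow one of three paths: unquantized tensors
// pass through as-is; quantized tensors whose mode has GatherQMM kernel
// support stay packed (weight + scales + optional qbias); everything else is
// dequantized up front so the plain GatherMM path can consume it.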

// StackedExpertWeights holds stacked weights for all experts.
type StackedExpertWeights struct {
	Weight    *mlx.Array
	Scales    *mlx.Array
	Biases    *mlx.Array
	Bits      int
	GroupSize int
}

// collectAndStackExpertWeights loads and stacks expert weights for one projection type.
func collectAndStackExpertWeights(
	tensors map[string]*mlx.Array,
	prefix string,
	projName string,
	numExperts int32,
	useQuantized bool,
	cfg *Config,
) *StackedExpertWeights {
	var w, s, b []*mlx.Array
	var bits, groupSize int

	for e := int32(0); e < numExperts; e++ {
		path := fmt.Sprintf("%s.mlp.experts.%d.%s", prefix, e, projName)
		ew := loadExpertWeight(tensors, path, useQuantized, cfg)
		if ew == nil {
			continue
		}
		w = append(w, ew.Weight)
		if ew.Scales != nil {
			s = append(s, ew.Scales)
		}
		if ew.Biases != nil {
			b = append(b, ew.Biases)
		}
		if e == 0 {
			bits = ew.Bits
			groupSize = ew.GroupSize
		}
	}

	result := &StackedExpertWeights{Bits: bits, GroupSize: groupSize}
	if len(w) > 0 {
		result.Weight = mlx.Stack(w, 0)
		if len(s) > 0 {
			result.Scales = mlx.Stack(s, 0)
		}
		if len(b) > 0 {
			result.Biases = mlx.Stack(b, 0)
		}
	}
	return result
}

// sanitizeExpertWeights stacks individual expert weights into tensors.
func sanitizeExpertWeights(tensors map[string]*mlx.Array, prefix string, numExperts int32, useQuantized bool, cfg *Config) (gate, up, down *StackedExpertWeights) {
	gate = collectAndStackExpertWeights(tensors, prefix, "gate_proj", numExperts, useQuantized, cfg)
	up = collectAndStackExpertWeights(tensors, prefix, "up_proj", numExperts, useQuantized, cfg)
	down = collectAndStackExpertWeights(tensors, prefix, "down_proj", numExperts, useQuantized, cfg)
	return gate, up, down
}

// sanitizeMLAWeights transforms kv_b_proj weights into absorbed MLA format.
func sanitizeMLAWeights(tensors map[string]*mlx.Array, prefix string, cfg *Config) (*mlx.Array, *mlx.Array) {
	path := prefix + ".self_attn.kv_b_proj"
	w := tensors[path+".weight"]
	if w == nil {
		return nil, nil
	}

	// Check if quantized and dequantize
	if scales := tensors[path+".weight_scale"]; scales != nil {
		qbiases := tensors[path+".weight_qbias"]
		groupSize, bits, mode := model.ResolveLinearQuantParams(
			cfg.QuantGroupSize,
			cfg.QuantBits,
			cfg.QuantMode,
			cfg.TensorQuant,
			path+".weight",
			w,
			scales,
		)
		w = mlx.Dequantize(w, scales, qbiases, groupSize, bits, mode)
	}

	headDim := cfg.QKNopeHeadDim + cfg.VHeadDim
	w = mlx.Reshape(w, cfg.NumAttentionHeads, headDim, cfg.KVLoraRank)

	wk := mlx.SliceStartStop(w, []int32{0, 0, 0}, []int32{cfg.NumAttentionHeads, cfg.QKNopeHeadDim, cfg.KVLoraRank})
	wv := mlx.SliceStartStop(w, []int32{0, cfg.QKNopeHeadDim, 0}, []int32{cfg.NumAttentionHeads, headDim, cfg.KVLoraRank})

	embedQ := mlx.Transpose(wk, 0, 2, 1)
	unembedOut := wv

	return embedQ, unembedOut
}
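
// Shape walk-through: kv_b_proj.weight is reshaped to
// [num_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank] and split into
// W_UK (the first qk_nope_head_dim rows, the key up-projection) and W_UV
// (the remaining v_head_dim rows, the value up-projection). W_UK is
// transposed to [num_heads, kv_lora_rank, qk_nope_head_dim] to become
// EmbedQ, which maps queries into the latent space; W_UV is kept as-is as
// UnembedOut, which maps latent attention output back to v_head_dim per head.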

// newModel creates a new GLM4-MoE-Lite model from a Root (config + tokenizer,
// no weights loaded yet). Called by the registry via base.New().
func newModel(root *model.Root) (base.Model, error) {
	configData, err := root.Manifest.ReadConfig("config.json")
	if err != nil {
		return nil, fmt.Errorf("load config: %w", err)
	}

	var cfg Config
	if err := json.Unmarshal(configData, &cfg); err != nil {
		return nil, fmt.Errorf("parse config: %w", err)
	}

	cfg.QHeadDim = cfg.QKNopeHeadDim + cfg.QKRopeHeadDim
	cfg.Scale = computeScale(&cfg)

	// Set up quantization parameters from pre-scanned metadata
	if qt := root.QuantType(); qt != "" {
		cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt)
		if gs := root.GroupSize(); gs > 0 {
			cfg.QuantGroupSize = gs
		}
	} else {
		cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams("")
	}
	cfg.TensorQuant = root.AllTensorQuant()

	// Load tokenizer
	tokData, err := root.Manifest.ReadConfig("tokenizer.json")
	if err != nil {
		return nil, fmt.Errorf("load tokenizer config: %w", err)
	}

	tokConfig := &tokenizer.TokenizerConfig{
		ConfigJSON: configData,
	}

	if genConfigData, err := root.Manifest.ReadConfig("generation_config.json"); err == nil {
		tokConfig.GenerationConfigJSON = genConfigData
	}

	if tokConfigData, err := root.Manifest.ReadConfig("tokenizer_config.json"); err == nil {
		tokConfig.TokenizerConfigJSON = tokConfigData
	}

	tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
	if err != nil {
		return nil, fmt.Errorf("parse tokenizer: %w", err)
	}

	m := &Model{
		Layers: make([]Block, cfg.NumHiddenLayers),
		Config: &cfg,
		tok:    tok,
	}

	return m, nil
}
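
// Construction is two-phase: newModel parses only config and tokenizer
// metadata, and weights arrive later through LoadWeights. A minimal sketch
// of the flow the registry drives (the caller-side names here are
// assumptions; only newModel and LoadWeights come from this file):
//
//	m, err := newModel(root)          // config + tokenizer, no tensors yet
//	if err != nil { /* handle */ }
//	tensors := loadTensors(manifest)  // hypothetical helper
//	err = m.(*Model).LoadWeights(tensors)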

// LoadWeights receives all tensors loaded from the manifest and assigns them
// to model fields. Handles MLA absorption, expert stacking, and quantized
// layer creation.
func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
	cfg := m.Config
	linears := model.NewLinearFactory(tensors, cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode, cfg.TensorQuant)
	useQuantized := supportsGatherQMM(cfg.QuantMode, cfg.QuantBits)
	if !useQuantized && cfg.TensorQuant != nil {
		for _, tq := range cfg.TensorQuant {
			if tq == nil {
				continue
			}
			_, bits, mode := model.QuantizationParams(tq.QuantType)
			if supportsGatherQMM(mode, bits) {
				useQuantized = true
				break
			}
		}
	}

	// Load embedding
	m.EmbedTokens = model.MakeEmbeddingLayer(tensors, "model.embed_tokens", cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode, cfg.TensorQuant)

	// Load final norm
	if w := tensors["model.norm.weight"]; w != nil {
		m.Norm = nn.NewRMSNorm(w, cfg.RMSNormEps)
	}

	// Load LM head
	m.LMHead = linears.Make("lm_head")

	// Load layers
	for i := int32(0); i < cfg.NumHiddenLayers; i++ {
		prefix := fmt.Sprintf("model.layers.%d", i)

		// Load attention (same for both block types)
		attn := &MLAAttention{}
		attn.QAProj = linears.Make(prefix + ".self_attn.q_a_proj")
		if w := tensors[prefix+".self_attn.q_a_layernorm.weight"]; w != nil {
			attn.QALayerNorm = nn.NewRMSNorm(w, cfg.RMSNormEps)
		}
		attn.QBProj = linears.Make(prefix + ".self_attn.q_b_proj")
		attn.KVAProjWithMQA = linears.Make(prefix + ".self_attn.kv_a_proj_with_mqa")
		if w := tensors[prefix+".self_attn.kv_a_layernorm.weight"]; w != nil {
			attn.KVALayerNorm = nn.NewRMSNorm(w, cfg.RMSNormEps)
		}
		attn.OProj = linears.Make(prefix + ".self_attn.o_proj")

		// Sanitize MLA weights for absorbed attention
		embedQ, unembedOut := sanitizeMLAWeights(tensors, prefix, cfg)
		attn.EmbedQ = nn.NewMultiLinear(embedQ)
		attn.UnembedOut = nn.NewMultiLinear(unembedOut)

		inputLN := tensors[prefix+".input_layernorm.weight"]
		postAttnLN := tensors[prefix+".post_attention_layernorm.weight"]

		if i < cfg.FirstKDenseReplace {
			// Dense block
			block := &DenseBlock{Attention: attn}
			if inputLN != nil {
				block.InputLayerNorm = nn.NewRMSNorm(inputLN, cfg.RMSNormEps)
			}
			if postAttnLN != nil {
				block.PostAttentionLayerNorm = nn.NewRMSNorm(postAttnLN, cfg.RMSNormEps)
			}

			block.MLP = &DenseMLP{
				GateProj: linears.Make(prefix + ".mlp.gate_proj"),
				UpProj:   linears.Make(prefix + ".mlp.up_proj"),
				DownProj: linears.Make(prefix + ".mlp.down_proj"),
			}

			m.Layers[i] = block
		} else {
			// MoE block
			block := &MoEBlock{Attention: attn}
			if inputLN != nil {
				block.InputLayerNorm = nn.NewRMSNorm(inputLN, cfg.RMSNormEps)
			}
			if postAttnLN != nil {
				block.PostAttentionLayerNorm = nn.NewRMSNorm(postAttnLN, cfg.RMSNormEps)
			}

			// Stack expert weights
			gate, up, down := sanitizeExpertWeights(tensors, prefix, cfg.NRoutedExperts, useQuantized, cfg)

			switchMLP := &SwitchMLP{UseQuantized: useQuantized}
			if useQuantized {
				switchMLP.GateWeightQ = gate.Weight
				switchMLP.GateScales = gate.Scales
				switchMLP.GateBiases = gate.Biases
				switchMLP.GateBits = gate.Bits
				switchMLP.GateGroupSize = gate.GroupSize
				switchMLP.UpWeightQ = up.Weight
				switchMLP.UpScales = up.Scales
				switchMLP.UpBiases = up.Biases
				switchMLP.UpBits = up.Bits
				switchMLP.UpGroupSize = up.GroupSize
				switchMLP.DownWeightQ = down.Weight
				switchMLP.DownScales = down.Scales
				switchMLP.DownBiases = down.Biases
				switchMLP.DownBits = down.Bits
				switchMLP.DownGroupSize = down.GroupSize
			} else {
				switchMLP.GateWeight = gate.Weight
				switchMLP.UpWeight = up.Weight
				switchMLP.DownWeight = down.Weight
			}

			moeGate := &MoEGate{}
			moeGate.Gate = linears.Make(prefix + ".mlp.gate")
			if bias := tensors[prefix+".mlp.gate.e_score_correction_bias"]; bias != nil {
				moeGate.EScoreCorrectionBias = bias
			}

			block.MoE = &MoE{
				Gate:      moeGate,
				SwitchMLP: switchMLP,
			}

			// Load shared experts if present
			if cfg.NSharedExperts > 0 {
				block.MoE.SharedExperts = &SharedExperts{
					GateProj: linears.Make(prefix + ".mlp.shared_experts.gate_proj"),
					UpProj:   linears.Make(prefix + ".mlp.shared_experts.up_proj"),
					DownProj: linears.Make(prefix + ".mlp.shared_experts.down_proj"),
				}
			}

			m.Layers[i] = block
		}
	}

	return nil
}

// Forward computes the forward pass of the model
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
	dims := tokens.Dims()
	B, L := int32(dims[0]), int32(dims[1])

	h := m.EmbedTokens.Forward(tokens)

	for i, layer := range m.Layers {
		var c cache.Cache
		if caches != nil {
			c = caches[i]
		}
		h = layer.Forward(h, c, B, L, m.Config)
	}

	h = m.Norm.Forward(h, m.RMSNormEps)
	return h
}
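
// Forward returns the final-norm hidden states, not logits; callers apply
// the LM head separately, e.g. logits := m.Unembed(m.Forward(tokens, caches)),
// which also allows unembedding only selected positions if a caller slices
// the hidden states first.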

// Unembed applies the LM head to get logits.
func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
	return m.LMHead.Forward(x)
}

// NumLayers returns the number of transformer layers
func (m *Model) NumLayers() int { return len(m.Layers) }

// MaxContextLength returns the maximum context length
func (m *Model) MaxContextLength() int { return int(m.MaxPositionEmbeddings) }

// VocabSize returns the vocabulary size
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }

// Tokenizer returns the model's tokenizer
func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }

// NewCache creates a new KV cache for the model
func (m *Model) NewCache(maxSeqLen int32) []cache.Cache {
	caches := make([]cache.Cache, len(m.Layers))
	for i := range caches {
		caches[i] = cache.NewKVCache()
	}
	return caches
}

// FormatPrompt applies the GLM-4 chat template with thinking enabled by default.
func (m *Model) FormatPrompt(prompt string) string {
	return "[gMASK]<sop><|user|>" + prompt + "<|assistant|><think>"
}

// FormatPromptWithThinking applies the GLM-4 chat template with explicit thinking control.
func (m *Model) FormatPromptWithThinking(prompt string, think bool) string {
	if think {
		return "[gMASK]<sop><|user|>" + prompt + "<|assistant|><think>"
	}
	return "[gMASK]<sop><|user|>" + prompt + "<|assistant|></think>"
}
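
// For a prompt "Hi" with thinking enabled, the rendered string is
// "[gMASK]<sop><|user|>Hi<|assistant|><think>"; with thinking disabled the
// assistant turn instead opens with "</think>", suppressing the thinking block.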

// NewRenderer returns a new Renderer for formatting multi-turn conversations.
func (m *Model) NewRenderer() *Renderer {
	return &Renderer{}
}

// NewParser returns a new Parser for extracting thinking and tool calls from output.
func (m *Model) NewParser() *Parser {
	return &Parser{}
}