// Package glm4_moe_lite provides the GLM4-MoE-Lite implementation for MLX.
// This model uses Multi-head Latent Attention (MLA) and Mixture of Experts (MoE).
package glm4_moe_lite

import (
	"encoding/json"
	"fmt"
	"math"

	"github.com/ollama/ollama/x/mlxrunner/cache"
	"github.com/ollama/ollama/x/mlxrunner/mlx"
	"github.com/ollama/ollama/x/mlxrunner/model"
	"github.com/ollama/ollama/x/mlxrunner/model/base"
	"github.com/ollama/ollama/x/models/nn"
	"github.com/ollama/ollama/x/tokenizer"
)

func init() {
	base.Register("Glm4MoeLiteForCausalLM", newModel)
	base.Register("GLM4MoeLite", newModel)
}

// RopeScaling holds RoPE scaling configuration
type RopeScaling struct {
	Factor       float32 `json:"factor"`
	MscaleAllDim float32 `json:"mscale_all_dim"`
}

// Config holds GLM4-MoE-Lite model configuration
type Config struct {
	HiddenSize            int32   `json:"hidden_size"`
	NumHiddenLayers       int32   `json:"num_hidden_layers"`
	IntermediateSize      int32   `json:"intermediate_size"`
	MoEIntermediateSize   int32   `json:"moe_intermediate_size"`
	NumAttentionHeads     int32   `json:"num_attention_heads"`
	NumKeyValueHeads      int32   `json:"num_key_value_heads"`
	VocabSize             int32   `json:"vocab_size"`
	RMSNormEps            float32 `json:"rms_norm_eps"`
	RopeTheta             float32 `json:"rope_theta"`
	MaxPositionEmbeddings int32   `json:"max_position_embeddings"`
	AttentionBias         bool    `json:"attention_bias"`

	// MLA (Multi-head Latent Attention) parameters
	QLoraRank     int32 `json:"q_lora_rank"`
	KVLoraRank    int32 `json:"kv_lora_rank"`
	QKRopeHeadDim int32 `json:"qk_rope_head_dim"`
	QKNopeHeadDim int32 `json:"qk_nope_head_dim"`
	VHeadDim      int32 `json:"v_head_dim"`

	// MoE parameters
	NRoutedExperts      int32   `json:"n_routed_experts"`
	NSharedExperts      int32   `json:"n_shared_experts"`
	NumExpertsPerTok    int32   `json:"num_experts_per_tok"`
	RoutedScalingFactor float32 `json:"routed_scaling_factor"`
	NormTopKProb        bool    `json:"norm_topk_prob"`
	FirstKDenseReplace  int32   `json:"first_k_dense_replace"`
	NGroup              int32   `json:"n_group"`
	TopKGroup           int32   `json:"topk_group"`

	// RoPE scaling
	RopeScaling *RopeScaling `json:"rope_scaling"`

	// Quantization parameters (set during load based on model quantization)
	QuantGroupSize int                               `json:"-"` // Group size for quantization (default 64)
	QuantBits      int                               `json:"-"` // Bits per weight (4 or 8)
	QuantMode      string                            `json:"-"` // Quantization mode ("affine", etc.)
	TensorQuant    map[string]*model.TensorQuantInfo `json:"-"`

	// Computed fields
	QHeadDim int32   `json:"-"` // qk_nope_head_dim + qk_rope_head_dim
	Scale    float32 `json:"-"` // 1/sqrt(QHeadDim) with mscale adjustment
}

// MLAAttention implements Multi-head Latent Attention with absorption.
type MLAAttention struct {
	QAProj         nn.LinearLayer
	QALayerNorm    *nn.RMSNorm
	QBProj         nn.LinearLayer
	KVAProjWithMQA nn.LinearLayer
	KVALayerNorm   *nn.RMSNorm
	EmbedQ         *nn.MultiLinear
	UnembedOut     *nn.MultiLinear
	OProj          nn.LinearLayer
}
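
// How absorption works here (informal note; see sanitizeMLAWeights and
// MLAAttention.Forward below): rather than expanding the compressed KV latent
// back to per-head keys and values, kv_b_proj is split into its key half
// (folded into the query path as EmbedQ) and its value half (applied after
// attention as UnembedOut). Attention then runs directly in the KVLoraRank
// latent space plus the shared QKRopeHeadDim rotary slice, so only that
// compressed representation has to be cached per token.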

// Forward computes absorbed MLA attention output.
func (a *MLAAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
	// Low-rank query path: down-project, normalize, then up-project to per-head queries.
	q := a.QAProj.Forward(x)
	q = a.QALayerNorm.Forward(q, cfg.RMSNormEps)
	q = a.QBProj.Forward(q)
	q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.QHeadDim)
	q = mlx.Transpose(q, 0, 2, 1, 3)

	// Split the query into its non-rotary (nope) and rotary (PE) parts.
	qNope := mlx.SliceStartStop(q, []int32{0, 0, 0, 0}, []int32{B, cfg.NumAttentionHeads, L, cfg.QKNopeHeadDim})
	qPE := mlx.SliceStartStop(q, []int32{0, 0, 0, cfg.QKNopeHeadDim}, []int32{B, cfg.NumAttentionHeads, L, cfg.QHeadDim})

	// kv_a_proj_with_mqa produces the compressed KV latent plus a single shared rotary key.
	compressedKV := a.KVAProjWithMQA.Forward(x)
	kvCompressed := mlx.SliceStartStop(compressedKV, []int32{0, 0, 0}, []int32{B, L, cfg.KVLoraRank})
	kPE := mlx.SliceStartStop(compressedKV, []int32{0, 0, cfg.KVLoraRank}, []int32{B, L, cfg.KVLoraRank + cfg.QKRopeHeadDim})
	kPE = mlx.Reshape(kPE, B, L, 1, cfg.QKRopeHeadDim)
	kPE = mlx.Transpose(kPE, 0, 2, 1, 3)
	kvLatent := a.KVALayerNorm.Forward(kvCompressed, cfg.RMSNormEps)
	kvLatent = mlx.ExpandDims(kvLatent, 1)

	offset := 0
	if c != nil {
		offset = c.Offset()
	}
	qPE = mlx.RoPEWithBase(qPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset)
	kPE = mlx.RoPEWithBase(kPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset)

	// Map queries into the latent space (absorbed key up-projection).
	qLatent := a.EmbedQ.Forward(qNope)

	// Only the latent and the rotary key are written to the cache; values are a
	// view into the cached latent, so a zero-width placeholder stands in for them.
	keys := mlx.Concatenate([]*mlx.Array{kvLatent, kPE}, 3)
	cachedL := L
	if c != nil {
		placeholderValues := mlx.ZerosF32([]int32{B, 1, L, 0})
		keys, _ = c.Update(keys, placeholderValues)
		cachedL = int32(keys.Dim(2))
	}
	values := mlx.SliceStartStop(keys, []int32{0, 0, 0, 0}, []int32{B, 1, cachedL, cfg.KVLoraRank})

	queries := mlx.Concatenate([]*mlx.Array{qLatent, qPE}, 3)
	out := mlx.ScaledDotProductAttentionCausal(queries, keys, values, cfg.Scale, L > 1)

	// Map the latent attention output back to per-head values (absorbed value up-projection).
	out = a.UnembedOut.Forward(out)
	out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.VHeadDim)
	return a.OProj.Forward(out)
}

// DenseMLP implements the standard SwiGLU MLP for dense layers
type DenseMLP struct {
	GateProj nn.LinearLayer
	UpProj   nn.LinearLayer
	DownProj nn.LinearLayer
}

// Forward applies the SwiGLU MLP
func (m *DenseMLP) Forward(x *mlx.Array) *mlx.Array {
	gate := mlx.SiLU(m.GateProj.Forward(x))
	up := m.UpProj.Forward(x)
	return m.DownProj.Forward(mlx.Mul(gate, up))
}

// MoEGate implements the expert gating mechanism
type MoEGate struct {
	Gate                 nn.LinearLayer
	EScoreCorrectionBias *mlx.Array
}

// Forward computes expert selection indices and scores
func (g *MoEGate) Forward(x *mlx.Array, cfg *Config) (*mlx.Array, *mlx.Array) {
	gates := g.Gate.Forward(x)
	scores := mlx.Sigmoid(gates)
	origScores := scores
	// The correction bias only influences which experts are selected; the
	// returned weights are gathered from the original sigmoid scores.
	if g.EScoreCorrectionBias != nil {
		scores = mlx.Add(scores, g.EScoreCorrectionBias)
	}
	topK := cfg.NumExpertsPerTok
	negScores := mlx.Neg(scores)
	inds := mlx.Argpartition(negScores, int(topK)-1, -1)
	dims := inds.Dims()
	inds = mlx.SliceStartStop(inds, []int32{0, 0, 0}, []int32{int32(dims[0]), int32(dims[1]), topK})
	scores = mlx.TakeAlongAxis(origScores, inds, -1)
	if topK > 1 && cfg.NormTopKProb {
		sumScores := mlx.Sum(scores, -1, true)
		scores = mlx.Div(scores, sumScores)
	}
	scores = mlx.MulScalar(scores, cfg.RoutedScalingFactor)
	return inds, scores
}

// SwitchMLP implements the MoE expert computation using stacked weights
type SwitchMLP struct {
	GateWeight *mlx.Array
	UpWeight   *mlx.Array
	DownWeight *mlx.Array

	GateWeightQ, GateScales, GateBiases *mlx.Array
	UpWeightQ, UpScales, UpBiases       *mlx.Array
	DownWeightQ, DownScales, DownBiases *mlx.Array

	GateBits      int
	UpBits        int
	DownBits      int
	GateGroupSize int
	UpGroupSize   int
	DownGroupSize int

	UseQuantized bool
}
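
// Note on the gather path below: for larger batches (B*L >= 64) the flattened
// (token, expert) pairs are sorted by expert index so each expert's rows form a
// contiguous run, and invOrder undoes the permutation before the result is
// reshaped back to (B, L, topK, HiddenSize). doSort is also passed through to
// GatherMM/GatherQMM, presumably so the kernels can exploit the sorted layout.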

// Forward applies the switched expert MLP
func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.Array {
	dims := x.Dims()
	B, L := int32(dims[0]), int32(dims[1])
	topK := cfg.NumExpertsPerTok

	xExpanded := mlx.ExpandDims(mlx.ExpandDims(x, -2), -2)
	xFlat := mlx.Reshape(xExpanded, B*L, 1, 1, cfg.HiddenSize)
	idxFlat := mlx.Reshape(indices, B*L, topK)

	doSort := B*L >= 64
	var invOrder *mlx.Array
	n := B * L * topK
	if doSort {
		idxAll := mlx.Flatten(idxFlat)
		order := mlx.Argsort(idxAll, 0)
		invOrder = mlx.Argsort(order, 0)
		xFlat = mlx.ExpandDims(mlx.Take(mlx.Squeeze(xFlat, 1), mlx.FloorDivideScalar(order, topK), 0), 1)
		idxFlat = mlx.Reshape(mlx.Take(idxAll, order, 0), n, 1)
	}

	var gate, up, hidden, down *mlx.Array
	if s.UseQuantized {
		gate = mlx.GatherQMM(xFlat, s.GateWeightQ, s.GateScales, s.GateBiases, nil, idxFlat, true, s.GateGroupSize, s.GateBits, cfg.QuantMode, doSort)
		up = mlx.GatherQMM(xFlat, s.UpWeightQ, s.UpScales, s.UpBiases, nil, idxFlat, true, s.UpGroupSize, s.UpBits, cfg.QuantMode, doSort)
		hidden = mlx.Mul(mlx.SiLU(gate), up)
		down = mlx.GatherQMM(hidden, s.DownWeightQ, s.DownScales, s.DownBiases, nil, idxFlat, true, s.DownGroupSize, s.DownBits, cfg.QuantMode, doSort)
	} else {
		gate = mlx.GatherMM(xFlat, mlx.Transpose(s.GateWeight, 0, 2, 1), nil, idxFlat, doSort)
		up = mlx.GatherMM(xFlat, mlx.Transpose(s.UpWeight, 0, 2, 1), nil, idxFlat, doSort)
		hidden = mlx.Mul(mlx.SiLU(gate), up)
		down = mlx.GatherMM(hidden, mlx.Transpose(s.DownWeight, 0, 2, 1), nil, idxFlat, doSort)
	}

	if doSort {
		down = mlx.Reshape(mlx.Take(mlx.Squeeze(mlx.Squeeze(down, 2), 1), invOrder, 0), B*L, topK, cfg.HiddenSize)
	} else {
		down = mlx.Squeeze(down, 2)
	}
	return mlx.Reshape(down, B, L, topK, cfg.HiddenSize)
}

// SharedExperts implements the shared expert MLP
type SharedExperts struct {
	GateProj nn.LinearLayer
	UpProj   nn.LinearLayer
	DownProj nn.LinearLayer
}

// Forward applies the shared expert MLP
func (s *SharedExperts) Forward(x *mlx.Array) *mlx.Array {
	gate := mlx.SiLU(s.GateProj.Forward(x))
	up := s.UpProj.Forward(x)
	return s.DownProj.Forward(mlx.Mul(gate, up))
}

// MoE implements the full Mixture of Experts layer
type MoE struct {
	Gate          *MoEGate
	SwitchMLP     *SwitchMLP
	SharedExperts *SharedExperts
}
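
// The routed and shared paths are combined in Forward below: the gate selects
// NumExpertsPerTok experts per token, SwitchMLP evaluates only those experts,
// their outputs are summed weighted by the (scaled) routing scores, and the
// always-on shared experts are added on top.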

// Forward applies the MoE layer
func (m *MoE) Forward(x *mlx.Array, cfg *Config) *mlx.Array {
	dims := x.Dims()
	B, L := int32(dims[0]), int32(dims[1])
	inds, scores := m.Gate.Forward(x, cfg)
	expertOut := m.SwitchMLP.Forward(x, inds, cfg)
	scoresExpanded := mlx.ExpandDims(scores, -1)
	y := mlx.Sum(mlx.Mul(expertOut, scoresExpanded), 2, false)
	if m.SharedExperts != nil {
		y = mlx.Add(y, m.SharedExperts.Forward(x))
	}
	return mlx.Reshape(y, B, L, cfg.HiddenSize)
}

// DenseBlock represents a dense transformer block (for first_k_dense_replace layers)
type DenseBlock struct {
	Attention              *MLAAttention
	MLP                    *DenseMLP
	InputLayerNorm         *nn.RMSNorm
	PostAttentionLayerNorm *nn.RMSNorm
}

// Forward applies the dense block
func (b *DenseBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
	r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg)
	h := mlx.Add(x, r)
	r = b.MLP.Forward(b.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps))
	return mlx.Add(h, r)
}

// MoEBlock represents a MoE transformer block
type MoEBlock struct {
	Attention              *MLAAttention
	MoE                    *MoE
	InputLayerNorm         *nn.RMSNorm
	PostAttentionLayerNorm *nn.RMSNorm
}

// Forward applies the MoE block
func (b *MoEBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
	r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg)
	h := mlx.Add(x, r)
	r = b.MoE.Forward(b.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps), cfg)
	return mlx.Add(h, r)
}

// Block interface for both dense and MoE blocks
type Block interface {
	Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array
}

// Model represents the complete GLM4-MoE-Lite model
type Model struct {
	EmbedTokens nn.EmbeddingLayer
	Layers      []Block
	Norm        *nn.RMSNorm
	LMHead      nn.LinearLayer

	tok *tokenizer.Tokenizer

	*Config
}

// computeScale computes the attention scale.
func computeScale(cfg *Config) float32 {
	keyLength := cfg.QKNopeHeadDim + cfg.QKRopeHeadDim
	scale := float32(1.0 / math.Sqrt(float64(keyLength)))
	// YaRN-style mscale correction: with RoPE scaling and mscale_all_dim set,
	// the softmax scale is multiplied by mscale squared.
	if cfg.RopeScaling != nil && cfg.RopeScaling.MscaleAllDim > 0 && cfg.RopeScaling.Factor > 1 {
		s := 0.1*cfg.RopeScaling.MscaleAllDim*float32(math.Log(float64(cfg.RopeScaling.Factor))) + 1.0
		scale *= s * s
	}
	return scale
}

// supportsGatherQMM returns true if the quantization mode has GatherQMM kernel support.
func supportsGatherQMM(mode string, bits int) bool {
	return mode == "affine" && (bits == 4 || bits == 8)
}

// ExpertWeight holds a single expert's weight with optional quantization components.
type ExpertWeight struct {
	Weight    *mlx.Array
	Scales    *mlx.Array
	Biases    *mlx.Array
	Bits      int
	GroupSize int
}

// loadExpertWeight loads an expert weight from the tensor map.
func loadExpertWeight(tensors map[string]*mlx.Array, path string, useQuantized bool, cfg *Config) *ExpertWeight {
	w := tensors[path+".weight"]
	if w == nil {
		return nil
	}
	scales := tensors[path+".weight_scale"]
	if scales != nil {
		qbiases := tensors[path+".weight_qbias"]
		groupSize, bits, mode := model.ResolveLinearQuantParams(
			cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode, cfg.TensorQuant,
			path+".weight", w, scales,
		)
		if useQuantized && supportsGatherQMM(mode, bits) {
			return &ExpertWeight{Weight: w, Scales: scales, Biases: qbiases, Bits: bits, GroupSize: groupSize}
		}
		return &ExpertWeight{Weight: mlx.Dequantize(w, scales, qbiases, groupSize, bits, mode)}
	}
	return &ExpertWeight{Weight: w}
}

// StackedExpertWeights holds stacked weights for all experts.
type StackedExpertWeights struct {
	Weight    *mlx.Array
	Scales    *mlx.Array
	Biases    *mlx.Array
	Bits      int
	GroupSize int
}

// collectAndStackExpertWeights loads and stacks expert weights for one projection type.
func collectAndStackExpertWeights(
	tensors map[string]*mlx.Array, prefix string, projName string,
	numExperts int32, useQuantized bool, cfg *Config,
) *StackedExpertWeights {
	var w, s, b []*mlx.Array
	var bits, groupSize int
	for e := int32(0); e < numExperts; e++ {
		path := fmt.Sprintf("%s.mlp.experts.%d.%s", prefix, e, projName)
		ew := loadExpertWeight(tensors, path, useQuantized, cfg)
		if ew == nil {
			continue
		}
		w = append(w, ew.Weight)
		if ew.Scales != nil {
			s = append(s, ew.Scales)
		}
		if ew.Biases != nil {
			b = append(b, ew.Biases)
		}
		if e == 0 {
			bits = ew.Bits
			groupSize = ew.GroupSize
		}
	}
	result := &StackedExpertWeights{Bits: bits, GroupSize: groupSize}
	if len(w) > 0 {
		result.Weight = mlx.Stack(w, 0)
		if len(s) > 0 {
			result.Scales = mlx.Stack(s, 0)
		}
		if len(b) > 0 {
			result.Biases = mlx.Stack(b, 0)
		}
	}
	return result
}
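
// Stacking yields one tensor per projection with the experts along axis 0
// (plus matching scale/bias stacks when the checkpoint is quantized), which is
// what SwitchMLP indexes by expert id at run time. Bits and group size are read
// from expert 0 on the assumption that every expert shares the same quantization.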

// sanitizeExpertWeights stacks individual expert weights into tensors.
func sanitizeExpertWeights(tensors map[string]*mlx.Array, prefix string, numExperts int32, useQuantized bool, cfg *Config) (gate, up, down *StackedExpertWeights) {
	gate = collectAndStackExpertWeights(tensors, prefix, "gate_proj", numExperts, useQuantized, cfg)
	up = collectAndStackExpertWeights(tensors, prefix, "up_proj", numExperts, useQuantized, cfg)
	down = collectAndStackExpertWeights(tensors, prefix, "down_proj", numExperts, useQuantized, cfg)
	return gate, up, down
}

// sanitizeMLAWeights transforms kv_b_proj weights into absorbed MLA format.
func sanitizeMLAWeights(tensors map[string]*mlx.Array, prefix string, cfg *Config) (*mlx.Array, *mlx.Array) {
	path := prefix + ".self_attn.kv_b_proj"
	w := tensors[path+".weight"]
	if w == nil {
		return nil, nil
	}

	// Check if quantized and dequantize
	if scales := tensors[path+".weight_scale"]; scales != nil {
		qbiases := tensors[path+".weight_qbias"]
		groupSize, bits, mode := model.ResolveLinearQuantParams(
			cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode, cfg.TensorQuant,
			path+".weight", w, scales,
		)
		w = mlx.Dequantize(w, scales, qbiases, groupSize, bits, mode)
	}

	// Per head, kv_b_proj packs the key up-projection (qk_nope_head_dim rows)
	// followed by the value up-projection (v_head_dim rows). Split the two and
	// transpose the key part so it can be applied directly to queries (EmbedQ);
	// the value part becomes UnembedOut.
	headDim := cfg.QKNopeHeadDim + cfg.VHeadDim
	w = mlx.Reshape(w, cfg.NumAttentionHeads, headDim, cfg.KVLoraRank)
	wk := mlx.SliceStartStop(w, []int32{0, 0, 0}, []int32{cfg.NumAttentionHeads, cfg.QKNopeHeadDim, cfg.KVLoraRank})
	wv := mlx.SliceStartStop(w, []int32{0, cfg.QKNopeHeadDim, 0}, []int32{cfg.NumAttentionHeads, headDim, cfg.KVLoraRank})
	embedQ := mlx.Transpose(wk, 0, 2, 1)
	unembedOut := wv
	return embedQ, unembedOut
}

// newModel creates a new GLM4-MoE-Lite model from a Root (config + tokenizer,
// no weights loaded yet). Called by the registry via base.New().
func newModel(root *model.Root) (base.Model, error) {
	configData, err := root.Manifest.ReadConfig("config.json")
	if err != nil {
		return nil, fmt.Errorf("load config: %w", err)
	}
	var cfg Config
	if err := json.Unmarshal(configData, &cfg); err != nil {
		return nil, fmt.Errorf("parse config: %w", err)
	}
	cfg.QHeadDim = cfg.QKNopeHeadDim + cfg.QKRopeHeadDim
	cfg.Scale = computeScale(&cfg)

	// Set up quantization parameters from pre-scanned metadata
	if qt := root.QuantType(); qt != "" {
		cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt)
		if gs := root.GroupSize(); gs > 0 {
			cfg.QuantGroupSize = gs
		}
	} else {
		cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams("")
	}
	cfg.TensorQuant = root.AllTensorQuant()

	// Load tokenizer
	tokData, err := root.Manifest.ReadConfig("tokenizer.json")
	if err != nil {
		return nil, fmt.Errorf("load tokenizer config: %w", err)
	}
	tokConfig := &tokenizer.TokenizerConfig{
		ConfigJSON: configData,
	}
	if genConfigData, err := root.Manifest.ReadConfig("generation_config.json"); err == nil {
		tokConfig.GenerationConfigJSON = genConfigData
	}
	if tokConfigData, err := root.Manifest.ReadConfig("tokenizer_config.json"); err == nil {
		tokConfig.TokenizerConfigJSON = tokConfigData
	}
	tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
	if err != nil {
		return nil, fmt.Errorf("parse tokenizer: %w", err)
	}

	m := &Model{
		Layers: make([]Block, cfg.NumHiddenLayers),
		Config: &cfg,
		tok:    tok,
	}
	return m, nil
}
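
// Construction is two-phase: newModel above parses config.json and the
// tokenizer files only, while the weight tensors (model.embed_tokens,
// model.norm, lm_head and the model.layers.N.* projections) arrive afterwards
// through LoadWeights below.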

// LoadWeights receives all tensors loaded from the manifest and assigns them
// to model fields. Handles MLA absorption, expert stacking, and quantized
// layer creation.
func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
	cfg := m.Config
	linears := model.NewLinearFactory(tensors, cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode, cfg.TensorQuant)

	useQuantized := supportsGatherQMM(cfg.QuantMode, cfg.QuantBits)
	if !useQuantized && cfg.TensorQuant != nil {
		for _, tq := range cfg.TensorQuant {
			if tq == nil {
				continue
			}
			_, bits, mode := model.QuantizationParams(tq.QuantType)
			if supportsGatherQMM(mode, bits) {
				useQuantized = true
				break
			}
		}
	}

	// Load embedding
	m.EmbedTokens = model.MakeEmbeddingLayer(tensors, "model.embed_tokens", cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode, cfg.TensorQuant)

	// Load final norm
	if w := tensors["model.norm.weight"]; w != nil {
		m.Norm = nn.NewRMSNorm(w, cfg.RMSNormEps)
	}

	// Load LM head
	m.LMHead = linears.Make("lm_head")

	// Load layers
	for i := int32(0); i < cfg.NumHiddenLayers; i++ {
		prefix := fmt.Sprintf("model.layers.%d", i)

		// Load attention (same for both block types)
		attn := &MLAAttention{}
		attn.QAProj = linears.Make(prefix + ".self_attn.q_a_proj")
		if w := tensors[prefix+".self_attn.q_a_layernorm.weight"]; w != nil {
			attn.QALayerNorm = nn.NewRMSNorm(w, cfg.RMSNormEps)
		}
		attn.QBProj = linears.Make(prefix + ".self_attn.q_b_proj")
		attn.KVAProjWithMQA = linears.Make(prefix + ".self_attn.kv_a_proj_with_mqa")
		if w := tensors[prefix+".self_attn.kv_a_layernorm.weight"]; w != nil {
			attn.KVALayerNorm = nn.NewRMSNorm(w, cfg.RMSNormEps)
		}
		attn.OProj = linears.Make(prefix + ".self_attn.o_proj")

		// Sanitize MLA weights for absorbed attention
		embedQ, unembedOut := sanitizeMLAWeights(tensors, prefix, cfg)
		attn.EmbedQ = nn.NewMultiLinear(embedQ)
		attn.UnembedOut = nn.NewMultiLinear(unembedOut)

		inputLN := tensors[prefix+".input_layernorm.weight"]
		postAttnLN := tensors[prefix+".post_attention_layernorm.weight"]

		if i < cfg.FirstKDenseReplace {
			// Dense block
			block := &DenseBlock{Attention: attn}
			if inputLN != nil {
				block.InputLayerNorm = nn.NewRMSNorm(inputLN, cfg.RMSNormEps)
			}
			if postAttnLN != nil {
				block.PostAttentionLayerNorm = nn.NewRMSNorm(postAttnLN, cfg.RMSNormEps)
			}
			block.MLP = &DenseMLP{
				GateProj: linears.Make(prefix + ".mlp.gate_proj"),
				UpProj:   linears.Make(prefix + ".mlp.up_proj"),
				DownProj: linears.Make(prefix + ".mlp.down_proj"),
			}
			m.Layers[i] = block
		} else {
			// MoE block
			block := &MoEBlock{Attention: attn}
			if inputLN != nil {
				block.InputLayerNorm = nn.NewRMSNorm(inputLN, cfg.RMSNormEps)
			}
			if postAttnLN != nil {
				block.PostAttentionLayerNorm = nn.NewRMSNorm(postAttnLN, cfg.RMSNormEps)
			}

			// Stack expert weights
			gate, up, down := sanitizeExpertWeights(tensors, prefix, cfg.NRoutedExperts, useQuantized, cfg)
			switchMLP := &SwitchMLP{UseQuantized: useQuantized}
			if useQuantized {
				switchMLP.GateWeightQ = gate.Weight
				switchMLP.GateScales = gate.Scales
				switchMLP.GateBiases = gate.Biases
				switchMLP.GateBits = gate.Bits
				switchMLP.GateGroupSize = gate.GroupSize
				switchMLP.UpWeightQ = up.Weight
				switchMLP.UpScales = up.Scales
				switchMLP.UpBiases = up.Biases
				switchMLP.UpBits = up.Bits
				switchMLP.UpGroupSize = up.GroupSize
				switchMLP.DownWeightQ = down.Weight
				switchMLP.DownScales = down.Scales
				switchMLP.DownBiases = down.Biases
				switchMLP.DownBits = down.Bits
				switchMLP.DownGroupSize = down.GroupSize
			} else {
				switchMLP.GateWeight = gate.Weight
				switchMLP.UpWeight = up.Weight
				switchMLP.DownWeight = down.Weight
			}

			moeGate := &MoEGate{}
			moeGate.Gate = linears.Make(prefix + ".mlp.gate")
			if bias := tensors[prefix+".mlp.gate.e_score_correction_bias"]; bias != nil {
				moeGate.EScoreCorrectionBias = bias
			}

			block.MoE = &MoE{
				Gate:      moeGate,
				SwitchMLP: switchMLP,
			}

			// Load shared experts if present
			if cfg.NSharedExperts > 0 {
				block.MoE.SharedExperts = &SharedExperts{
					GateProj: linears.Make(prefix + ".mlp.shared_experts.gate_proj"),
					UpProj:   linears.Make(prefix + ".mlp.shared_experts.up_proj"),
					DownProj: linears.Make(prefix + ".mlp.shared_experts.down_proj"),
				}
			}
			m.Layers[i] = block
		}
	}
	return nil
}

// Forward computes the forward pass of the model
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
	dims := tokens.Dims()
	B, L := int32(dims[0]), int32(dims[1])
	h := m.EmbedTokens.Forward(tokens)
	for i, layer := range m.Layers {
		var c cache.Cache
		if caches != nil {
			c = caches[i]
		}
		h = layer.Forward(h, c, B, L, m.Config)
	}
	h = m.Norm.Forward(h, m.RMSNormEps)
	return h
}

// Unembed applies the LM head to get logits.
func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
	return m.LMHead.Forward(x)
}

// NumLayers returns the number of transformer layers
func (m *Model) NumLayers() int {
	return len(m.Layers)
}

// MaxContextLength returns the maximum context length
func (m *Model) MaxContextLength() int {
	return int(m.MaxPositionEmbeddings)
}

// VocabSize returns the vocabulary size
func (m *Model) VocabSize() int32 {
	return m.Config.VocabSize
}

// Tokenizer returns the model's tokenizer
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
	return m.tok
}

// NewCache creates a new KV cache for the model
func (m *Model) NewCache(maxSeqLen int32) []cache.Cache {
	caches := make([]cache.Cache, len(m.Layers))
	for i := range caches {
		caches[i] = cache.NewKVCache()
	}
	return caches
}

// FormatPrompt applies the GLM-4 chat template with thinking enabled by default.
func (m *Model) FormatPrompt(prompt string) string {
	return "[gMASK]<|user|>" + prompt + "<|assistant|>"
}

// FormatPromptWithThinking applies the GLM-4 chat template with explicit thinking control.
// Thinking is on by default for GLM-4, so both paths currently emit the same prompt.
func (m *Model) FormatPromptWithThinking(prompt string, think bool) string {
	if think {
		return "[gMASK]<|user|>" + prompt + "<|assistant|>"
	}
	return "[gMASK]<|user|>" + prompt + "<|assistant|>"
}

// NewRenderer returns a new Renderer for formatting multi-turn conversations.
func (m *Model) NewRenderer() *Renderer {
	return &Renderer{}
}

// NewParser returns a new Parser for extracting thinking and tool calls from output.
func (m *Model) NewParser() *Parser {
	return &Parser{}
}
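
// Informal usage sketch (tokenization and sampling happen in the runner,
// outside this package): after newModel and LoadWeights, a decode step
// allocates per-layer caches and turns (B, L) token ids into logits:
//
//	caches := m.NewCache(maxSeqLen)
//	logits := m.Unembed(m.Forward(tokens, caches))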