mirror of
https://github.com/ollama/ollama.git
synced 2026-04-20 07:54:25 +02:00
The previous approach tracked array lifecycles through reference counting, where each array recorded its inputs and a reference count that was decremented as dependents were freed. This is not really necessary, as MLX tracks references internally. It is also error-prone, because it is easy to create new arrays and forget to free them when the Go variable goes out of scope. Instead, we can pin just the arrays we want to keep (typically outputs and specific intermediates, like the cache). All other arrays are freed by default when we run sweep. This avoids most causes of memory leaks while still giving the freedom to save what we want.
775 lines
23 KiB
Go
775 lines
23 KiB
Go
//go:build mlx
|
|
|
|
// Package glm4_moe_lite provides the GLM4-MoE-Lite implementation for MLX.
|
|
// This model uses Multi-head Latent Attention (MLA) and Mixture of Experts (MoE).
|
|
package glm4_moe_lite
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"math"
|
|
|
|
"github.com/ollama/ollama/x/mlxrunner/cache"
|
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
|
"github.com/ollama/ollama/x/mlxrunner/model"
|
|
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
|
"github.com/ollama/ollama/x/models/nn"
|
|
"github.com/ollama/ollama/x/tokenizer"
|
|
)
|
|
|
|
func init() {
|
|
base.Register("Glm4MoeLiteForCausalLM", newModel)
|
|
base.Register("GLM4MoeLite", newModel)
|
|
}
|
|
|
|
// RopeScaling mirrors the optional "rope_scaling" object of the model
// configuration. Both values feed the attention scale adjustment performed by
// computeScale.
type RopeScaling struct {
	// Factor is decoded from the "factor" key.
	Factor float32 `json:"factor"`

	// MscaleAllDim is decoded from the "mscale_all_dim" key.
	MscaleAllDim float32 `json:"mscale_all_dim"`
}
|
|
|
|
// Config holds GLM4-MoE-Lite model configuration
|
|
type Config struct {
|
|
HiddenSize int32 `json:"hidden_size"`
|
|
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
|
IntermediateSize int32 `json:"intermediate_size"`
|
|
MoEIntermediateSize int32 `json:"moe_intermediate_size"`
|
|
NumAttentionHeads int32 `json:"num_attention_heads"`
|
|
NumKeyValueHeads int32 `json:"num_key_value_heads"`
|
|
VocabSize int32 `json:"vocab_size"`
|
|
RMSNormEps float32 `json:"rms_norm_eps"`
|
|
RopeTheta float32 `json:"rope_theta"`
|
|
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
|
|
AttentionBias bool `json:"attention_bias"`
|
|
|
|
// MLA (Multi-head Latent Attention) parameters
|
|
QLoraRank int32 `json:"q_lora_rank"`
|
|
KVLoraRank int32 `json:"kv_lora_rank"`
|
|
QKRopeHeadDim int32 `json:"qk_rope_head_dim"`
|
|
QKNopeHeadDim int32 `json:"qk_nope_head_dim"`
|
|
VHeadDim int32 `json:"v_head_dim"`
|
|
|
|
// MoE parameters
|
|
NRoutedExperts int32 `json:"n_routed_experts"`
|
|
NSharedExperts int32 `json:"n_shared_experts"`
|
|
NumExpertsPerTok int32 `json:"num_experts_per_tok"`
|
|
RoutedScalingFactor float32 `json:"routed_scaling_factor"`
|
|
NormTopKProb bool `json:"norm_topk_prob"`
|
|
FirstKDenseReplace int32 `json:"first_k_dense_replace"`
|
|
NGroup int32 `json:"n_group"`
|
|
TopKGroup int32 `json:"topk_group"`
|
|
|
|
// RoPE scaling
|
|
RopeScaling *RopeScaling `json:"rope_scaling"`
|
|
|
|
// Quantization parameters (set during load based on model quantization)
|
|
QuantGroupSize int `json:"-"` // Group size for quantization (default 64)
|
|
QuantBits int `json:"-"` // Bits per weight (4 or 8)
|
|
QuantMode string `json:"-"` // Quantization mode ("affine", etc.)
|
|
TensorQuant map[string]*model.TensorQuantInfo `json:"-"`
|
|
|
|
// Computed fields
|
|
QHeadDim int32 `json:"-"` // qk_nope_head_dim + qk_rope_head_dim
|
|
Scale float32 `json:"-"` // 1/sqrt(QHeadDim) with mscale adjustment
|
|
}
|
|
|
|
// MLAAttention implements Multi-head Latent Attention with absorption.
|
|
type MLAAttention struct {
|
|
QAProj nn.LinearLayer
|
|
QALayerNorm *nn.RMSNorm
|
|
QBProj nn.LinearLayer
|
|
|
|
KVAProjWithMQA nn.LinearLayer
|
|
KVALayerNorm *nn.RMSNorm
|
|
|
|
EmbedQ *nn.MultiLinear
|
|
UnembedOut *nn.MultiLinear
|
|
|
|
OProj nn.LinearLayer
|
|
}
|
|
|
|
// Forward computes absorbed MLA attention output.
|
|
func (a *MLAAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
|
q := a.QAProj.Forward(x)
|
|
q = a.QALayerNorm.Forward(q, cfg.RMSNormEps)
|
|
q = a.QBProj.Forward(q)
|
|
|
|
q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.QHeadDim)
|
|
q = mlx.Transpose(q, 0, 2, 1, 3)
|
|
|
|
qNope := mlx.SliceStartStop(q, []int32{0, 0, 0, 0}, []int32{B, cfg.NumAttentionHeads, L, cfg.QKNopeHeadDim})
|
|
qPE := mlx.SliceStartStop(q, []int32{0, 0, 0, cfg.QKNopeHeadDim}, []int32{B, cfg.NumAttentionHeads, L, cfg.QHeadDim})
|
|
|
|
compressedKV := a.KVAProjWithMQA.Forward(x)
|
|
|
|
kvCompressed := mlx.SliceStartStop(compressedKV, []int32{0, 0, 0}, []int32{B, L, cfg.KVLoraRank})
|
|
kPE := mlx.SliceStartStop(compressedKV, []int32{0, 0, cfg.KVLoraRank}, []int32{B, L, cfg.KVLoraRank + cfg.QKRopeHeadDim})
|
|
|
|
kPE = mlx.Reshape(kPE, B, L, 1, cfg.QKRopeHeadDim)
|
|
kPE = mlx.Transpose(kPE, 0, 2, 1, 3)
|
|
|
|
kvLatent := a.KVALayerNorm.Forward(kvCompressed, cfg.RMSNormEps)
|
|
kvLatent = mlx.ExpandDims(kvLatent, 1)
|
|
|
|
offset := 0
|
|
if c != nil {
|
|
offset = c.Offset()
|
|
}
|
|
qPE = mlx.RoPEWithBase(qPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset)
|
|
kPE = mlx.RoPEWithBase(kPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset)
|
|
|
|
qLatent := a.EmbedQ.Forward(qNope)
|
|
|
|
keys := mlx.Concatenate([]*mlx.Array{kvLatent, kPE}, 3)
|
|
|
|
cachedL := L
|
|
if c != nil {
|
|
placeholderValues := mlx.ZerosF32([]int32{B, 1, L, 0})
|
|
keys, _ = c.Update(keys, placeholderValues)
|
|
cachedL = int32(keys.Dim(2))
|
|
}
|
|
|
|
values := mlx.SliceStartStop(keys, []int32{0, 0, 0, 0}, []int32{B, 1, cachedL, cfg.KVLoraRank})
|
|
|
|
queries := mlx.Concatenate([]*mlx.Array{qLatent, qPE}, 3)
|
|
|
|
out := mlx.ScaledDotProductAttentionCausal(queries, keys, values, cfg.Scale, L > 1)
|
|
out = a.UnembedOut.Forward(out)
|
|
|
|
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.VHeadDim)
|
|
|
|
return a.OProj.Forward(out)
|
|
}
|
|
|
|
// DenseMLP implements the standard SwiGLU MLP for dense layers
|
|
type DenseMLP struct {
|
|
GateProj nn.LinearLayer
|
|
UpProj nn.LinearLayer
|
|
DownProj nn.LinearLayer
|
|
}
|
|
|
|
// Forward applies the SwiGLU MLP
|
|
func (m *DenseMLP) Forward(x *mlx.Array) *mlx.Array {
|
|
gate := mlx.SiLU(m.GateProj.Forward(x))
|
|
up := m.UpProj.Forward(x)
|
|
return m.DownProj.Forward(mlx.Mul(gate, up))
|
|
}
|
|
|
|
// MoEGate implements the expert gating mechanism
|
|
type MoEGate struct {
|
|
Gate nn.LinearLayer
|
|
EScoreCorrectionBias *mlx.Array
|
|
}
|
|
|
|
// Forward computes expert selection indices and scores
|
|
func (g *MoEGate) Forward(x *mlx.Array, cfg *Config) (*mlx.Array, *mlx.Array) {
|
|
gates := g.Gate.Forward(x)
|
|
|
|
scores := mlx.Sigmoid(gates)
|
|
origScores := scores
|
|
|
|
if g.EScoreCorrectionBias != nil {
|
|
scores = mlx.Add(scores, g.EScoreCorrectionBias)
|
|
}
|
|
|
|
topK := cfg.NumExpertsPerTok
|
|
negScores := mlx.Neg(scores)
|
|
inds := mlx.Argpartition(negScores, int(topK)-1, -1)
|
|
|
|
dims := inds.Dims()
|
|
inds = mlx.SliceStartStop(inds, []int32{0, 0, 0}, []int32{int32(dims[0]), int32(dims[1]), topK})
|
|
|
|
scores = mlx.TakeAlongAxis(origScores, inds, -1)
|
|
|
|
if topK > 1 && cfg.NormTopKProb {
|
|
sumScores := mlx.Sum(scores, -1, true)
|
|
scores = mlx.Div(scores, sumScores)
|
|
}
|
|
|
|
scores = mlx.MulScalar(scores, cfg.RoutedScalingFactor)
|
|
|
|
return inds, scores
|
|
}
|
|
|
|
// SwitchMLP implements the MoE expert computation using stacked weights
|
|
type SwitchMLP struct {
|
|
GateWeight *mlx.Array
|
|
UpWeight *mlx.Array
|
|
DownWeight *mlx.Array
|
|
|
|
GateWeightQ, GateScales, GateBiases *mlx.Array
|
|
UpWeightQ, UpScales, UpBiases *mlx.Array
|
|
DownWeightQ, DownScales, DownBiases *mlx.Array
|
|
|
|
GateBits int
|
|
UpBits int
|
|
DownBits int
|
|
|
|
GateGroupSize int
|
|
UpGroupSize int
|
|
DownGroupSize int
|
|
|
|
UseQuantized bool
|
|
}
|
|
|
|
// Forward applies the switched expert MLP
|
|
func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.Array {
|
|
dims := x.Dims()
|
|
B, L := int32(dims[0]), int32(dims[1])
|
|
topK := cfg.NumExpertsPerTok
|
|
|
|
xExpanded := mlx.ExpandDims(mlx.ExpandDims(x, -2), -2)
|
|
|
|
xFlat := mlx.Reshape(xExpanded, B*L, 1, 1, cfg.HiddenSize)
|
|
|
|
idxFlat := mlx.Reshape(indices, B*L, topK)
|
|
|
|
doSort := B*L >= 64
|
|
var invOrder *mlx.Array
|
|
n := B * L * topK
|
|
|
|
if doSort {
|
|
idxAll := mlx.Flatten(idxFlat)
|
|
order := mlx.Argsort(idxAll, 0)
|
|
invOrder = mlx.Argsort(order, 0)
|
|
xFlat = mlx.ExpandDims(mlx.Take(mlx.Squeeze(xFlat, 1), mlx.FloorDivideScalar(order, topK), 0), 1)
|
|
idxFlat = mlx.Reshape(mlx.Take(idxAll, order, 0), n, 1)
|
|
}
|
|
|
|
var gate, up, hidden, down *mlx.Array
|
|
|
|
if s.UseQuantized {
|
|
gate = mlx.GatherQMM(xFlat, s.GateWeightQ, s.GateScales, s.GateBiases,
|
|
nil, idxFlat, true, s.GateGroupSize, s.GateBits, cfg.QuantMode, doSort)
|
|
up = mlx.GatherQMM(xFlat, s.UpWeightQ, s.UpScales, s.UpBiases,
|
|
nil, idxFlat, true, s.UpGroupSize, s.UpBits, cfg.QuantMode, doSort)
|
|
|
|
hidden = mlx.Mul(mlx.SiLU(gate), up)
|
|
|
|
down = mlx.GatherQMM(hidden, s.DownWeightQ, s.DownScales, s.DownBiases,
|
|
nil, idxFlat, true, s.DownGroupSize, s.DownBits, cfg.QuantMode, doSort)
|
|
} else {
|
|
gate = mlx.GatherMM(xFlat, mlx.Transpose(s.GateWeight, 0, 2, 1), nil, idxFlat, doSort)
|
|
up = mlx.GatherMM(xFlat, mlx.Transpose(s.UpWeight, 0, 2, 1), nil, idxFlat, doSort)
|
|
|
|
hidden = mlx.Mul(mlx.SiLU(gate), up)
|
|
|
|
down = mlx.GatherMM(hidden, mlx.Transpose(s.DownWeight, 0, 2, 1), nil, idxFlat, doSort)
|
|
}
|
|
|
|
if doSort {
|
|
down = mlx.Reshape(mlx.Take(mlx.Squeeze(mlx.Squeeze(down, 2), 1), invOrder, 0), B*L, topK, cfg.HiddenSize)
|
|
} else {
|
|
down = mlx.Squeeze(down, 2)
|
|
}
|
|
|
|
return mlx.Reshape(down, B, L, topK, cfg.HiddenSize)
|
|
}
|
|
|
|
// SharedExperts implements the shared expert MLP
|
|
type SharedExperts struct {
|
|
GateProj nn.LinearLayer
|
|
UpProj nn.LinearLayer
|
|
DownProj nn.LinearLayer
|
|
}
|
|
|
|
// Forward applies the shared expert MLP
|
|
func (s *SharedExperts) Forward(x *mlx.Array) *mlx.Array {
|
|
gate := mlx.SiLU(s.GateProj.Forward(x))
|
|
up := s.UpProj.Forward(x)
|
|
return s.DownProj.Forward(mlx.Mul(gate, up))
|
|
}
|
|
|
|
// MoE implements the full Mixture of Experts layer
|
|
type MoE struct {
|
|
Gate *MoEGate
|
|
SwitchMLP *SwitchMLP
|
|
SharedExperts *SharedExperts
|
|
}
|
|
|
|
// Forward applies the MoE layer
|
|
func (m *MoE) Forward(x *mlx.Array, cfg *Config) *mlx.Array {
|
|
dims := x.Dims()
|
|
B, L := int32(dims[0]), int32(dims[1])
|
|
|
|
inds, scores := m.Gate.Forward(x, cfg)
|
|
|
|
expertOut := m.SwitchMLP.Forward(x, inds, cfg)
|
|
|
|
scoresExpanded := mlx.ExpandDims(scores, -1)
|
|
y := mlx.Sum(mlx.Mul(expertOut, scoresExpanded), 2, false)
|
|
|
|
if m.SharedExperts != nil {
|
|
y = mlx.Add(y, m.SharedExperts.Forward(x))
|
|
}
|
|
|
|
return mlx.Reshape(y, B, L, cfg.HiddenSize)
|
|
}
|
|
|
|
// DenseBlock represents a dense transformer block (for first_k_dense_replace layers)
|
|
type DenseBlock struct {
|
|
Attention *MLAAttention
|
|
MLP *DenseMLP
|
|
InputLayerNorm *nn.RMSNorm
|
|
PostAttentionLayerNorm *nn.RMSNorm
|
|
}
|
|
|
|
// Forward applies the dense block
|
|
func (b *DenseBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
|
r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg)
|
|
h := mlx.Add(x, r)
|
|
|
|
r = b.MLP.Forward(b.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps))
|
|
return mlx.Add(h, r)
|
|
}
|
|
|
|
// MoEBlock represents a MoE transformer block
|
|
type MoEBlock struct {
|
|
Attention *MLAAttention
|
|
MoE *MoE
|
|
InputLayerNorm *nn.RMSNorm
|
|
PostAttentionLayerNorm *nn.RMSNorm
|
|
}
|
|
|
|
// Forward applies the MoE block
|
|
func (b *MoEBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
|
r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg)
|
|
h := mlx.Add(x, r)
|
|
|
|
r = b.MoE.Forward(b.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps), cfg)
|
|
return mlx.Add(h, r)
|
|
}
|
|
|
|
// Block interface for both dense and MoE blocks
|
|
type Block interface {
|
|
Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array
|
|
}
|
|
|
|
// Model represents the complete GLM4-MoE-Lite model
|
|
type Model struct {
|
|
EmbedTokens *nn.Embedding
|
|
Layers []Block
|
|
Norm *nn.RMSNorm
|
|
LMHead nn.LinearLayer
|
|
|
|
tok *tokenizer.Tokenizer
|
|
*Config
|
|
}
|
|
|
|
// computeScale computes the attention scale.
|
|
func computeScale(cfg *Config) float32 {
|
|
keyLength := cfg.QKNopeHeadDim + cfg.QKRopeHeadDim
|
|
scale := float32(1.0 / math.Sqrt(float64(keyLength)))
|
|
if cfg.RopeScaling != nil && cfg.RopeScaling.MscaleAllDim > 0 && cfg.RopeScaling.Factor > 1 {
|
|
s := 0.1*cfg.RopeScaling.MscaleAllDim*float32(math.Log(float64(cfg.RopeScaling.Factor))) + 1.0
|
|
scale *= s * s
|
|
}
|
|
return scale
|
|
}
|
|
|
|
// supportsGatherQMM reports whether the given quantization mode and bit width
// can be used with the GatherQMM kernel: only "affine" with 4 or 8 bits.
func supportsGatherQMM(mode string, bits int) bool {
	if mode != "affine" {
		return false
	}
	switch bits {
	case 4, 8:
		return true
	default:
		return false
	}
}
|
|
|
|
// ExpertWeight holds a single expert's weight with optional quantization components.
|
|
type ExpertWeight struct {
|
|
Weight *mlx.Array
|
|
Scales *mlx.Array
|
|
Biases *mlx.Array
|
|
Bits int
|
|
GroupSize int
|
|
}
|
|
|
|
// loadExpertWeight loads an expert weight from the tensor map.
|
|
func loadExpertWeight(tensors map[string]*mlx.Array, path string, useQuantized bool, cfg *Config) *ExpertWeight {
|
|
w := tensors[path+".weight"]
|
|
if w == nil {
|
|
return nil
|
|
}
|
|
|
|
scales := tensors[path+".weight_scale"]
|
|
if scales != nil {
|
|
qbiases := tensors[path+".weight_qbias"]
|
|
|
|
groupSize, bits, mode := model.ResolveLinearQuantParams(
|
|
cfg.QuantGroupSize,
|
|
cfg.QuantBits,
|
|
cfg.QuantMode,
|
|
cfg.TensorQuant,
|
|
path+".weight",
|
|
w,
|
|
scales,
|
|
)
|
|
|
|
if useQuantized && supportsGatherQMM(mode, bits) {
|
|
return &ExpertWeight{Weight: w, Scales: scales, Biases: qbiases, Bits: bits, GroupSize: groupSize}
|
|
}
|
|
|
|
return &ExpertWeight{Weight: mlx.Dequantize(w, scales, qbiases, groupSize, bits, mode)}
|
|
}
|
|
|
|
return &ExpertWeight{Weight: w}
|
|
}
|
|
|
|
// StackedExpertWeights holds stacked weights for all experts.
|
|
type StackedExpertWeights struct {
|
|
Weight *mlx.Array
|
|
Scales *mlx.Array
|
|
Biases *mlx.Array
|
|
Bits int
|
|
GroupSize int
|
|
}
|
|
|
|
// collectAndStackExpertWeights loads and stacks expert weights for one projection type.
|
|
func collectAndStackExpertWeights(
|
|
tensors map[string]*mlx.Array,
|
|
prefix string,
|
|
projName string,
|
|
numExperts int32,
|
|
useQuantized bool,
|
|
cfg *Config,
|
|
) *StackedExpertWeights {
|
|
var w, s, b []*mlx.Array
|
|
var bits, groupSize int
|
|
|
|
for e := int32(0); e < numExperts; e++ {
|
|
path := fmt.Sprintf("%s.mlp.experts.%d.%s", prefix, e, projName)
|
|
ew := loadExpertWeight(tensors, path, useQuantized, cfg)
|
|
if ew == nil {
|
|
continue
|
|
}
|
|
w = append(w, ew.Weight)
|
|
if ew.Scales != nil {
|
|
s = append(s, ew.Scales)
|
|
}
|
|
if ew.Biases != nil {
|
|
b = append(b, ew.Biases)
|
|
}
|
|
if e == 0 {
|
|
bits = ew.Bits
|
|
groupSize = ew.GroupSize
|
|
}
|
|
}
|
|
|
|
result := &StackedExpertWeights{Bits: bits, GroupSize: groupSize}
|
|
if len(w) > 0 {
|
|
result.Weight = mlx.Stack(w, 0)
|
|
if len(s) > 0 {
|
|
result.Scales = mlx.Stack(s, 0)
|
|
}
|
|
if len(b) > 0 {
|
|
result.Biases = mlx.Stack(b, 0)
|
|
}
|
|
}
|
|
return result
|
|
}
|
|
|
|
// sanitizeExpertWeights stacks individual expert weights into tensors.
|
|
func sanitizeExpertWeights(tensors map[string]*mlx.Array, prefix string, numExperts int32, useQuantized bool, cfg *Config) (gate, up, down *StackedExpertWeights) {
|
|
gate = collectAndStackExpertWeights(tensors, prefix, "gate_proj", numExperts, useQuantized, cfg)
|
|
up = collectAndStackExpertWeights(tensors, prefix, "up_proj", numExperts, useQuantized, cfg)
|
|
down = collectAndStackExpertWeights(tensors, prefix, "down_proj", numExperts, useQuantized, cfg)
|
|
return gate, up, down
|
|
}
|
|
|
|
// sanitizeMLAWeights transforms kv_b_proj weights into absorbed MLA format.
|
|
func sanitizeMLAWeights(tensors map[string]*mlx.Array, prefix string, cfg *Config) (*mlx.Array, *mlx.Array) {
|
|
path := prefix + ".self_attn.kv_b_proj"
|
|
w := tensors[path+".weight"]
|
|
if w == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
// Check if quantized and dequantize
|
|
if scales := tensors[path+".weight_scale"]; scales != nil {
|
|
qbiases := tensors[path+".weight_qbias"]
|
|
groupSize, bits, mode := model.ResolveLinearQuantParams(
|
|
cfg.QuantGroupSize,
|
|
cfg.QuantBits,
|
|
cfg.QuantMode,
|
|
cfg.TensorQuant,
|
|
path+".weight",
|
|
w,
|
|
scales,
|
|
)
|
|
w = mlx.Dequantize(w, scales, qbiases, groupSize, bits, mode)
|
|
}
|
|
|
|
headDim := cfg.QKNopeHeadDim + cfg.VHeadDim
|
|
w = mlx.Reshape(w, cfg.NumAttentionHeads, headDim, cfg.KVLoraRank)
|
|
|
|
wk := mlx.SliceStartStop(w, []int32{0, 0, 0}, []int32{cfg.NumAttentionHeads, cfg.QKNopeHeadDim, cfg.KVLoraRank})
|
|
wv := mlx.SliceStartStop(w, []int32{0, cfg.QKNopeHeadDim, 0}, []int32{cfg.NumAttentionHeads, headDim, cfg.KVLoraRank})
|
|
|
|
embedQ := mlx.Transpose(wk, 0, 2, 1)
|
|
unembedOut := wv
|
|
|
|
return embedQ, unembedOut
|
|
}
|
|
|
|
// newModel creates a new GLM4-MoE-Lite model from a Root (config + tokenizer,
|
|
// no weights loaded yet). Called by the registry via base.New().
|
|
func newModel(root *model.Root) (base.Model, error) {
|
|
configData, err := root.Manifest.ReadConfig("config.json")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("load config: %w", err)
|
|
}
|
|
|
|
var cfg Config
|
|
if err := json.Unmarshal(configData, &cfg); err != nil {
|
|
return nil, fmt.Errorf("parse config: %w", err)
|
|
}
|
|
|
|
cfg.QHeadDim = cfg.QKNopeHeadDim + cfg.QKRopeHeadDim
|
|
cfg.Scale = computeScale(&cfg)
|
|
|
|
// Set up quantization parameters from pre-scanned metadata
|
|
if qt := root.QuantType(); qt != "" {
|
|
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt)
|
|
if gs := root.GroupSize(); gs > 0 {
|
|
cfg.QuantGroupSize = gs
|
|
}
|
|
} else {
|
|
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams("")
|
|
}
|
|
cfg.TensorQuant = root.AllTensorQuant()
|
|
|
|
// Load tokenizer
|
|
tokData, err := root.Manifest.ReadConfig("tokenizer.json")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("load tokenizer config: %w", err)
|
|
}
|
|
|
|
tokConfig := &tokenizer.TokenizerConfig{
|
|
ConfigJSON: configData,
|
|
}
|
|
|
|
if genConfigData, err := root.Manifest.ReadConfig("generation_config.json"); err == nil {
|
|
tokConfig.GenerationConfigJSON = genConfigData
|
|
}
|
|
|
|
if tokConfigData, err := root.Manifest.ReadConfig("tokenizer_config.json"); err == nil {
|
|
tokConfig.TokenizerConfigJSON = tokConfigData
|
|
}
|
|
|
|
tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parse tokenizer: %w", err)
|
|
}
|
|
|
|
m := &Model{
|
|
Layers: make([]Block, cfg.NumHiddenLayers),
|
|
Config: &cfg,
|
|
tok: tok,
|
|
}
|
|
|
|
return m, nil
|
|
}
|
|
|
|
// LoadWeights receives all tensors loaded from the manifest and assigns them
|
|
// to model fields. Handles MLA absorption, expert stacking, and quantized
|
|
// layer creation.
|
|
func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
|
cfg := m.Config
|
|
linears := model.NewLinearFactory(tensors, cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode, cfg.TensorQuant)
|
|
useQuantized := supportsGatherQMM(cfg.QuantMode, cfg.QuantBits)
|
|
if !useQuantized && cfg.TensorQuant != nil {
|
|
for _, tq := range cfg.TensorQuant {
|
|
if tq == nil {
|
|
continue
|
|
}
|
|
_, bits, mode := model.QuantizationParams(tq.QuantType)
|
|
if supportsGatherQMM(mode, bits) {
|
|
useQuantized = true
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
// Load embedding
|
|
if w := tensors["model.embed_tokens.weight"]; w != nil {
|
|
m.EmbedTokens = nn.NewEmbedding(w)
|
|
}
|
|
|
|
// Load final norm
|
|
if w := tensors["model.norm.weight"]; w != nil {
|
|
m.Norm = nn.NewRMSNorm(w, cfg.RMSNormEps)
|
|
}
|
|
|
|
// Load LM head
|
|
m.LMHead = linears.Make("lm_head")
|
|
|
|
// Load layers
|
|
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
|
|
prefix := fmt.Sprintf("model.layers.%d", i)
|
|
|
|
// Load attention (same for both block types)
|
|
attn := &MLAAttention{}
|
|
attn.QAProj = linears.Make(prefix + ".self_attn.q_a_proj")
|
|
if w := tensors[prefix+".self_attn.q_a_layernorm.weight"]; w != nil {
|
|
attn.QALayerNorm = nn.NewRMSNorm(w, cfg.RMSNormEps)
|
|
}
|
|
attn.QBProj = linears.Make(prefix + ".self_attn.q_b_proj")
|
|
attn.KVAProjWithMQA = linears.Make(prefix + ".self_attn.kv_a_proj_with_mqa")
|
|
if w := tensors[prefix+".self_attn.kv_a_layernorm.weight"]; w != nil {
|
|
attn.KVALayerNorm = nn.NewRMSNorm(w, cfg.RMSNormEps)
|
|
}
|
|
attn.OProj = linears.Make(prefix + ".self_attn.o_proj")
|
|
|
|
// Sanitize MLA weights for absorbed attention
|
|
embedQ, unembedOut := sanitizeMLAWeights(tensors, prefix, cfg)
|
|
attn.EmbedQ = nn.NewMultiLinear(embedQ)
|
|
attn.UnembedOut = nn.NewMultiLinear(unembedOut)
|
|
|
|
inputLN := tensors[prefix+".input_layernorm.weight"]
|
|
postAttnLN := tensors[prefix+".post_attention_layernorm.weight"]
|
|
|
|
if i < cfg.FirstKDenseReplace {
|
|
// Dense block
|
|
block := &DenseBlock{Attention: attn}
|
|
if inputLN != nil {
|
|
block.InputLayerNorm = nn.NewRMSNorm(inputLN, cfg.RMSNormEps)
|
|
}
|
|
if postAttnLN != nil {
|
|
block.PostAttentionLayerNorm = nn.NewRMSNorm(postAttnLN, cfg.RMSNormEps)
|
|
}
|
|
|
|
block.MLP = &DenseMLP{
|
|
GateProj: linears.Make(prefix + ".mlp.gate_proj"),
|
|
UpProj: linears.Make(prefix + ".mlp.up_proj"),
|
|
DownProj: linears.Make(prefix + ".mlp.down_proj"),
|
|
}
|
|
|
|
m.Layers[i] = block
|
|
} else {
|
|
// MoE block
|
|
block := &MoEBlock{Attention: attn}
|
|
if inputLN != nil {
|
|
block.InputLayerNorm = nn.NewRMSNorm(inputLN, cfg.RMSNormEps)
|
|
}
|
|
if postAttnLN != nil {
|
|
block.PostAttentionLayerNorm = nn.NewRMSNorm(postAttnLN, cfg.RMSNormEps)
|
|
}
|
|
|
|
// Stack expert weights
|
|
gate, up, down := sanitizeExpertWeights(tensors, prefix, cfg.NRoutedExperts, useQuantized, cfg)
|
|
|
|
switchMLP := &SwitchMLP{UseQuantized: useQuantized}
|
|
if useQuantized {
|
|
switchMLP.GateWeightQ = gate.Weight
|
|
switchMLP.GateScales = gate.Scales
|
|
switchMLP.GateBiases = gate.Biases
|
|
switchMLP.GateBits = gate.Bits
|
|
switchMLP.GateGroupSize = gate.GroupSize
|
|
switchMLP.UpWeightQ = up.Weight
|
|
switchMLP.UpScales = up.Scales
|
|
switchMLP.UpBiases = up.Biases
|
|
switchMLP.UpBits = up.Bits
|
|
switchMLP.UpGroupSize = up.GroupSize
|
|
switchMLP.DownWeightQ = down.Weight
|
|
switchMLP.DownScales = down.Scales
|
|
switchMLP.DownBiases = down.Biases
|
|
switchMLP.DownBits = down.Bits
|
|
switchMLP.DownGroupSize = down.GroupSize
|
|
} else {
|
|
switchMLP.GateWeight = gate.Weight
|
|
switchMLP.UpWeight = up.Weight
|
|
switchMLP.DownWeight = down.Weight
|
|
}
|
|
|
|
moeGate := &MoEGate{}
|
|
moeGate.Gate = linears.Make(prefix + ".mlp.gate")
|
|
if bias := tensors[prefix+".mlp.gate.e_score_correction_bias"]; bias != nil {
|
|
moeGate.EScoreCorrectionBias = bias
|
|
}
|
|
|
|
block.MoE = &MoE{
|
|
Gate: moeGate,
|
|
SwitchMLP: switchMLP,
|
|
}
|
|
|
|
// Load shared experts if present
|
|
if cfg.NSharedExperts > 0 {
|
|
block.MoE.SharedExperts = &SharedExperts{
|
|
GateProj: linears.Make(prefix + ".mlp.shared_experts.gate_proj"),
|
|
UpProj: linears.Make(prefix + ".mlp.shared_experts.up_proj"),
|
|
DownProj: linears.Make(prefix + ".mlp.shared_experts.down_proj"),
|
|
}
|
|
}
|
|
|
|
m.Layers[i] = block
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Forward computes the forward pass of the model
|
|
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
|
dims := tokens.Dims()
|
|
B, L := int32(dims[0]), int32(dims[1])
|
|
|
|
h := m.EmbedTokens.Forward(tokens)
|
|
|
|
for i, layer := range m.Layers {
|
|
var c cache.Cache
|
|
if caches != nil {
|
|
c = caches[i]
|
|
}
|
|
h = layer.Forward(h, c, B, L, m.Config)
|
|
}
|
|
|
|
h = m.Norm.Forward(h, m.RMSNormEps)
|
|
return h
|
|
}
|
|
|
|
// Unembed applies the LM head to get logits.
|
|
func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
|
|
return m.LMHead.Forward(x)
|
|
}
|
|
|
|
// NumLayers returns the number of transformer layers
|
|
func (m *Model) NumLayers() int { return len(m.Layers) }
|
|
|
|
// MaxContextLength returns the maximum context length
|
|
func (m *Model) MaxContextLength() int32 { return m.MaxPositionEmbeddings }
|
|
|
|
// VocabSize returns the vocabulary size
|
|
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
|
|
|
|
// Tokenizer returns the model's tokenizer
|
|
func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }
|
|
|
|
// NewCache creates a new KV cache for the model
|
|
func (m *Model) NewCache(maxSeqLen int32) []cache.Cache {
|
|
caches := make([]cache.Cache, len(m.Layers))
|
|
for i := range caches {
|
|
caches[i] = cache.NewKVCache()
|
|
}
|
|
return caches
|
|
}
|
|
|
|
// FormatPrompt applies the GLM-4 chat template with thinking enabled by default.
|
|
func (m *Model) FormatPrompt(prompt string) string {
|
|
return "[gMASK]<sop><|user|>" + prompt + "<|assistant|><think>"
|
|
}
|
|
|
|
// FormatPromptWithThinking applies the GLM-4 chat template with explicit thinking control.
|
|
func (m *Model) FormatPromptWithThinking(prompt string, think bool) string {
|
|
if think {
|
|
return "[gMASK]<sop><|user|>" + prompt + "<|assistant|><think>"
|
|
}
|
|
return "[gMASK]<sop><|user|>" + prompt + "<|assistant|></think>"
|
|
}
|
|
|
|
// NewRenderer returns a new Renderer for formatting multi-turn conversations.
|
|
func (m *Model) NewRenderer() *Renderer {
|
|
return &Renderer{}
|
|
}
|
|
|
|
// NewParser returns a new Parser for extracting thinking and tool calls from output.
|
|
func (m *Model) NewParser() *Parser {
|
|
return &Parser{}
|
|
}
|