mirror of
https://github.com/ollama/ollama.git
synced 2026-04-17 22:54:05 +02:00
Converts SiLU/GELUApprox to compiled kernels and adds SwiGLU, matching upstream mlx/mlx_lm's activations pattern. Routes llama, qwen3, qwen3_5 (dense + MoE), and glm4_moe_lite MLP paths through mlx.SwiGLU so each MLP invocation runs as one fused Metal/CUDA kernel rather than a chain of per-op launches.
1393 lines
42 KiB
Go
1393 lines
42 KiB
Go
// Package qwen3_5 provides the Qwen 3.5 text and MoE implementation for MLX.
|
|
package qwen3_5
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"math"
|
|
"strings"
|
|
|
|
"github.com/ollama/ollama/x/mlxrunner/cache"
|
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
|
"github.com/ollama/ollama/x/mlxrunner/model"
|
|
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
|
"github.com/ollama/ollama/x/models/nn"
|
|
"github.com/ollama/ollama/x/tokenizer"
|
|
)
|
|
|
|
func init() {
|
|
base.Register("Qwen3_5ForCausalLM", NewModel)
|
|
base.Register("Qwen3_5ForConditionalGeneration", NewModel)
|
|
base.Register("Qwen3NextForCausalLM", NewModel)
|
|
base.Register("Qwen3NextForConditionalGeneration", NewModel)
|
|
}
|
|
|
|
// RopeParameters carries optional rope metadata embedded under rope_parameters.
|
|
type RopeParameters struct {
|
|
Type string `json:"type"`
|
|
RopeType string `json:"rope_type"`
|
|
RopeTheta float32 `json:"rope_theta"`
|
|
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
|
|
}
|
|
|
|
// Config holds Qwen 3.5 text config (top-level or nested text_config).
|
|
type Config struct {
|
|
ModelType string `json:"model_type"`
|
|
HiddenSize int32 `json:"hidden_size"`
|
|
IntermediateSize int32 `json:"intermediate_size"`
|
|
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
|
NumAttentionHeads int32 `json:"num_attention_heads"`
|
|
NumKeyValueHeads int32 `json:"num_key_value_heads"`
|
|
HeadDim int32 `json:"head_dim"`
|
|
RMSNormEps float32 `json:"rms_norm_eps"`
|
|
VocabSize int32 `json:"vocab_size"`
|
|
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
|
|
AttentionBias bool `json:"attention_bias"`
|
|
TieWordEmbeddings bool `json:"tie_word_embeddings"`
|
|
LayerTypes []string `json:"layer_types"`
|
|
|
|
FullAttentionInterval int32 `json:"full_attention_interval"`
|
|
|
|
LinearNumValueHeads int32 `json:"linear_num_value_heads"`
|
|
LinearNumKeyHeads int32 `json:"linear_num_key_heads"`
|
|
LinearKeyHeadDim int32 `json:"linear_key_head_dim"`
|
|
LinearValueHeadDim int32 `json:"linear_value_head_dim"`
|
|
LinearConvKernelDim int32 `json:"linear_conv_kernel_dim"`
|
|
DecoderSparseStep int32 `json:"decoder_sparse_step"`
|
|
SharedExpertGateRank int32 `json:"-"`
|
|
|
|
NumExperts int32 `json:"num_experts"`
|
|
NumExpertsPerTok int32 `json:"num_experts_per_tok"`
|
|
SharedExpertIntermediateSize int32 `json:"shared_expert_intermediate_size"`
|
|
MoeIntermediateSize int32 `json:"moe_intermediate_size"`
|
|
NormTopKProb bool `json:"norm_topk_prob"`
|
|
MLPOnlyLayers []int32 `json:"mlp_only_layers"`
|
|
|
|
RopeTheta float32 `json:"rope_theta"`
|
|
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
|
|
RopeScaling map[string]any `json:"rope_scaling"`
|
|
RopeParameters *RopeParameters `json:"rope_parameters"`
|
|
|
|
// Quantization metadata.
|
|
QuantGroupSize int `json:"-"`
|
|
QuantBits int `json:"-"`
|
|
QuantMode string `json:"-"`
|
|
TensorQuant map[string]*model.TensorQuantInfo `json:"-"`
|
|
|
|
// Computed fields.
|
|
Scale float32 `json:"-"`
|
|
RopeDim int32 `json:"-"`
|
|
}
|
|
|
|
// Model is the Qwen 3.5 model.
|
|
type Model struct {
|
|
EmbedTokens nn.EmbeddingLayer
|
|
Layers []*Layer
|
|
Norm *nn.RMSNorm
|
|
LMHead nn.LinearLayer
|
|
|
|
tok *tokenizer.Tokenizer
|
|
*Config
|
|
|
|
weightPrefix string
|
|
}
|
|
|
|
// Layer is a transformer decoder layer.
|
|
type Layer struct {
|
|
InputNorm *nn.RMSNorm
|
|
PostAttentionNorm *nn.RMSNorm
|
|
|
|
IsLinear bool
|
|
FullAttn *FullAttention
|
|
Linear *GatedDeltaNet
|
|
MLP MLPBlock
|
|
}
|
|
|
|
// FullAttention is the full-attention branch used every N layers.
|
|
type FullAttention struct {
|
|
QProj nn.LinearLayer
|
|
KProj nn.LinearLayer
|
|
VProj nn.LinearLayer
|
|
OProj nn.LinearLayer
|
|
|
|
QNorm *nn.RMSNorm
|
|
KNorm *nn.RMSNorm
|
|
}
|
|
|
|
// GatedDeltaNet is the recurrent linear-attention branch.
|
|
type GatedDeltaNet struct {
|
|
InProjQKV nn.LinearLayer
|
|
InProjZ nn.LinearLayer
|
|
InProjB nn.LinearLayer
|
|
InProjA nn.LinearLayer
|
|
InProjQKVZ nn.LinearLayer
|
|
InProjBA nn.LinearLayer
|
|
OutProj nn.LinearLayer
|
|
|
|
Conv1D *nn.Conv1d
|
|
ConvWeight *mlx.Array
|
|
NormWeight *mlx.Array
|
|
DtBias *mlx.Array
|
|
ALog *mlx.Array
|
|
AExp *mlx.Array
|
|
}
|
|
|
|
// MLPBlock is the feed-forward interface for dense and MoE blocks.
|
|
type MLPBlock interface {
|
|
Forward(x *mlx.Array, cfg *Config) *mlx.Array
|
|
}
|
|
|
|
// DenseMLP is SwiGLU feed-forward.
|
|
type DenseMLP struct {
|
|
GateProj nn.LinearLayer
|
|
UpProj nn.LinearLayer
|
|
DownProj nn.LinearLayer
|
|
}
|
|
|
|
// SparseMoE is Qwen3.5's sparse MoE with shared expert.
|
|
type SparseMoE struct {
|
|
Gate nn.LinearLayer
|
|
SwitchMLP *SwitchMLP
|
|
SharedExpert *DenseMLP
|
|
SharedExpertGate nn.LinearLayer
|
|
}
|
|
|
|
// SwitchMLP executes selected expert MLPs.
|
|
type SwitchMLP struct {
|
|
GateWeight *mlx.Array
|
|
UpWeight *mlx.Array
|
|
DownWeight *mlx.Array
|
|
|
|
GateWeightQ, GateScales, GateBiases *mlx.Array
|
|
UpWeightQ, UpScales, UpBiases *mlx.Array
|
|
DownWeightQ, DownScales, DownBiases *mlx.Array
|
|
|
|
GateBits int
|
|
UpBits int
|
|
DownBits int
|
|
|
|
GateGroupSize int
|
|
UpGroupSize int
|
|
DownGroupSize int
|
|
|
|
UseQuantized bool
|
|
}
|
|
|
|
type stackedExpertWeights struct {
|
|
Weight *mlx.Array
|
|
Scales *mlx.Array
|
|
Biases *mlx.Array
|
|
Bits int
|
|
GroupSize int
|
|
Mode string
|
|
}
|
|
|
|
func parseConfig(configData []byte) (Config, error) {
|
|
var rawTop map[string]json.RawMessage
|
|
if err := json.Unmarshal(configData, &rawTop); err != nil {
|
|
return Config{}, fmt.Errorf("parse config envelope: %w", err)
|
|
}
|
|
|
|
var cfg Config
|
|
activeRaw := rawTop
|
|
if textRaw, ok := rawTop["text_config"]; ok {
|
|
if err := json.Unmarshal(textRaw, &cfg); err != nil {
|
|
return Config{}, fmt.Errorf("parse text_config: %w", err)
|
|
}
|
|
if err := json.Unmarshal(textRaw, &activeRaw); err != nil {
|
|
return Config{}, fmt.Errorf("parse text_config envelope: %w", err)
|
|
}
|
|
} else {
|
|
if err := json.Unmarshal(configData, &cfg); err != nil {
|
|
return Config{}, fmt.Errorf("parse config: %w", err)
|
|
}
|
|
}
|
|
|
|
if cfg.HiddenSize <= 0 {
|
|
return Config{}, fmt.Errorf("invalid hidden_size: %d", cfg.HiddenSize)
|
|
}
|
|
if cfg.NumHiddenLayers <= 0 {
|
|
return Config{}, fmt.Errorf("invalid num_hidden_layers: %d", cfg.NumHiddenLayers)
|
|
}
|
|
if cfg.NumAttentionHeads <= 0 {
|
|
return Config{}, fmt.Errorf("invalid num_attention_heads: %d", cfg.NumAttentionHeads)
|
|
}
|
|
if cfg.NumKeyValueHeads <= 0 {
|
|
cfg.NumKeyValueHeads = cfg.NumAttentionHeads
|
|
}
|
|
if cfg.HeadDim <= 0 {
|
|
if cfg.HiddenSize%cfg.NumAttentionHeads != 0 {
|
|
return Config{}, fmt.Errorf("hidden_size (%d) must be divisible by num_attention_heads (%d)", cfg.HiddenSize, cfg.NumAttentionHeads)
|
|
}
|
|
cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
|
|
}
|
|
if cfg.HeadDim <= 0 {
|
|
return Config{}, fmt.Errorf("invalid head_dim: %d", cfg.HeadDim)
|
|
}
|
|
|
|
if cfg.RMSNormEps == 0 {
|
|
cfg.RMSNormEps = 1e-6
|
|
}
|
|
if cfg.LinearConvKernelDim <= 0 {
|
|
cfg.LinearConvKernelDim = 4
|
|
}
|
|
if cfg.LinearNumKeyHeads <= 0 || cfg.LinearNumValueHeads <= 0 || cfg.LinearKeyHeadDim <= 0 || cfg.LinearValueHeadDim <= 0 {
|
|
return Config{}, fmt.Errorf("invalid linear attention config (k_heads=%d v_heads=%d k_dim=%d v_dim=%d)",
|
|
cfg.LinearNumKeyHeads, cfg.LinearNumValueHeads, cfg.LinearKeyHeadDim, cfg.LinearValueHeadDim)
|
|
}
|
|
if cfg.LinearNumValueHeads%cfg.LinearNumKeyHeads != 0 {
|
|
return Config{}, fmt.Errorf("linear_num_value_heads (%d) must be divisible by linear_num_key_heads (%d)", cfg.LinearNumValueHeads, cfg.LinearNumKeyHeads)
|
|
}
|
|
|
|
if cfg.RopeParameters != nil {
|
|
if cfg.RopeParameters.RopeTheta > 0 {
|
|
cfg.RopeTheta = cfg.RopeParameters.RopeTheta
|
|
}
|
|
if cfg.RopeParameters.PartialRotaryFactor > 0 {
|
|
cfg.PartialRotaryFactor = cfg.RopeParameters.PartialRotaryFactor
|
|
}
|
|
}
|
|
if cfg.RopeTheta == 0 {
|
|
cfg.RopeTheta = 100000.0
|
|
}
|
|
if cfg.PartialRotaryFactor == 0 {
|
|
cfg.PartialRotaryFactor = 0.25
|
|
}
|
|
if cfg.PartialRotaryFactor < 0 {
|
|
cfg.PartialRotaryFactor = 0.25
|
|
}
|
|
ropeDim := int32(float32(cfg.HeadDim) * cfg.PartialRotaryFactor)
|
|
if ropeDim <= 0 {
|
|
ropeDim = cfg.HeadDim
|
|
}
|
|
if ropeDim > cfg.HeadDim {
|
|
ropeDim = cfg.HeadDim
|
|
}
|
|
cfg.RopeDim = ropeDim
|
|
|
|
if cfg.FullAttentionInterval <= 0 {
|
|
for i, lt := range cfg.LayerTypes {
|
|
if strings.Contains(strings.ToLower(lt), "full") {
|
|
cfg.FullAttentionInterval = int32(i + 1)
|
|
break
|
|
}
|
|
}
|
|
if cfg.FullAttentionInterval <= 0 {
|
|
cfg.FullAttentionInterval = 4
|
|
}
|
|
}
|
|
if cfg.FullAttentionInterval > cfg.NumHiddenLayers {
|
|
cfg.FullAttentionInterval = cfg.NumHiddenLayers
|
|
}
|
|
|
|
if cfg.NumExperts > 0 {
|
|
if cfg.NumExpertsPerTok <= 0 {
|
|
cfg.NumExpertsPerTok = 1
|
|
}
|
|
if cfg.MoeIntermediateSize <= 0 {
|
|
cfg.MoeIntermediateSize = cfg.IntermediateSize
|
|
}
|
|
if cfg.SharedExpertIntermediateSize <= 0 {
|
|
cfg.SharedExpertIntermediateSize = cfg.IntermediateSize
|
|
}
|
|
if _, ok := activeRaw["norm_topk_prob"]; !ok {
|
|
cfg.NormTopKProb = true
|
|
}
|
|
if cfg.DecoderSparseStep <= 0 {
|
|
cfg.DecoderSparseStep = 1
|
|
}
|
|
}
|
|
|
|
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
|
|
return cfg, nil
|
|
}
|
|
|
|
type tensorPathLayout struct {
|
|
containerPrefix string
|
|
modelPrefix string
|
|
}
|
|
|
|
func (l tensorPathLayout) modelPath(suffix string) string {
|
|
return l.containerPrefix + l.modelPrefix + suffix
|
|
}
|
|
|
|
func resolveTensorPathLayout(tensors map[string]*mlx.Array) tensorPathLayout {
|
|
for _, layout := range []tensorPathLayout{
|
|
{containerPrefix: "", modelPrefix: "model."},
|
|
{containerPrefix: "language_model.", modelPrefix: "model."},
|
|
{containerPrefix: "language_model.", modelPrefix: ""},
|
|
{containerPrefix: "model.language_model.", modelPrefix: "model."},
|
|
{containerPrefix: "model.language_model.", modelPrefix: ""},
|
|
} {
|
|
if tensors[layout.modelPath("embed_tokens.weight")] != nil {
|
|
return layout
|
|
}
|
|
}
|
|
|
|
return tensorPathLayout{modelPrefix: "model."}
|
|
}
|
|
|
|
func layerIsLinear(cfg *Config, layer int32) bool {
|
|
if len(cfg.LayerTypes) == int(cfg.NumHiddenLayers) {
|
|
t := strings.ToLower(cfg.LayerTypes[layer])
|
|
return !strings.Contains(t, "full")
|
|
}
|
|
if cfg.FullAttentionInterval <= 0 {
|
|
return true
|
|
}
|
|
return (layer+1)%cfg.FullAttentionInterval != 0
|
|
}
|
|
|
|
func layerUsesMoE(cfg *Config, layer int32) bool {
|
|
if cfg.NumExperts <= 0 {
|
|
return false
|
|
}
|
|
for _, l := range cfg.MLPOnlyLayers {
|
|
if l == layer {
|
|
return false
|
|
}
|
|
}
|
|
if cfg.DecoderSparseStep <= 1 {
|
|
return true
|
|
}
|
|
return (layer+1)%cfg.DecoderSparseStep == 0
|
|
}
|
|
|
|
// NewModel creates a Qwen 3.5 model from a manifest root.
|
|
func NewModel(root *model.Root) (base.Model, error) {
|
|
configData, err := root.Manifest.ReadConfig("config.json")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("load config: %w", err)
|
|
}
|
|
|
|
cfg, err := parseConfig(configData)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if qt := root.QuantType(); qt != "" {
|
|
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt)
|
|
if gs := root.GroupSize(); gs > 0 {
|
|
cfg.QuantGroupSize = gs
|
|
}
|
|
} else {
|
|
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams("")
|
|
}
|
|
cfg.TensorQuant = root.AllTensorQuant()
|
|
|
|
tokData, err := root.Manifest.ReadConfig("tokenizer.json")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("load tokenizer config: %w", err)
|
|
}
|
|
|
|
tokConfig := &tokenizer.TokenizerConfig{ConfigJSON: configData}
|
|
if genConfigData, err := root.Manifest.ReadConfig("generation_config.json"); err == nil {
|
|
tokConfig.GenerationConfigJSON = genConfigData
|
|
}
|
|
if tokConfigData, err := root.Manifest.ReadConfig("tokenizer_config.json"); err == nil {
|
|
tokConfig.TokenizerConfigJSON = tokConfigData
|
|
}
|
|
|
|
tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parse tokenizer: %w", err)
|
|
}
|
|
|
|
m := &Model{
|
|
Layers: make([]*Layer, cfg.NumHiddenLayers),
|
|
Config: &cfg,
|
|
tok: tok,
|
|
}
|
|
|
|
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
|
|
m.Layers[i] = &Layer{IsLinear: layerIsLinear(&cfg, i)}
|
|
}
|
|
|
|
return m, nil
|
|
}
|
|
|
|
func tensorAny(tensors map[string]*mlx.Array, keys ...string) (*mlx.Array, string) {
|
|
for _, k := range keys {
|
|
if v := tensors[k]; v != nil {
|
|
return v, k
|
|
}
|
|
}
|
|
return nil, ""
|
|
}
|
|
|
|
func tensorByBase(tensors map[string]*mlx.Array, base string) (*mlx.Array, string) {
|
|
return tensorAny(tensors, base+".weight", base)
|
|
}
|
|
|
|
func supportsGatherQMM(mode string, bits int) bool {
|
|
switch mode {
|
|
case "affine":
|
|
return bits == 4 || bits == 8
|
|
case "mxfp8":
|
|
return bits == 8
|
|
case "nvfp4", "mxfp4":
|
|
return bits == 4
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
func freeTensorKeys(tensors map[string]*mlx.Array, keys ...string) {
|
|
for _, k := range keys {
|
|
if k == "" {
|
|
continue
|
|
}
|
|
if t := tensors[k]; t != nil {
|
|
delete(tensors, k)
|
|
}
|
|
}
|
|
}
|
|
|
|
func stackAndClone(parts []*mlx.Array) *mlx.Array {
|
|
if len(parts) == 0 {
|
|
return nil
|
|
}
|
|
stacked := mlx.Stack(parts, 0)
|
|
cloned := stacked.Clone()
|
|
mlx.Eval(cloned)
|
|
return cloned
|
|
}
|
|
|
|
func transposeExpertWeightForGatherMM(w *mlx.Array) *mlx.Array {
|
|
if w == nil || !w.Valid() || w.NumDims() != 3 {
|
|
return w
|
|
}
|
|
t := mlx.Transpose(w, 0, 2, 1)
|
|
cloned := t.Clone()
|
|
mlx.Eval(cloned)
|
|
return cloned
|
|
}
|
|
|
|
func describeMoEProjection(prefix string, w *stackedExpertWeights) string {
|
|
if w == nil {
|
|
return prefix + "=missing"
|
|
}
|
|
if w.Scales != nil {
|
|
return fmt.Sprintf("%s=qmm(mode=%s,bits=%d,gs=%d)", prefix, w.Mode, w.Bits, w.GroupSize)
|
|
}
|
|
if w.Bits > 0 || w.Mode != "" {
|
|
reason := "dequantized"
|
|
if !supportsGatherQMM(w.Mode, w.Bits) {
|
|
reason = "unsupported_gather_qmm"
|
|
}
|
|
return fmt.Sprintf("%s=%s(mode=%s,bits=%d,gs=%d)", prefix, reason, w.Mode, w.Bits, w.GroupSize)
|
|
}
|
|
return prefix + "=fp"
|
|
}
|
|
|
|
func summarizeMoEFallbackReason(gateW, upW, downW *stackedExpertWeights) string {
|
|
for _, w := range []*stackedExpertWeights{gateW, upW, downW} {
|
|
if w == nil {
|
|
return "missing_projection"
|
|
}
|
|
if w.Scales != nil {
|
|
continue
|
|
}
|
|
if w.Bits > 0 || w.Mode != "" {
|
|
if !supportsGatherQMM(w.Mode, w.Bits) {
|
|
return fmt.Sprintf("unsupported_gather_qmm(mode=%s,bits=%d)", w.Mode, w.Bits)
|
|
}
|
|
return "dequantized_quant_weights"
|
|
}
|
|
}
|
|
return "unquantized_weights"
|
|
}
|
|
|
|
func sliceStackedExpertAxis1(a *mlx.Array, start, stop int32) *mlx.Array {
|
|
if a == nil || !a.Valid() {
|
|
return nil
|
|
}
|
|
dims := a.Dims()
|
|
if len(dims) < 2 {
|
|
return nil
|
|
}
|
|
beg := make([]int32, len(dims))
|
|
end := make([]int32, len(dims))
|
|
for i, d := range dims {
|
|
end[i] = int32(d)
|
|
}
|
|
beg[1] = start
|
|
end[1] = stop
|
|
return mlx.SliceStartStop(a, beg, end)
|
|
}
|
|
|
|
func loadStackedProjection(tensors map[string]*mlx.Array, cfg *Config, useQuantized bool, bases ...string) *stackedExpertWeights {
|
|
for _, base := range bases {
|
|
w, key := tensorByBase(tensors, base)
|
|
if w == nil {
|
|
continue
|
|
}
|
|
|
|
scales := tensors[key+"_scale"]
|
|
if scales == nil {
|
|
return &stackedExpertWeights{Weight: w}
|
|
}
|
|
|
|
qbiases := tensors[key+"_qbias"]
|
|
groupSize, bits, mode := model.ResolveLinearQuantParams(
|
|
cfg.QuantGroupSize,
|
|
cfg.QuantBits,
|
|
cfg.QuantMode,
|
|
cfg.TensorQuant,
|
|
key,
|
|
w,
|
|
scales,
|
|
)
|
|
if useQuantized && supportsGatherQMM(mode, bits) {
|
|
return &stackedExpertWeights{
|
|
Weight: w,
|
|
Scales: scales,
|
|
Biases: qbiases,
|
|
Bits: bits,
|
|
GroupSize: groupSize,
|
|
Mode: mode,
|
|
}
|
|
}
|
|
|
|
return &stackedExpertWeights{
|
|
Weight: mlx.Dequantize(w, scales, qbiases, groupSize, bits, mode),
|
|
Bits: bits,
|
|
GroupSize: groupSize,
|
|
Mode: mode,
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func collectPerExpertProjection(tensors map[string]*mlx.Array, cfg *Config, useQuantized bool, layerPrefix, proj string, numExperts int32) *stackedExpertWeights {
|
|
weights := make([]*mlx.Array, 0, numExperts)
|
|
scales := make([]*mlx.Array, 0, numExperts)
|
|
biases := make([]*mlx.Array, 0, numExperts)
|
|
consumedKeys := make([]string, 0, numExperts*3)
|
|
bits := 0
|
|
groupSize := 0
|
|
mode := cfg.QuantMode
|
|
|
|
for e := int32(0); e < numExperts; e++ {
|
|
base := fmt.Sprintf("%s.mlp.experts.%d.%s", layerPrefix, e, proj)
|
|
w, key := tensorByBase(tensors, base)
|
|
if w == nil {
|
|
continue
|
|
}
|
|
consumedKeys = append(consumedKeys, key)
|
|
|
|
s := tensors[key+"_scale"]
|
|
if s == nil {
|
|
weights = append(weights, w)
|
|
continue
|
|
}
|
|
consumedKeys = append(consumedKeys, key+"_scale")
|
|
qb := tensors[key+"_qbias"]
|
|
if qb != nil {
|
|
consumedKeys = append(consumedKeys, key+"_qbias")
|
|
}
|
|
gs, b, m := model.ResolveLinearQuantParams(
|
|
cfg.QuantGroupSize,
|
|
cfg.QuantBits,
|
|
cfg.QuantMode,
|
|
cfg.TensorQuant,
|
|
key,
|
|
w,
|
|
s,
|
|
)
|
|
if bits == 0 {
|
|
bits = b
|
|
groupSize = gs
|
|
mode = m
|
|
}
|
|
if useQuantized && supportsGatherQMM(m, b) {
|
|
weights = append(weights, w)
|
|
scales = append(scales, s)
|
|
if qb != nil {
|
|
biases = append(biases, qb)
|
|
}
|
|
} else {
|
|
weights = append(weights, mlx.Dequantize(w, s, qb, gs, b, m))
|
|
}
|
|
}
|
|
|
|
if len(weights) == 0 {
|
|
return nil
|
|
}
|
|
|
|
out := &stackedExpertWeights{Weight: stackAndClone(weights), Bits: bits, GroupSize: groupSize, Mode: mode}
|
|
if len(scales) == len(weights) {
|
|
out.Scales = stackAndClone(scales)
|
|
}
|
|
if len(biases) == len(weights) {
|
|
out.Biases = stackAndClone(biases)
|
|
}
|
|
freeTensorKeys(tensors, consumedKeys...)
|
|
return out
|
|
}
|
|
|
|
func splitGateUpProjection(tensors map[string]*mlx.Array, cfg *Config, useQuantized bool, layerPrefix string) (gate, up, down *stackedExpertWeights) {
|
|
gateUp, key := tensorAny(
|
|
tensors,
|
|
layerPrefix+".mlp.experts.gate_up_proj.weight",
|
|
layerPrefix+".mlp.experts.gate_up_proj",
|
|
)
|
|
if gateUp == nil {
|
|
return nil, nil, nil
|
|
}
|
|
|
|
if scales := tensors[key+"_scale"]; scales != nil {
|
|
qbiases := tensors[key+"_qbias"]
|
|
groupSize, bits, mode := model.ResolveLinearQuantParams(
|
|
cfg.QuantGroupSize,
|
|
cfg.QuantBits,
|
|
cfg.QuantMode,
|
|
cfg.TensorQuant,
|
|
key,
|
|
gateUp,
|
|
scales,
|
|
)
|
|
if useQuantized && supportsGatherQMM(mode, bits) {
|
|
gate = &stackedExpertWeights{
|
|
Bits: bits,
|
|
GroupSize: groupSize,
|
|
Mode: mode,
|
|
}
|
|
up = &stackedExpertWeights{
|
|
Bits: bits,
|
|
GroupSize: groupSize,
|
|
Mode: mode,
|
|
}
|
|
// Keep quantized packed tensor and split along the out-dim (axis=1).
|
|
// This assumes MLX quantization preserves the leading [experts, out, ...] layout.
|
|
if gateUp.NumDims() != 3 {
|
|
return nil, nil, nil
|
|
}
|
|
shape := gateUp.Dims()
|
|
nExperts, twoHidden, inHidden := int32(shape[0]), int32(shape[1]), int32(shape[2])
|
|
_ = nExperts
|
|
_ = inHidden
|
|
mid := twoHidden / 2
|
|
|
|
gate.Weight = sliceStackedExpertAxis1(gateUp, 0, mid)
|
|
up.Weight = sliceStackedExpertAxis1(gateUp, mid, twoHidden)
|
|
gate.Scales = sliceStackedExpertAxis1(scales, 0, mid)
|
|
up.Scales = sliceStackedExpertAxis1(scales, mid, twoHidden)
|
|
if qbiases != nil {
|
|
gate.Biases = sliceStackedExpertAxis1(qbiases, 0, mid)
|
|
up.Biases = sliceStackedExpertAxis1(qbiases, mid, twoHidden)
|
|
}
|
|
} else {
|
|
gateUp = mlx.Dequantize(gateUp, scales, qbiases, groupSize, bits, mode)
|
|
gate = &stackedExpertWeights{Bits: bits, GroupSize: groupSize, Mode: mode}
|
|
up = &stackedExpertWeights{Bits: bits, GroupSize: groupSize, Mode: mode}
|
|
}
|
|
}
|
|
|
|
if gateUp.NumDims() != 3 {
|
|
return nil, nil, nil
|
|
}
|
|
shape := gateUp.Dims()
|
|
nExperts, twoHidden, inHidden := int32(shape[0]), int32(shape[1]), int32(shape[2])
|
|
mid := twoHidden / 2
|
|
|
|
if gate == nil {
|
|
gate = &stackedExpertWeights{}
|
|
}
|
|
if up == nil {
|
|
up = &stackedExpertWeights{}
|
|
}
|
|
if gate.Weight == nil {
|
|
gate.Weight = mlx.SliceStartStop(gateUp, []int32{0, 0, 0}, []int32{nExperts, mid, inHidden})
|
|
}
|
|
if up.Weight == nil {
|
|
up.Weight = mlx.SliceStartStop(gateUp, []int32{0, mid, 0}, []int32{nExperts, twoHidden, inHidden})
|
|
}
|
|
|
|
downW, downKey := tensorAny(
|
|
tensors,
|
|
layerPrefix+".mlp.experts.down_proj.weight",
|
|
layerPrefix+".mlp.experts.down_proj",
|
|
)
|
|
if downW == nil {
|
|
return gate, up, nil
|
|
}
|
|
if scales := tensors[downKey+"_scale"]; scales != nil {
|
|
qbiases := tensors[downKey+"_qbias"]
|
|
groupSize, bits, mode := model.ResolveLinearQuantParams(
|
|
cfg.QuantGroupSize,
|
|
cfg.QuantBits,
|
|
cfg.QuantMode,
|
|
cfg.TensorQuant,
|
|
downKey,
|
|
downW,
|
|
scales,
|
|
)
|
|
if useQuantized && supportsGatherQMM(mode, bits) {
|
|
down = &stackedExpertWeights{
|
|
Weight: downW,
|
|
Scales: scales,
|
|
Biases: qbiases,
|
|
Bits: bits,
|
|
GroupSize: groupSize,
|
|
Mode: mode,
|
|
}
|
|
return gate, up, down
|
|
}
|
|
downW = mlx.Dequantize(downW, scales, qbiases, groupSize, bits, mode)
|
|
down = &stackedExpertWeights{Bits: bits, GroupSize: groupSize, Mode: mode}
|
|
}
|
|
if down == nil {
|
|
down = &stackedExpertWeights{}
|
|
}
|
|
down.Weight = downW
|
|
return gate, up, down
|
|
}
|
|
|
|
func sanitizeConvWeight(w *mlx.Array) *mlx.Array {
|
|
if w == nil {
|
|
return nil
|
|
}
|
|
if w.NumDims() == 3 {
|
|
if w.Dim(1) == 1 {
|
|
return mlx.Squeeze(w, 1)
|
|
}
|
|
if w.Dim(2) == 1 {
|
|
return mlx.Squeeze(w, 2)
|
|
}
|
|
}
|
|
return w
|
|
}
|
|
|
|
func depthwiseConv1dKernelWeight(w *mlx.Array) *mlx.Array {
|
|
if w == nil {
|
|
return nil
|
|
}
|
|
switch w.NumDims() {
|
|
case 2:
|
|
// qwen3.5 manual path stores [C, K]; MLX grouped conv expects [Cout, K, Cin/groups].
|
|
// For depthwise conv (groups=C), that is [C, K, 1].
|
|
return mlx.ExpandDims(w, 2)
|
|
case 3:
|
|
switch {
|
|
case w.Dim(2) == 1:
|
|
// [C, K, 1]
|
|
return w
|
|
case w.Dim(1) == 1:
|
|
// [C, 1, K] -> [C, K, 1]
|
|
return mlx.Transpose(w, 0, 2, 1)
|
|
case w.Dim(0) == 1:
|
|
// [1, K, C] -> [C, K, 1]
|
|
return mlx.Transpose(w, 2, 1, 0)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func shouldShiftNormKey(key string) bool {
|
|
for _, suffix := range []string{
|
|
".input_layernorm.weight",
|
|
".post_attention_layernorm.weight",
|
|
"model.norm.weight",
|
|
".self_attn.q_norm.weight",
|
|
".self_attn.k_norm.weight",
|
|
} {
|
|
if strings.HasSuffix(key, suffix) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func maybeShiftNormWeight(key string, w *mlx.Array, shouldShift bool) *mlx.Array {
|
|
if !shouldShift || w == nil || w.NumDims() != 1 || !shouldShiftNormKey(key) {
|
|
return w
|
|
}
|
|
return mlx.AddScalar(w, 1.0)
|
|
}
|
|
|
|
// LoadWeights assigns tensors to model fields.
|
|
func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
|
layout := resolveTensorPathLayout(tensors)
|
|
m.weightPrefix = layout.containerPrefix
|
|
prefix := m.weightPrefix
|
|
modelPrefix := layout.containerPrefix + layout.modelPrefix
|
|
cfg := m.Config
|
|
|
|
linears := model.NewLinearFactory(tensors, cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode, cfg.TensorQuant)
|
|
|
|
shouldShiftNormWeights := false
|
|
mtpKeys := make([]string, 0)
|
|
for name, t := range tensors {
|
|
if strings.Contains(name, "mtp.") {
|
|
shouldShiftNormWeights = true
|
|
mtpKeys = append(mtpKeys, name)
|
|
continue
|
|
}
|
|
if !shouldShiftNormWeights && strings.Contains(name, ".linear_attn.conv1d.weight") && t != nil && t.NumDims() == 3 && t.Dim(2) != 1 {
|
|
shouldShiftNormWeights = true
|
|
}
|
|
}
|
|
if len(mtpKeys) > 0 {
|
|
freeTensorKeys(tensors, mtpKeys...)
|
|
}
|
|
|
|
embedTokens := model.MakeEmbeddingLayer(tensors, modelPrefix+"embed_tokens", cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode, cfg.TensorQuant)
|
|
if embedTokens == nil {
|
|
return fmt.Errorf("missing embedding weight: %sembed_tokens.weight", modelPrefix)
|
|
}
|
|
m.EmbedTokens = embedTokens
|
|
|
|
normKey := modelPrefix + "norm.weight"
|
|
normWeight := maybeShiftNormWeight(normKey, tensors[normKey], shouldShiftNormWeights)
|
|
if normWeight == nil {
|
|
return fmt.Errorf("missing final norm weight: %snorm.weight", modelPrefix)
|
|
}
|
|
m.Norm = nn.NewRMSNorm(normWeight, cfg.RMSNormEps)
|
|
|
|
if cfg.TieWordEmbeddings {
|
|
m.LMHead = m.EmbedTokens.AsLinear()
|
|
} else if lmHead := linears.Make(prefix + "lm_head"); lmHead != nil {
|
|
m.LMHead = lmHead
|
|
} else if lmHead := linears.Make("lm_head"); lmHead != nil {
|
|
m.LMHead = lmHead
|
|
} else {
|
|
m.LMHead = m.EmbedTokens.AsLinear()
|
|
}
|
|
|
|
useQuantizedExperts := supportsGatherQMM(cfg.QuantMode, cfg.QuantBits)
|
|
if !useQuantizedExperts && cfg.TensorQuant != nil {
|
|
for _, tq := range cfg.TensorQuant {
|
|
if tq == nil {
|
|
continue
|
|
}
|
|
_, bits, mode := model.QuantizationParams(tq.QuantType)
|
|
if supportsGatherQMM(mode, bits) {
|
|
useQuantizedExperts = true
|
|
break
|
|
}
|
|
}
|
|
}
|
|
moeLoadSummaries := make([]string, 0)
|
|
|
|
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
|
|
layerPrefix := fmt.Sprintf("%slayers.%d", modelPrefix, i)
|
|
layer := &Layer{IsLinear: layerIsLinear(cfg, i)}
|
|
|
|
if w := maybeShiftNormWeight(layerPrefix+".input_layernorm.weight", tensors[layerPrefix+".input_layernorm.weight"], shouldShiftNormWeights); w != nil {
|
|
layer.InputNorm = nn.NewRMSNorm(w, cfg.RMSNormEps)
|
|
}
|
|
if w := maybeShiftNormWeight(layerPrefix+".post_attention_layernorm.weight", tensors[layerPrefix+".post_attention_layernorm.weight"], shouldShiftNormWeights); w != nil {
|
|
layer.PostAttentionNorm = nn.NewRMSNorm(w, cfg.RMSNormEps)
|
|
}
|
|
if layer.InputNorm == nil || layer.PostAttentionNorm == nil {
|
|
return fmt.Errorf("layer %d: missing layer norms", i)
|
|
}
|
|
|
|
if layer.IsLinear {
|
|
lin := &GatedDeltaNet{}
|
|
lin.InProjQKV = linears.Make(layerPrefix + ".linear_attn.in_proj_qkv")
|
|
lin.InProjZ = linears.Make(layerPrefix + ".linear_attn.in_proj_z")
|
|
lin.InProjB = linears.Make(layerPrefix + ".linear_attn.in_proj_b")
|
|
lin.InProjA = linears.Make(layerPrefix + ".linear_attn.in_proj_a")
|
|
lin.InProjQKVZ = linears.Make(layerPrefix + ".linear_attn.in_proj_qkvz")
|
|
lin.InProjBA = linears.Make(layerPrefix + ".linear_attn.in_proj_ba")
|
|
lin.OutProj = linears.Make(layerPrefix + ".linear_attn.out_proj")
|
|
|
|
lin.ConvWeight = sanitizeConvWeight(tensors[layerPrefix+".linear_attn.conv1d.weight"])
|
|
if lin.ConvWeight == nil {
|
|
lin.ConvWeight = sanitizeConvWeight(tensors[layerPrefix+".linear_attn.conv1d"])
|
|
}
|
|
lin.NormWeight, _ = tensorAny(tensors,
|
|
layerPrefix+".linear_attn.norm.weight",
|
|
layerPrefix+".linear_attn.norm",
|
|
)
|
|
lin.DtBias, _ = tensorAny(tensors,
|
|
layerPrefix+".linear_attn.dt_bias",
|
|
layerPrefix+".linear_attn.dt_proj",
|
|
)
|
|
lin.ALog, _ = tensorAny(tensors,
|
|
layerPrefix+".linear_attn.A_log",
|
|
layerPrefix+".linear_attn.a_log",
|
|
)
|
|
if lin.ALog != nil {
|
|
lin.AExp = mlx.Exp(lin.ALog.AsType(mlx.DTypeFloat32))
|
|
}
|
|
|
|
hasSplit := lin.InProjQKV != nil && lin.InProjZ != nil && lin.InProjB != nil && lin.InProjA != nil
|
|
hasCombined := lin.InProjQKVZ != nil && lin.InProjBA != nil
|
|
if (!hasSplit && !hasCombined) || lin.OutProj == nil {
|
|
return fmt.Errorf("layer %d: missing linear attention projections", i)
|
|
}
|
|
if lin.ConvWeight == nil || lin.NormWeight == nil || lin.DtBias == nil || lin.ALog == nil || lin.AExp == nil {
|
|
return fmt.Errorf("layer %d: missing linear attention state tensors", i)
|
|
}
|
|
if lin.ConvWeight.NumDims() != 2 {
|
|
return fmt.Errorf("layer %d: conv1d weight must be 2D after sanitization, got %dD", i, lin.ConvWeight.NumDims())
|
|
}
|
|
if convKernel := depthwiseConv1dKernelWeight(lin.ConvWeight); convKernel != nil {
|
|
lin.Conv1D = nn.NewConv1d(convKernel, nil, 1, 0, 1, int32(lin.ConvWeight.Dim(0)))
|
|
}
|
|
|
|
layer.Linear = lin
|
|
} else {
|
|
attn := &FullAttention{}
|
|
attn.QProj = linears.Make(layerPrefix + ".self_attn.q_proj")
|
|
attn.KProj = linears.Make(layerPrefix + ".self_attn.k_proj")
|
|
attn.VProj = linears.Make(layerPrefix + ".self_attn.v_proj")
|
|
attn.OProj = linears.Make(layerPrefix + ".self_attn.o_proj")
|
|
|
|
if w := maybeShiftNormWeight(layerPrefix+".self_attn.q_norm.weight", tensors[layerPrefix+".self_attn.q_norm.weight"], shouldShiftNormWeights); w != nil {
|
|
attn.QNorm = nn.NewRMSNorm(w, cfg.RMSNormEps)
|
|
}
|
|
if w := maybeShiftNormWeight(layerPrefix+".self_attn.k_norm.weight", tensors[layerPrefix+".self_attn.k_norm.weight"], shouldShiftNormWeights); w != nil {
|
|
attn.KNorm = nn.NewRMSNorm(w, cfg.RMSNormEps)
|
|
}
|
|
|
|
if attn.QProj == nil || attn.KProj == nil || attn.VProj == nil || attn.OProj == nil {
|
|
return fmt.Errorf("layer %d: missing full attention projections", i)
|
|
}
|
|
if attn.QNorm == nil || attn.KNorm == nil {
|
|
return fmt.Errorf("layer %d: missing full attention q/k norms", i)
|
|
}
|
|
layer.FullAttn = attn
|
|
}
|
|
|
|
if layerUsesMoE(cfg, i) {
|
|
moe := &SparseMoE{}
|
|
moe.Gate = linears.Make(layerPrefix + ".mlp.gate")
|
|
if moe.Gate == nil {
|
|
return fmt.Errorf("layer %d: missing moe gate", i)
|
|
}
|
|
|
|
gateW := loadStackedProjection(tensors, cfg, useQuantizedExperts,
|
|
layerPrefix+".mlp.switch_mlp.gate_proj",
|
|
layerPrefix+".mlp.experts.gate_proj",
|
|
)
|
|
upW := loadStackedProjection(tensors, cfg, useQuantizedExperts,
|
|
layerPrefix+".mlp.switch_mlp.up_proj",
|
|
layerPrefix+".mlp.experts.up_proj",
|
|
)
|
|
downW := loadStackedProjection(tensors, cfg, useQuantizedExperts,
|
|
layerPrefix+".mlp.switch_mlp.down_proj",
|
|
layerPrefix+".mlp.experts.down_proj",
|
|
)
|
|
if gateW == nil || upW == nil || downW == nil {
|
|
g2, u2, d2 := splitGateUpProjection(tensors, cfg, useQuantizedExperts, layerPrefix)
|
|
if gateW == nil {
|
|
gateW = g2
|
|
}
|
|
if upW == nil {
|
|
upW = u2
|
|
}
|
|
if downW == nil {
|
|
downW = d2
|
|
}
|
|
}
|
|
if gateW == nil || upW == nil || downW == nil {
|
|
gateW = collectPerExpertProjection(tensors, cfg, useQuantizedExperts, layerPrefix, "gate_proj", cfg.NumExperts)
|
|
upW = collectPerExpertProjection(tensors, cfg, useQuantizedExperts, layerPrefix, "up_proj", cfg.NumExperts)
|
|
downW = collectPerExpertProjection(tensors, cfg, useQuantizedExperts, layerPrefix, "down_proj", cfg.NumExperts)
|
|
}
|
|
|
|
if gateW == nil || upW == nil || downW == nil {
|
|
return fmt.Errorf("layer %d: missing switch expert weights", i)
|
|
}
|
|
|
|
switchMLP := &SwitchMLP{}
|
|
if gateW.Scales != nil && upW.Scales != nil && downW.Scales != nil {
|
|
switchMLP.UseQuantized = true
|
|
switchMLP.GateWeightQ = gateW.Weight
|
|
switchMLP.GateScales = gateW.Scales
|
|
switchMLP.GateBiases = gateW.Biases
|
|
switchMLP.GateBits = gateW.Bits
|
|
switchMLP.GateGroupSize = gateW.GroupSize
|
|
switchMLP.UpWeightQ = upW.Weight
|
|
switchMLP.UpScales = upW.Scales
|
|
switchMLP.UpBiases = upW.Biases
|
|
switchMLP.UpBits = upW.Bits
|
|
switchMLP.UpGroupSize = upW.GroupSize
|
|
switchMLP.DownWeightQ = downW.Weight
|
|
switchMLP.DownScales = downW.Scales
|
|
switchMLP.DownBiases = downW.Biases
|
|
switchMLP.DownBits = downW.Bits
|
|
switchMLP.DownGroupSize = downW.GroupSize
|
|
} else {
|
|
switchMLP.GateWeight = transposeExpertWeightForGatherMM(gateW.Weight)
|
|
switchMLP.UpWeight = transposeExpertWeightForGatherMM(upW.Weight)
|
|
switchMLP.DownWeight = transposeExpertWeightForGatherMM(downW.Weight)
|
|
moeLoadSummaries = append(moeLoadSummaries,
|
|
fmt.Sprintf(
|
|
"layer=%d moe_fallback reason=%s %s %s %s",
|
|
i,
|
|
summarizeMoEFallbackReason(gateW, upW, downW),
|
|
describeMoEProjection("gate", gateW),
|
|
describeMoEProjection("up", upW),
|
|
describeMoEProjection("down", downW),
|
|
),
|
|
)
|
|
}
|
|
if switchMLP.UseQuantized {
|
|
moeLoadSummaries = append(moeLoadSummaries,
|
|
fmt.Sprintf(
|
|
"layer=%d moe_quantized %s %s %s",
|
|
i,
|
|
describeMoEProjection("gate", gateW),
|
|
describeMoEProjection("up", upW),
|
|
describeMoEProjection("down", downW),
|
|
),
|
|
)
|
|
}
|
|
moe.SwitchMLP = switchMLP
|
|
|
|
sharedGateProj := linears.Make(layerPrefix + ".mlp.shared_expert.gate_proj")
|
|
sharedUpProj := linears.Make(layerPrefix + ".mlp.shared_expert.up_proj")
|
|
sharedDownProj := linears.Make(layerPrefix + ".mlp.shared_expert.down_proj")
|
|
if sharedGateProj != nil && sharedUpProj != nil && sharedDownProj != nil {
|
|
moe.SharedExpert = &DenseMLP{
|
|
GateProj: sharedGateProj,
|
|
UpProj: sharedUpProj,
|
|
DownProj: sharedDownProj,
|
|
}
|
|
moe.SharedExpertGate = linears.Make(layerPrefix + ".mlp.shared_expert_gate")
|
|
}
|
|
|
|
layer.MLP = moe
|
|
} else {
|
|
mlp := &DenseMLP{
|
|
GateProj: linears.Make(layerPrefix + ".mlp.gate_proj"),
|
|
UpProj: linears.Make(layerPrefix + ".mlp.up_proj"),
|
|
DownProj: linears.Make(layerPrefix + ".mlp.down_proj"),
|
|
}
|
|
if mlp.GateProj == nil || mlp.UpProj == nil || mlp.DownProj == nil {
|
|
return fmt.Errorf("layer %d: missing dense mlp projections", i)
|
|
}
|
|
layer.MLP = mlp
|
|
}
|
|
|
|
m.Layers[i] = layer
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func softplus(x *mlx.Array) *mlx.Array {
|
|
return mlx.Logaddexp(x, mlx.Zeros(x.DType(), x.Dims()...))
|
|
}
|
|
|
|
func depthwiseCausalConv1d(x, w *mlx.Array, outLen int32) *mlx.Array {
|
|
if x == nil || w == nil {
|
|
return nil
|
|
}
|
|
if w.NumDims() != 2 {
|
|
return nil
|
|
}
|
|
B := int32(x.Dim(0))
|
|
C := int32(w.Dim(0))
|
|
K := int32(w.Dim(1))
|
|
var out *mlx.Array
|
|
for i := int32(0); i < K; i++ {
|
|
seg := mlx.SliceStartStop(x, []int32{0, i, 0}, []int32{B, i + outLen, C})
|
|
wi := mlx.SliceStartStop(w, []int32{0, i}, []int32{C, i + 1})
|
|
wi = mlx.Reshape(wi, 1, 1, C)
|
|
term := mlx.Mul(seg, wi)
|
|
if out == nil {
|
|
out = term
|
|
} else {
|
|
out = mlx.Add(out, term)
|
|
}
|
|
}
|
|
return out
|
|
}
|
|
|
|
func splitQKVZBA(mixedQKVZ, mixedBA *mlx.Array, cfg *Config, B, L int32) (q, k, v, z, b, a *mlx.Array) {
|
|
nk := cfg.LinearNumKeyHeads
|
|
nv := cfg.LinearNumValueHeads
|
|
dk := cfg.LinearKeyHeadDim
|
|
dv := cfg.LinearValueHeadDim
|
|
vPerK := nv / nk
|
|
|
|
mixedQKVZ = mlx.Reshape(mixedQKVZ, B, L, nk, 2*dk+2*vPerK*dv)
|
|
q = mlx.SliceStartStop(mixedQKVZ, []int32{0, 0, 0, 0}, []int32{B, L, nk, dk})
|
|
k = mlx.SliceStartStop(mixedQKVZ, []int32{0, 0, 0, dk}, []int32{B, L, nk, 2 * dk})
|
|
v = mlx.SliceStartStop(mixedQKVZ, []int32{0, 0, 0, 2 * dk}, []int32{B, L, nk, 2*dk + vPerK*dv})
|
|
z = mlx.SliceStartStop(mixedQKVZ, []int32{0, 0, 0, 2*dk + vPerK*dv}, []int32{B, L, nk, 2*dk + 2*vPerK*dv})
|
|
|
|
v = mlx.Reshape(v, B, L, nv, dv)
|
|
z = mlx.Reshape(z, B, L, nv, dv)
|
|
|
|
mixedBA = mlx.Reshape(mixedBA, B, L, nk, 2*vPerK)
|
|
b = mlx.SliceStartStop(mixedBA, []int32{0, 0, 0, 0}, []int32{B, L, nk, vPerK})
|
|
a = mlx.SliceStartStop(mixedBA, []int32{0, 0, 0, vPerK}, []int32{B, L, nk, 2 * vPerK})
|
|
b = mlx.Reshape(b, B, L, nv)
|
|
a = mlx.Reshape(a, B, L, nv)
|
|
|
|
return q, k, v, z, b, a
|
|
}
|
|
|
|
func (a *FullAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
|
qg := a.QProj.Forward(x)
|
|
qg = mlx.Reshape(qg, B, L, cfg.NumAttentionHeads, cfg.HeadDim*2)
|
|
q := mlx.SliceStartStop(qg, []int32{0, 0, 0, 0}, []int32{B, L, cfg.NumAttentionHeads, cfg.HeadDim})
|
|
gate := mlx.SliceStartStop(qg, []int32{0, 0, 0, cfg.HeadDim}, []int32{B, L, cfg.NumAttentionHeads, cfg.HeadDim * 2})
|
|
gate = mlx.Reshape(gate, B, L, cfg.NumAttentionHeads*cfg.HeadDim)
|
|
|
|
k := a.KProj.Forward(x)
|
|
v := a.VProj.Forward(x)
|
|
k = mlx.Reshape(k, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
|
|
v = mlx.Reshape(v, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
|
|
|
|
q = a.QNorm.Forward(q, cfg.RMSNormEps)
|
|
k = a.KNorm.Forward(k, cfg.RMSNormEps)
|
|
|
|
q = mlx.Transpose(q, 0, 2, 1, 3)
|
|
k = mlx.Transpose(k, 0, 2, 1, 3)
|
|
v = mlx.Transpose(v, 0, 2, 1, 3)
|
|
|
|
offset := 0
|
|
if c != nil {
|
|
offset = c.Offset()
|
|
}
|
|
q = mlx.RoPEWithBase(q, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, offset)
|
|
k = mlx.RoPEWithBase(k, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, offset)
|
|
|
|
if c != nil {
|
|
k, v = c.Update(k, v)
|
|
}
|
|
|
|
out := mlx.ScaledDotProductAttentionCausal(q, k, v, cfg.Scale, L > 1)
|
|
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
|
|
gateSigmoid := mlx.Sigmoid(gate)
|
|
out = mlx.Mul(out, gateSigmoid)
|
|
out = a.OProj.Forward(out)
|
|
return out
|
|
}
|
|
|
|
func (g *GatedDeltaNet) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
|
var qkv, z, b, a *mlx.Array
|
|
useSplitProj := g.InProjQKV != nil && g.InProjZ != nil && g.InProjB != nil && g.InProjA != nil
|
|
if useSplitProj {
|
|
qkv = g.InProjQKV.Forward(x)
|
|
z = g.InProjZ.Forward(x)
|
|
z = mlx.Reshape(z, B, L, cfg.LinearNumValueHeads, cfg.LinearValueHeadDim)
|
|
b = g.InProjB.Forward(x)
|
|
a = g.InProjA.Forward(x)
|
|
} else {
|
|
mixedQKVZ := g.InProjQKVZ.Forward(x)
|
|
mixedBA := g.InProjBA.Forward(x)
|
|
var q, k, v *mlx.Array
|
|
q, k, v, z, b, a = splitQKVZBA(mixedQKVZ, mixedBA, cfg, B, L)
|
|
qkv = mlx.Concatenate([]*mlx.Array{
|
|
mlx.Reshape(q, B, L, cfg.LinearNumKeyHeads*cfg.LinearKeyHeadDim),
|
|
mlx.Reshape(k, B, L, cfg.LinearNumKeyHeads*cfg.LinearKeyHeadDim),
|
|
mlx.Reshape(v, B, L, cfg.LinearNumValueHeads*cfg.LinearValueHeadDim),
|
|
}, -1)
|
|
}
|
|
convTail := cfg.LinearConvKernelDim - 1
|
|
var convState *mlx.Array
|
|
var rc *cache.RecurrentCache
|
|
if c != nil {
|
|
if typed, ok := c.(*cache.RecurrentCache); ok {
|
|
rc = typed
|
|
convState = rc.ConvState(int(B), x.DType())
|
|
}
|
|
}
|
|
if convState == nil {
|
|
convState = mlx.Zeros(x.DType(), int(B), int(convTail), int(2*cfg.LinearNumKeyHeads*cfg.LinearKeyHeadDim+cfg.LinearNumValueHeads*cfg.LinearValueHeadDim))
|
|
}
|
|
|
|
convInput := mlx.Concatenate([]*mlx.Array{convState, qkv}, 1)
|
|
var convOut *mlx.Array
|
|
if g.Conv1D != nil {
|
|
convOut = g.Conv1D.Forward(convInput)
|
|
} else {
|
|
convOut = depthwiseCausalConv1d(convInput, g.ConvWeight, L)
|
|
}
|
|
convOut = mlx.SiLU(convOut)
|
|
if rc != nil {
|
|
total := int32(convInput.Dim(1))
|
|
start := total - convTail
|
|
nextConv := mlx.SliceStartStop(convInput, []int32{0, start, 0}, []int32{B, total, int32(convInput.Dim(2))})
|
|
rc.SetConvState(nextConv)
|
|
}
|
|
|
|
keyDim := cfg.LinearNumKeyHeads * cfg.LinearKeyHeadDim
|
|
valueDim := cfg.LinearNumValueHeads * cfg.LinearValueHeadDim
|
|
q := mlx.SliceStartStop(convOut, []int32{0, 0, 0}, []int32{B, L, keyDim})
|
|
k := mlx.SliceStartStop(convOut, []int32{0, 0, keyDim}, []int32{B, L, 2 * keyDim})
|
|
v := mlx.SliceStartStop(convOut, []int32{0, 0, 2 * keyDim}, []int32{B, L, 2*keyDim + valueDim})
|
|
q = mlx.Reshape(q, B, L, cfg.LinearNumKeyHeads, cfg.LinearKeyHeadDim)
|
|
k = mlx.Reshape(k, B, L, cfg.LinearNumKeyHeads, cfg.LinearKeyHeadDim)
|
|
v = mlx.Reshape(v, B, L, cfg.LinearNumValueHeads, cfg.LinearValueHeadDim)
|
|
invScale := float32(1.0 / math.Sqrt(float64(cfg.LinearKeyHeadDim)))
|
|
q = mlx.MulScalar(mlx.RMSNormFn(q, nil, 1e-6), invScale*invScale)
|
|
k = mlx.MulScalar(mlx.RMSNormFn(k, nil, 1e-6), invScale)
|
|
|
|
gDecay := softplus(mlx.Add(a, g.DtBias))
|
|
gDecay = mlx.Mul(gDecay, g.AExp)
|
|
gDecay = mlx.Exp(mlx.MulScalar(gDecay, -1))
|
|
gDecay = gDecay.AsType(a.DType())
|
|
|
|
beta := mlx.Sigmoid(b)
|
|
|
|
var state *mlx.Array
|
|
if rc != nil {
|
|
state = rc.DeltaState(int(B), x.DType())
|
|
}
|
|
if state == nil {
|
|
state = mlx.Zeros(x.DType(), int(B), int(cfg.LinearNumValueHeads), int(cfg.LinearValueHeadDim), int(cfg.LinearKeyHeadDim))
|
|
}
|
|
|
|
out, state := mlx.GatedDelta(q, k, v, gDecay, beta, state)
|
|
outDType := out.DType()
|
|
out = mlx.RMSNormFn(out, g.NormWeight, cfg.RMSNormEps)
|
|
out = mlx.Mul(out.AsType(mlx.DTypeFloat32), mlx.SiLU(z.AsType(mlx.DTypeFloat32))).AsType(outDType)
|
|
out = mlx.Reshape(out, B, L, valueDim)
|
|
out = g.OutProj.Forward(out)
|
|
if rc != nil {
|
|
rc.SetDeltaState(state)
|
|
rc.Advance(int(L))
|
|
}
|
|
return out
|
|
}
|
|
|
|
func (m *DenseMLP) Forward(x *mlx.Array, _ *Config) *mlx.Array {
|
|
return m.DownProj.Forward(mlx.SwiGLU(m.GateProj.Forward(x), m.UpProj.Forward(x)))
|
|
}
|
|
|
|
func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.Array {
|
|
dims := x.Dims()
|
|
B, L := int32(dims[0]), int32(dims[1])
|
|
topK := cfg.NumExpertsPerTok
|
|
|
|
xExpanded := mlx.ExpandDims(mlx.ExpandDims(x, -2), -2)
|
|
xFlat := mlx.Reshape(xExpanded, B*L, 1, 1, cfg.HiddenSize)
|
|
idxFlat := mlx.Reshape(indices, B*L, topK)
|
|
|
|
doSort := B*L >= 64
|
|
var invOrder *mlx.Array
|
|
n := B * L * topK
|
|
|
|
if doSort {
|
|
idxAll := mlx.Flatten(idxFlat)
|
|
order := mlx.Argsort(idxAll, 0)
|
|
invOrder = mlx.Argsort(order, 0)
|
|
xFlat = mlx.ExpandDims(mlx.Take(mlx.Squeeze(xFlat, 1), mlx.FloorDivideScalar(order, topK), 0), 1)
|
|
idxFlat = mlx.Reshape(mlx.Take(idxAll, order, 0), n, 1)
|
|
}
|
|
|
|
var gate, up, hidden, down *mlx.Array
|
|
if s.UseQuantized {
|
|
gate = mlx.GatherQMM(xFlat, s.GateWeightQ, s.GateScales, s.GateBiases,
|
|
nil, idxFlat, true, s.GateGroupSize, s.GateBits, cfg.QuantMode, doSort)
|
|
up = mlx.GatherQMM(xFlat, s.UpWeightQ, s.UpScales, s.UpBiases,
|
|
nil, idxFlat, true, s.UpGroupSize, s.UpBits, cfg.QuantMode, doSort)
|
|
hidden = mlx.SwiGLU(gate, up)
|
|
down = mlx.GatherQMM(hidden, s.DownWeightQ, s.DownScales, s.DownBiases,
|
|
nil, idxFlat, true, s.DownGroupSize, s.DownBits, cfg.QuantMode, doSort)
|
|
} else {
|
|
gate = mlx.GatherMM(xFlat, s.GateWeight, nil, idxFlat, doSort)
|
|
up = mlx.GatherMM(xFlat, s.UpWeight, nil, idxFlat, doSort)
|
|
hidden = mlx.SwiGLU(gate, up)
|
|
down = mlx.GatherMM(hidden, s.DownWeight, nil, idxFlat, doSort)
|
|
}
|
|
|
|
if doSort {
|
|
down = mlx.Reshape(mlx.Take(mlx.Squeeze(mlx.Squeeze(down, 2), 1), invOrder, 0), B*L, topK, cfg.HiddenSize)
|
|
} else {
|
|
down = mlx.Squeeze(down, 2)
|
|
}
|
|
|
|
return mlx.Reshape(down, B, L, topK, cfg.HiddenSize)
|
|
}
|
|
|
|
func (m *SparseMoE) Forward(x *mlx.Array, cfg *Config) *mlx.Array {
|
|
dims := x.Dims()
|
|
B, L := int32(dims[0]), int32(dims[1])
|
|
|
|
probs := mlx.SoftmaxAxis(m.Gate.Forward(x), -1, true)
|
|
neg := mlx.Neg(probs)
|
|
inds := mlx.Argpartition(neg, int(cfg.NumExpertsPerTok)-1, -1)
|
|
shape := inds.Dims()
|
|
inds = mlx.SliceStartStop(inds, []int32{0, 0, 0}, []int32{int32(shape[0]), int32(shape[1]), cfg.NumExpertsPerTok})
|
|
|
|
scores := mlx.TakeAlongAxis(probs, inds, -1)
|
|
if cfg.NormTopKProb && cfg.NumExpertsPerTok > 1 {
|
|
sumScores := mlx.Sum(scores, -1, true)
|
|
scores = mlx.Div(scores, sumScores)
|
|
}
|
|
|
|
expertOut := m.SwitchMLP.Forward(x, inds, cfg)
|
|
y := mlx.Sum(mlx.Mul(expertOut, mlx.ExpandDims(scores, -1)), 2, false)
|
|
|
|
if m.SharedExpert != nil {
|
|
shared := m.SharedExpert.Forward(x, cfg)
|
|
if m.SharedExpertGate != nil {
|
|
shared = mlx.Mul(shared, mlx.Sigmoid(m.SharedExpertGate.Forward(x)))
|
|
}
|
|
y = mlx.Add(y, shared)
|
|
}
|
|
|
|
return mlx.Reshape(y, B, L, cfg.HiddenSize)
|
|
}
|
|
|
|
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
|
var r *mlx.Array
|
|
normed := l.InputNorm.Forward(x, cfg.RMSNormEps)
|
|
if l.IsLinear {
|
|
r = l.Linear.Forward(normed, c, B, L, cfg)
|
|
} else {
|
|
r = l.FullAttn.Forward(normed, c, B, L, cfg)
|
|
}
|
|
h := mlx.Add(x, r)
|
|
r = l.MLP.Forward(l.PostAttentionNorm.Forward(h, cfg.RMSNormEps), cfg)
|
|
return mlx.Add(h, r)
|
|
}
|
|
|
|
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
|
dims := tokens.Dims()
|
|
B, L := int32(dims[0]), int32(dims[1])
|
|
|
|
h := m.EmbedTokens.Forward(tokens)
|
|
for i, layer := range m.Layers {
|
|
var c cache.Cache
|
|
if caches != nil && i < len(caches) {
|
|
c = caches[i]
|
|
}
|
|
h = layer.Forward(h, c, B, L, m.Config)
|
|
}
|
|
out := m.Norm.Forward(h, m.RMSNormEps)
|
|
return out
|
|
}
|
|
|
|
func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
|
|
return m.LMHead.Forward(x)
|
|
}
|
|
|
|
func (m *Model) NumLayers() int {
|
|
return len(m.Layers)
|
|
}
|
|
|
|
func (m *Model) MaxContextLength() int {
|
|
return int(m.MaxPositionEmbeddings)
|
|
}
|
|
|
|
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
|
|
return m.tok
|
|
}
|
|
|
|
func (m *Model) NewCaches() []cache.Cache {
|
|
caches := make([]cache.Cache, len(m.Layers))
|
|
convTail := m.LinearConvKernelDim - 1
|
|
convDim := 2*m.LinearNumKeyHeads*m.LinearKeyHeadDim + m.LinearNumValueHeads*m.LinearValueHeadDim
|
|
for i, layer := range m.Layers {
|
|
if layer.IsLinear {
|
|
caches[i] = cache.NewRecurrentCache(convTail, convDim, m.LinearNumValueHeads, m.LinearValueHeadDim, m.LinearKeyHeadDim)
|
|
} else {
|
|
caches[i] = cache.NewKVCache()
|
|
}
|
|
}
|
|
return caches
|
|
}
|