// Commit note: Add QuantizedEmbedding and EmbeddingLayer interface so models
// can use quantized embedding weights and expose tied output projections.
// This change updates gemma3, glm4_moe_lite, llama, qwen3, and qwen3_5 to use
// the new interface.
package nn

import "github.com/ollama/ollama/x/mlxrunner/mlx"

// Layer is the interface for neural network layers with a Forward method.
type Layer interface {
	// Forward applies the layer to x and returns the transformed array.
	Forward(x *mlx.Array) *mlx.Array
}
// LinearLayer is an interface for linear layers (both regular and quantized).
type LinearLayer interface {
	// Forward applies the linear projection to x.
	Forward(x *mlx.Array) *mlx.Array
	// OutputDim reports the layer's output feature dimension.
	OutputDim() int32
}
// EmbeddingLayer is an interface for embedding layers that can also expose a
// tied-output projection when the model reuses embedding weights as the LM head.
type EmbeddingLayer interface {
	// Forward looks up embedding rows for the given token indices.
	Forward(indices *mlx.Array) *mlx.Array
	// AsLinear returns a linear layer backed by the same embedding weights,
	// for use as a tied output (LM head) projection.
	AsLinear() LinearLayer
}
// Conv1d applies 1D convolution over NLC input.
type Conv1d struct {
	Weight   *mlx.Array // convolution kernel
	Bias     *mlx.Array // optional bias, may be nil
	Stride   int32      // step between output positions
	Padding  int32      // zero-padding applied to each side
	Dilation int32      // spacing between kernel taps
	Groups   int32      // number of channel groups
}
func NewConv1d(weight, bias *mlx.Array, stride, padding, dilation, groups int32) *Conv1d {
|
|
if stride <= 0 {
|
|
stride = 1
|
|
}
|
|
if dilation <= 0 {
|
|
dilation = 1
|
|
}
|
|
if groups <= 0 {
|
|
groups = 1
|
|
}
|
|
return &Conv1d{
|
|
Weight: weight,
|
|
Bias: bias,
|
|
Stride: stride,
|
|
Padding: padding,
|
|
Dilation: dilation,
|
|
Groups: groups,
|
|
}
|
|
}
|
|
|
|
func (c *Conv1d) Forward(x *mlx.Array) *mlx.Array {
|
|
return mlx.Conv1d(x, c.Weight, c.Bias, c.Stride, c.Padding, c.Dilation, c.Groups)
|
|
}
|
|
|
|
// Linear applies an affine transformation: y = x @ W.T + b
type Linear struct {
	Weight *mlx.Array // weight matrix; rows are output features
	Bias   *mlx.Array // optional bias, may be nil or invalid
}
func NewLinear(weight *mlx.Array, bias *mlx.Array) *Linear {
|
|
return &Linear{Weight: weight, Bias: bias}
|
|
}
|
|
|
|
func (l *Linear) Forward(x *mlx.Array) *mlx.Array {
|
|
w := l.Weight.Transpose(1, 0)
|
|
if l.Bias != nil && l.Bias.Valid() {
|
|
return l.Bias.Addmm(x, w, 1.0, 1.0)
|
|
}
|
|
return x.Matmul(w)
|
|
}
|
|
|
|
func (l *Linear) OutputDim() int32 {
|
|
return int32(l.Weight.Dim(0))
|
|
}
|
|
|
|
// QuantizedLinear applies an affine transformation using quantized weights.
type QuantizedLinear struct {
	Weight    *mlx.Array // Quantized (packed) weight data
	Scales    *mlx.Array // Scale factors for dequantization
	QBiases   *mlx.Array // Quantization biases (nil for nvfp4)
	Bias      *mlx.Array // Layer bias [output_dims] or nil
	GroupSize int        // number of weights sharing one scale/qbias
	Bits      int        // bit width of each quantized weight
	Mode      string     // quantization scheme identifier (e.g. affine, nvfp4)
}
func NewQuantizedLinear(weight *mlx.Array, bias *mlx.Array, groupSize, bits int, mode string) *QuantizedLinear {
|
|
qw, scales, qbiases := mlx.Quantize(weight, groupSize, bits, mode)
|
|
if qbiases != nil {
|
|
mlx.Eval(qw, scales, qbiases)
|
|
} else {
|
|
mlx.Eval(qw, scales)
|
|
}
|
|
return &QuantizedLinear{
|
|
Weight: qw,
|
|
Scales: scales,
|
|
QBiases: qbiases,
|
|
Bias: bias,
|
|
GroupSize: groupSize,
|
|
Bits: bits,
|
|
Mode: mode,
|
|
}
|
|
}
|
|
|
|
func (ql *QuantizedLinear) Forward(x *mlx.Array) *mlx.Array {
|
|
out := mlx.QuantizedMatmul(x, ql.Weight, ql.Scales, ql.QBiases, true, ql.GroupSize, ql.Bits, ql.Mode)
|
|
if ql.Bias != nil && ql.Bias.Valid() {
|
|
out = out.Add(ql.Bias)
|
|
}
|
|
return out
|
|
}
|
|
|
|
func (ql *QuantizedLinear) OutputDim() int32 {
|
|
return int32(ql.Weight.Dim(0))
|
|
}
|
|
|
|
// RMSNorm represents an RMS normalization layer.
type RMSNorm struct {
	Weight *mlx.Array // learned per-feature scale
	Eps    float32    // numerical-stability epsilon used when the call-site eps is 0
}
func NewRMSNorm(weight *mlx.Array, eps float32) *RMSNorm {
|
|
return &RMSNorm{Weight: weight, Eps: eps}
|
|
}
|
|
|
|
func (rn *RMSNorm) Forward(x *mlx.Array, eps float32) *mlx.Array {
|
|
if eps == 0 {
|
|
eps = rn.Eps
|
|
}
|
|
return mlx.RMSNormFn(x, rn.Weight, eps)
|
|
}
|
|
|
|
// Embedding represents an embedding layer.
type Embedding struct {
	Weight *mlx.Array // embedding table; rows are per-token vectors
}
func NewEmbedding(weight *mlx.Array) *Embedding {
|
|
return &Embedding{Weight: weight}
|
|
}
|
|
|
|
func (e *Embedding) Forward(indices *mlx.Array) *mlx.Array {
|
|
return e.Weight.TakeAxis(indices, 0)
|
|
}
|
|
|
|
func (e *Embedding) AsLinear() LinearLayer {
|
|
return NewLinear(e.Weight, nil)
|
|
}
|
|
|
|
// QuantizedEmbedding performs row-wise embedding lookup from affine/nvfp4/etc.
// packed weights and dequantizes only the selected rows.
type QuantizedEmbedding struct {
	Weight    *mlx.Array // quantized (packed) embedding rows
	Scales    *mlx.Array // per-group dequantization scales
	QBiases   *mlx.Array // per-group quantization biases (nil for modes without them)
	GroupSize int        // number of weights sharing one scale/qbias
	Bits      int        // bit width of each quantized weight
	Mode      string     // quantization scheme identifier
}
func NewQuantizedEmbedding(weight, scales, qbiases *mlx.Array, groupSize, bits int, mode string) *QuantizedEmbedding {
|
|
return &QuantizedEmbedding{
|
|
Weight: weight,
|
|
Scales: scales,
|
|
QBiases: qbiases,
|
|
GroupSize: groupSize,
|
|
Bits: bits,
|
|
Mode: mode,
|
|
}
|
|
}
|
|
|
|
func (qe *QuantizedEmbedding) Forward(indices *mlx.Array) *mlx.Array {
|
|
weight := qe.Weight.TakeAxis(indices, 0)
|
|
scales := qe.Scales.TakeAxis(indices, 0)
|
|
var qbiases *mlx.Array
|
|
if qe.QBiases != nil && qe.QBiases.Valid() {
|
|
qbiases = qe.QBiases.TakeAxis(indices, 0)
|
|
}
|
|
return mlx.Dequantize(weight, scales, qbiases, qe.GroupSize, qe.Bits, qe.Mode)
|
|
}
|
|
|
|
func (qe *QuantizedEmbedding) AsLinear() LinearLayer {
|
|
return &QuantizedLinear{
|
|
Weight: qe.Weight,
|
|
Scales: qe.Scales,
|
|
QBiases: qe.QBiases,
|
|
GroupSize: qe.GroupSize,
|
|
Bits: qe.Bits,
|
|
Mode: qe.Mode,
|
|
}
|
|
}
|
|
|
|
// LayerNorm represents a standard layer normalization layer (with bias).
type LayerNorm struct {
	Weight *mlx.Array // learned per-feature scale
	Bias   *mlx.Array // learned per-feature shift
	Eps    float32    // numerical-stability epsilon; 0 means "use default"
}
func (ln *LayerNorm) Forward(x *mlx.Array) *mlx.Array {
|
|
eps := ln.Eps
|
|
if eps == 0 {
|
|
eps = 1e-5
|
|
}
|
|
return mlx.LayerNormFn(x, ln.Weight, ln.Bias, eps)
|
|
}
|
|
|
|
// MultiLinearLayer is an interface for per-head linear layers.
type MultiLinearLayer interface {
	// Forward applies the per-head projections to x.
	Forward(x *mlx.Array) *mlx.Array
}
// MultiLinear performs per-head linear projections.
// Weight shape: [num_heads, output_dims, input_dims]
type MultiLinear struct {
	Weight *mlx.Array // stacked per-head weight matrices
}
func NewMultiLinear(weight *mlx.Array) *MultiLinear {
|
|
return &MultiLinear{Weight: weight}
|
|
}
|
|
|
|
func (ml *MultiLinear) Forward(x *mlx.Array) *mlx.Array {
|
|
wT := ml.Weight.Transpose(0, 2, 1)
|
|
return x.Matmul(wT)
|
|
}
|
|
|
|
// ApplyCausalMask applies causal (lower triangular) mask to attention scores.
|
|
func ApplyCausalMask(scores *mlx.Array) *mlx.Array {
|
|
shape := scores.Dims()
|
|
seqLen := int32(shape[2])
|
|
mask := mlx.Tri(seqLen, seqLen, 0)
|
|
negInf := mlx.NewScalarArray(float32(-1e9))
|
|
mask = mask.ExpandDims(0).ExpandDims(0)
|
|
return mlx.Where(mask, scores, negInf)
|
|
}
|
|
|
|
// ApplyCausalMaskWithOffset applies causal mask for cached attention.
|
|
func ApplyCausalMaskWithOffset(scores *mlx.Array, offset int32) *mlx.Array {
|
|
if offset == 0 {
|
|
return ApplyCausalMask(scores)
|
|
}
|
|
shape := scores.Dims()
|
|
queryLen := int32(shape[2])
|
|
keyLen := int32(shape[3])
|
|
mask := mlx.Tri(queryLen, keyLen, int(offset))
|
|
negInf := mlx.NewScalarArray(float32(-1e9))
|
|
mask = mask.ExpandDims(0).ExpandDims(0)
|
|
return mlx.Where(mask, scores, negInf)
|
|
}
|