mirror of
https://github.com/ollama/ollama.git
synced 2026-04-17 21:54:08 +02:00
mlxrunner: positions tensor and RoPEWithBase
RoPEWithBase now takes a positions *Array instead of a scalar offset. For a single contiguous sequence it dispatches to the scalar mlx_fast_rope, which is bit-identical to the previous behavior. For multi-sequence batches (future work) it uses mlx_fast_rope_dynamic. Models compute positions via batch.SequentialPositions and pass them through the layer stack.
This commit is contained in:
20
x/mlxrunner/batch/positions.go
Normal file
20
x/mlxrunner/batch/positions.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package batch
|
||||
|
||||
import "github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
|
||||
// SequentialPositions builds per-token sequential positions for all sequences
|
||||
// in the batch. Each sequence's positions start at its corresponding offset.
|
||||
//
|
||||
// offsets must have one entry per sequence (matching batch.SeqIDs), representing
|
||||
// the starting position for that sequence's new tokens (typically the cache offset).
|
||||
func SequentialPositions(b *ForwardBatch, offsets []int32) *mlx.Array {
|
||||
total := b.TotalLen()
|
||||
pos := make([]int32, 0, total)
|
||||
for i, seqLen := range b.SeqLens {
|
||||
offset := offsets[i]
|
||||
for j := range seqLen {
|
||||
pos = append(pos, offset+int32(j))
|
||||
}
|
||||
}
|
||||
return mlx.NewArrayInt32(pos, []int32{int32(total)})
|
||||
}
|
||||
@@ -322,12 +322,55 @@ func SiLU(a *Array) *Array {
|
||||
return a.Multiply(sig)
|
||||
}
|
||||
|
||||
func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, offset int) *Array {
|
||||
// RoPEWithBase applies rotary position embeddings using per-token positions.
|
||||
//
|
||||
// positions is an int32 tensor of per-token absolute positions. For a single
|
||||
// sequence with contiguous positions starting at offset, this dispatches to
|
||||
// the scalar mlx_fast_rope (bit-identical to the old offset-based API).
|
||||
func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, positions *Array) *Array {
|
||||
posData := positions.Ints()
|
||||
if len(posData) == 0 {
|
||||
return x
|
||||
}
|
||||
|
||||
// Fast path: single contiguous run — use scalar mlx_fast_rope
|
||||
offset := posData[0]
|
||||
contiguous := true
|
||||
for i := 1; i < len(posData); i++ {
|
||||
if posData[i] != posData[i-1]+1 {
|
||||
contiguous = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if contiguous {
|
||||
freqs := New("")
|
||||
out := New("FAST_ROPE")
|
||||
C.mlx_fast_rope(
|
||||
&out.ctx,
|
||||
x.ctx,
|
||||
C.int(dims),
|
||||
C.bool(traditional),
|
||||
C.mlx_optional_float{
|
||||
value: C.float(base),
|
||||
has_value: C.bool(func() bool { return base != 0 }()),
|
||||
},
|
||||
C.float(scale),
|
||||
C.int(offset),
|
||||
freqs.ctx,
|
||||
DefaultStream().ctx,
|
||||
)
|
||||
return out
|
||||
}
|
||||
|
||||
// Multi-sequence path: use mlx_fast_rope_dynamic with per-token offsets.
|
||||
// Transpose [1, H, L, D] → [L, H, 1, D] so L becomes batch dim,
|
||||
// apply dynamic rope with positions [L], transpose back.
|
||||
rotIn := Transpose(x, 2, 1, 0, 3)
|
||||
freqs := New("")
|
||||
out := New("FAST_ROPE")
|
||||
C.mlx_fast_rope(
|
||||
&out.ctx,
|
||||
x.ctx,
|
||||
rotOut := New("FAST_ROPE_DYNAMIC")
|
||||
C.mlx_fast_rope_dynamic(
|
||||
&rotOut.ctx,
|
||||
rotIn.ctx,
|
||||
C.int(dims),
|
||||
C.bool(traditional),
|
||||
C.mlx_optional_float{
|
||||
@@ -335,11 +378,11 @@ func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, off
|
||||
has_value: C.bool(func() bool { return base != 0 }()),
|
||||
},
|
||||
C.float(scale),
|
||||
C.int(offset),
|
||||
positions.ctx,
|
||||
freqs.ctx,
|
||||
DefaultStream().ctx,
|
||||
)
|
||||
return out
|
||||
return Transpose(rotOut, 2, 1, 0, 3)
|
||||
}
|
||||
|
||||
func Sigmoid(a *Array) *Array {
|
||||
|
||||
@@ -415,7 +415,7 @@ func (m *Model) Forward(b *batch.ForwardBatch, caches []cache.Cache) *mlx.Array
|
||||
if caches != nil && i < len(caches) {
|
||||
c = caches[i]
|
||||
}
|
||||
h = layer.Forward(h, c, B, L, m.TextConfig)
|
||||
h = layer.Forward(h, b, c, B, L, m.TextConfig)
|
||||
}
|
||||
|
||||
return mlx.RMSNormFn(h, m.NormScaled, m.RMSNormEps)
|
||||
@@ -455,10 +455,10 @@ func (m *Model) FormatPrompt(prompt string) string {
|
||||
return fmt.Sprintf("<start_of_turn>user\n%s<end_of_turn>\n<start_of_turn>model\n", prompt)
|
||||
}
|
||||
|
||||
func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *TextConfig) *mlx.Array {
|
||||
func (l *DecoderLayer) Forward(x *mlx.Array, b *batch.ForwardBatch, c cache.Cache, B, L int32, cfg *TextConfig) *mlx.Array {
|
||||
normed := mlx.RMSNormFn(x, l.InputNormScaled, cfg.RMSNormEps)
|
||||
|
||||
attnOut := l.Attention.Forward(normed, c, B, L, l.IsSliding, cfg)
|
||||
attnOut := l.Attention.Forward(normed, b, c, B, L, l.IsSliding, cfg)
|
||||
attnOut = mlx.RMSNormFn(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps)
|
||||
h := mlx.Add(x, attnOut)
|
||||
|
||||
@@ -470,7 +470,7 @@ func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Tex
|
||||
return mlx.Add(h, mlpOut)
|
||||
}
|
||||
|
||||
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig) *mlx.Array {
|
||||
func (a *Attention) Forward(x *mlx.Array, b *batch.ForwardBatch, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig) *mlx.Array {
|
||||
q := a.QProj.Forward(x)
|
||||
k := a.KProj.Forward(x)
|
||||
v := a.VProj.Forward(x)
|
||||
@@ -492,12 +492,9 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b
|
||||
ropeTheta = cfg.RopeLocalBaseFreq
|
||||
}
|
||||
|
||||
offset := 0
|
||||
if c != nil {
|
||||
offset = int(c.Offsets()[0])
|
||||
}
|
||||
q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, ropeTheta, 1.0, offset)
|
||||
k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, ropeTheta, 1.0, offset)
|
||||
positions := batch.SequentialPositions(b, c.Offsets())
|
||||
q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, ropeTheta, 1.0, positions)
|
||||
k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, ropeTheta, 1.0, positions)
|
||||
|
||||
if c != nil {
|
||||
k, v, _ = c.Update(nil, k, v)
|
||||
|
||||
@@ -88,7 +88,7 @@ type MLAAttention struct {
|
||||
}
|
||||
|
||||
// Forward computes absorbed MLA attention output.
|
||||
func (a *MLAAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
func (a *MLAAttention) Forward(x, positions *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
q := a.QAProj.Forward(x)
|
||||
q = a.QALayerNorm.Forward(q, cfg.RMSNormEps)
|
||||
q = a.QBProj.Forward(q)
|
||||
@@ -110,12 +110,8 @@ func (a *MLAAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Con
|
||||
kvLatent := a.KVALayerNorm.Forward(kvCompressed, cfg.RMSNormEps)
|
||||
kvLatent = mlx.ExpandDims(kvLatent, 1)
|
||||
|
||||
offset := 0
|
||||
if c != nil {
|
||||
offset = int(c.Offsets()[0])
|
||||
}
|
||||
qPE = mlx.RoPEWithBase(qPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset)
|
||||
kPE = mlx.RoPEWithBase(kPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset)
|
||||
qPE = mlx.RoPEWithBase(qPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, positions)
|
||||
kPE = mlx.RoPEWithBase(kPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, positions)
|
||||
|
||||
qLatent := a.EmbedQ.Forward(qNope)
|
||||
|
||||
@@ -314,8 +310,8 @@ type DenseBlock struct {
|
||||
}
|
||||
|
||||
// Forward applies the dense block
|
||||
func (b *DenseBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg)
|
||||
func (b *DenseBlock) Forward(x, positions *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), positions, c, B, L, cfg)
|
||||
h := mlx.Add(x, r)
|
||||
|
||||
r = b.MLP.Forward(b.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps))
|
||||
@@ -331,8 +327,8 @@ type MoEBlock struct {
|
||||
}
|
||||
|
||||
// Forward applies the MoE block
|
||||
func (b *MoEBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg)
|
||||
func (b *MoEBlock) Forward(x, positions *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), positions, c, B, L, cfg)
|
||||
h := mlx.Add(x, r)
|
||||
|
||||
r = b.MoE.Forward(b.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps), cfg)
|
||||
@@ -341,7 +337,7 @@ func (b *MoEBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config)
|
||||
|
||||
// Block interface for both dense and MoE blocks
|
||||
type Block interface {
|
||||
Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array
|
||||
Forward(x, positions *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array
|
||||
}
|
||||
|
||||
// Model represents the complete GLM4-MoE-Lite model
|
||||
@@ -708,13 +704,14 @@ func (m *Model) Forward(b *batch.ForwardBatch, caches []cache.Cache) *mlx.Array
|
||||
B, L := int32(dims[0]), int32(dims[1])
|
||||
|
||||
h := m.EmbedTokens.Forward(b.InputIDs)
|
||||
positions := batch.SequentialPositions(b, caches[0].Offsets())
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
var c cache.Cache
|
||||
if caches != nil {
|
||||
c = caches[i]
|
||||
}
|
||||
h = layer.Forward(h, c, B, L, m.Config)
|
||||
h = layer.Forward(h, positions, c, B, L, m.Config)
|
||||
}
|
||||
|
||||
h = m.Norm.Forward(h, m.RMSNormEps)
|
||||
|
||||
@@ -242,12 +242,14 @@ func (m *Model) Forward(b *batch.ForwardBatch, caches []cache.Cache) *mlx.Array
|
||||
B, L := int32(dims[0]), int32(dims[1])
|
||||
|
||||
h := m.EmbedTokens.Forward(b.InputIDs)
|
||||
positions := batch.SequentialPositions(b, caches[0].Offsets())
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
var c cache.Cache
|
||||
if caches != nil && i < len(caches) {
|
||||
c = caches[i]
|
||||
}
|
||||
h = layer.Forward(h, c, B, L, m.Config)
|
||||
h = layer.Forward(h, positions, c, B, L, m.Config)
|
||||
}
|
||||
|
||||
return m.Norm.Forward(h, m.RMSNormEps)
|
||||
@@ -277,12 +279,12 @@ func (m *Model) NewCaches() []cache.Cache {
|
||||
return caches
|
||||
}
|
||||
|
||||
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
h := mlx.Add(x, l.Attention.Forward(l.AttentionNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg))
|
||||
func (l *Layer) Forward(x, positions *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
h := mlx.Add(x, l.Attention.Forward(l.AttentionNorm.Forward(x, cfg.RMSNormEps), positions, c, B, L, cfg))
|
||||
return mlx.Add(h, l.MLP.Forward(l.MLPNorm.Forward(h, cfg.RMSNormEps)))
|
||||
}
|
||||
|
||||
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
func (a *Attention) Forward(x, positions *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
q := a.QProj.Forward(x)
|
||||
k := a.KProj.Forward(x)
|
||||
v := a.VProj.Forward(x)
|
||||
@@ -296,12 +298,8 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config
|
||||
v = mlx.Reshape(v, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
|
||||
v = mlx.Transpose(v, 0, 2, 1, 3)
|
||||
|
||||
offset := 0
|
||||
if c != nil {
|
||||
offset = int(c.Offsets()[0])
|
||||
}
|
||||
q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
|
||||
k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
|
||||
q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, positions)
|
||||
k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, positions)
|
||||
|
||||
if c != nil {
|
||||
k, v, _ = c.Update(nil, k, v)
|
||||
|
||||
@@ -259,12 +259,14 @@ func (m *Model) Forward(b *batch.ForwardBatch, caches []cache.Cache) *mlx.Array
|
||||
B, L := int32(dims[0]), int32(dims[1])
|
||||
|
||||
h := m.EmbedTokens.Forward(b.InputIDs)
|
||||
positions := batch.SequentialPositions(b, caches[0].Offsets())
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
var c cache.Cache
|
||||
if caches != nil && i < len(caches) {
|
||||
c = caches[i]
|
||||
}
|
||||
h = layer.Forward(h, c, B, L, m.Config)
|
||||
h = layer.Forward(h, positions, c, B, L, m.Config)
|
||||
}
|
||||
|
||||
return m.Norm.Forward(h, m.RMSNormEps)
|
||||
@@ -294,12 +296,12 @@ func (m *Model) NewCaches() []cache.Cache {
|
||||
return caches
|
||||
}
|
||||
|
||||
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
h := mlx.Add(x, l.Attention.Forward(l.AttentionNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg))
|
||||
func (l *Layer) Forward(x, positions *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
h := mlx.Add(x, l.Attention.Forward(l.AttentionNorm.Forward(x, cfg.RMSNormEps), positions, c, B, L, cfg))
|
||||
return mlx.Add(h, l.MLP.Forward(l.MLPNorm.Forward(h, cfg.RMSNormEps)))
|
||||
}
|
||||
|
||||
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
func (a *Attention) Forward(x, positions *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
q := a.QProj.Forward(x)
|
||||
k := a.KProj.Forward(x)
|
||||
v := a.VProj.Forward(x)
|
||||
@@ -315,12 +317,8 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config
|
||||
k = mlx.Transpose(k, 0, 2, 1, 3)
|
||||
v = mlx.Transpose(v, 0, 2, 1, 3)
|
||||
|
||||
offset := 0
|
||||
if c != nil {
|
||||
offset = int(c.Offsets()[0])
|
||||
}
|
||||
q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
|
||||
k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
|
||||
q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, positions)
|
||||
k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, positions)
|
||||
|
||||
if c != nil {
|
||||
k, v, _ = c.Update(nil, k, v)
|
||||
|
||||
@@ -1127,7 +1127,7 @@ func splitQKVZBA(mixedQKVZ, mixedBA *mlx.Array, cfg *Config, B, L int32) (q, k,
|
||||
return q, k, v, z, b, a
|
||||
}
|
||||
|
||||
func (a *FullAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
func (a *FullAttention) Forward(x, positions *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
qg := a.QProj.Forward(x)
|
||||
qg = mlx.Reshape(qg, B, L, cfg.NumAttentionHeads, cfg.HeadDim*2)
|
||||
q := mlx.SliceStartStop(qg, []int32{0, 0, 0, 0}, []int32{B, L, cfg.NumAttentionHeads, cfg.HeadDim})
|
||||
@@ -1146,12 +1146,8 @@ func (a *FullAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Co
|
||||
k = mlx.Transpose(k, 0, 2, 1, 3)
|
||||
v = mlx.Transpose(v, 0, 2, 1, 3)
|
||||
|
||||
offset := 0
|
||||
if c != nil {
|
||||
offset = int(c.Offsets()[0])
|
||||
}
|
||||
q = mlx.RoPEWithBase(q, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, offset)
|
||||
k = mlx.RoPEWithBase(k, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, offset)
|
||||
q = mlx.RoPEWithBase(q, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, positions)
|
||||
k = mlx.RoPEWithBase(k, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, positions)
|
||||
|
||||
if c != nil {
|
||||
k, v, _ = c.Update(nil, k, v)
|
||||
@@ -1333,13 +1329,13 @@ func (m *SparseMoE) Forward(x *mlx.Array, cfg *Config) *mlx.Array {
|
||||
return mlx.Reshape(y, B, L, cfg.HiddenSize)
|
||||
}
|
||||
|
||||
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
func (l *Layer) Forward(x, positions *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
var r *mlx.Array
|
||||
normed := l.InputNorm.Forward(x, cfg.RMSNormEps)
|
||||
if l.IsLinear {
|
||||
r = l.Linear.Forward(normed, c, B, L, cfg)
|
||||
} else {
|
||||
r = l.FullAttn.Forward(normed, c, B, L, cfg)
|
||||
r = l.FullAttn.Forward(normed, positions, c, B, L, cfg)
|
||||
}
|
||||
h := mlx.Add(x, r)
|
||||
r = l.MLP.Forward(l.PostAttentionNorm.Forward(h, cfg.RMSNormEps), cfg)
|
||||
@@ -1351,12 +1347,14 @@ func (m *Model) Forward(b *batch.ForwardBatch, caches []cache.Cache) *mlx.Array
|
||||
B, L := int32(dims[0]), int32(dims[1])
|
||||
|
||||
h := m.EmbedTokens.Forward(b.InputIDs)
|
||||
positions := batch.SequentialPositions(b, caches[0].Offsets())
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
var c cache.Cache
|
||||
if caches != nil && i < len(caches) {
|
||||
c = caches[i]
|
||||
}
|
||||
h = layer.Forward(h, c, B, L, m.Config)
|
||||
h = layer.Forward(h, positions, c, B, L, m.Config)
|
||||
}
|
||||
out := m.Norm.Forward(h, m.RMSNormEps)
|
||||
return out
|
||||
|
||||
Reference in New Issue
Block a user