mlxrunner: positions tensor and RoPEWithBase

RoPEWithBase now takes a positions *Array instead of a scalar offset.
For a single contiguous sequence it dispatches to the scalar
mlx_fast_rope — bit-identical to before. For multi-sequence (future)
it uses mlx_fast_rope_dynamic.

Models compute positions via batch.SequentialPositions and pass them
through the layer stack.
This commit is contained in:
Jesse Gross
2026-04-02 12:08:20 -07:00
parent b7b2aa5d4e
commit 1ea8e70d94
7 changed files with 111 additions and 60 deletions

View File

@@ -0,0 +1,20 @@
package batch
import "github.com/ollama/ollama/x/mlxrunner/mlx"
// SequentialPositions builds per-token sequential positions for all sequences
// in the batch. Each sequence's positions start at its corresponding offset.
//
// offsets must have one entry per sequence (matching batch.SeqIDs), representing
// the starting position for that sequence's new tokens (typically the cache
// offset). A length mismatch indicates a caller bug and panics with a
// descriptive message rather than an opaque index-out-of-range.
func SequentialPositions(b *ForwardBatch, offsets []int32) *mlx.Array {
	if len(offsets) != len(b.SeqLens) {
		panic("batch: SequentialPositions: offsets length does not match number of sequences")
	}
	total := b.TotalLen()
	// One position per token across every sequence in the batch.
	pos := make([]int32, 0, total)
	for i, seqLen := range b.SeqLens {
		offset := offsets[i]
		for j := range seqLen {
			pos = append(pos, offset+int32(j))
		}
	}
	// 1-D int32 tensor of length total, consumed by RoPEWithBase.
	return mlx.NewArrayInt32(pos, []int32{int32(total)})
}

View File

@@ -322,12 +322,55 @@ func SiLU(a *Array) *Array {
return a.Multiply(sig)
}
func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, offset int) *Array {
// RoPEWithBase applies rotary position embeddings using per-token positions.
//
// positions is an int32 tensor of per-token absolute positions. For a single
// sequence with contiguous positions starting at offset, this dispatches to
// the scalar mlx_fast_rope (bit-identical to the old offset-based API).
func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, positions *Array) *Array {
posData := positions.Ints()
if len(posData) == 0 {
return x
}
// Fast path: single contiguous run — use scalar mlx_fast_rope
offset := posData[0]
contiguous := true
for i := 1; i < len(posData); i++ {
if posData[i] != posData[i-1]+1 {
contiguous = false
break
}
}
if contiguous {
freqs := New("")
out := New("FAST_ROPE")
C.mlx_fast_rope(
&out.ctx,
x.ctx,
C.int(dims),
C.bool(traditional),
C.mlx_optional_float{
value: C.float(base),
has_value: C.bool(func() bool { return base != 0 }()),
},
C.float(scale),
C.int(offset),
freqs.ctx,
DefaultStream().ctx,
)
return out
}
// Multi-sequence path: use mlx_fast_rope_dynamic with per-token offsets.
// Transpose [1, H, L, D] → [L, H, 1, D] so L becomes batch dim,
// apply dynamic rope with positions [L], transpose back.
rotIn := Transpose(x, 2, 1, 0, 3)
freqs := New("")
out := New("FAST_ROPE")
C.mlx_fast_rope(
&out.ctx,
x.ctx,
rotOut := New("FAST_ROPE_DYNAMIC")
C.mlx_fast_rope_dynamic(
&rotOut.ctx,
rotIn.ctx,
C.int(dims),
C.bool(traditional),
C.mlx_optional_float{
@@ -335,11 +378,11 @@ func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, off
has_value: C.bool(func() bool { return base != 0 }()),
},
C.float(scale),
C.int(offset),
positions.ctx,
freqs.ctx,
DefaultStream().ctx,
)
return out
return Transpose(rotOut, 2, 1, 0, 3)
}
func Sigmoid(a *Array) *Array {

View File

@@ -415,7 +415,7 @@ func (m *Model) Forward(b *batch.ForwardBatch, caches []cache.Cache) *mlx.Array
if caches != nil && i < len(caches) {
c = caches[i]
}
h = layer.Forward(h, c, B, L, m.TextConfig)
h = layer.Forward(h, b, c, B, L, m.TextConfig)
}
return mlx.RMSNormFn(h, m.NormScaled, m.RMSNormEps)
@@ -455,10 +455,10 @@ func (m *Model) FormatPrompt(prompt string) string {
return fmt.Sprintf("<start_of_turn>user\n%s<end_of_turn>\n<start_of_turn>model\n", prompt)
}
func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *TextConfig) *mlx.Array {
func (l *DecoderLayer) Forward(x *mlx.Array, b *batch.ForwardBatch, c cache.Cache, B, L int32, cfg *TextConfig) *mlx.Array {
normed := mlx.RMSNormFn(x, l.InputNormScaled, cfg.RMSNormEps)
attnOut := l.Attention.Forward(normed, c, B, L, l.IsSliding, cfg)
attnOut := l.Attention.Forward(normed, b, c, B, L, l.IsSliding, cfg)
attnOut = mlx.RMSNormFn(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps)
h := mlx.Add(x, attnOut)
@@ -470,7 +470,7 @@ func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Tex
return mlx.Add(h, mlpOut)
}
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig) *mlx.Array {
func (a *Attention) Forward(x *mlx.Array, b *batch.ForwardBatch, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig) *mlx.Array {
q := a.QProj.Forward(x)
k := a.KProj.Forward(x)
v := a.VProj.Forward(x)
@@ -492,12 +492,9 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b
ropeTheta = cfg.RopeLocalBaseFreq
}
offset := 0
if c != nil {
offset = int(c.Offsets()[0])
}
q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, ropeTheta, 1.0, offset)
k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, ropeTheta, 1.0, offset)
positions := batch.SequentialPositions(b, c.Offsets())
q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, ropeTheta, 1.0, positions)
k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, ropeTheta, 1.0, positions)
if c != nil {
k, v, _ = c.Update(nil, k, v)

View File

@@ -88,7 +88,7 @@ type MLAAttention struct {
}
// Forward computes absorbed MLA attention output.
func (a *MLAAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
func (a *MLAAttention) Forward(x, positions *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
q := a.QAProj.Forward(x)
q = a.QALayerNorm.Forward(q, cfg.RMSNormEps)
q = a.QBProj.Forward(q)
@@ -110,12 +110,8 @@ func (a *MLAAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Con
kvLatent := a.KVALayerNorm.Forward(kvCompressed, cfg.RMSNormEps)
kvLatent = mlx.ExpandDims(kvLatent, 1)
offset := 0
if c != nil {
offset = int(c.Offsets()[0])
}
qPE = mlx.RoPEWithBase(qPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset)
kPE = mlx.RoPEWithBase(kPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset)
qPE = mlx.RoPEWithBase(qPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, positions)
kPE = mlx.RoPEWithBase(kPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, positions)
qLatent := a.EmbedQ.Forward(qNope)
@@ -314,8 +310,8 @@ type DenseBlock struct {
}
// Forward applies the dense block
func (b *DenseBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg)
func (b *DenseBlock) Forward(x, positions *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), positions, c, B, L, cfg)
h := mlx.Add(x, r)
r = b.MLP.Forward(b.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps))
@@ -331,8 +327,8 @@ type MoEBlock struct {
}
// Forward applies the MoE block
func (b *MoEBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg)
func (b *MoEBlock) Forward(x, positions *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), positions, c, B, L, cfg)
h := mlx.Add(x, r)
r = b.MoE.Forward(b.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps), cfg)
@@ -341,7 +337,7 @@ func (b *MoEBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config)
// Block interface for both dense and MoE blocks
type Block interface {
Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array
Forward(x, positions *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array
}
// Model represents the complete GLM4-MoE-Lite model
@@ -708,13 +704,14 @@ func (m *Model) Forward(b *batch.ForwardBatch, caches []cache.Cache) *mlx.Array
B, L := int32(dims[0]), int32(dims[1])
h := m.EmbedTokens.Forward(b.InputIDs)
positions := batch.SequentialPositions(b, caches[0].Offsets())
for i, layer := range m.Layers {
var c cache.Cache
if caches != nil {
c = caches[i]
}
h = layer.Forward(h, c, B, L, m.Config)
h = layer.Forward(h, positions, c, B, L, m.Config)
}
h = m.Norm.Forward(h, m.RMSNormEps)

View File

@@ -242,12 +242,14 @@ func (m *Model) Forward(b *batch.ForwardBatch, caches []cache.Cache) *mlx.Array
B, L := int32(dims[0]), int32(dims[1])
h := m.EmbedTokens.Forward(b.InputIDs)
positions := batch.SequentialPositions(b, caches[0].Offsets())
for i, layer := range m.Layers {
var c cache.Cache
if caches != nil && i < len(caches) {
c = caches[i]
}
h = layer.Forward(h, c, B, L, m.Config)
h = layer.Forward(h, positions, c, B, L, m.Config)
}
return m.Norm.Forward(h, m.RMSNormEps)
@@ -277,12 +279,12 @@ func (m *Model) NewCaches() []cache.Cache {
return caches
}
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
h := mlx.Add(x, l.Attention.Forward(l.AttentionNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg))
func (l *Layer) Forward(x, positions *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
h := mlx.Add(x, l.Attention.Forward(l.AttentionNorm.Forward(x, cfg.RMSNormEps), positions, c, B, L, cfg))
return mlx.Add(h, l.MLP.Forward(l.MLPNorm.Forward(h, cfg.RMSNormEps)))
}
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
func (a *Attention) Forward(x, positions *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
q := a.QProj.Forward(x)
k := a.KProj.Forward(x)
v := a.VProj.Forward(x)
@@ -296,12 +298,8 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config
v = mlx.Reshape(v, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
v = mlx.Transpose(v, 0, 2, 1, 3)
offset := 0
if c != nil {
offset = int(c.Offsets()[0])
}
q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, positions)
k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, positions)
if c != nil {
k, v, _ = c.Update(nil, k, v)

View File

@@ -259,12 +259,14 @@ func (m *Model) Forward(b *batch.ForwardBatch, caches []cache.Cache) *mlx.Array
B, L := int32(dims[0]), int32(dims[1])
h := m.EmbedTokens.Forward(b.InputIDs)
positions := batch.SequentialPositions(b, caches[0].Offsets())
for i, layer := range m.Layers {
var c cache.Cache
if caches != nil && i < len(caches) {
c = caches[i]
}
h = layer.Forward(h, c, B, L, m.Config)
h = layer.Forward(h, positions, c, B, L, m.Config)
}
return m.Norm.Forward(h, m.RMSNormEps)
@@ -294,12 +296,12 @@ func (m *Model) NewCaches() []cache.Cache {
return caches
}
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
h := mlx.Add(x, l.Attention.Forward(l.AttentionNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg))
func (l *Layer) Forward(x, positions *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
h := mlx.Add(x, l.Attention.Forward(l.AttentionNorm.Forward(x, cfg.RMSNormEps), positions, c, B, L, cfg))
return mlx.Add(h, l.MLP.Forward(l.MLPNorm.Forward(h, cfg.RMSNormEps)))
}
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
func (a *Attention) Forward(x, positions *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
q := a.QProj.Forward(x)
k := a.KProj.Forward(x)
v := a.VProj.Forward(x)
@@ -315,12 +317,8 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config
k = mlx.Transpose(k, 0, 2, 1, 3)
v = mlx.Transpose(v, 0, 2, 1, 3)
offset := 0
if c != nil {
offset = int(c.Offsets()[0])
}
q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, positions)
k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, positions)
if c != nil {
k, v, _ = c.Update(nil, k, v)

View File

@@ -1127,7 +1127,7 @@ func splitQKVZBA(mixedQKVZ, mixedBA *mlx.Array, cfg *Config, B, L int32) (q, k,
return q, k, v, z, b, a
}
func (a *FullAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
func (a *FullAttention) Forward(x, positions *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
qg := a.QProj.Forward(x)
qg = mlx.Reshape(qg, B, L, cfg.NumAttentionHeads, cfg.HeadDim*2)
q := mlx.SliceStartStop(qg, []int32{0, 0, 0, 0}, []int32{B, L, cfg.NumAttentionHeads, cfg.HeadDim})
@@ -1146,12 +1146,8 @@ func (a *FullAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Co
k = mlx.Transpose(k, 0, 2, 1, 3)
v = mlx.Transpose(v, 0, 2, 1, 3)
offset := 0
if c != nil {
offset = int(c.Offsets()[0])
}
q = mlx.RoPEWithBase(q, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, offset)
k = mlx.RoPEWithBase(k, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, offset)
q = mlx.RoPEWithBase(q, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, positions)
k = mlx.RoPEWithBase(k, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, positions)
if c != nil {
k, v, _ = c.Update(nil, k, v)
@@ -1333,13 +1329,13 @@ func (m *SparseMoE) Forward(x *mlx.Array, cfg *Config) *mlx.Array {
return mlx.Reshape(y, B, L, cfg.HiddenSize)
}
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
func (l *Layer) Forward(x, positions *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
var r *mlx.Array
normed := l.InputNorm.Forward(x, cfg.RMSNormEps)
if l.IsLinear {
r = l.Linear.Forward(normed, c, B, L, cfg)
} else {
r = l.FullAttn.Forward(normed, c, B, L, cfg)
r = l.FullAttn.Forward(normed, positions, c, B, L, cfg)
}
h := mlx.Add(x, r)
r = l.MLP.Forward(l.PostAttentionNorm.Forward(h, cfg.RMSNormEps), cfg)
@@ -1351,12 +1347,14 @@ func (m *Model) Forward(b *batch.ForwardBatch, caches []cache.Cache) *mlx.Array
B, L := int32(dims[0]), int32(dims[1])
h := m.EmbedTokens.Forward(b.InputIDs)
positions := batch.SequentialPositions(b, caches[0].Offsets())
for i, layer := range m.Layers {
var c cache.Cache
if caches != nil && i < len(caches) {
c = caches[i]
}
h = layer.Forward(h, c, B, L, m.Config)
h = layer.Forward(h, positions, c, B, L, m.Config)
}
out := m.Norm.Forward(h, m.RMSNormEps)
return out