diff --git a/x/mlxrunner/batch/positions.go b/x/mlxrunner/batch/positions.go new file mode 100644 index 000000000..48d7afc6f --- /dev/null +++ b/x/mlxrunner/batch/positions.go @@ -0,0 +1,20 @@ +package batch + +import "github.com/ollama/ollama/x/mlxrunner/mlx" + +// SequentialPositions builds per-token sequential positions for all sequences +// in the batch. Each sequence's positions start at its corresponding offset. +// +// offsets must have one entry per sequence (matching batch.SeqIDs), representing +// the starting position for that sequence's new tokens (typically the cache offset). +func SequentialPositions(b *ForwardBatch, offsets []int32) *mlx.Array { + total := b.TotalLen() + pos := make([]int32, 0, total) + for i, seqLen := range b.SeqLens { + offset := offsets[i] + for j := range seqLen { + pos = append(pos, offset+int32(j)) + } + } + return mlx.NewArrayInt32(pos, []int32{int32(total)}) +} diff --git a/x/mlxrunner/mlx/ops_extra.go b/x/mlxrunner/mlx/ops_extra.go index 9de0037da..503565998 100644 --- a/x/mlxrunner/mlx/ops_extra.go +++ b/x/mlxrunner/mlx/ops_extra.go @@ -322,12 +322,55 @@ func SiLU(a *Array) *Array { return a.Multiply(sig) } -func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, offset int) *Array { +// RoPEWithBase applies rotary position embeddings using per-token positions. +// +// positions is an int32 tensor of per-token absolute positions. For a single +// sequence with contiguous positions starting at offset, this dispatches to +// the scalar mlx_fast_rope (bit-identical to the old offset-based API). +func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, positions *Array) *Array { + posData := positions.Ints() + if len(posData) == 0 { + return x + } + + // Fast path: single contiguous run — use scalar mlx_fast_rope + offset := posData[0] + contiguous := true + for i := 1; i < len(posData); i++ { + if posData[i] != posData[i-1]+1 { + contiguous = false + break + } + } + if contiguous { + freqs := New("") + out := New("FAST_ROPE") + C.mlx_fast_rope( + &out.ctx, + x.ctx, + C.int(dims), + C.bool(traditional), + C.mlx_optional_float{ + value: C.float(base), + has_value: C.bool(func() bool { return base != 0 }()), + }, + C.float(scale), + C.int(offset), + freqs.ctx, + DefaultStream().ctx, + ) + return out + } + + // Multi-sequence path: use mlx_fast_rope_dynamic with per-token offsets. + // Transpose [1, H, L, D] → [L, H, 1, D] so L becomes batch dim, + // apply dynamic rope with positions [L], transpose back. + rotIn := Transpose(x, 2, 1, 0, 3) freqs := New("") - out := New("FAST_ROPE") - C.mlx_fast_rope( - &out.ctx, - x.ctx, + rotOut := New("FAST_ROPE_DYNAMIC") + C.mlx_fast_rope_dynamic( + &rotOut.ctx, + rotIn.ctx, C.int(dims), C.bool(traditional), C.mlx_optional_float{ @@ -335,11 +378,11 @@ func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, off has_value: C.bool(func() bool { return base != 0 }()), }, C.float(scale), - C.int(offset), + positions.ctx, freqs.ctx, DefaultStream().ctx, ) - return out + return Transpose(rotOut, 2, 1, 0, 3) } func Sigmoid(a *Array) *Array { diff --git a/x/models/gemma3/gemma3.go b/x/models/gemma3/gemma3.go index f297810ba..688b47e00 100644 --- a/x/models/gemma3/gemma3.go +++ b/x/models/gemma3/gemma3.go @@ -415,7 +415,7 @@ func (m *Model) Forward(b *batch.ForwardBatch, caches []cache.Cache) *mlx.Array if caches != nil && i < len(caches) { c = caches[i] } - h = layer.Forward(h, c, B, L, m.TextConfig) + h = layer.Forward(h, b, c, B, L, m.TextConfig) } return mlx.RMSNormFn(h, m.NormScaled, m.RMSNormEps) @@ -455,10 +455,10 @@ func (m *Model) FormatPrompt(prompt string) string { return fmt.Sprintf("user\n%s\nmodel\n", prompt) } -func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *TextConfig) *mlx.Array { +func (l *DecoderLayer) Forward(x *mlx.Array, b *batch.ForwardBatch, c cache.Cache, B, L int32, cfg *TextConfig) *mlx.Array { normed := mlx.RMSNormFn(x, l.InputNormScaled, cfg.RMSNormEps) - attnOut := l.Attention.Forward(normed, c, B, L, l.IsSliding, cfg) + attnOut := l.Attention.Forward(normed, b, c, B, L, l.IsSliding, cfg) attnOut = mlx.RMSNormFn(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps) h := mlx.Add(x, attnOut) @@ -470,7 +470,7 @@ func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Tex return mlx.Add(h, mlpOut) } -func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig) *mlx.Array { +func (a *Attention) Forward(x *mlx.Array, b *batch.ForwardBatch, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig) *mlx.Array { q := a.QProj.Forward(x) k := a.KProj.Forward(x) v := a.VProj.Forward(x) @@ -492,12 +492,9 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b ropeTheta = cfg.RopeLocalBaseFreq } - offset := 0 - if c != nil { - offset = int(c.Offsets()[0]) - } - q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, ropeTheta, 1.0, offset) - k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, ropeTheta, 1.0, offset) + positions := batch.SequentialPositions(b, c.Offsets()) + q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, ropeTheta, 1.0, positions) + k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, ropeTheta, 1.0, positions) if c != nil { k, v, _ = c.Update(nil, k, v) diff --git a/x/models/glm4_moe_lite/glm4_moe_lite.go b/x/models/glm4_moe_lite/glm4_moe_lite.go index 9d9c7e3da..724eb206f 100644 --- a/x/models/glm4_moe_lite/glm4_moe_lite.go +++ b/x/models/glm4_moe_lite/glm4_moe_lite.go @@ -88,7 +88,7 @@ type MLAAttention struct { } // Forward computes absorbed MLA attention output. -func (a *MLAAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array { +func (a *MLAAttention) Forward(x, positions *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array { q := a.QAProj.Forward(x) q = a.QALayerNorm.Forward(q, cfg.RMSNormEps) q = a.QBProj.Forward(q) @@ -110,12 +110,8 @@ func (a *MLAAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Con kvLatent := a.KVALayerNorm.Forward(kvCompressed, cfg.RMSNormEps) kvLatent = mlx.ExpandDims(kvLatent, 1) - offset := 0 - if c != nil { - offset = int(c.Offsets()[0]) - } - qPE = mlx.RoPEWithBase(qPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset) - kPE = mlx.RoPEWithBase(kPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset) + qPE = mlx.RoPEWithBase(qPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, positions) + kPE = mlx.RoPEWithBase(kPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, positions) qLatent := a.EmbedQ.Forward(qNope) @@ -314,8 +310,8 @@ type DenseBlock struct { } // Forward applies the dense block -func (b *DenseBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array { - r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg) +func (b *DenseBlock) Forward(x, positions *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array { + r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), positions, c, B, L, cfg) h := mlx.Add(x, r) r = b.MLP.Forward(b.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps)) @@ -331,8 +327,8 @@ type MoEBlock struct { } // Forward applies the MoE block -func (b *MoEBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array { - r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg) +func (b *MoEBlock) Forward(x, positions *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array { + r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), positions, c, B, L, cfg) h := mlx.Add(x, r) r = b.MoE.Forward(b.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps), cfg) @@ -341,7 +337,7 @@ func (b *MoEBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) // Block interface for both dense and MoE blocks type Block interface { - Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array + Forward(x, positions *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array } // Model represents the complete GLM4-MoE-Lite model @@ -708,13 +704,14 @@ func (m *Model) Forward(b *batch.ForwardBatch, caches []cache.Cache) *mlx.Array B, L := int32(dims[0]), int32(dims[1]) h := m.EmbedTokens.Forward(b.InputIDs) + positions := batch.SequentialPositions(b, caches[0].Offsets()) for i, layer := range m.Layers { var c cache.Cache if caches != nil { c = caches[i] } - h = layer.Forward(h, c, B, L, m.Config) + h = layer.Forward(h, positions, c, B, L, m.Config) } h = m.Norm.Forward(h, m.RMSNormEps) diff --git a/x/models/llama/llama.go b/x/models/llama/llama.go index c1d53cb44..ee93efd74 100644 --- a/x/models/llama/llama.go +++ b/x/models/llama/llama.go @@ -242,12 +242,14 @@ func (m *Model) Forward(b *batch.ForwardBatch, caches []cache.Cache) *mlx.Array B, L := int32(dims[0]), int32(dims[1]) h := m.EmbedTokens.Forward(b.InputIDs) + positions := batch.SequentialPositions(b, caches[0].Offsets()) + for i, layer := range m.Layers { var c cache.Cache if caches != nil && i < len(caches) { c = caches[i] } - h = layer.Forward(h, c, B, L, m.Config) + h = layer.Forward(h, positions, c, B, L, m.Config) } return m.Norm.Forward(h, m.RMSNormEps) @@ -277,12 +279,12 @@ func (m *Model) NewCaches() []cache.Cache { return caches } -func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array { - h := mlx.Add(x, l.Attention.Forward(l.AttentionNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg)) +func (l *Layer) Forward(x, positions *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array { + h := mlx.Add(x, l.Attention.Forward(l.AttentionNorm.Forward(x, cfg.RMSNormEps), positions, c, B, L, cfg)) return mlx.Add(h, l.MLP.Forward(l.MLPNorm.Forward(h, cfg.RMSNormEps))) } -func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array { +func (a *Attention) Forward(x, positions *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array { q := a.QProj.Forward(x) k := a.KProj.Forward(x) v := a.VProj.Forward(x) @@ -296,12 +298,8 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config v = mlx.Reshape(v, B, L, cfg.NumKeyValueHeads, cfg.HeadDim) v = mlx.Transpose(v, 0, 2, 1, 3) - offset := 0 - if c != nil { - offset = int(c.Offsets()[0]) - } - q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset) - k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset) + q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, positions) + k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, positions) if c != nil { k, v, _ = c.Update(nil, k, v) diff --git a/x/models/qwen3/qwen3.go b/x/models/qwen3/qwen3.go index 022eef4ce..0fd2ec565 100644 --- a/x/models/qwen3/qwen3.go +++ b/x/models/qwen3/qwen3.go @@ -259,12 +259,14 @@ func (m *Model) Forward(b *batch.ForwardBatch, caches []cache.Cache) *mlx.Array B, L := int32(dims[0]), int32(dims[1]) h := m.EmbedTokens.Forward(b.InputIDs) + positions := batch.SequentialPositions(b, caches[0].Offsets()) + for i, layer := range m.Layers { var c cache.Cache if caches != nil && i < len(caches) { c = caches[i] } - h = layer.Forward(h, c, B, L, m.Config) + h = layer.Forward(h, positions, c, B, L, m.Config) } return m.Norm.Forward(h, m.RMSNormEps) @@ -294,12 +296,12 @@ func (m *Model) NewCaches() []cache.Cache { return caches } -func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array { - h := mlx.Add(x, l.Attention.Forward(l.AttentionNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg)) +func (l *Layer) Forward(x, positions *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array { + h := mlx.Add(x, l.Attention.Forward(l.AttentionNorm.Forward(x, cfg.RMSNormEps), positions, c, B, L, cfg)) return mlx.Add(h, l.MLP.Forward(l.MLPNorm.Forward(h, cfg.RMSNormEps))) } -func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array { +func (a *Attention) Forward(x, positions *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array { q := a.QProj.Forward(x) k := a.KProj.Forward(x) v := a.VProj.Forward(x) @@ -315,12 +317,8 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config k = mlx.Transpose(k, 0, 2, 1, 3) v = mlx.Transpose(v, 0, 2, 1, 3) - offset := 0 - if c != nil { - offset = int(c.Offsets()[0]) - } - q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset) - k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset) + q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, positions) + k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, positions) if c != nil { k, v, _ = c.Update(nil, k, v) diff --git a/x/models/qwen3_5/qwen3_5.go b/x/models/qwen3_5/qwen3_5.go index 9b08fc029..f1e32b443 100644 --- a/x/models/qwen3_5/qwen3_5.go +++ b/x/models/qwen3_5/qwen3_5.go @@ -1127,7 +1127,7 @@ func splitQKVZBA(mixedQKVZ, mixedBA *mlx.Array, cfg *Config, B, L int32) (q, k, return q, k, v, z, b, a } -func (a *FullAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array { +func (a *FullAttention) Forward(x, positions *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array { qg := a.QProj.Forward(x) qg = mlx.Reshape(qg, B, L, cfg.NumAttentionHeads, cfg.HeadDim*2) q := mlx.SliceStartStop(qg, []int32{0, 0, 0, 0}, []int32{B, L, cfg.NumAttentionHeads, cfg.HeadDim}) @@ -1146,12 +1146,8 @@ func (a *FullAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Co k = mlx.Transpose(k, 0, 2, 1, 3) v = mlx.Transpose(v, 0, 2, 1, 3) - offset := 0 - if c != nil { - offset = int(c.Offsets()[0]) - } - q = mlx.RoPEWithBase(q, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, offset) - k = mlx.RoPEWithBase(k, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, offset) + q = mlx.RoPEWithBase(q, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, positions) + k = mlx.RoPEWithBase(k, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, positions) if c != nil { k, v, _ = c.Update(nil, k, v) @@ -1333,13 +1329,13 @@ func (m *SparseMoE) Forward(x *mlx.Array, cfg *Config) *mlx.Array { return mlx.Reshape(y, B, L, cfg.HiddenSize) } -func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array { +func (l *Layer) Forward(x, positions *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array { var r *mlx.Array normed := l.InputNorm.Forward(x, cfg.RMSNormEps) if l.IsLinear { r = l.Linear.Forward(normed, c, B, L, cfg) } else { - r = l.FullAttn.Forward(normed, c, B, L, cfg) + r = l.FullAttn.Forward(normed, positions, c, B, L, cfg) } h := mlx.Add(x, r) r = l.MLP.Forward(l.PostAttentionNorm.Forward(h, cfg.RMSNormEps), cfg) @@ -1351,12 +1347,14 @@ func (m *Model) Forward(b *batch.ForwardBatch, caches []cache.Cache) *mlx.Array B, L := int32(dims[0]), int32(dims[1]) h := m.EmbedTokens.Forward(b.InputIDs) + positions := batch.SequentialPositions(b, caches[0].Offsets()) + for i, layer := range m.Layers { var c cache.Cache if caches != nil && i < len(caches) { c = caches[i] } - h = layer.Forward(h, c, B, L, m.Config) + h = layer.Forward(h, positions, c, B, L, m.Config) } out := m.Norm.Forward(h, m.RMSNormEps) return out