diff --git a/x/mlxrunner/batch/batch.go b/x/mlxrunner/batch/batch.go new file mode 100644 index 000000000..c93f52e29 --- /dev/null +++ b/x/mlxrunner/batch/batch.go @@ -0,0 +1,32 @@ +package batch + +import "github.com/ollama/ollama/x/mlxrunner/mlx" + +// ForwardBatch carries per-step metadata through the model forward pass. +// InputIDs shape is [1, N] where N = sum(SeqLens). SeqLens indicates how +// many tokens belong to each sequence. +type ForwardBatch struct { + // InputIDs holds token IDs across all sequences. Shape: [1, N]. + InputIDs *mlx.Array + + // SeqIDs uniquely identifies each sequence in the batch. + SeqIDs []int + + // SeqLens is the number of new tokens per sequence in this step. + // For decode batching every entry is 1. For prefill it may vary. + SeqLens []int +} + +// TotalLen returns the total number of tokens across all sequences. +func (b *ForwardBatch) TotalLen() int { + n := 0 + for _, l := range b.SeqLens { + n += l + } + return n +} + +// NumSeqs returns the number of sequences in the batch. +func (b *ForwardBatch) NumSeqs() int { + return len(b.SeqIDs) +} diff --git a/x/mlxrunner/model/base/base.go b/x/mlxrunner/model/base/base.go index 3a85b6eb0..f56458220 100644 --- a/x/mlxrunner/model/base/base.go +++ b/x/mlxrunner/model/base/base.go @@ -6,6 +6,7 @@ import ( "log/slog" "sync" + "github.com/ollama/ollama/x/mlxrunner/batch" "github.com/ollama/ollama/x/mlxrunner/cache" "github.com/ollama/ollama/x/mlxrunner/mlx" "github.com/ollama/ollama/x/mlxrunner/model" @@ -14,7 +15,7 @@ import ( // Model is the interface that model implementations must satisfy. type Model interface { - Forward(inputs *mlx.Array, cache []cache.Cache) *mlx.Array + Forward(b *batch.ForwardBatch, cache []cache.Cache) *mlx.Array Unembed(x *mlx.Array) *mlx.Array NumLayers() int Tokenizer() *tokenizer.Tokenizer diff --git a/x/mlxrunner/pipeline.go b/x/mlxrunner/pipeline.go index 4ad776389..a27237fca 100644 --- a/x/mlxrunner/pipeline.go +++ b/x/mlxrunner/pipeline.go @@ -9,6 +9,7 @@ import ( "time" "github.com/ollama/ollama/logutil" + "github.com/ollama/ollama/x/mlxrunner/batch" "github.com/ollama/ollama/x/mlxrunner/mlx" ) @@ -131,7 +132,11 @@ func (r *Runner) TextGenerationPipeline(request Request) error { } } - r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches) + r.Model.Forward(&batch.ForwardBatch{ + InputIDs: mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), + SeqIDs: []int{0}, + SeqLens: []int{n}, + }, caches) mlx.Sweep() materializeCaches() processed += n @@ -149,7 +154,11 @@ func (r *Runner) TextGenerationPipeline(request Request) error { } step := func(token *mlx.Array) (*mlx.Array, *mlx.Array) { - fwd := r.Model.Forward(token.ExpandDims(0), caches) + fwd := r.Model.Forward(&batch.ForwardBatch{ + InputIDs: token.ExpandDims(0), + SeqIDs: []int{0}, + SeqLens: []int{1}, + }, caches) logits := r.Model.Unembed(fwd) logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1) diff --git a/x/models/gemma3/gemma3.go b/x/models/gemma3/gemma3.go index 266222b69..0da2355fc 100644 --- a/x/models/gemma3/gemma3.go +++ b/x/models/gemma3/gemma3.go @@ -6,6 +6,7 @@ import ( "fmt" "math" + "github.com/ollama/ollama/x/mlxrunner/batch" "github.com/ollama/ollama/x/mlxrunner/cache" "github.com/ollama/ollama/x/mlxrunner/mlx" "github.com/ollama/ollama/x/mlxrunner/model" @@ -402,11 +403,11 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error { return nil } -func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array { - dims := tokens.Dims() +func (m *Model) Forward(b *batch.ForwardBatch, caches []cache.Cache) *mlx.Array { + dims := b.InputIDs.Dims() B, L := int32(dims[0]), int32(dims[1]) - h := m.EmbedTokens.Forward(tokens) + h := m.EmbedTokens.Forward(b.InputIDs) h = mlx.MulScalar(h, float32(math.Sqrt(float64(m.HiddenSize)))) for i, layer := range m.Layers { diff --git a/x/models/glm4_moe_lite/glm4_moe_lite.go b/x/models/glm4_moe_lite/glm4_moe_lite.go index a0a37a3f0..8732bc418 100644 --- a/x/models/glm4_moe_lite/glm4_moe_lite.go +++ b/x/models/glm4_moe_lite/glm4_moe_lite.go @@ -7,6 +7,7 @@ import ( "fmt" "math" + "github.com/ollama/ollama/x/mlxrunner/batch" "github.com/ollama/ollama/x/mlxrunner/cache" "github.com/ollama/ollama/x/mlxrunner/mlx" "github.com/ollama/ollama/x/mlxrunner/model" @@ -702,11 +703,11 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error { } // Forward computes the forward pass of the model -func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array { - dims := tokens.Dims() +func (m *Model) Forward(b *batch.ForwardBatch, caches []cache.Cache) *mlx.Array { + dims := b.InputIDs.Dims() B, L := int32(dims[0]), int32(dims[1]) - h := m.EmbedTokens.Forward(tokens) + h := m.EmbedTokens.Forward(b.InputIDs) for i, layer := range m.Layers { var c cache.Cache diff --git a/x/models/llama/llama.go b/x/models/llama/llama.go index ca99d9148..22a65d452 100644 --- a/x/models/llama/llama.go +++ b/x/models/llama/llama.go @@ -6,6 +6,7 @@ import ( "fmt" "math" + "github.com/ollama/ollama/x/mlxrunner/batch" "github.com/ollama/ollama/x/mlxrunner/cache" "github.com/ollama/ollama/x/mlxrunner/mlx" "github.com/ollama/ollama/x/mlxrunner/model" @@ -236,11 +237,11 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error { return nil } -func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array { - dims := tokens.Dims() +func (m *Model) Forward(b *batch.ForwardBatch, caches []cache.Cache) *mlx.Array { + dims := b.InputIDs.Dims() B, L := int32(dims[0]), int32(dims[1]) - h := m.EmbedTokens.Forward(tokens) + h := m.EmbedTokens.Forward(b.InputIDs) for i, layer := range m.Layers { var c cache.Cache if caches != nil && i < len(caches) { diff --git a/x/models/qwen3/qwen3.go b/x/models/qwen3/qwen3.go index 6773f0eb5..60a5ca501 100644 --- a/x/models/qwen3/qwen3.go +++ b/x/models/qwen3/qwen3.go @@ -6,6 +6,7 @@ import ( "fmt" "math" + "github.com/ollama/ollama/x/mlxrunner/batch" "github.com/ollama/ollama/x/mlxrunner/cache" "github.com/ollama/ollama/x/mlxrunner/mlx" "github.com/ollama/ollama/x/mlxrunner/model" @@ -253,11 +254,11 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error { return nil } -func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array { - dims := tokens.Dims() +func (m *Model) Forward(b *batch.ForwardBatch, caches []cache.Cache) *mlx.Array { + dims := b.InputIDs.Dims() B, L := int32(dims[0]), int32(dims[1]) - h := m.EmbedTokens.Forward(tokens) + h := m.EmbedTokens.Forward(b.InputIDs) for i, layer := range m.Layers { var c cache.Cache if caches != nil && i < len(caches) { diff --git a/x/models/qwen3_5/qwen3_5.go b/x/models/qwen3_5/qwen3_5.go index 5dbb59dce..30740588f 100644 --- a/x/models/qwen3_5/qwen3_5.go +++ b/x/models/qwen3_5/qwen3_5.go @@ -7,6 +7,7 @@ import ( "math" "strings" + "github.com/ollama/ollama/x/mlxrunner/batch" "github.com/ollama/ollama/x/mlxrunner/cache" "github.com/ollama/ollama/x/mlxrunner/mlx" "github.com/ollama/ollama/x/mlxrunner/model" @@ -1345,11 +1346,11 @@ func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *m return mlx.Add(h, r) } -func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array { - dims := tokens.Dims() +func (m *Model) Forward(b *batch.ForwardBatch, caches []cache.Cache) *mlx.Array { + dims := b.InputIDs.Dims() B, L := int32(dims[0]), int32(dims[1]) - h := m.EmbedTokens.Forward(tokens) + h := m.EmbedTokens.Forward(b.InputIDs) for i, layer := range m.Layers { var c cache.Cache if caches != nil && i < len(caches) {