mirror of
https://github.com/ollama/ollama.git
synced 2026-04-25 02:06:11 +02:00
mlxrunner: introduce ForwardBatch for model forward pass
Replace the raw *mlx.Array token input with a ForwardBatch struct that carries InputIDs alongside sequence metadata (SeqIDs, SeqLens). InputIDs remain [1, N] shaped — all model code is unchanged beyond the signature.
This commit is contained in:
32
x/mlxrunner/batch/batch.go
Normal file
32
x/mlxrunner/batch/batch.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package batch
|
||||
|
||||
import "github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
|
||||
// ForwardBatch carries per-step metadata through the model forward pass.
|
||||
// InputIDs shape is [1, N] where N = sum(SeqLens). SeqLens indicates how
|
||||
// many tokens belong to each sequence.
|
||||
type ForwardBatch struct {
|
||||
// InputIDs holds token IDs across all sequences. Shape: [1, N].
|
||||
InputIDs *mlx.Array
|
||||
|
||||
// SeqIDs uniquely identifies each sequence in the batch.
|
||||
SeqIDs []int
|
||||
|
||||
// SeqLens is the number of new tokens per sequence in this step.
|
||||
// For decode batching every entry is 1. For prefill it may vary.
|
||||
SeqLens []int
|
||||
}
|
||||
|
||||
// TotalLen returns the total number of tokens across all sequences.
|
||||
func (b *ForwardBatch) TotalLen() int {
|
||||
n := 0
|
||||
for _, l := range b.SeqLens {
|
||||
n += l
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// NumSeqs returns the number of sequences in the batch.
|
||||
func (b *ForwardBatch) NumSeqs() int {
|
||||
return len(b.SeqIDs)
|
||||
}
|
||||
Reference in New Issue
Block a user