Compare commits

...

2 Commits

Author SHA1 Message Date
Jesse Gross
a50199cd70 mlxrunner: batch the sampler across multiple sequences
Register sequences with Add/Remove; each Sample call takes any subset of
registered slots and samples one token per row, appending to each slot's
ring-buffer history. When all slots share Options and penalty rings are
full, one fused transform pass runs over the whole batch via a persistent
pooled history tensor; otherwise calls fall back to per-slot serial
processing indexed against the same pool.

Performance is unchanged for a single sequence, which is all that is
exposed for now.
2026-04-21 15:09:19 -07:00
Jesse Gross
5264ba9194 mlxrunner: track sampler history in a fixed-size ring buffer
AppendToken used to concatenate the new token onto the history tensor
and slice it back to RepeatLastN every decode step, churning the graph
shape and reallocating a fresh tensor each call. The stateful penalties
don't care about order within the window, so a fixed-capacity ring with
one SliceUpdate per append keeps the tensor shape constant across
steps.
2026-04-21 14:40:19 -07:00
8 changed files with 770 additions and 223 deletions

View File

@@ -72,6 +72,10 @@ func (t *Array) AsStrided(shape []int, strides []int, offset int) *Array {
} }
func (t *Array) Concatenate(axis int, others ...*Array) *Array { func (t *Array) Concatenate(axis int, others ...*Array) *Array {
if len(others) == 0 {
return t
}
vector := C.mlx_vector_array_new() vector := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(vector) defer C.mlx_vector_array_free(vector)
@@ -127,9 +131,9 @@ func (t *Array) GatherMM(other, lhs, rhs *Array, sorted bool) *Array {
return out return out
} }
func (t *Array) Logsumexp(keepDims bool) *Array { func (t *Array) LogsumexpAxis(axis int, keepDims bool) *Array {
out := New("LOGSUMEXP") out := New("LOGSUMEXP_AXIS")
C.mlx_logsumexp(&out.ctx, t.ctx, C.bool(keepDims), DefaultStream().ctx) C.mlx_logsumexp_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
return out return out
} }

View File

@@ -376,6 +376,9 @@ func Concatenate(arrays []*Array, axis int) *Array {
if len(arrays) == 0 { if len(arrays) == 0 {
return nil return nil
} }
if len(arrays) == 1 {
return arrays[0]
}
return arrays[0].Concatenate(axis, arrays[1:]...) return arrays[0].Concatenate(axis, arrays[1:]...)
} }

View File

@@ -49,14 +49,15 @@ func (r *Runner) Prepare(request *Request) error {
return nil return nil
} }
// The runner serializes requests today so we just use a fixed slot ID.
const pipelineSlot = 0
func (r *Runner) TextGenerationPipeline(ctx context.Context, request Request) error { func (r *Runner) TextGenerationPipeline(ctx context.Context, request Request) error {
mlx.ResetPeakMemory() mlx.ResetPeakMemory()
var sample, nextSample sampler.Result var sample, nextSample sampler.Result
defer func() { defer func() {
if request.Sampler != nil { r.Sampler.Remove(pipelineSlot)
request.Sampler.Free()
}
mlx.Unpin(sample.Arrays()...) mlx.Unpin(sample.Arrays()...)
mlx.Unpin(nextSample.Arrays()...) mlx.Unpin(nextSample.Arrays()...)
mlx.Sweep() mlx.Sweep()
@@ -70,7 +71,6 @@ func (r *Runner) TextGenerationPipeline(ctx context.Context, request Request) er
}() }()
inputs := request.Tokens inputs := request.Tokens
request.Sampler.ResetHistory(inputs)
session := r.cache.begin(r.Model, inputs) session := r.cache.begin(r.Model, inputs)
defer session.close() defer session.close()
@@ -122,7 +122,7 @@ func (r *Runner) TextGenerationPipeline(ctx context.Context, request Request) er
} }
} }
r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches) r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], 1, n), caches)
mlx.Sweep() mlx.Sweep()
materializeCaches() materializeCaches()
processed += n processed += n
@@ -139,21 +139,28 @@ func (r *Runner) TextGenerationPipeline(ctx context.Context, request Request) er
mlx.ClearCache() mlx.ClearCache()
} }
// Register the sampler after prefill completes.
r.Sampler.Add(pipelineSlot, request.SamplerOpts, inputs)
step := func(token *mlx.Array) sampler.Result { step := func(token *mlx.Array) sampler.Result {
fwd := r.Model.Forward(token.ExpandDims(0), caches) fwd := r.Model.Forward(token, caches)
logits := r.Model.Unembed(fwd) logits := r.Model.Unembed(fwd)
logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1) logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
sample := request.Sampler.Sample(logits) sample := r.Sampler.Sample([]int{pipelineSlot}, logits)
mlx.Pin(sample.Arrays()...) mlx.Pin(sample.Arrays()...)
mlx.Sweep() mlx.Sweep()
mlx.AsyncEval(sample.Arrays()...) mlx.AsyncEval(sample.Arrays()...)
return sample return sample
} }
sample = step(mlx.FromValues(tokens[processed:], total-processed)) sample = step(mlx.FromValues(tokens[processed:], 1, total-processed))
dec := decoder{tokenizer: r.Tokenizer} dec := decoder{
tokenizer: r.Tokenizer,
wantLogprobs: request.SamplerOpts.Logprobs,
wantTopLogprobs: request.SamplerOpts.TopLogprobs,
}
final := CompletionResponse{Done: true, PromptEvalCount: len(inputs), EvalCount: request.Options.NumPredict, DoneReason: 1} final := CompletionResponse{Done: true, PromptEvalCount: len(inputs), EvalCount: request.Options.NumPredict, DoneReason: 1}
for i := range request.Options.NumPredict { for i := range request.Options.NumPredict {
@@ -161,8 +168,7 @@ func (r *Runner) TextGenerationPipeline(ctx context.Context, request Request) er
return err return err
} }
request.Sampler.AppendToken(sample.Token) nextSample = step(sample.Token.ExpandDims(-1))
nextSample = step(sample.Token)
if i == 0 { if i == 0 {
mlx.Eval(sample.Arrays()...) mlx.Eval(sample.Arrays()...)
@@ -209,15 +215,17 @@ func (r *Runner) TextGenerationPipeline(ctx context.Context, request Request) er
// with those bytes so Content and Logprobs stay aligned when a chunk does // with those bytes so Content and Logprobs stay aligned when a chunk does
// flush. // flush.
type decoder struct { type decoder struct {
tokenizer *tokenizer.Tokenizer tokenizer *tokenizer.Tokenizer
buf bytes.Buffer buf bytes.Buffer
logprobs []llm.Logprob logprobs []llm.Logprob
wantLogprobs bool
wantTopLogprobs int
} }
func (d *decoder) decode(res sampler.Result) (CompletionResponse, bool) { func (d *decoder) decode(res sampler.Result) (CompletionResponse, bool) {
output := int32(res.Token.Int()) output := int32(res.Token.Int())
d.buf.WriteString(d.tokenizer.Decode([]int32{output})) d.buf.WriteString(d.tokenizer.Decode([]int32{output}))
d.logprobs = append(d.logprobs, buildLogprob(res, d.tokenizer.Decode)...) d.logprobs = append(d.logprobs, buildLogprob(res, d.wantLogprobs, d.wantTopLogprobs, d.tokenizer.Decode)...)
content := flushValidUTF8Prefix(&d.buf) content := flushValidUTF8Prefix(&d.buf)
if content == "" { if content == "" {
@@ -228,8 +236,13 @@ func (d *decoder) decode(res sampler.Result) (CompletionResponse, bool) {
return resp, true return resp, true
} }
func buildLogprob(sample sampler.Result, decode func([]int32) string) []llm.Logprob { // buildLogprob converts the sampler's logprob tensors into the wire-format
if sample.Logprob == nil { // llm.Logprob entries the caller wants. The sampler populates its logprob
// tensors whenever any registered slot requested them, so the caller must
// gate emission on its own request config (wantLogprobs / wantTopLogprobs)
// rather than on whether the tensors happen to be non-nil.
func buildLogprob(sample sampler.Result, wantLogprobs bool, wantTopLogprobs int, decode func([]int32) string) []llm.Logprob {
if !wantLogprobs || sample.Logprob == nil {
return nil return nil
} }
tok := func(id int32) string { return decode([]int32{id}) } tok := func(id int32) string { return decode([]int32{id}) }
@@ -241,7 +254,7 @@ func buildLogprob(sample sampler.Result, decode func([]int32) string) []llm.Logp
}, },
} }
if sample.TopTokens != nil { if wantTopLogprobs > 0 && sample.TopTokens != nil {
ids := sample.TopTokens.Ints() ids := sample.TopTokens.Ints()
vals := sample.TopLogprobs.Floats() vals := sample.TopLogprobs.Floats()
pairs := make([]llm.TokenLogprob, len(ids)) pairs := make([]llm.TokenLogprob, len(ids))
@@ -251,9 +264,14 @@ func buildLogprob(sample sampler.Result, decode func([]int32) string) []llm.Logp
Logprob: float64(vals[i]), Logprob: float64(vals[i]),
} }
} }
// The sampler emits the top maxK across registered slots via
// Argpartition, which leaves entries unsorted.
sort.Slice(pairs, func(i, j int) bool { sort.Slice(pairs, func(i, j int) bool {
return pairs[i].Logprob > pairs[j].Logprob return pairs[i].Logprob > pairs[j].Logprob
}) })
if wantTopLogprobs < len(pairs) {
pairs = pairs[:wantTopLogprobs]
}
out.TopLogprobs = pairs out.TopLogprobs = pairs
} }
return []llm.Logprob{out} return []llm.Logprob{out}

View File

@@ -27,15 +27,16 @@ type Request struct {
Responses chan CompletionResponse Responses chan CompletionResponse
Pipeline func(context.Context, Request) error Pipeline func(context.Context, Request) error
Ctx context.Context //nolint:containedctx Ctx context.Context //nolint:containedctx
Tokens []int32 Tokens []int32
Sampler *sample.Sampler SamplerOpts sample.Options
} }
type Runner struct { type Runner struct {
Model base.Model Model base.Model
Tokenizer *tokenizer.Tokenizer Tokenizer *tokenizer.Tokenizer
Requests chan Request Requests chan Request
Sampler *sample.Sampler
cache kvCache cache kvCache
contextLength int contextLength int
} }
@@ -67,6 +68,7 @@ func (r *Runner) Load(modelName string) error {
r.Model = m r.Model = m
r.Tokenizer = m.Tokenizer() r.Tokenizer = m.Tokenizer()
r.contextLength = m.MaxContextLength() r.contextLength = m.MaxContextLength()
r.Sampler = sample.New(r.contextLength)
mlx.EnableCompile() mlx.EnableCompile()
return nil return nil

View File

@@ -24,14 +24,15 @@ type logprobEntry struct {
func runSampleLogprobs(t *testing.T, logits []float32, topK int) (int, float64, []logprobEntry) { func runSampleLogprobs(t *testing.T, logits []float32, topK int) (int, float64, []logprobEntry) {
t.Helper() t.Helper()
s := New(Options{Logprobs: true, TopLogprobs: topK}) s := New(128)
defer func() { defer func() {
s.Free() s.Free()
mlx.Sweep() mlx.Sweep()
}() }()
s.Add(0, Options{Logprobs: true, TopLogprobs: topK}, nil)
tensor := mlx.FromValues(logits, 1, len(logits)) tensor := mlx.FromValues(logits, 1, len(logits))
res := s.Sample(tensor) res := s.Sample([]int{0}, tensor)
mlx.Pin(res.Arrays()...) mlx.Pin(res.Arrays()...)
defer mlx.Unpin(res.Arrays()...) defer mlx.Unpin(res.Arrays()...)
@@ -225,6 +226,42 @@ func TestSampleLogprobsSelectedTokenCorrectness(t *testing.T) {
} }
} }
// TestBatchedLogprobsPerRow verifies that per-row logprobs in a batched
// sample call match the per-slot reference. The numerically-stable softmax
// must reduce along the last axis only, not over the whole batch.
func TestBatchedLogprobsPerRow(t *testing.T) {
rowA := []float32{2, 1, 0}
rowB := []float32{0, 5, 0}
_, wantA, _ := runSampleLogprobs(t, rowA, 0)
_, wantB, _ := runSampleLogprobs(t, rowB, 0)
s := New(128)
t.Cleanup(func() {
s.Free()
mlx.Sweep()
})
s.Add(1, Options{Logprobs: true}, nil)
s.Add(2, Options{Logprobs: true}, nil)
logits := mlx.FromValues(append(append([]float32{}, rowA...), rowB...), 2, 3)
res := s.Sample([]int{1, 2}, logits)
mlx.Pin(res.Arrays()...)
t.Cleanup(func() { mlx.Unpin(res.Arrays()...) })
mlx.Eval(res.Arrays()...)
got := res.Logprob.Floats()
if len(got) != 2 {
t.Fatalf("Logprob length = %d, want 2", len(got))
}
if math.Abs(float64(got[0])-wantA) > 1e-5 {
t.Errorf("row 0 logprob = %f, want %f (per-slot reference)", got[0], wantA)
}
if math.Abs(float64(got[1])-wantB) > 1e-5 {
t.Errorf("row 1 logprob = %f, want %f (per-slot reference)", got[1], wantB)
}
}
func TestSampleLogprobsTopKOrdering(t *testing.T) { func TestSampleLogprobsTopKOrdering(t *testing.T) {
// Logits chosen so argmax order differs from index order. // Logits chosen so argmax order differs from index order.
logits := []float32{2.0, 5.0, 1.0, 4.0, 3.0} logits := []float32{2.0, 5.0, 1.0, 4.0, 3.0}

View File

@@ -1,13 +1,13 @@
package sample package sample
import ( import (
"fmt"
"math" "math"
"slices"
"github.com/ollama/ollama/x/mlxrunner/mlx" "github.com/ollama/ollama/x/mlxrunner/mlx"
) )
type Transform func(*Sampler, *mlx.Array) *mlx.Array
type Options struct { type Options struct {
Temperature float32 Temperature float32
TopP float32 TopP float32
@@ -24,21 +24,15 @@ type Options struct {
TopLogprobs int TopLogprobs int
} }
type Sampler struct { // Result bundles the outputs of one decode step. Logprob/TopTokens/
Options // TopLogprobs are populated whenever any registered slot has Logprobs
// (respectively TopLogprobs>0). Consumers need to filter by their
history *mlx.Array // per-slot Options.
historyLen int
transforms []Transform
}
// Result bundles the outputs of one decode step. The logprob tensors are
// populated only when the sampler is configured to report them.
type Result struct { type Result struct {
Token *mlx.Array // sampled token id, shape [B] Token *mlx.Array // sampled token ids, shape [B]
Logprob *mlx.Array // sampled-token logprob, shape [B,1]; nil unless Logprobs Logprob *mlx.Array // sampled-token logprobs, shape [B,1]; nil unless any registered slot has Logprobs
TopTokens *mlx.Array // top-K token ids, shape [B,K]; nil unless TopLogprobs > 0 TopTokens *mlx.Array // top-K token ids, shape [B,maxK]; nil unless any registered slot has TopLogprobs>0
TopLogprobs *mlx.Array // top-K logprobs, shape [B,K]; nil unless TopLogprobs > 0 TopLogprobs *mlx.Array // top-K logprobs, shape [B,maxK]; same
} }
// Arrays returns the tensor fields as a slice so callers can drive the mlx // Arrays returns the tensor fields as a slice so callers can drive the mlx
@@ -48,121 +42,300 @@ func (r Result) Arrays() []*mlx.Array {
return []*mlx.Array{r.Token, r.Logprob, r.TopTokens, r.TopLogprobs} return []*mlx.Array{r.Token, r.Logprob, r.TopTokens, r.TopLogprobs}
} }
func New(opts Options) *Sampler { // Sampler is a batched, slot-based sampler. Sequences are registered with
if opts.RepeatPenalty <= 0 { // Add and released with Remove. Each Sample call takes a subset of
opts.RepeatPenalty = 1 // registered slots (in any order) with their [B,V] logits, samples one
// token per row, and appends it to that slot's ring-buffer history. Slots
// not named in a given call are untouched.
type Sampler struct {
slots []*slotState
byID map[int]*slotState
// history is the pooled ring-buffer storage, [B, W] int32. Row i
// belongs to slots[i]; W is max(RepeatLastN) across penalty slots.
// Allocated on the first penalty slot, rebuilt only in Add/Remove.
history *mlx.Array
// allSameOpts: every registered slot shares Options. When true the
// canonical shared value is s.slots[0].opts.
allSameOpts bool
// anyLogprobs / maxTopLogprobs: compute-for-all output config.
// Sample populates Logprob (and Top* when maxTopLogprobs>0) whenever
// any registered slot requests them, even if that slot isn't in the
// current call.
anyLogprobs bool
maxTopLogprobs int
// numCtx is the runner's context window; normalize uses it to
// resolve the repeat_last_n == -1 sentinel.
numCtx int
}
type slotState struct {
opts Options
transforms []transform
historyLen int
}
type slotCtx struct {
opts Options
history *mlx.Array // 2D [B, W] when penalties are configured; nil otherwise
}
type transform func(*slotCtx, *mlx.Array) *mlx.Array
// New constructs an empty sampler with no registered slots. numCtx is
// the runner's context window and must be positive.
func New(numCtx int) *Sampler {
return &Sampler{
byID: make(map[int]*slotState),
allSameOpts: true,
numCtx: numCtx,
}
}
// historyWidth returns the column count of the pooled history tensor,
// or 0 when no penalty slot has forced it to be allocated.
func (s *Sampler) historyWidth() int {
if s.history == nil {
return 0
}
return s.history.Dim(1)
}
func (o Options) usesHistory() bool {
// RepeatLastN == 0 disables the penalty ring per the repeat_last_n API
// contract (0 = disabled), overriding any penalty coefficients.
if o.RepeatLastN == 0 {
return false
}
return o.RepeatPenalty != 1 || o.PresencePenalty != 0 || o.FrequencyPenalty != 0
}
func (o Options) normalize(numCtx int) Options {
if o.RepeatPenalty <= 0 {
o.RepeatPenalty = 1
}
// Resolve the repeat_last_n == -1 sentinel ("-1 = num_ctx") against
// the caller's context window.
if o.RepeatLastN < 0 {
o.RepeatLastN = numCtx
}
if !o.usesHistory() {
// Zero the ring capacity so slots that differ only in a spurious
// RepeatLastN still batch together and don't inflate pool width.
o.RepeatLastN = 0
}
return o
}
func (o Options) buildTransforms() []transform {
var ts []transform
if o.usesHistory() {
ts = append(ts, penalty)
} }
s := &Sampler{Options: opts} hasTopP := o.TopP > 0 && o.TopP < 1
hasTopK := o.TopK > 0
var transforms []Transform
if s.usesHistory() {
transforms = append(transforms, penalty)
}
hasTopP := opts.TopP > 0 && opts.TopP < 1
hasTopK := opts.TopK > 0
switch { switch {
case hasTopP: case hasTopP:
// topKTopP always does a full descending sort for the top-P // topKTopP always does a full descending sort for the top-P
// cumulative mask and opportunistically masks top-K during the // cumulative mask and opportunistically masks top-K during the
// same pass when it is also configured. // same pass when it is also configured.
transforms = append(transforms, topKTopP) ts = append(ts, topKTopP)
case hasTopK: case hasTopK:
// Argpartition (partial sort) is cheaper than a full sort. // Argpartition (partial sort) is cheaper than a full sort.
transforms = append(transforms, topK) ts = append(ts, topK)
} }
if opts.MinP != 0 { if o.MinP != 0 {
transforms = append(transforms, minP) ts = append(ts, minP)
} }
if opts.Temperature == 0 { if o.Temperature == 0 {
transforms = append(transforms, greedy) ts = append(ts, greedy)
} else { } else {
transforms = append(transforms, temperature) ts = append(ts, temperature)
} }
return ts
s.transforms = transforms
return s
} }
func (s *Sampler) usesHistory() bool { // Add registers a sequence under seqID. The last RepeatLastN entries of
return s.RepeatPenalty != 1 || s.PresencePenalty != 0 || s.FrequencyPenalty != 0 // priorTokens seed the ring buffer.
} func (s *Sampler) Add(seqID int, opts Options, priorTokens []int32) {
if _, dup := s.byID[seqID]; dup {
func (s *Sampler) setHistory(history *mlx.Array, historyLen int) { panic(fmt.Sprintf("sample.Sampler.Add: seqID %d already registered", seqID))
if history != nil {
mlx.Pin(history)
} }
if s.history != nil {
opts = opts.normalize(s.numCtx)
slot := &slotState{
opts: opts,
transforms: opts.buildTransforms(),
}
// Grow the pool to hold this slot's row. The pool is lazy — the first
// penalty slot allocates it — and thereafter every registered slot
// gets a row (rows for non-penalty slots are zero and never read).
// Invariant: s.history is pinned whenever non-nil.
if s.history != nil || opts.usesHistory() {
targetWidth := max(opts.RepeatLastN, s.historyWidth())
newRow := makeHistoryRow(priorTokens, opts.RepeatLastN, targetWidth)
var pool *mlx.Array
switch {
case s.history == nil && len(s.slots) == 0:
pool = newRow
case s.history == nil:
// First penalty slot with non-penalty slots already registered;
// seed zero rows so s.slots and pool row indices stay aligned.
zeros := mlx.Zeros(mlx.DTypeInt32, len(s.slots), targetWidth)
pool = zeros.Concatenate(0, newRow)
case targetWidth > s.historyWidth():
pad := mlx.Zeros(mlx.DTypeInt32, s.history.Dim(0), targetWidth-s.historyWidth())
pool = s.history.Concatenate(1, pad).Concatenate(0, newRow)
default:
pool = s.history.Concatenate(0, newRow)
}
mlx.Pin(pool)
mlx.Unpin(s.history) mlx.Unpin(s.history)
s.history = pool
if opts.usesHistory() {
// Cap on seed so the next write's ring position
// (historyLen % RepeatLastN) lands at 0, overwriting the
// oldest entry when the ring was filled from priors.
slot.historyLen = min(len(priorTokens), opts.RepeatLastN)
}
} }
s.history = history
s.historyLen = historyLen s.slots = append(s.slots, slot)
s.byID[seqID] = slot
s.recomputeInvariants()
} }
func (s *Sampler) ResetHistory(history []int32) { // makeHistoryRow builds a [1, width] int32 row with the last repeatLastN
if !s.usesHistory() { // entries of priorTokens packed into [0, min(len, repeatLastN)), zeros
// elsewhere.
func makeHistoryRow(priorTokens []int32, repeatLastN, width int) *mlx.Array {
take := min(len(priorTokens), repeatLastN)
if take <= 0 {
return mlx.Zeros(mlx.DTypeInt32, 1, width)
}
row := make([]int32, width)
copy(row, priorTokens[len(priorTokens)-take:])
return mlx.NewArrayInt32(row, []int32{1, int32(width)})
}
// recomputeInvariants refreshes allSameOpts and anyLogprobs/maxTopLogprobs
// from s.slots. Called at the end of Add and Remove.
func (s *Sampler) recomputeInvariants() {
if len(s.slots) == 0 {
s.allSameOpts = true
s.anyLogprobs = false
s.maxTopLogprobs = 0
return return
} }
if s.RepeatLastN > 0 && len(history) > s.RepeatLastN { first := s.slots[0].opts
history = history[len(history)-s.RepeatLastN:] s.allSameOpts = true
s.anyLogprobs = false
s.maxTopLogprobs = 0
for _, slot := range s.slots {
if slot.opts != first {
s.allSameOpts = false
}
if slot.opts.Logprobs {
s.anyLogprobs = true
if slot.opts.TopLogprobs > s.maxTopLogprobs {
s.maxTopLogprobs = slot.opts.TopLogprobs
}
}
} }
if len(history) == 0 { }
s.setHistory(nil, 0)
// Remove releases the slot. The pool tensor is rebuilt to drop the row.
func (s *Sampler) Remove(seqID int) {
slot, ok := s.byID[seqID]
if !ok {
return
}
delete(s.byID, seqID)
row := slices.Index(s.slots, slot)
s.slots = slices.Delete(s.slots, row, row+1)
s.recomputeInvariants()
if s.history == nil {
return return
} }
tokens := append([]int32(nil), history...) n := s.history.Dim(0)
s.setHistory(mlx.NewArrayInt32(tokens, []int32{int32(len(tokens))}), len(tokens)) var newHistory *mlx.Array
} switch {
case n == 1:
func (s *Sampler) AppendToken(token *mlx.Array) { newHistory = nil
if !s.usesHistory() || token == nil { case row == 0:
return newHistory = s.history.Slice(mlx.Slice(1, n), mlx.Slice())
} case row == n-1:
newHistory = s.history.Slice(mlx.Slice(0, row), mlx.Slice())
next := token.AsType(mlx.DTypeInt32) default:
nextLen := next.Size() before := s.history.Slice(mlx.Slice(0, row), mlx.Slice())
after := s.history.Slice(mlx.Slice(row+1, n), mlx.Slice())
if s.history != nil && s.historyLen > 0 { newHistory = before.Concatenate(0, after)
next = s.history.Concatenate(0, next) }
nextLen += s.historyLen
} mlx.Pin(newHistory)
mlx.Unpin(s.history)
if s.RepeatLastN > 0 && nextLen > s.RepeatLastN { s.history = newHistory
trim := nextLen - s.RepeatLastN
next = next.Slice(mlx.Slice(trim, nextLen))
nextLen = s.RepeatLastN
}
s.setHistory(next, nextLen)
} }
// Free releases the pooled history tensor and resets the sampler to the
// New-equivalent state so it may be reused.
func (s *Sampler) Free() { func (s *Sampler) Free() {
s.setHistory(nil, 0) mlx.Unpin(s.history)
*s = Sampler{
byID: make(map[int]*slotState),
allSameOpts: true,
numCtx: s.numCtx,
}
} }
// Sample runs the configured transform chain on the raw per-token logits // Sample draws one token per row of logits ([B,V]); seqIDs[i] names the
// and returns the sampled token id plus, when configured, the reported // slot whose logits live at row i. Each sampled token is appended to its
// log-probability tensors for the selected token and the top-K tokens. // slot's ring. Slots not named in seqIDs are untouched.
func (s *Sampler) Sample(logits *mlx.Array) Result { func (s *Sampler) Sample(seqIDs []int, logits *mlx.Array) Result {
scores := logits if len(seqIDs) == 0 {
for _, transform := range s.transforms { return Result{}
scores = transform(s, scores)
} }
res := Result{Token: scores}
if s.Logprobs { slots := make([]*slotState, len(seqIDs))
// Compute log_softmax in fp32 and subtract the max before for i, id := range seqIDs {
// logsumexp so the final subtraction stays on small values. slot, ok := s.byID[id]
// Otherwise it cancels two large numbers and loses precision. if !ok {
panic(fmt.Sprintf("sample.Sampler.Sample: seqID %d not registered", id))
}
slots[i] = slot
}
var token *mlx.Array
if opts0, ok := s.canBatch(slots); ok {
token = s.sampleTokensUniform(slots, opts0, logits)
} else {
token = s.sampleTokensSerial(slots, logits)
}
res := Result{Token: token}
if s.anyLogprobs {
// Log-softmax over original logits so every row holds a truthful
// value (compute-for-all; consumers filter per-slot). Subtract
// max first for numerical stability in the logsumexp.
lp := logits.AsType(mlx.DTypeFloat32) lp := logits.AsType(mlx.DTypeFloat32)
lp = lp.Subtract(lp.MaxAxis(-1, true)) lp = lp.Subtract(lp.MaxAxis(-1, true))
lp = lp.Subtract(lp.Logsumexp(true)) lp = lp.Subtract(lp.LogsumexpAxis(-1, true))
res.Logprob = lp.TakeAlongAxis(res.Token.ExpandDims(-1), -1) res.Logprob = lp.TakeAlongAxis(token.ExpandDims(-1), -1)
if k := s.TopLogprobs; k > 0 { if s.maxTopLogprobs > 0 {
k := s.maxTopLogprobs
if vocab := lp.Dim(lp.NumDims() - 1); k > vocab { if vocab := lp.Dim(lp.NumDims() - 1); k > vocab {
k = vocab k = vocab
} }
@@ -176,55 +349,180 @@ func (s *Sampler) Sample(logits *mlx.Array) Result {
return res return res
} }
func greedy(_ *Sampler, scores *mlx.Array) *mlx.Array { // canBatch reports whether the call can take the uniform batched path.
return scores.Argmax(-1, false) // All slots must share Options; when penalties are active the call must
// additionally cover every registered slot in registration order with a
// full ring, because the uniform path indexes the pool positionally.
func (s *Sampler) canBatch(slots []*slotState) (Options, bool) {
if !s.allSameOpts {
return Options{}, false
}
// slots is non-empty (Sample guards) and every slot is registered,
// so s.slots[0].opts is the canonical shared value.
shared := s.slots[0].opts
if !shared.usesHistory() {
return shared, true
}
if len(slots) != len(s.slots) {
return Options{}, false
}
for i, slot := range slots {
if s.slots[i] != slot || slot.historyLen < shared.RepeatLastN {
return Options{}, false
}
}
return shared, true
} }
func temperature(s *Sampler, scores *mlx.Array) *mlx.Array { // sampleTokensUniform runs one fused transform pass over the whole batch.
return mlx.DivScalar(scores, s.Temperature).Categorical(-1) // Reached only when canBatch is true, which lets the pool be used in place
// with a single PutAlongAxis write-back and no gather.
func (s *Sampler) sampleTokensUniform(slots []*slotState, opts Options, logits *mlx.Array) *mlx.Array {
B := len(slots)
var hist *mlx.Array
if opts.usesHistory() {
hist = s.history
if s.historyWidth() > opts.RepeatLastN {
hist = hist.Slice(mlx.Slice(), mlx.Slice(0, opts.RepeatLastN))
}
}
ctx := &slotCtx{opts: opts, history: hist}
scores := logits
for _, t := range slots[0].transforms {
scores = t(ctx, scores)
}
token := scores
if !opts.usesHistory() {
return token
}
writeIdxData := make([]int32, B)
for i, slot := range slots {
writeIdxData[i] = int32(slot.historyLen % opts.RepeatLastN)
slot.historyLen++
}
writeIdx := mlx.NewArrayInt32(writeIdxData, []int32{int32(B), 1})
s.history.Set(s.history.PutAlongAxis(writeIdx, token.ExpandDims(-1), 1))
return token
}
// sampleTokensSerial runs each slot's transforms against its own row of
// logits.
func (s *Sampler) sampleTokensSerial(slots []*slotState, logits *mlx.Array) *mlx.Array {
perSlotTokens := make([]*mlx.Array, len(slots))
rowOf := make(map[*slotState]int, len(s.slots))
for i, slot := range s.slots {
rowOf[slot] = i
}
for i, slot := range slots {
row := logits.Slice(mlx.Slice(i, i+1), mlx.Slice())
var hist *mlx.Array
if slot.opts.usesHistory() && slot.historyLen > 0 && s.history != nil {
poolRow := rowOf[slot]
fill := min(slot.historyLen, slot.opts.RepeatLastN)
hist = s.history.Slice(
mlx.Slice(poolRow, poolRow+1),
mlx.Slice(0, fill),
)
}
ctx := &slotCtx{opts: slot.opts, history: hist}
scores := row
for _, t := range slot.transforms {
scores = t(ctx, scores)
}
perSlotTokens[i] = scores
}
token := mlx.Concatenate(perSlotTokens, 0)
if s.history != nil {
// For each writing slot collect its flat (row-major) pool offset
// and the call-order position of its token. One PutAlongAxis on a
// flat view of the pool scatters all writes in a single op.
flatOffsets := make([]int32, 0, len(slots))
tokenPos := make([]int32, 0, len(slots))
for i, slot := range slots {
if !slot.opts.usesHistory() {
continue
}
ringPos := slot.historyLen % slot.opts.RepeatLastN
flatOffsets = append(flatOffsets, int32(rowOf[slot]*s.historyWidth()+ringPos))
tokenPos = append(tokenPos, int32(i))
slot.historyLen++
}
if len(flatOffsets) > 0 {
m := len(flatOffsets)
flatIdx := mlx.NewArrayInt32(flatOffsets, []int32{int32(m), 1})
writingTokens := token
if m != len(slots) {
tokenPosIdx := mlx.NewArrayInt32(tokenPos, []int32{int32(m)})
writingTokens = token.TakeAxis(tokenPosIdx, 0)
}
flatHist := s.history.Reshape(s.history.Dim(0)*s.historyWidth(), 1)
s.history.Set(flatHist.PutAlongAxis(flatIdx, writingTokens.ExpandDims(-1), 0).Reshape(s.history.Dim(0), s.historyWidth()))
}
}
return token
}
func greedy(_ *slotCtx, scores *mlx.Array) *mlx.Array {
return scores.Argmax(-1, false).AsType(mlx.DTypeInt32)
}
func temperature(ctx *slotCtx, scores *mlx.Array) *mlx.Array {
return mlx.DivScalar(scores, ctx.opts.Temperature).Categorical(-1).AsType(mlx.DTypeInt32)
} }
// topKTopP applies top-P in a descending sort pass and, when top-K is also // topKTopP applies top-P in a descending sort pass and, when top-K is also
// configured, masks any surviving value below the K-th largest in the same // configured, masks any surviving value below the K-th largest in the same
// pass. Callers dispatch here whenever top-P is enabled — the top-K-only // pass. Callers dispatch here whenever top-P is enabled — the top-K-only case
// case uses a cheaper partial sort via the topK transform. // uses a cheaper partial sort via the topK transform.
func topKTopP(s *Sampler, scores *mlx.Array) *mlx.Array { func topKTopP(ctx *slotCtx, scores *mlx.Array) *mlx.Array {
vocab := scores.Dim(scores.NumDims() - 1) vocab := scores.Dim(scores.NumDims() - 1)
applyTopK := s.TopK > 0 && s.TopK < vocab applyTopK := ctx.opts.TopK > 0 && ctx.opts.TopK < vocab
order := scores.Negative().ArgsortAxis(-1) order := scores.Negative().ArgsortAxis(-1)
sorted := scores.TakeAlongAxis(order, -1) sorted := scores.TakeAlongAxis(order, -1)
negInf := mlx.FromValue(float32(math.Inf(-1))) negInf := mlx.FromValue(float32(math.Inf(-1)))
// Top-P: in descending order, keep tokens whose exclusive cumulative // Top-P: in descending order, keep tokens whose exclusive cumulative
// probability is still below s.TopP. // probability is still below TopP.
probs := mlx.SoftmaxAxis(sorted, -1, true) probs := mlx.SoftmaxAxis(sorted, -1, true)
prevCumProbs := probs.Cumsum(-1, false, true).Subtract(probs) prevCumProbs := probs.Cumsum(-1, false, true).Subtract(probs)
keep := prevCumProbs.Less(mlx.FromValue(s.TopP)) keep := prevCumProbs.Less(mlx.FromValue(ctx.opts.TopP))
sorted = mlx.Where(keep, sorted, negInf) sorted = mlx.Where(keep, sorted, negInf)
out := scores.PutAlongAxis(order, sorted, -1) out := scores.PutAlongAxis(order, sorted, -1)
// Top-K: sorted is already in descending order, so positions [K, V) // Top-K: sorted is already in descending order, so positions [K, V) are
// are the ones to drop. Scatter -inf through their original-layout // the ones to drop. Scatter -inf through their original-layout indices
// indices (order[K:]). Positional (not value-based) so exactly K // (order[K:]). Positional (not value-based) so exactly K tokens survive —
// tokens survive — ties at the K-th logit get broken by the sort // ties at the K-th logit get broken by the sort order rather than
// order rather than promoted through the filter. // promoted through the filter.
if applyTopK { if applyTopK {
dropOrder := order.Slice(mlx.Slice(), mlx.Slice(s.TopK, mlx.End)) dropOrder := order.Slice(mlx.Slice(), mlx.Slice(ctx.opts.TopK, mlx.End))
out = out.PutAlongAxis(dropOrder, negInf, -1) out = out.PutAlongAxis(dropOrder, negInf, -1)
} }
return out return out
} }
func minP(s *Sampler, scores *mlx.Array) *mlx.Array { func minP(ctx *slotCtx, scores *mlx.Array) *mlx.Array {
if s.MinP <= 0 || s.MinP > 1 { if ctx.opts.MinP <= 0 || ctx.opts.MinP > 1 {
return scores return scores
} }
maxScore := scores.MaxAxis(-1, true) maxScore := scores.MaxAxis(-1, true)
threshold := mlx.AddScalar(maxScore, float32(math.Log(float64(s.MinP)))) threshold := mlx.AddScalar(maxScore, float32(math.Log(float64(ctx.opts.MinP))))
return mlx.Where( return mlx.Where(
scores.Less(threshold), scores.Less(threshold),
@@ -233,48 +531,43 @@ func minP(s *Sampler, scores *mlx.Array) *mlx.Array {
) )
} }
func topK(s *Sampler, scores *mlx.Array) *mlx.Array { func topK(ctx *slotCtx, scores *mlx.Array) *mlx.Array {
if s.TopK <= 0 { if ctx.opts.TopK <= 0 {
return scores return scores
} }
vocab := scores.Dim(scores.NumDims() - 1) vocab := scores.Dim(scores.NumDims() - 1)
if s.TopK >= vocab { if ctx.opts.TopK >= vocab {
return scores return scores
} }
mask := scores.Negative().ArgpartitionAxis(s.TopK-1, -1).Slice(mlx.Slice(), mlx.Slice(s.TopK, mlx.End)) mask := scores.Negative().ArgpartitionAxis(ctx.opts.TopK-1, -1).Slice(mlx.Slice(), mlx.Slice(ctx.opts.TopK, mlx.End))
return scores.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1) return scores.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1)
} }
func penalty(s *Sampler, scores *mlx.Array) *mlx.Array { func penalty(ctx *slotCtx, scores *mlx.Array) *mlx.Array {
if s.historyLen == 0 { tokenIndices := ctx.history
if tokenIndices == nil {
return scores return scores
} }
tokenIndices := s.history if ctx.opts.RepeatPenalty != 1 || ctx.opts.PresencePenalty != 0 {
if scores.NumDims() > 1 {
tokenIndices = tokenIndices.ExpandDims(0)
}
if s.RepeatPenalty != 1 || s.PresencePenalty != 0 {
adjusted := scores.TakeAlongAxis(tokenIndices, -1) adjusted := scores.TakeAlongAxis(tokenIndices, -1)
if s.RepeatPenalty != 1 { if ctx.opts.RepeatPenalty != 1 {
factor := mlx.Where( factor := mlx.Where(
adjusted.Less(mlx.FromValue(float32(0))), adjusted.Less(mlx.FromValue(float32(0))),
mlx.FromValue(s.RepeatPenalty), mlx.FromValue(ctx.opts.RepeatPenalty),
mlx.FromValue(1/s.RepeatPenalty), mlx.FromValue(1/ctx.opts.RepeatPenalty),
) )
adjusted = adjusted.Multiply(factor) adjusted = adjusted.Multiply(factor)
} }
if s.PresencePenalty != 0 { if ctx.opts.PresencePenalty != 0 {
adjusted = mlx.AddScalar(adjusted, -s.PresencePenalty) adjusted = mlx.AddScalar(adjusted, -ctx.opts.PresencePenalty)
} }
scores = scores.PutAlongAxis(tokenIndices, adjusted, -1) scores = scores.PutAlongAxis(tokenIndices, adjusted, -1)
} }
if s.FrequencyPenalty != 0 { if ctx.opts.FrequencyPenalty != 0 {
scores = scores.ScatterAddAxis(tokenIndices, mlx.FromValue(-s.FrequencyPenalty), -1) scores = scores.ScatterAddAxis(tokenIndices, mlx.FromValue(-ctx.opts.FrequencyPenalty), -1)
} }
return scores return scores

View File

@@ -9,93 +9,283 @@ import (
"github.com/ollama/ollama/x/mlxrunner/mlx" "github.com/ollama/ollama/x/mlxrunner/mlx"
) )
func TestPresencePenaltyUsesAppendedTokenImmediately(t *testing.T) { // slotLogits builds a [1, V] logits tensor for a single-slot Sample call.
s := New(Options{RepeatLastN: 1, PresencePenalty: 6}) func slotLogits(values []float32) *mlx.Array {
defer func() { return mlx.FromValues(values, 1, len(values))
}
// batchLogits stacks per-row float32 slices of equal length into a [B, V]
// logits tensor. It panics (failing the test loudly) when called with no
// rows or with rows of mismatched vocab sizes — previously an empty call
// died with an opaque index-out-of-range on rows[0].
func batchLogits(rows ...[]float32) *mlx.Array {
	if len(rows) == 0 {
		panic("batchLogits: at least one row required")
	}
	v := len(rows[0])
	flat := make([]float32, 0, len(rows)*v)
	for _, r := range rows {
		if len(r) != v {
			panic("batchLogits: rows must share vocab size")
		}
		flat = append(flat, r...)
	}
	return mlx.FromValues(flat, len(rows), v)
}
// sampleOne runs Sample on a freshly-added single slot and returns the
// sampled token id. Used both for the single-slot options table and as the
// reference oracle for the batched-equivalence test.
func sampleOne(t *testing.T, opts Options, priorTokens []int32, values []float32) int {
t.Helper()
s := New(128)
t.Cleanup(func() {
s.Free() s.Free()
mlx.Sweep() mlx.Sweep()
}() })
s.Add(0, opts, priorTokens)
s.ResetHistory([]int32{0}) got := s.Sample([]int{0}, slotLogits(values)).Token
s.AppendToken(mlx.NewArrayInt32([]int32{1}, []int32{1}))
logits := mlx.FromValues([]float32{0, 5, 4}, 3)
got := s.Sample(logits).Token
mlx.Eval(got) mlx.Eval(got)
return got.Int()
}
// logits will be [0, -1, 4] after the penalty // logOf returns log(p) as a float32 so tests can build logits that softmax to
// and then (index) 2 after the greedy sampler // a chosen probability distribution.
gotInt := got.Int() func logOf(p float64) float32 { return float32(math.Log(p)) }
if gotInt != 2 {
t.Fatalf("got %d, want 2", gotInt) // TestSampleSingleSlotOptions pins the per-slot behavior of each Options
// knob against a concrete expected token. Expected values are worked out by
// hand from the math of each transform, not from a second call into the
// sampler — so a regression in any single transform shows up here.
func TestSampleSingleSlotOptions(t *testing.T) {
cases := []struct {
name string
opts Options
priors []int32
logits []float32
want int
}{
{
name: "presence penalty",
opts: Options{RepeatLastN: 1, PresencePenalty: 6},
priors: []int32{1},
logits: []float32{0, 5, 4},
want: 2, // token 1: 5 - 6 = -1, argmax shifts to 2
},
{
name: "repeat penalty on positive logits",
opts: Options{RepeatLastN: 1, RepeatPenalty: 2},
priors: []int32{1},
logits: []float32{0, 5, 4},
want: 2, // token 1 positive → divided: 5/2 = 2.5, argmax shifts to 2
},
{
name: "repeat penalty on negative logits",
opts: Options{RepeatLastN: 1, RepeatPenalty: 4},
priors: []int32{1},
logits: []float32{-5, -1, -3},
want: 2, // token 1 negative → multiplied: -1*4 = -4, argmax shifts to 2
},
{
name: "frequency penalty",
opts: Options{RepeatLastN: 4, FrequencyPenalty: 2},
priors: []int32{1, 1},
logits: []float32{0, 5, 4},
want: 2, // 5 - 2*count(1)=2*2=4 → 1, argmax shifts to 2
},
{
name: "top-k",
opts: Options{Temperature: 1, TopK: 1},
logits: []float32{1, 5, 4},
want: 1, // only argmax survives → deterministic even with temperature
},
{
name: "top-p",
opts: Options{Temperature: 1, TopP: 0.4},
logits: []float32{logOf(0.5), logOf(0.3), logOf(0.2)},
want: 0, // exclusive cumsum below 0.4 keeps only token 0
},
{
name: "min-p",
opts: Options{Temperature: 1, MinP: 0.7},
logits: []float32{logOf(0.5), logOf(0.3), logOf(0.2)},
want: 0, // threshold 0.5*0.7=0.35 drops all but the top token
},
{
name: "RepeatLastN=0 disables penalties",
opts: Options{RepeatLastN: 0, RepeatPenalty: 2, PresencePenalty: 10},
priors: []int32{1},
logits: []float32{0, 5, 4},
want: 1, // 0 = disabled per API contract, argmax unchanged
},
{
name: "RepeatLastN=-1 resolves to num_ctx",
opts: Options{RepeatLastN: -1, PresencePenalty: 6},
priors: []int32{1},
logits: []float32{0, 5, 4},
want: 2, // -1 → num_ctx (128); penalty applies, argmax shifts
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
if got := sampleOne(t, tc.opts, tc.priors, tc.logits); got != tc.want {
t.Errorf("got %d, want %d", got, tc.want)
}
})
} }
} }
func TestRepeatPenaltyUsesHistoryWithoutPresencePenalty(t *testing.T) { // TestSampleHistoryWindow verifies that penalty history respects the
s := New(Options{RepeatLastN: 1, RepeatPenalty: 2}) // RepeatLastN window: priors longer than RepeatLastN are trimmed on Add,
defer func() { // and once the ring wraps, tokens that rotate out no longer contribute
// to penalties.
func TestSampleHistoryWindow(t *testing.T) {
s := New(128)
t.Cleanup(func() {
s.Free() s.Free()
mlx.Sweep() mlx.Sweep()
}() })
s.ResetHistory([]int32{1}) // RepeatLastN=2 with priors {1, 2, 3}: makeHistoryRow keeps only
// {2, 3}. Token 1 was trimmed — its penalty is NOT active.
s.Add(0, Options{RepeatLastN: 2, PresencePenalty: 10}, []int32{1, 2, 3})
logits := mlx.FromValues([]float32{0, 5, 4}, 3) // Step 1: logits favor token 1 (trimmed). If the trim were broken it
got := s.Sample(logits).Token // would be penalized and the argmax would move.
mlx.Eval(got) step1 := s.Sample([]int{0}, slotLogits([]float32{0, 5, 0, 0, 0})).Token
mlx.Eval(step1)
if got := step1.Int(); got != 1 {
t.Fatalf("step 1 = %d, want 1 (token 1 trimmed from priors)", got)
}
// After step 1 the ring holds {1, 3}; token 2 has rotated out.
// token 1 is repeated and positive, so 5 / 2 falls below token 2. // Step 2: logits favor token 2 (rotated out). If the ring wrap were
gotInt := got.Int() // wrong, token 2 would still be penalized.
if gotInt != 2 { step2 := s.Sample([]int{0}, slotLogits([]float32{0, 0, 5, 0, 0})).Token
t.Fatalf("got %d, want 2", gotInt) mlx.Eval(step2)
if got := step2.Int(); got != 2 {
t.Fatalf("step 2 = %d, want 2 (token 2 rotated out of ring)", got)
} }
} }
func TestFrequencyPenaltyUsesTokenCounts(t *testing.T) { // TestBatchSamplingPreservesPerSlotBehavior is the core equivalence test:
s := New(Options{RepeatLastN: 4, FrequencyPenalty: 2}) // for every representative dispatch branch (uniform, serial on mixed opts,
defer func() { // serial on partial ring, subset/out-of-order), a batched Sample call must
s.Free() // produce the same token per row as running the same slot alone.
mlx.Sweep() func TestBatchSamplingPreservesPerSlotBehavior(t *testing.T) {
}() type slot struct {
id int
opts Options
priors []int32
}
s.ResetHistory([]int32{1, 1}) cases := []struct {
name string
slots []slot
sample []int
rows [][]float32
}{
{
name: "uniform",
slots: []slot{
{10, Options{RepeatLastN: 2, PresencePenalty: 5}, []int32{1, 2}},
{20, Options{RepeatLastN: 2, PresencePenalty: 5}, []int32{0, 2}},
},
sample: []int{10, 20},
rows: [][]float32{{0, 5, 4}, {3, 0, 0}},
},
{
name: "serial — mixed opts",
slots: []slot{
{1, Options{RepeatLastN: 1, RepeatPenalty: 2}, []int32{1}},
{2, Options{Temperature: 1, TopK: 1}, nil},
},
sample: []int{1, 2},
rows: [][]float32{{0, 5, 4, 1}, {2, 1, 5, 3}},
},
{
name: "serial — partial ring",
slots: []slot{
{1, Options{RepeatLastN: 4, PresencePenalty: 5}, []int32{1, 1, 1, 1}},
{2, Options{RepeatLastN: 4, PresencePenalty: 5}, []int32{2}},
},
sample: []int{1, 2},
rows: [][]float32{{0, 5, 4}, {0, 4, 5}},
},
{
name: "subset out-of-order",
slots: []slot{
{10, Options{RepeatLastN: 2, PresencePenalty: 10}, []int32{1, 1}},
{20, Options{RepeatLastN: 2, PresencePenalty: 10}, []int32{2, 2}},
{30, Options{RepeatLastN: 2, PresencePenalty: 10}, []int32{3, 3}},
},
sample: []int{30, 10},
rows: [][]float32{{5, 5, 5, 0, 5, 5}, {5, 0, 5, 5, 0, 5}},
},
}
logits := mlx.FromValues([]float32{0, 5, 4}, 3) for _, tc := range cases {
got := s.Sample(logits).Token t.Run(tc.name, func(t *testing.T) {
mlx.Eval(got) // Per-slot reference for each sampled seq.
want := make([]int, len(tc.sample))
for i, id := range tc.sample {
var spec slot
for _, s := range tc.slots {
if s.id == id {
spec = s
break
}
}
want[i] = sampleOne(t, spec.opts, spec.priors, tc.rows[i])
}
// token 1 appears twice, so 5 - (2 * 2) falls below token 2. // Batched call.
gotInt := got.Int() s := New(128)
if gotInt != 2 { t.Cleanup(func() {
t.Fatalf("got %d, want 2", gotInt) s.Free()
mlx.Sweep()
})
for _, spec := range tc.slots {
s.Add(spec.id, spec.opts, spec.priors)
}
res := s.Sample(tc.sample, batchLogits(tc.rows...))
mlx.Eval(res.Token)
got := res.Token.Ints()
for i, id := range tc.sample {
if got[i] != want[i] {
t.Errorf("seq %d: batched = %d, per-slot = %d", id, got[i], want[i])
}
}
})
} }
} }
func TestMinPMasksTokensBelowThreshold(t *testing.T) { // TestRemoveDoesNotLeakHistory: after Remove, a newly-added slot at the
s := New(Options{MinP: 0.5}) // recycled row must start from its own priors only — no carryover from
defer func() { // the removed slot's history.
func TestRemoveDoesNotLeakHistory(t *testing.T) {
opts := Options{RepeatLastN: 1, PresencePenalty: 10}
s := New(128)
t.Cleanup(func() {
s.Free() s.Free()
mlx.Sweep() mlx.Sweep()
}() })
s.Add(1, opts, []int32{1})
s.Add(2, opts, []int32{2})
s.Remove(1)
s.Add(3, opts, []int32{0})
logits := mlx.FromValues([]float32{ // Slot 2 retains history {2}; slot 3 retains history {0}. With
float32(math.Log(0.5)), // equal logits and PresencePenalty=10 the argmax drops to the first
float32(math.Log(0.3)), // unpenalized token.
float32(math.Log(0.2)), res := s.Sample([]int{2, 3}, batchLogits(
}, 3) []float32{3, 3, 0},
got := minP(s, logits) []float32{3, 3, 0},
mlx.Eval(got) ))
mlx.Eval(res.Token)
gotFloats := got.Floats() tokens := res.Token.Ints()
if len(gotFloats) != 3 { if tokens[0] != 0 {
t.Fatalf("got %d scores, want 3", len(gotFloats)) t.Errorf("slot 2 = %d, want 0 (token 2 penalized)", tokens[0])
} }
if tokens[1] != 1 {
if math.IsInf(float64(gotFloats[0]), -1) || math.IsInf(float64(gotFloats[1]), -1) { t.Errorf("slot 3 = %d, want 1 (token 0 penalized, no slot-1 carryover)", tokens[1])
t.Fatalf("kept tokens were masked: %v", gotFloats)
}
if !math.IsInf(float64(gotFloats[2]), -1) {
t.Fatalf("lowest-probability token should be masked, got %v", gotFloats)
} }
} }

View File

@@ -93,7 +93,7 @@ func Execute(args []string) error {
} }
request.Pipeline = runner.TextGenerationPipeline request.Pipeline = runner.TextGenerationPipeline
request.Sampler = sample.New(sample.Options{ request.SamplerOpts = sample.Options{
Temperature: request.Options.Temperature, Temperature: request.Options.Temperature,
TopP: request.Options.TopP, TopP: request.Options.TopP,
MinP: request.Options.MinP, MinP: request.Options.MinP,
@@ -104,7 +104,7 @@ func Execute(args []string) error {
FrequencyPenalty: request.Options.FrequencyPenalty, FrequencyPenalty: request.Options.FrequencyPenalty,
Logprobs: request.Logprobs, Logprobs: request.Logprobs,
TopLogprobs: request.TopLogprobs, TopLogprobs: request.TopLogprobs,
}) }
if err := runner.Prepare(&request); err != nil { if err := runner.Prepare(&request); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest) http.Error(w, err.Error(), http.StatusBadRequest)