// Mirror of https://github.com/ollama/ollama.git (synced 2026-04-23 01:05:47 +02:00).
//
// When both filters are active, avoid paying for a full sort in top-P and a
// partial sort in top-K. Single-filter paths are unchanged. Improves
// generation throughput on gemma4:e4b by 1.5%.
//
// 282 lines, 7.7 KiB, Go.

package sample
|
|
|
|
import (
|
|
"math"
|
|
|
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
|
)
|
|
|
|
// Transform is a single stage of the sampling pipeline. It is handed the
// owning Sampler together with the array produced by the previous stage
// (the raw logits for the first stage) and returns the array passed on to
// the next stage. The last stage in the chain yields the sampled token ids.
type Transform func(s *Sampler, scores *mlx.Array) *mlx.Array
|
|
|
|
// Options configures a Sampler. The zero value selects greedy decoding with
// every filter and penalty disabled (New normalizes RepeatPenalty to 1).
type Options struct {
	// Temperature divides the scores before categorical sampling. A value
	// of exactly 0 selects greedy (argmax) decoding instead.
	Temperature float32

	// TopP enables nucleus filtering when it lies strictly between 0 and 1.
	TopP float32

	// MinP masks tokens whose score is below max(score) + log(MinP). It
	// only takes effect for values in (0, 1].
	MinP float32

	// TopK keeps only the K highest-scoring tokens when it is positive and
	// smaller than the vocabulary size.
	TopK int

	// RepeatLastN, when positive, caps the penalty history at the most
	// recent RepeatLastN tokens.
	RepeatLastN int

	// RepeatPenalty scales the scores of tokens found in the history:
	// negative scores are multiplied by it, the others divided by it.
	// New replaces values <= 0 with 1, which disables the penalty.
	RepeatPenalty float32

	// PresencePenalty is subtracted from the score of every token found in
	// the history.
	PresencePenalty float32

	// FrequencyPenalty is scatter-added (negated) at every history position,
	// presumably accumulating once per occurrence of a token.
	FrequencyPenalty float32

	// Logprobs makes Sample fill Result.Logprob with the log-probability of
	// the selected token.
	Logprobs bool

	// TopLogprobs, when > 0 and Logprobs is set, additionally reports that
	// many of the highest-probability (token, logprob) pairs.
	TopLogprobs int
}
|
|
|
|
type Sampler struct {
|
|
Options
|
|
|
|
history *mlx.Array
|
|
historyLen int
|
|
transforms []Transform
|
|
}
|
|
|
|
// Result bundles the outputs of one decode step. The logprob tensors are
|
|
// populated only when the sampler is configured to report them.
|
|
type Result struct {
|
|
Token *mlx.Array // sampled token id, shape [B]
|
|
Logprob *mlx.Array // sampled-token logprob, shape [B,1]; nil unless Logprobs
|
|
TopTokens *mlx.Array // top-K token ids, shape [B,K]; nil unless TopLogprobs > 0
|
|
TopLogprobs *mlx.Array // top-K logprobs, shape [B,K]; nil unless TopLogprobs > 0
|
|
}
|
|
|
|
// Arrays returns the tensor fields as a slice so callers can drive the mlx
|
|
// lifecycle verbs (Pin, Unpin, Eval, AsyncEval) over the whole group. Unset
|
|
// fields stay nil; the mlx helpers skip them.
|
|
func (r Result) Arrays() []*mlx.Array {
|
|
return []*mlx.Array{r.Token, r.Logprob, r.TopTokens, r.TopLogprobs}
|
|
}
|
|
|
|
func New(opts Options) *Sampler {
|
|
if opts.RepeatPenalty <= 0 {
|
|
opts.RepeatPenalty = 1
|
|
}
|
|
|
|
s := &Sampler{Options: opts}
|
|
|
|
var transforms []Transform
|
|
if s.usesHistory() {
|
|
transforms = append(transforms, penalty)
|
|
}
|
|
|
|
hasTopP := opts.TopP > 0 && opts.TopP < 1
|
|
hasTopK := opts.TopK > 0
|
|
switch {
|
|
case hasTopP:
|
|
// topKTopP always does a full descending sort for the top-P
|
|
// cumulative mask and opportunistically masks top-K during the
|
|
// same pass when it is also configured.
|
|
transforms = append(transforms, topKTopP)
|
|
case hasTopK:
|
|
// Argpartition (partial sort) is cheaper than a full sort.
|
|
transforms = append(transforms, topK)
|
|
}
|
|
|
|
if opts.MinP != 0 {
|
|
transforms = append(transforms, minP)
|
|
}
|
|
|
|
if opts.Temperature == 0 {
|
|
transforms = append(transforms, greedy)
|
|
} else {
|
|
transforms = append(transforms, temperature)
|
|
}
|
|
|
|
s.transforms = transforms
|
|
return s
|
|
}
|
|
|
|
func (s *Sampler) usesHistory() bool {
|
|
return s.RepeatPenalty != 1 || s.PresencePenalty != 0 || s.FrequencyPenalty != 0
|
|
}
|
|
|
|
func (s *Sampler) setHistory(history *mlx.Array, historyLen int) {
|
|
if history != nil {
|
|
mlx.Pin(history)
|
|
}
|
|
if s.history != nil {
|
|
mlx.Unpin(s.history)
|
|
}
|
|
s.history = history
|
|
s.historyLen = historyLen
|
|
}
|
|
|
|
func (s *Sampler) ResetHistory(history []int32) {
|
|
if !s.usesHistory() {
|
|
return
|
|
}
|
|
if s.RepeatLastN > 0 && len(history) > s.RepeatLastN {
|
|
history = history[len(history)-s.RepeatLastN:]
|
|
}
|
|
if len(history) == 0 {
|
|
s.setHistory(nil, 0)
|
|
return
|
|
}
|
|
|
|
tokens := append([]int32(nil), history...)
|
|
s.setHistory(mlx.NewArrayInt32(tokens, []int32{int32(len(tokens))}), len(tokens))
|
|
}
|
|
|
|
func (s *Sampler) AppendToken(token *mlx.Array) {
|
|
if !s.usesHistory() || token == nil {
|
|
return
|
|
}
|
|
|
|
next := token.AsType(mlx.DTypeInt32)
|
|
nextLen := next.Size()
|
|
|
|
if s.history != nil && s.historyLen > 0 {
|
|
next = s.history.Concatenate(0, next)
|
|
nextLen += s.historyLen
|
|
}
|
|
|
|
if s.RepeatLastN > 0 && nextLen > s.RepeatLastN {
|
|
trim := nextLen - s.RepeatLastN
|
|
next = next.Slice(mlx.Slice(trim, nextLen))
|
|
nextLen = s.RepeatLastN
|
|
}
|
|
|
|
s.setHistory(next, nextLen)
|
|
}
|
|
|
|
func (s *Sampler) Free() {
|
|
s.setHistory(nil, 0)
|
|
}
|
|
|
|
// Sample runs the configured transform chain on the raw per-token logits
|
|
// and returns the sampled token id plus, when configured, the reported
|
|
// log-probability tensors for the selected token and the top-K tokens.
|
|
func (s *Sampler) Sample(logits *mlx.Array) Result {
|
|
scores := logits
|
|
for _, transform := range s.transforms {
|
|
scores = transform(s, scores)
|
|
}
|
|
res := Result{Token: scores}
|
|
|
|
if s.Logprobs {
|
|
// Compute log_softmax in fp32 and subtract the max before
|
|
// logsumexp so the final subtraction stays on small values.
|
|
// Otherwise it cancels two large numbers and loses precision.
|
|
lp := logits.AsType(mlx.DTypeFloat32)
|
|
lp = lp.Subtract(lp.MaxAxis(-1, true))
|
|
lp = lp.Subtract(lp.Logsumexp(true))
|
|
res.Logprob = lp.TakeAlongAxis(res.Token.ExpandDims(-1), -1)
|
|
if k := s.TopLogprobs; k > 0 {
|
|
if vocab := lp.Dim(lp.NumDims() - 1); k > vocab {
|
|
k = vocab
|
|
}
|
|
// Argpartition on the negated values places the K largest
|
|
// (unsorted) in positions [0:K].
|
|
idx := lp.Negative().ArgpartitionAxis(k-1, -1).Slice(mlx.Slice(), mlx.Slice(0, k))
|
|
res.TopTokens = idx.AsType(mlx.DTypeInt32)
|
|
res.TopLogprobs = lp.TakeAlongAxis(idx, -1)
|
|
}
|
|
}
|
|
return res
|
|
}
|
|
|
|
func greedy(_ *Sampler, scores *mlx.Array) *mlx.Array {
|
|
return scores.Argmax(-1, false)
|
|
}
|
|
|
|
func temperature(s *Sampler, scores *mlx.Array) *mlx.Array {
|
|
return mlx.DivScalar(scores, s.Temperature).Categorical(-1)
|
|
}
|
|
|
|
// topKTopP applies top-P in a descending sort pass and, when top-K is also
|
|
// configured, masks any surviving value below the K-th largest in the same
|
|
// pass. Callers dispatch here whenever top-P is enabled — the top-K-only
|
|
// case uses a cheaper partial sort via the topK transform.
|
|
func topKTopP(s *Sampler, scores *mlx.Array) *mlx.Array {
|
|
vocab := scores.Dim(scores.NumDims() - 1)
|
|
applyTopK := s.TopK > 0 && s.TopK < vocab
|
|
|
|
order := scores.Negative().ArgsortAxis(-1)
|
|
sorted := scores.TakeAlongAxis(order, -1)
|
|
negInf := mlx.FromValue(float32(math.Inf(-1)))
|
|
|
|
// Top-P: in descending order, keep tokens whose exclusive cumulative
|
|
// probability is still below s.TopP.
|
|
probs := mlx.SoftmaxAxis(sorted, -1, true)
|
|
prevCumProbs := probs.Cumsum(-1, false, true).Subtract(probs)
|
|
keep := prevCumProbs.Less(mlx.FromValue(s.TopP))
|
|
sorted = mlx.Where(keep, sorted, negInf)
|
|
|
|
out := scores.PutAlongAxis(order, sorted, -1)
|
|
|
|
// Top-K: sorted is already in descending order, so positions [K, V)
|
|
// are the ones to drop. Scatter -inf through their original-layout
|
|
// indices (order[K:]). Positional (not value-based) so exactly K
|
|
// tokens survive — ties at the K-th logit get broken by the sort
|
|
// order rather than promoted through the filter.
|
|
if applyTopK {
|
|
dropOrder := order.Slice(mlx.Slice(), mlx.Slice(s.TopK, mlx.End))
|
|
out = out.PutAlongAxis(dropOrder, negInf, -1)
|
|
}
|
|
|
|
return out
|
|
}
|
|
|
|
func minP(s *Sampler, scores *mlx.Array) *mlx.Array {
|
|
if s.MinP <= 0 || s.MinP > 1 {
|
|
return scores
|
|
}
|
|
|
|
maxScore := scores.MaxAxis(-1, true)
|
|
threshold := mlx.AddScalar(maxScore, float32(math.Log(float64(s.MinP))))
|
|
|
|
return mlx.Where(
|
|
scores.Less(threshold),
|
|
mlx.FromValue(float32(math.Inf(-1))),
|
|
scores,
|
|
)
|
|
}
|
|
|
|
func topK(s *Sampler, scores *mlx.Array) *mlx.Array {
|
|
if s.TopK <= 0 {
|
|
return scores
|
|
}
|
|
|
|
vocab := scores.Dim(scores.NumDims() - 1)
|
|
if s.TopK >= vocab {
|
|
return scores
|
|
}
|
|
|
|
mask := scores.Negative().ArgpartitionAxis(s.TopK-1, -1).Slice(mlx.Slice(), mlx.Slice(s.TopK, mlx.End))
|
|
return scores.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1)
|
|
}
|
|
|
|
func penalty(s *Sampler, scores *mlx.Array) *mlx.Array {
|
|
if s.historyLen == 0 {
|
|
return scores
|
|
}
|
|
|
|
tokenIndices := s.history
|
|
if scores.NumDims() > 1 {
|
|
tokenIndices = tokenIndices.ExpandDims(0)
|
|
}
|
|
|
|
if s.RepeatPenalty != 1 || s.PresencePenalty != 0 {
|
|
adjusted := scores.TakeAlongAxis(tokenIndices, -1)
|
|
if s.RepeatPenalty != 1 {
|
|
factor := mlx.Where(
|
|
adjusted.Less(mlx.FromValue(float32(0))),
|
|
mlx.FromValue(s.RepeatPenalty),
|
|
mlx.FromValue(1/s.RepeatPenalty),
|
|
)
|
|
adjusted = adjusted.Multiply(factor)
|
|
}
|
|
if s.PresencePenalty != 0 {
|
|
adjusted = mlx.AddScalar(adjusted, -s.PresencePenalty)
|
|
}
|
|
scores = scores.PutAlongAxis(tokenIndices, adjusted, -1)
|
|
}
|
|
|
|
if s.FrequencyPenalty != 0 {
|
|
scores = scores.ScatterAddAxis(tokenIndices, mlx.FromValue(-s.FrequencyPenalty), -1)
|
|
}
|
|
|
|
return scores
|
|
}
|