mirror of
https://github.com/ollama/ollama.git
synced 2026-04-23 09:15:44 +02:00
103 lines
2.1 KiB
Go
103 lines
2.1 KiB
Go
//go:build mlx
|
|
|
|
package sample
|
|
|
|
import (
|
|
"math"
|
|
"testing"
|
|
|
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
|
)
|
|
|
|
func TestPresencePenaltyUsesAppendedTokenImmediately(t *testing.T) {
|
|
// RepeatLastN = 1, PresencePenalty = 6
|
|
s := New(0, 0, 0, 0, 1, 1, 6, 0)
|
|
defer func() {
|
|
s.Free()
|
|
mlx.Sweep()
|
|
}()
|
|
|
|
s.ResetHistory([]int32{0})
|
|
s.AppendToken(mlx.NewArrayInt32([]int32{1}, []int32{1}))
|
|
|
|
logits := mlx.FromValues([]float32{0, 5, 4}, 3)
|
|
got := s.Sample(logits)
|
|
mlx.Eval(got)
|
|
|
|
// logits will be [0, -1, 4] after the penalty
|
|
// and then (index) 2 after the greedy sampler
|
|
gotInt := got.Int()
|
|
if gotInt != 2 {
|
|
t.Fatalf("got %d, want 2", gotInt)
|
|
}
|
|
}
|
|
|
|
func TestRepeatPenaltyUsesHistoryWithoutPresencePenalty(t *testing.T) {
|
|
s := New(0, 0, 0, 0, 1, 2, 0, 0)
|
|
defer func() {
|
|
s.Free()
|
|
mlx.Sweep()
|
|
}()
|
|
|
|
s.ResetHistory([]int32{1})
|
|
|
|
logits := mlx.FromValues([]float32{0, 5, 4}, 3)
|
|
got := s.Sample(logits)
|
|
mlx.Eval(got)
|
|
|
|
// token 1 is repeated and positive, so 5 / 2 falls below token 2.
|
|
gotInt := got.Int()
|
|
if gotInt != 2 {
|
|
t.Fatalf("got %d, want 2", gotInt)
|
|
}
|
|
}
|
|
|
|
func TestFrequencyPenaltyUsesTokenCounts(t *testing.T) {
|
|
s := New(0, 0, 0, 0, 4, 1, 0, 2)
|
|
defer func() {
|
|
s.Free()
|
|
mlx.Sweep()
|
|
}()
|
|
|
|
s.ResetHistory([]int32{1, 1})
|
|
|
|
logits := mlx.FromValues([]float32{0, 5, 4}, 3)
|
|
got := s.Sample(logits)
|
|
mlx.Eval(got)
|
|
|
|
// token 1 appears twice, so 5 - (2 * 2) falls below token 2.
|
|
gotInt := got.Int()
|
|
if gotInt != 2 {
|
|
t.Fatalf("got %d, want 2", gotInt)
|
|
}
|
|
}
|
|
|
|
func TestMinPMasksTokensBelowThreshold(t *testing.T) {
|
|
s := New(0, 0, 0.5, 0, 0, 1, 0, 0)
|
|
defer func() {
|
|
s.Free()
|
|
mlx.Sweep()
|
|
}()
|
|
|
|
logits := mlx.FromValues([]float32{
|
|
float32(math.Log(0.5)),
|
|
float32(math.Log(0.3)),
|
|
float32(math.Log(0.2)),
|
|
}, 3)
|
|
got := minP(s, logits)
|
|
mlx.Eval(got)
|
|
|
|
gotFloats := got.Floats()
|
|
if len(gotFloats) != 3 {
|
|
t.Fatalf("got %d scores, want 3", len(gotFloats))
|
|
}
|
|
|
|
if math.IsInf(float64(gotFloats[0]), -1) || math.IsInf(float64(gotFloats[1]), -1) {
|
|
t.Fatalf("kept tokens were masked: %v", gotFloats)
|
|
}
|
|
|
|
if !math.IsInf(float64(gotFloats[2]), -1) {
|
|
t.Fatalf("lowest-probability token should be masked, got %v", gotFloats)
|
|
}
|
|
}
|