This commit is contained in:
ParthSareen
2025-03-12 00:46:12 -04:00
parent 4aeb67ef4c
commit a5d638dfe7
4 changed files with 191 additions and 63 deletions

View File

@@ -3,6 +3,7 @@ package sample
import (
"container/heap"
"math"
"math/rand"
"slices"
)
@@ -25,32 +26,6 @@ func (h *tokenHeap) Pop() any {
return x
}
// temperature applies scaling and softmax to the logits
func temperature(ts []token, temp float32) []token {
// Find max logit for numerical stability
maxLogit := float32(math.Inf(-1))
for _, t := range ts {
if t.value > maxLogit {
maxLogit = t.value
}
}
// Apply temperature and compute exp(x - max)
temp = max(temp, 1e-7)
var sum float32
for i, v := range ts {
ts[i].value = float32(math.Exp(float64((v.value - maxLogit) / temp)))
sum += ts[i].value
}
// Normalize
for i := range ts {
ts[i].value /= sum
}
return ts
}
// topK limits the number of tokens considered to the k highest logits
func topK(ts []token, k int) []token {
if k >= len(ts) || k <= 0 {
@@ -134,3 +109,59 @@ func minP(ts []token, p float32) []token {
ts = validTokens
return ts
}
func temperature(ts []token, temp float32) {
for i := range ts {
ts[i].value /= temp
}
}
func softmax(ts []token) {
if len(ts) == 0 {
return
}
// Find max logit for numerical stability
maxLogit := ts[0].value
for _, t := range ts {
if t.value > maxLogit {
maxLogit = t.value
}
}
// Compute exp(logit - maxLogit) and sum them
var sumExp float32
for i, t := range ts {
expVal := float32(math.Exp(float64(t.value - maxLogit)))
ts[i].value = expVal
sumExp += expVal
}
// Normalize probabilities
for i := range ts {
ts[i].value /= sumExp
}
}
// applyDist selects a token based on probabilities and seed
func dist(ts []token, seed int64) int {
rng := rand.New(rand.NewSource(seed))
cdf := make([]float32, len(ts))
var cumSum float32
for i, t := range ts {
cumSum += t.value
cdf[i] = cumSum
}
r := rng.Float32() * cumSum
// Select token based on CDF
for i, probSum := range cdf {
if r < probSum {
return i
}
}
return len(ts) - 1
}