Files
ollama/x/mlxrunner/sample/sample.go
Daniel Hiltgen 10e51c5177 MLX: add header vendoring and remove go build tag (#14642)
* prefer rocm v6 on windows

Avoid building with v7 - more changes are needed

* MLX: add header vendoring and remove go build tag

This switches to using a vendoring approach for the mlx-c headers so that Go
can build without requiring a cmake first.  This enables building the new MLX
based code by default.  Every time cmake runs, the headers are refreshed, so we
can easily keep them in sync when we bump mlx versions.  Basic Windows
and Linux support are verified.

* ci: harden for flaky choco repo servers

CI sometimes fails because choco does not actually install ccache. Since ccache only speeds up the build, we can proceed without it.

* review comments
2026-03-09 17:24:45 -07:00

190 lines
4.3 KiB
Go

package sample
import (
"math"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
type (
	// Transform is one stage of the sampling pipeline. It receives the
	// sampler (for its configuration and token history) and the array
	// produced by the previous stage, and returns the array for the next
	// stage. In a pipeline built by New the final stage (greedy or
	// temperature) returns selected token ids rather than logits.
	Transform func(*Sampler, *mlx.Array) *mlx.Array

	// Sampler holds sampling parameters, the recent-token history used by
	// the presence penalty, and the ordered list of transforms applied by
	// Sample.
	Sampler struct {
		Temperature     float32 // 0 selects greedy argmax in New; otherwise logits are divided by it before categorical sampling
		TopP            float32 // nucleus threshold; the topP stage only filters when 0 < TopP < 1
		MinP            float32 // relative probability floor; the minP stage only filters when 0 < MinP <= 1
		TopK            int     // keep only the TopK highest scores; <= 0 disables
		RepeatLastN     int     // when > 0, history is trimmed to the most recent RepeatLastN tokens
		PresencePenalty float32 // subtracted from scores of tokens found in history; 0 disables history tracking

		history    *mlx.Array  // int32 token ids, pinned via setHistory; nil when empty
		historyLen int         // number of token ids held in history
		transforms []Transform // pipeline assembled by New, run in order by Sample
	}
)
// New builds a Sampler and assembles its transform pipeline.
//
// Stages are appended in a fixed order:
//  1. presence penalty
//  2. top-p
//  3. min-p
//  4. top-k
//  5. a final token-selection stage (greedy argmax or temperature sampling)
//
// A filtering stage is only scheduled when its parameter would actually
// change the logits, mirroring the guard inside each transform.
//
// A non-positive temperature selects greedy decoding. Dividing logits by a
// negative temperature would invert the distribution and favour the least
// likely tokens.
//
// The parameter names keep underscores because topP, minP and topK are
// already taken by the transform functions in this package.
func New(temp, top_p, min_p float32, top_k, repeatLastN int, presencePenalty float32) *Sampler {
	s := &Sampler{
		Temperature:     temp,
		TopP:            top_p,
		MinP:            min_p,
		TopK:            top_k,
		RepeatLastN:     repeatLastN,
		PresencePenalty: presencePenalty,
	}

	var transforms []Transform
	if presencePenalty != 0 {
		transforms = append(transforms, penalty)
	}
	if top_p > 0 && top_p < 1 {
		transforms = append(transforms, topP)
	}
	// minP leaves logits untouched outside (0, 1], so don't schedule it there.
	if min_p > 0 && min_p <= 1 {
		transforms = append(transforms, minP)
	}
	if top_k > 0 {
		transforms = append(transforms, topK)
	}
	// The last stage turns logits into a token id.
	if temp <= 0 {
		transforms = append(transforms, greedy)
	} else {
		transforms = append(transforms, temperature)
	}
	s.transforms = transforms
	return s
}
// usesHistory reports whether any configured stage needs the recent-token
// history. Only the presence penalty reads it, so history tracking is skipped
// entirely when PresencePenalty is zero.
func (s *Sampler) usesHistory() bool {
	needsHistory := s.PresencePenalty != 0
	return needsHistory
}
// setHistory installs history (which may be nil) as the tracked token array
// and records its length.
//
// The incoming array is pinned before the previous one is unpinned. This
// keeps the array pinned throughout if the caller passes the current history
// back in.
func (s *Sampler) setHistory(history *mlx.Array, historyLen int) {
	prev := s.history
	if history != nil {
		mlx.Pin(history)
	}
	if prev != nil {
		mlx.Unpin(prev)
	}
	s.history, s.historyLen = history, historyLen
}
// ResetHistory replaces the tracked token history with a copy of the given
// token ids. When RepeatLastN > 0, only the last RepeatLastN of them are
// kept. An empty input clears the history. It is a no-op when no stage uses
// history.
func (s *Sampler) ResetHistory(history []int32) {
	if !s.usesHistory() {
		return
	}
	if n := s.RepeatLastN; n > 0 && len(history) > n {
		history = history[len(history)-n:]
	}

	count := len(history)
	if count == 0 {
		s.setHistory(nil, 0)
		return
	}

	// Copy the caller's slice before handing it to mlx (presumably so the
	// array cannot alias caller-owned memory).
	tokens := make([]int32, count)
	copy(tokens, history)
	s.setHistory(mlx.NewArrayInt32(tokens, []int32{int32(count)}), count)
}
// AppendToken appends the sampled token id(s) to the tracked history. When
// RepeatLastN > 0, the result is trimmed to the most recent RepeatLastN
// entries. It is a no-op when no stage uses history or token is nil.
//
// The token is converted to int32 to match histories built by ResetHistory.
// A 0-d (scalar) token is promoted to a 1-element vector so it can be
// concatenated along axis 0 and later used as gather indices by the penalty
// stage. A scalar is presumably what greedy/temperature return for 1-D
// logits.
func (s *Sampler) AppendToken(token *mlx.Array) {
	if !s.usesHistory() || token == nil {
		return
	}

	next := token.AsType(mlx.DTypeInt32)
	if next.NumDims() == 0 {
		next = next.ExpandDims(0)
	}
	nextLen := next.Size()

	if s.history != nil && s.historyLen > 0 {
		next = s.history.Concatenate(0, next)
		nextLen += s.historyLen
	}

	// Keep only the newest RepeatLastN ids.
	if s.RepeatLastN > 0 && nextLen > s.RepeatLastN {
		trim := nextLen - s.RepeatLastN
		next = next.Slice(mlx.Slice(trim, nextLen))
		nextLen = s.RepeatLastN
	}

	s.setHistory(next, nextLen)
}
// Free releases the sampler's tracked history, unpinning it if present. The
// sampler remains usable afterwards, and history simply starts empty again.
func (s *Sampler) Free() {
	var none *mlx.Array
	s.setHistory(none, 0)
}
// Sample feeds logits through each configured transform in order and returns
// the output of the last one. For a Sampler built by New, that output is the
// selected token id(s).
func (s *Sampler) Sample(logits *mlx.Array) *mlx.Array {
	out := logits
	for i := range s.transforms {
		out = s.transforms[i](s, out)
	}
	return out
}
// greedy selects the index of the highest score along the last axis. The
// reduced axis is dropped rather than kept.
func greedy(_ *Sampler, logits *mlx.Array) *mlx.Array {
	const keepDims = false
	return logits.Argmax(-1, keepDims)
}
// temperature divides logits by the sampler's Temperature. It then draws a
// token from the resulting categorical distribution over the last axis.
func temperature(s *Sampler, logits *mlx.Array) *mlx.Array {
	scaled := mlx.DivScalar(logits, s.Temperature)
	return scaled.Categorical(-1)
}
// topP applies nucleus filtering. Scores are ranked in descending order. A
// token is kept while the probability mass ranked strictly ahead of it is
// below TopP; every other score is set to -Inf. The top-ranked token always
// survives because nothing is ranked ahead of it. It is a no-op unless
// 0 < TopP < 1.
func topP(s *Sampler, logprobs *mlx.Array) *mlx.Array {
	if s.TopP <= 0 || s.TopP >= 1 {
		return logprobs
	}

	// Indices that order the scores from highest to lowest.
	order := logprobs.Negative().ArgsortAxis(-1)
	ranked := logprobs.TakeAlongAxis(order, -1)
	probs := mlx.SoftmaxAxis(ranked, -1, true)

	// Mass ranked ahead of each token: cumulative sum minus the token's own
	// probability. This assumes the Cumsum flags mean reverse=false,
	// inclusive=true.
	massAhead := probs.Cumsum(-1, false, true).Subtract(probs)
	inNucleus := massAhead.Less(mlx.FromValue(s.TopP))

	negInf := mlx.FromValue(float32(math.Inf(-1)))
	kept := mlx.Where(inNucleus, ranked, negInf)

	// Scatter the filtered scores back to their original positions.
	return logprobs.PutAlongAxis(order, kept, -1)
}
// minP drops every token whose probability is below MinP times the
// probability of the best token, by setting its score to -Inf. It works in
// log space as log(p_max * MinP) = score_max + log(MinP). This comparison is
// unaffected by a shared normalising constant, so raw logits work as well. It
// is a no-op unless 0 < MinP <= 1.
func minP(s *Sampler, logprobs *mlx.Array) *mlx.Array {
	if s.MinP <= 0 || s.MinP > 1 {
		return logprobs
	}

	best := logprobs.Argmax(-1, true)
	bestScore := logprobs.TakeAlongAxis(best, -1)
	floor := mlx.AddScalar(bestScore, float32(math.Log(float64(s.MinP))))

	belowFloor := logprobs.Less(floor)
	negInf := mlx.FromValue(float32(math.Inf(-1)))
	return mlx.Where(belowFloor, negInf, logprobs)
}
// topK keeps only the TopK highest scores along the last axis and sets every
// other entry to -Inf. It is a no-op when TopK <= 0 or when TopK already
// covers the whole last axis.
//
// Both 1-D logits and 2-D logits are handled. The 2-D case presumably means
// (batch, vocab); TODO confirm.
func topK(s *Sampler, logprobs *mlx.Array) *mlx.Array {
	if s.TopK <= 0 {
		return logprobs
	}
	nd := logprobs.NumDims()
	vocab := logprobs.Dim(nd - 1)
	if s.TopK >= vocab {
		return logprobs
	}

	// Partition the negated scores around position TopK-1. The first TopK
	// indices along the last axis then point at the TopK largest scores, and
	// indices TopK..vocab-1 are the ones to mask. The explicit end index
	// avoids relying on any special meaning of a 0 stop bound.
	partitioned := logprobs.Negative().ArgpartitionAxis(s.TopK-1, -1)
	tail := mlx.Slice(s.TopK, vocab)

	var mask *mlx.Array
	if nd == 1 {
		mask = partitioned.Slice(tail)
	} else {
		mask = partitioned.Slice(mlx.Slice(), tail)
	}
	return logprobs.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1)
}
// penalty subtracts PresencePenalty from the score of every token id present
// in the tracked history. A token that occurs several times is still
// penalised only once. Each occurrence reads the original score and writes
// back the same adjusted value. It is a no-op when the penalty is zero or the
// history is empty.
func penalty(s *Sampler, logprobs *mlx.Array) *mlx.Array {
	if s.PresencePenalty == 0 || s.history == nil || s.historyLen == 0 {
		return logprobs
	}

	idx := s.history
	if logprobs.NumDims() > 1 {
		// Add a leading axis so the index rank matches the logits rank. This
		// presumably assumes logits are at most 2-D; confirm with callers.
		idx = idx.ExpandDims(0)
	}

	penalized := mlx.AddScalar(logprobs.TakeAlongAxis(idx, -1), -s.PresencePenalty)
	return logprobs.PutAlongAxis(idx, penalized, -1)
}