mirror of
https://github.com/ollama/ollama.git
synced 2026-04-22 00:36:11 +02:00
Compare commits
6 Commits
pdevine/ad
...
jessegross
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a50199cd70 | ||
|
|
5264ba9194 | ||
|
|
ce99f24731 | ||
|
|
04f5f0cdb4 | ||
|
|
fb36a01ffe | ||
|
|
0c65ed33bc |
@@ -381,7 +381,7 @@ export const useSendMessage = (chatId: string) => {
|
||||
role: "assistant",
|
||||
content: "",
|
||||
thinking: "",
|
||||
model: effectiveModel,
|
||||
model: effectiveModel.model,
|
||||
}),
|
||||
);
|
||||
lastMessage = newMessages[newMessages.length - 1];
|
||||
@@ -433,7 +433,7 @@ export const useSendMessage = (chatId: string) => {
|
||||
role: "assistant",
|
||||
content: "",
|
||||
thinking: "",
|
||||
model: effectiveModel,
|
||||
model: effectiveModel.model,
|
||||
}),
|
||||
);
|
||||
lastMessage = newMessages[newMessages.length - 1];
|
||||
@@ -520,7 +520,7 @@ export const useSendMessage = (chatId: string) => {
|
||||
thinkingTimeStart:
|
||||
lastMessage.thinkingTimeStart || event.thinkingTimeStart,
|
||||
thinkingTimeEnd: event.thinkingTimeEnd,
|
||||
model: selectedModel,
|
||||
model: selectedModel.model,
|
||||
});
|
||||
newMessages[newMessages.length - 1] = updatedMessage;
|
||||
} else {
|
||||
@@ -533,7 +533,7 @@ export const useSendMessage = (chatId: string) => {
|
||||
tool_calls: event.toolCalls,
|
||||
thinkingTimeStart: event.thinkingTimeStart,
|
||||
thinkingTimeEnd: event.thinkingTimeEnd,
|
||||
model: selectedModel,
|
||||
model: selectedModel.model,
|
||||
}),
|
||||
);
|
||||
}
|
||||
@@ -699,7 +699,7 @@ export const useSendMessage = (chatId: string) => {
|
||||
queryClient.setQueryData(["chat", newId], {
|
||||
chat: new Chat({
|
||||
id: newId,
|
||||
model: effectiveModel,
|
||||
model: effectiveModel.model,
|
||||
messages: [
|
||||
new Message({
|
||||
role: "user",
|
||||
|
||||
57
cmd/cmd.go
57
cmd/cmd.go
@@ -1975,8 +1975,61 @@ func launchInteractiveModel(cmd *cobra.Command, modelName string) error {
|
||||
Options: map[string]any{},
|
||||
ShowConnect: true,
|
||||
}
|
||||
// loadOrUnloadModel is cloud-safe here: remote/cloud models skip local preload
|
||||
// and only validate auth/connectivity before interactive chat starts.
|
||||
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
requestedCloud := modelref.HasExplicitCloudSource(modelName)
|
||||
|
||||
info, err := func() (*api.ShowResponse, error) {
|
||||
showReq := &api.ShowRequest{Name: modelName}
|
||||
info, err := client.Show(cmd.Context(), showReq)
|
||||
var se api.StatusError
|
||||
if errors.As(err, &se) && se.StatusCode == http.StatusNotFound {
|
||||
if requestedCloud {
|
||||
return nil, err
|
||||
}
|
||||
if err := PullHandler(cmd, []string{modelName}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return client.Show(cmd.Context(), &api.ShowRequest{Name: modelName})
|
||||
}
|
||||
return info, err
|
||||
}()
|
||||
if err != nil {
|
||||
if handleCloudAuthorizationError(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
ensureCloudStub(cmd.Context(), client, modelName)
|
||||
|
||||
opts.Think, err = inferThinkingOption(&info.Capabilities, &opts, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
audioCapable := slices.Contains(info.Capabilities, model.CapabilityAudio)
|
||||
opts.MultiModal = slices.Contains(info.Capabilities, model.CapabilityVision) || audioCapable
|
||||
|
||||
// TODO: remove the projector info and vision info checks below,
|
||||
// these are left in for backwards compatibility with older servers
|
||||
// that don't have the capabilities field in the model info
|
||||
if len(info.ProjectorInfo) != 0 {
|
||||
opts.MultiModal = true
|
||||
}
|
||||
for k := range info.ModelInfo {
|
||||
if strings.Contains(k, ".vision.") {
|
||||
opts.MultiModal = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
applyShowResponseToRunOptions(&opts, info)
|
||||
|
||||
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
||||
return fmt.Errorf("error loading model: %w", err)
|
||||
}
|
||||
|
||||
@@ -226,7 +226,7 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("%s", strings.TrimSpace(string(respBody)))
|
||||
return api.StatusError{StatusCode: resp.StatusCode, ErrorMessage: strings.TrimSpace(string(respBody))}
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
|
||||
@@ -10,6 +10,8 @@ import (
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"unsafe"
|
||||
|
||||
"github.com/ollama/ollama/logutil"
|
||||
@@ -18,20 +20,28 @@ import (
|
||||
type Array struct {
|
||||
ctx C.mlx_array
|
||||
name string
|
||||
pinned int
|
||||
pinned atomic.Int32
|
||||
}
|
||||
|
||||
var arrays []*Array
|
||||
var (
|
||||
arrays []*Array
|
||||
arraysMu sync.Mutex
|
||||
)
|
||||
|
||||
// constructor utilities
|
||||
|
||||
func New(name string) *Array {
|
||||
t := &Array{name: name}
|
||||
|
||||
if tracing {
|
||||
traceScratch = append(traceScratch, t)
|
||||
} else {
|
||||
arraysMu.Lock()
|
||||
defer arraysMu.Unlock()
|
||||
|
||||
arrays = append(arrays, t)
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
@@ -131,7 +141,7 @@ func (t *Array) Clone() *Array {
|
||||
func Pin(s ...*Array) {
|
||||
for _, t := range s {
|
||||
if t != nil {
|
||||
t.pinned++
|
||||
t.pinned.Add(1)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -140,8 +150,7 @@ func Pin(s ...*Array) {
|
||||
func Unpin(s ...*Array) {
|
||||
for _, t := range s {
|
||||
if t != nil {
|
||||
t.pinned--
|
||||
if t.pinned < 0 {
|
||||
if t.pinned.Add(-1) < 0 {
|
||||
panic(fmt.Sprintf("mlx.Unpin: negative pin count on array %q", t.name))
|
||||
}
|
||||
}
|
||||
@@ -151,9 +160,11 @@ func Unpin(s ...*Array) {
|
||||
// Sweep releases all unpinned arrays, primarily intermediate tensors. MLX will truly
|
||||
// free them when there are no other references, including dependencies in the graph.
|
||||
func Sweep() {
|
||||
arraysMu.Lock()
|
||||
defer arraysMu.Unlock()
|
||||
n := 0
|
||||
for _, t := range arrays {
|
||||
if t.pinned > 0 && t.Valid() {
|
||||
if t.pinned.Load() > 0 && t.Valid() {
|
||||
arrays[n] = t
|
||||
n++
|
||||
} else if t.Valid() {
|
||||
@@ -180,7 +191,7 @@ func (t *Array) String() string {
|
||||
func (t *Array) LogValue() slog.Value {
|
||||
attrs := []slog.Attr{
|
||||
slog.String("name", t.name),
|
||||
slog.Int("pinned", t.pinned),
|
||||
slog.Int("pinned", int(t.pinned.Load())),
|
||||
}
|
||||
if t.Valid() {
|
||||
attrs = append(attrs,
|
||||
@@ -194,19 +205,19 @@ func (t *Array) LogValue() slog.Value {
|
||||
|
||||
// shape utilities
|
||||
|
||||
func (t Array) Size() int {
|
||||
func (t *Array) Size() int {
|
||||
return int(C.mlx_array_size(t.ctx))
|
||||
}
|
||||
|
||||
func (t Array) NumBytes() int {
|
||||
func (t *Array) NumBytes() int {
|
||||
return int(C.mlx_array_nbytes(t.ctx))
|
||||
}
|
||||
|
||||
func (t Array) NumDims() int {
|
||||
func (t *Array) NumDims() int {
|
||||
return int(C.mlx_array_ndim(t.ctx))
|
||||
}
|
||||
|
||||
func (t Array) Dims() []int {
|
||||
func (t *Array) Dims() []int {
|
||||
dims := make([]int, t.NumDims())
|
||||
for i := range dims {
|
||||
dims[i] = t.Dim(i)
|
||||
@@ -215,29 +226,29 @@ func (t Array) Dims() []int {
|
||||
return dims
|
||||
}
|
||||
|
||||
func (t Array) Dim(dim int) int {
|
||||
func (t *Array) Dim(dim int) int {
|
||||
return int(C.mlx_array_dim(t.ctx, C.int(dim)))
|
||||
}
|
||||
|
||||
func (t Array) DType() DType {
|
||||
func (t *Array) DType() DType {
|
||||
return DType(C.mlx_array_dtype(t.ctx))
|
||||
}
|
||||
|
||||
// data utilities
|
||||
|
||||
func (t Array) Int() int {
|
||||
func (t *Array) Int() int {
|
||||
var item C.int64_t
|
||||
C.mlx_array_item_int64(&item, t.ctx)
|
||||
return int(item)
|
||||
}
|
||||
|
||||
func (t Array) Float() float64 {
|
||||
func (t *Array) Float() float64 {
|
||||
var item C.double
|
||||
C.mlx_array_item_float64(&item, t.ctx)
|
||||
return float64(item)
|
||||
}
|
||||
|
||||
func (t Array) Ints() []int {
|
||||
func (t *Array) Ints() []int {
|
||||
if dt := t.DType(); dt != DTypeInt32 {
|
||||
panic(fmt.Sprintf("mlx: Ints requires DTypeInt32, got %v", dt))
|
||||
}
|
||||
@@ -248,7 +259,7 @@ func (t Array) Ints() []int {
|
||||
return ints
|
||||
}
|
||||
|
||||
func (t Array) Floats() []float32 {
|
||||
func (t *Array) Floats() []float32 {
|
||||
if dt := t.DType(); dt != DTypeFloat32 {
|
||||
panic(fmt.Sprintf("mlx: Floats requires DTypeFloat32, got %v", dt))
|
||||
}
|
||||
@@ -259,7 +270,7 @@ func (t Array) Floats() []float32 {
|
||||
return floats
|
||||
}
|
||||
|
||||
func (t Array) Save(name string) error {
|
||||
func (t *Array) Save(name string) error {
|
||||
cName := C.CString(name)
|
||||
defer C.free(unsafe.Pointer(cName))
|
||||
C.mlx_save(cName, t.ctx)
|
||||
@@ -268,6 +279,8 @@ func (t Array) Save(name string) error {
|
||||
|
||||
// LogArrays logs all live arrays, sorted by size
|
||||
func LogArrays() {
|
||||
arraysMu.Lock()
|
||||
defer arraysMu.Unlock()
|
||||
sort.Slice(arrays, func(i, j int) bool {
|
||||
return arrays[i].NumBytes() > arrays[j].NumBytes()
|
||||
})
|
||||
@@ -276,7 +289,7 @@ func LogArrays() {
|
||||
for _, t := range arrays {
|
||||
nb := t.NumBytes()
|
||||
total += nb
|
||||
logutil.Trace(fmt.Sprintf("tensor %-60s %5s %5s pinned=%d %v", t.name, t.DType(), PrettyBytes(nb), t.pinned, t.Dims()))
|
||||
logutil.Trace(fmt.Sprintf("tensor %-60s %5s %5s pinned=%d %v", t.name, t.DType(), PrettyBytes(nb), t.pinned.Load(), t.Dims()))
|
||||
}
|
||||
logutil.Trace(fmt.Sprintf("tensors total: %d, size: %s, active: %s", len(arrays), PrettyBytes(total), PrettyBytes(ActiveMemory())))
|
||||
}
|
||||
|
||||
@@ -150,7 +150,7 @@ func closureCallback(res *C.mlx_vector_array, input C.mlx_vector_array, payload
|
||||
traceScratch = nil
|
||||
defer func() {
|
||||
for _, a := range traceScratch {
|
||||
if a.pinned > 0 {
|
||||
if a.pinned.Load() > 0 {
|
||||
panic("mlx: traced array was pinned during compilation")
|
||||
}
|
||||
if a.Valid() {
|
||||
|
||||
@@ -24,8 +24,8 @@ func ScaledDotProductAttention(query, key, value, mask *Array, scale float32) *A
|
||||
}
|
||||
|
||||
type LayerNorm struct {
|
||||
Weight Array `weight:"weight"`
|
||||
Bias Array `weight:"bias"`
|
||||
Weight *Array `weight:"weight"`
|
||||
Bias *Array `weight:"bias"`
|
||||
}
|
||||
|
||||
func (r *LayerNorm) Forward(x *Array, eps float32) *Array {
|
||||
@@ -35,10 +35,10 @@ func (r *LayerNorm) Forward(x *Array, eps float32) *Array {
|
||||
}
|
||||
|
||||
type RMSNorm struct {
|
||||
Weight Array `weight:"weight"`
|
||||
Weight *Array `weight:"weight"`
|
||||
}
|
||||
|
||||
func (r RMSNorm) Forward(x *Array, eps float32) *Array {
|
||||
func (r *RMSNorm) Forward(x *Array, eps float32) *Array {
|
||||
out := New("FAST_RMSNORM")
|
||||
C.mlx_fast_rms_norm(&out.ctx, x.ctx, r.Weight.ctx, C.float(eps), DefaultStream().ctx)
|
||||
return out
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
package mlx
|
||||
|
||||
type Linear struct {
|
||||
Weight Array `weight:"weight"`
|
||||
Bias Array `weight:"bias"`
|
||||
Weight *Array `weight:"weight"`
|
||||
Bias *Array `weight:"bias"`
|
||||
}
|
||||
|
||||
// Forward computes the linear transformation: x @ Weight.T + Bias
|
||||
func (m Linear) Forward(x *Array) *Array {
|
||||
func (m *Linear) Forward(x *Array) *Array {
|
||||
w := m.Weight.Transpose(1, 0)
|
||||
if m.Bias.Valid() {
|
||||
return m.Bias.Addmm(x, w, 1.0, 1.0)
|
||||
@@ -15,14 +15,14 @@ func (m Linear) Forward(x *Array) *Array {
|
||||
return x.Matmul(w)
|
||||
}
|
||||
|
||||
func (m Linear) Gather(x, lhs, rhs *Array, sorted bool) *Array {
|
||||
func (m *Linear) Gather(x, lhs, rhs *Array, sorted bool) *Array {
|
||||
w := m.Weight.Transpose(0, 2, 1)
|
||||
// TODO: bias
|
||||
return x.GatherMM(w, lhs, rhs, sorted)
|
||||
}
|
||||
|
||||
type Embedding struct {
|
||||
Weight Array `weight:"weight"`
|
||||
Weight *Array `weight:"weight"`
|
||||
}
|
||||
|
||||
func (e *Embedding) Forward(indices *Array) *Array {
|
||||
|
||||
@@ -72,6 +72,10 @@ func (t *Array) AsStrided(shape []int, strides []int, offset int) *Array {
|
||||
}
|
||||
|
||||
func (t *Array) Concatenate(axis int, others ...*Array) *Array {
|
||||
if len(others) == 0 {
|
||||
return t
|
||||
}
|
||||
|
||||
vector := C.mlx_vector_array_new()
|
||||
defer C.mlx_vector_array_free(vector)
|
||||
|
||||
@@ -127,9 +131,9 @@ func (t *Array) GatherMM(other, lhs, rhs *Array, sorted bool) *Array {
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Logsumexp(keepDims bool) *Array {
|
||||
out := New("LOGSUMEXP")
|
||||
C.mlx_logsumexp(&out.ctx, t.ctx, C.bool(keepDims), DefaultStream().ctx)
|
||||
func (t *Array) LogsumexpAxis(axis int, keepDims bool) *Array {
|
||||
out := New("LOGSUMEXP_AXIS")
|
||||
C.mlx_logsumexp_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
|
||||
@@ -376,6 +376,9 @@ func Concatenate(arrays []*Array, axis int) *Array {
|
||||
if len(arrays) == 0 {
|
||||
return nil
|
||||
}
|
||||
if len(arrays) == 1 {
|
||||
return arrays[0]
|
||||
}
|
||||
return arrays[0].Concatenate(axis, arrays[1:]...)
|
||||
}
|
||||
|
||||
|
||||
@@ -6,11 +6,9 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
@@ -22,19 +20,44 @@ func prefillChunkSize() int {
|
||||
return 2 << 10
|
||||
}
|
||||
|
||||
func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
// Prepare tokenizes the prompt and validates it against the model's
|
||||
// context length. It is safe to call from any goroutine. On success it
|
||||
// populates request.Tokens and adjusts request.Options.NumPredict.
|
||||
func (r *Runner) Prepare(request *Request) error {
|
||||
if r.Model == nil {
|
||||
return errors.New("model not loaded")
|
||||
}
|
||||
|
||||
tokens := r.Tokenizer.Encode(request.Prompt, r.Tokenizer.AddBOS())
|
||||
if len(tokens) == 0 {
|
||||
return errors.New("empty prompt")
|
||||
}
|
||||
|
||||
if len(tokens) >= r.contextLength {
|
||||
return fmt.Errorf("input length (%d tokens) exceeds the model's maximum context length (%d tokens)", len(tokens), r.contextLength)
|
||||
}
|
||||
|
||||
// Cap generation to stay within the model's context length
|
||||
maxGenerate := r.contextLength - len(tokens)
|
||||
if request.Options.NumPredict <= 0 {
|
||||
request.Options.NumPredict = maxGenerate
|
||||
} else {
|
||||
request.Options.NumPredict = min(request.Options.NumPredict, maxGenerate)
|
||||
}
|
||||
|
||||
request.Tokens = tokens
|
||||
return nil
|
||||
}
|
||||
|
||||
// The runner serializes requests today so we just use a fixed slot ID.
|
||||
const pipelineSlot = 0
|
||||
|
||||
func (r *Runner) TextGenerationPipeline(ctx context.Context, request Request) error {
|
||||
mlx.ResetPeakMemory()
|
||||
ctx := request.Ctx
|
||||
var sample, nextSample sampler.Result
|
||||
|
||||
defer func() {
|
||||
if request.Sampler != nil {
|
||||
request.Sampler.Free()
|
||||
}
|
||||
r.Sampler.Remove(pipelineSlot)
|
||||
mlx.Unpin(sample.Arrays()...)
|
||||
mlx.Unpin(nextSample.Arrays()...)
|
||||
mlx.Sweep()
|
||||
@@ -47,27 +70,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
slog.Info("peak memory", "size", mlx.PrettyBytes(mlx.PeakMemory()))
|
||||
}()
|
||||
|
||||
inputs := r.Tokenizer.Encode(request.Prompt, r.Tokenizer.AddBOS())
|
||||
if len(inputs) == 0 {
|
||||
return errors.New("empty prompt")
|
||||
}
|
||||
|
||||
if len(inputs) >= r.contextLength {
|
||||
return api.StatusError{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
ErrorMessage: fmt.Sprintf("input length (%d tokens) exceeds the model's maximum context length (%d tokens)", len(inputs), r.contextLength),
|
||||
}
|
||||
}
|
||||
|
||||
// Cap generation to stay within the model's context length
|
||||
maxGenerate := r.contextLength - len(inputs)
|
||||
if request.Options.NumPredict <= 0 {
|
||||
request.Options.NumPredict = maxGenerate
|
||||
} else {
|
||||
request.Options.NumPredict = min(request.Options.NumPredict, maxGenerate)
|
||||
}
|
||||
|
||||
request.Sampler.ResetHistory(inputs)
|
||||
inputs := request.Tokens
|
||||
|
||||
session := r.cache.begin(r.Model, inputs)
|
||||
defer session.close()
|
||||
@@ -119,7 +122,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
}
|
||||
}
|
||||
|
||||
r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
|
||||
r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], 1, n), caches)
|
||||
mlx.Sweep()
|
||||
materializeCaches()
|
||||
processed += n
|
||||
@@ -136,21 +139,28 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
mlx.ClearCache()
|
||||
}
|
||||
|
||||
// Register the sampler after prefill completes.
|
||||
r.Sampler.Add(pipelineSlot, request.SamplerOpts, inputs)
|
||||
|
||||
step := func(token *mlx.Array) sampler.Result {
|
||||
fwd := r.Model.Forward(token.ExpandDims(0), caches)
|
||||
fwd := r.Model.Forward(token, caches)
|
||||
logits := r.Model.Unembed(fwd)
|
||||
logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
|
||||
|
||||
sample := request.Sampler.Sample(logits)
|
||||
sample := r.Sampler.Sample([]int{pipelineSlot}, logits)
|
||||
mlx.Pin(sample.Arrays()...)
|
||||
mlx.Sweep()
|
||||
mlx.AsyncEval(sample.Arrays()...)
|
||||
return sample
|
||||
}
|
||||
|
||||
sample = step(mlx.FromValues(tokens[processed:], total-processed))
|
||||
sample = step(mlx.FromValues(tokens[processed:], 1, total-processed))
|
||||
|
||||
dec := decoder{tokenizer: r.Tokenizer}
|
||||
dec := decoder{
|
||||
tokenizer: r.Tokenizer,
|
||||
wantLogprobs: request.SamplerOpts.Logprobs,
|
||||
wantTopLogprobs: request.SamplerOpts.TopLogprobs,
|
||||
}
|
||||
|
||||
final := CompletionResponse{Done: true, PromptEvalCount: len(inputs), EvalCount: request.Options.NumPredict, DoneReason: 1}
|
||||
for i := range request.Options.NumPredict {
|
||||
@@ -158,8 +168,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
return err
|
||||
}
|
||||
|
||||
request.Sampler.AppendToken(sample.Token)
|
||||
nextSample = step(sample.Token)
|
||||
nextSample = step(sample.Token.ExpandDims(-1))
|
||||
|
||||
if i == 0 {
|
||||
mlx.Eval(sample.Arrays()...)
|
||||
@@ -206,15 +215,17 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
// with those bytes so Content and Logprobs stay aligned when a chunk does
|
||||
// flush.
|
||||
type decoder struct {
|
||||
tokenizer *tokenizer.Tokenizer
|
||||
buf bytes.Buffer
|
||||
logprobs []llm.Logprob
|
||||
tokenizer *tokenizer.Tokenizer
|
||||
buf bytes.Buffer
|
||||
logprobs []llm.Logprob
|
||||
wantLogprobs bool
|
||||
wantTopLogprobs int
|
||||
}
|
||||
|
||||
func (d *decoder) decode(res sampler.Result) (CompletionResponse, bool) {
|
||||
output := int32(res.Token.Int())
|
||||
d.buf.WriteString(d.tokenizer.Decode([]int32{output}))
|
||||
d.logprobs = append(d.logprobs, buildLogprob(res, d.tokenizer.Decode)...)
|
||||
d.logprobs = append(d.logprobs, buildLogprob(res, d.wantLogprobs, d.wantTopLogprobs, d.tokenizer.Decode)...)
|
||||
|
||||
content := flushValidUTF8Prefix(&d.buf)
|
||||
if content == "" {
|
||||
@@ -225,8 +236,13 @@ func (d *decoder) decode(res sampler.Result) (CompletionResponse, bool) {
|
||||
return resp, true
|
||||
}
|
||||
|
||||
func buildLogprob(sample sampler.Result, decode func([]int32) string) []llm.Logprob {
|
||||
if sample.Logprob == nil {
|
||||
// buildLogprob converts the sampler's logprob tensors into the wire-format
|
||||
// llm.Logprob entries the caller wants. The sampler populates its logprob
|
||||
// tensors whenever any registered slot requested them, so the caller must
|
||||
// gate emission on its own request config (wantLogprobs / wantTopLogprobs)
|
||||
// rather than on whether the tensors happen to be non-nil.
|
||||
func buildLogprob(sample sampler.Result, wantLogprobs bool, wantTopLogprobs int, decode func([]int32) string) []llm.Logprob {
|
||||
if !wantLogprobs || sample.Logprob == nil {
|
||||
return nil
|
||||
}
|
||||
tok := func(id int32) string { return decode([]int32{id}) }
|
||||
@@ -238,7 +254,7 @@ func buildLogprob(sample sampler.Result, decode func([]int32) string) []llm.Logp
|
||||
},
|
||||
}
|
||||
|
||||
if sample.TopTokens != nil {
|
||||
if wantTopLogprobs > 0 && sample.TopTokens != nil {
|
||||
ids := sample.TopTokens.Ints()
|
||||
vals := sample.TopLogprobs.Floats()
|
||||
pairs := make([]llm.TokenLogprob, len(ids))
|
||||
@@ -248,9 +264,14 @@ func buildLogprob(sample sampler.Result, decode func([]int32) string) []llm.Logp
|
||||
Logprob: float64(vals[i]),
|
||||
}
|
||||
}
|
||||
// The sampler emits the top maxK across registered slots via
|
||||
// Argpartition, which leaves entries unsorted.
|
||||
sort.Slice(pairs, func(i, j int) bool {
|
||||
return pairs[i].Logprob > pairs[j].Logprob
|
||||
})
|
||||
if wantTopLogprobs < len(pairs) {
|
||||
pairs = pairs[:wantTopLogprobs]
|
||||
}
|
||||
out.TopLogprobs = pairs
|
||||
}
|
||||
return []llm.Logprob{out}
|
||||
|
||||
@@ -18,20 +18,25 @@ import (
|
||||
"github.com/ollama/ollama/x/tokenizer"
|
||||
)
|
||||
|
||||
// Request is a short-lived struct that carries a completion request through
|
||||
// a channel from the HTTP handler to the runner goroutine. The ctx field
|
||||
// must travel with the request so that cancellation propagates across the
|
||||
// channel boundary.
|
||||
type Request struct {
|
||||
CompletionRequest
|
||||
Responses chan CompletionResponse
|
||||
Pipeline func(Request) error
|
||||
Pipeline func(context.Context, Request) error
|
||||
|
||||
Ctx context.Context
|
||||
|
||||
Sampler *sample.Sampler
|
||||
Ctx context.Context //nolint:containedctx
|
||||
Tokens []int32
|
||||
SamplerOpts sample.Options
|
||||
}
|
||||
|
||||
type Runner struct {
|
||||
Model base.Model
|
||||
Tokenizer *tokenizer.Tokenizer
|
||||
Requests chan Request
|
||||
Sampler *sample.Sampler
|
||||
cache kvCache
|
||||
contextLength int
|
||||
}
|
||||
@@ -63,6 +68,7 @@ func (r *Runner) Load(modelName string) error {
|
||||
r.Model = m
|
||||
r.Tokenizer = m.Tokenizer()
|
||||
r.contextLength = m.MaxContextLength()
|
||||
r.Sampler = sample.New(r.contextLength)
|
||||
|
||||
mlx.EnableCompile()
|
||||
return nil
|
||||
@@ -131,7 +137,7 @@ func (r *Runner) Run(host, port string, mux http.Handler) error {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case request := <-r.Requests:
|
||||
if err := request.Pipeline(request); err != nil {
|
||||
if err := request.Pipeline(request.Ctx, request); err != nil {
|
||||
slog.Info("Request terminated", "error", err)
|
||||
var statusErr api.StatusError
|
||||
if !errors.As(err, &statusErr) {
|
||||
|
||||
@@ -24,14 +24,15 @@ type logprobEntry struct {
|
||||
func runSampleLogprobs(t *testing.T, logits []float32, topK int) (int, float64, []logprobEntry) {
|
||||
t.Helper()
|
||||
|
||||
s := New(Options{Logprobs: true, TopLogprobs: topK})
|
||||
s := New(128)
|
||||
defer func() {
|
||||
s.Free()
|
||||
mlx.Sweep()
|
||||
}()
|
||||
s.Add(0, Options{Logprobs: true, TopLogprobs: topK}, nil)
|
||||
|
||||
tensor := mlx.FromValues(logits, 1, len(logits))
|
||||
res := s.Sample(tensor)
|
||||
res := s.Sample([]int{0}, tensor)
|
||||
|
||||
mlx.Pin(res.Arrays()...)
|
||||
defer mlx.Unpin(res.Arrays()...)
|
||||
@@ -225,6 +226,42 @@ func TestSampleLogprobsSelectedTokenCorrectness(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestBatchedLogprobsPerRow verifies that per-row logprobs in a batched
|
||||
// sample call match the per-slot reference. The numerically-stable softmax
|
||||
// must reduce along the last axis only, not over the whole batch.
|
||||
func TestBatchedLogprobsPerRow(t *testing.T) {
|
||||
rowA := []float32{2, 1, 0}
|
||||
rowB := []float32{0, 5, 0}
|
||||
|
||||
_, wantA, _ := runSampleLogprobs(t, rowA, 0)
|
||||
_, wantB, _ := runSampleLogprobs(t, rowB, 0)
|
||||
|
||||
s := New(128)
|
||||
t.Cleanup(func() {
|
||||
s.Free()
|
||||
mlx.Sweep()
|
||||
})
|
||||
s.Add(1, Options{Logprobs: true}, nil)
|
||||
s.Add(2, Options{Logprobs: true}, nil)
|
||||
|
||||
logits := mlx.FromValues(append(append([]float32{}, rowA...), rowB...), 2, 3)
|
||||
res := s.Sample([]int{1, 2}, logits)
|
||||
mlx.Pin(res.Arrays()...)
|
||||
t.Cleanup(func() { mlx.Unpin(res.Arrays()...) })
|
||||
mlx.Eval(res.Arrays()...)
|
||||
|
||||
got := res.Logprob.Floats()
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("Logprob length = %d, want 2", len(got))
|
||||
}
|
||||
if math.Abs(float64(got[0])-wantA) > 1e-5 {
|
||||
t.Errorf("row 0 logprob = %f, want %f (per-slot reference)", got[0], wantA)
|
||||
}
|
||||
if math.Abs(float64(got[1])-wantB) > 1e-5 {
|
||||
t.Errorf("row 1 logprob = %f, want %f (per-slot reference)", got[1], wantB)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSampleLogprobsTopKOrdering(t *testing.T) {
|
||||
// Logits chosen so argmax order differs from index order.
|
||||
logits := []float32{2.0, 5.0, 1.0, 4.0, 3.0}
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
package sample
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
type Transform func(*Sampler, *mlx.Array) *mlx.Array
|
||||
|
||||
type Options struct {
|
||||
Temperature float32
|
||||
TopP float32
|
||||
@@ -24,21 +24,15 @@ type Options struct {
|
||||
TopLogprobs int
|
||||
}
|
||||
|
||||
type Sampler struct {
|
||||
Options
|
||||
|
||||
history *mlx.Array
|
||||
historyLen int
|
||||
transforms []Transform
|
||||
}
|
||||
|
||||
// Result bundles the outputs of one decode step. The logprob tensors are
|
||||
// populated only when the sampler is configured to report them.
|
||||
// Result bundles the outputs of one decode step. Logprob/TopTokens/
|
||||
// TopLogprobs are populated whenever any registered slot has Logprobs
|
||||
// (respectively TopLogprobs>0). Consumers need to filter by their
|
||||
// per-slot Options.
|
||||
type Result struct {
|
||||
Token *mlx.Array // sampled token id, shape [B]
|
||||
Logprob *mlx.Array // sampled-token logprob, shape [B,1]; nil unless Logprobs
|
||||
TopTokens *mlx.Array // top-K token ids, shape [B,K]; nil unless TopLogprobs > 0
|
||||
TopLogprobs *mlx.Array // top-K logprobs, shape [B,K]; nil unless TopLogprobs > 0
|
||||
Token *mlx.Array // sampled token ids, shape [B]
|
||||
Logprob *mlx.Array // sampled-token logprobs, shape [B,1]; nil unless any registered slot has Logprobs
|
||||
TopTokens *mlx.Array // top-K token ids, shape [B,maxK]; nil unless any registered slot has TopLogprobs>0
|
||||
TopLogprobs *mlx.Array // top-K logprobs, shape [B,maxK]; same
|
||||
}
|
||||
|
||||
// Arrays returns the tensor fields as a slice so callers can drive the mlx
|
||||
@@ -48,121 +42,300 @@ func (r Result) Arrays() []*mlx.Array {
|
||||
return []*mlx.Array{r.Token, r.Logprob, r.TopTokens, r.TopLogprobs}
|
||||
}
|
||||
|
||||
func New(opts Options) *Sampler {
|
||||
if opts.RepeatPenalty <= 0 {
|
||||
opts.RepeatPenalty = 1
|
||||
// Sampler is a batched, slot-based sampler. Sequences are registered with
|
||||
// Add and released with Remove. Each Sample call takes a subset of
|
||||
// registered slots (in any order) with their [B,V] logits, samples one
|
||||
// token per row, and appends it to that slot's ring-buffer history. Slots
|
||||
// not named in a given call are untouched.
|
||||
type Sampler struct {
|
||||
slots []*slotState
|
||||
byID map[int]*slotState
|
||||
|
||||
// history is the pooled ring-buffer storage, [B, W] int32. Row i
|
||||
// belongs to slots[i]; W is max(RepeatLastN) across penalty slots.
|
||||
// Allocated on the first penalty slot, rebuilt only in Add/Remove.
|
||||
history *mlx.Array
|
||||
|
||||
// allSameOpts: every registered slot shares Options. When true the
|
||||
// canonical shared value is s.slots[0].opts.
|
||||
allSameOpts bool
|
||||
|
||||
// anyLogprobs / maxTopLogprobs: compute-for-all output config.
|
||||
// Sample populates Logprob (and Top* when maxTopLogprobs>0) whenever
|
||||
// any registered slot requests them, even if that slot isn't in the
|
||||
// current call.
|
||||
anyLogprobs bool
|
||||
maxTopLogprobs int
|
||||
|
||||
// numCtx is the runner's context window; normalize uses it to
|
||||
// resolve the repeat_last_n == -1 sentinel.
|
||||
numCtx int
|
||||
}
|
||||
|
||||
type slotState struct {
|
||||
opts Options
|
||||
transforms []transform
|
||||
historyLen int
|
||||
}
|
||||
|
||||
type slotCtx struct {
|
||||
opts Options
|
||||
history *mlx.Array // 2D [B, W] when penalties are configured; nil otherwise
|
||||
}
|
||||
|
||||
type transform func(*slotCtx, *mlx.Array) *mlx.Array
|
||||
|
||||
// New constructs an empty sampler with no registered slots. numCtx is
|
||||
// the runner's context window and must be positive.
|
||||
func New(numCtx int) *Sampler {
|
||||
return &Sampler{
|
||||
byID: make(map[int]*slotState),
|
||||
allSameOpts: true,
|
||||
numCtx: numCtx,
|
||||
}
|
||||
}
|
||||
|
||||
// historyWidth returns the column count of the pooled history tensor,
|
||||
// or 0 when no penalty slot has forced it to be allocated.
|
||||
func (s *Sampler) historyWidth() int {
|
||||
if s.history == nil {
|
||||
return 0
|
||||
}
|
||||
return s.history.Dim(1)
|
||||
}
|
||||
|
||||
func (o Options) usesHistory() bool {
|
||||
// RepeatLastN == 0 disables the penalty ring per the repeat_last_n API
|
||||
// contract (0 = disabled), overriding any penalty coefficients.
|
||||
if o.RepeatLastN == 0 {
|
||||
return false
|
||||
}
|
||||
return o.RepeatPenalty != 1 || o.PresencePenalty != 0 || o.FrequencyPenalty != 0
|
||||
}
|
||||
|
||||
func (o Options) normalize(numCtx int) Options {
|
||||
if o.RepeatPenalty <= 0 {
|
||||
o.RepeatPenalty = 1
|
||||
}
|
||||
// Resolve the repeat_last_n == -1 sentinel ("-1 = num_ctx") against
|
||||
// the caller's context window.
|
||||
if o.RepeatLastN < 0 {
|
||||
o.RepeatLastN = numCtx
|
||||
}
|
||||
if !o.usesHistory() {
|
||||
// Zero the ring capacity so slots that differ only in a spurious
|
||||
// RepeatLastN still batch together and don't inflate pool width.
|
||||
o.RepeatLastN = 0
|
||||
}
|
||||
return o
|
||||
}
|
||||
|
||||
func (o Options) buildTransforms() []transform {
|
||||
var ts []transform
|
||||
if o.usesHistory() {
|
||||
ts = append(ts, penalty)
|
||||
}
|
||||
|
||||
s := &Sampler{Options: opts}
|
||||
|
||||
var transforms []Transform
|
||||
if s.usesHistory() {
|
||||
transforms = append(transforms, penalty)
|
||||
}
|
||||
|
||||
hasTopP := opts.TopP > 0 && opts.TopP < 1
|
||||
hasTopK := opts.TopK > 0
|
||||
hasTopP := o.TopP > 0 && o.TopP < 1
|
||||
hasTopK := o.TopK > 0
|
||||
switch {
|
||||
case hasTopP:
|
||||
// topKTopP always does a full descending sort for the top-P
|
||||
// cumulative mask and opportunistically masks top-K during the
|
||||
// same pass when it is also configured.
|
||||
transforms = append(transforms, topKTopP)
|
||||
ts = append(ts, topKTopP)
|
||||
case hasTopK:
|
||||
// Argpartition (partial sort) is cheaper than a full sort.
|
||||
transforms = append(transforms, topK)
|
||||
ts = append(ts, topK)
|
||||
}
|
||||
|
||||
if opts.MinP != 0 {
|
||||
transforms = append(transforms, minP)
|
||||
if o.MinP != 0 {
|
||||
ts = append(ts, minP)
|
||||
}
|
||||
|
||||
if opts.Temperature == 0 {
|
||||
transforms = append(transforms, greedy)
|
||||
if o.Temperature == 0 {
|
||||
ts = append(ts, greedy)
|
||||
} else {
|
||||
transforms = append(transforms, temperature)
|
||||
ts = append(ts, temperature)
|
||||
}
|
||||
|
||||
s.transforms = transforms
|
||||
return s
|
||||
return ts
|
||||
}
|
||||
|
||||
func (s *Sampler) usesHistory() bool {
|
||||
return s.RepeatPenalty != 1 || s.PresencePenalty != 0 || s.FrequencyPenalty != 0
|
||||
}
|
||||
|
||||
func (s *Sampler) setHistory(history *mlx.Array, historyLen int) {
|
||||
if history != nil {
|
||||
mlx.Pin(history)
|
||||
// Add registers a sequence under seqID. The last RepeatLastN entries of
|
||||
// priorTokens seed the ring buffer.
|
||||
func (s *Sampler) Add(seqID int, opts Options, priorTokens []int32) {
|
||||
if _, dup := s.byID[seqID]; dup {
|
||||
panic(fmt.Sprintf("sample.Sampler.Add: seqID %d already registered", seqID))
|
||||
}
|
||||
if s.history != nil {
|
||||
|
||||
opts = opts.normalize(s.numCtx)
|
||||
slot := &slotState{
|
||||
opts: opts,
|
||||
transforms: opts.buildTransforms(),
|
||||
}
|
||||
|
||||
// Grow the pool to hold this slot's row. The pool is lazy — the first
|
||||
// penalty slot allocates it — and thereafter every registered slot
|
||||
// gets a row (rows for non-penalty slots are zero and never read).
|
||||
// Invariant: s.history is pinned whenever non-nil.
|
||||
if s.history != nil || opts.usesHistory() {
|
||||
targetWidth := max(opts.RepeatLastN, s.historyWidth())
|
||||
newRow := makeHistoryRow(priorTokens, opts.RepeatLastN, targetWidth)
|
||||
|
||||
var pool *mlx.Array
|
||||
switch {
|
||||
case s.history == nil && len(s.slots) == 0:
|
||||
pool = newRow
|
||||
case s.history == nil:
|
||||
// First penalty slot with non-penalty slots already registered;
|
||||
// seed zero rows so s.slots and pool row indices stay aligned.
|
||||
zeros := mlx.Zeros(mlx.DTypeInt32, len(s.slots), targetWidth)
|
||||
pool = zeros.Concatenate(0, newRow)
|
||||
case targetWidth > s.historyWidth():
|
||||
pad := mlx.Zeros(mlx.DTypeInt32, s.history.Dim(0), targetWidth-s.historyWidth())
|
||||
pool = s.history.Concatenate(1, pad).Concatenate(0, newRow)
|
||||
default:
|
||||
pool = s.history.Concatenate(0, newRow)
|
||||
}
|
||||
|
||||
mlx.Pin(pool)
|
||||
mlx.Unpin(s.history)
|
||||
s.history = pool
|
||||
|
||||
if opts.usesHistory() {
|
||||
// Cap on seed so the next write's ring position
|
||||
// (historyLen % RepeatLastN) lands at 0, overwriting the
|
||||
// oldest entry when the ring was filled from priors.
|
||||
slot.historyLen = min(len(priorTokens), opts.RepeatLastN)
|
||||
}
|
||||
}
|
||||
s.history = history
|
||||
s.historyLen = historyLen
|
||||
|
||||
s.slots = append(s.slots, slot)
|
||||
s.byID[seqID] = slot
|
||||
s.recomputeInvariants()
|
||||
}
|
||||
|
||||
func (s *Sampler) ResetHistory(history []int32) {
|
||||
if !s.usesHistory() {
|
||||
// makeHistoryRow builds a [1, width] int32 row with the last repeatLastN
|
||||
// entries of priorTokens packed into [0, min(len, repeatLastN)), zeros
|
||||
// elsewhere.
|
||||
func makeHistoryRow(priorTokens []int32, repeatLastN, width int) *mlx.Array {
|
||||
take := min(len(priorTokens), repeatLastN)
|
||||
if take <= 0 {
|
||||
return mlx.Zeros(mlx.DTypeInt32, 1, width)
|
||||
}
|
||||
row := make([]int32, width)
|
||||
copy(row, priorTokens[len(priorTokens)-take:])
|
||||
return mlx.NewArrayInt32(row, []int32{1, int32(width)})
|
||||
}
|
||||
|
||||
// recomputeInvariants refreshes allSameOpts and anyLogprobs/maxTopLogprobs
|
||||
// from s.slots. Called at the end of Add and Remove.
|
||||
func (s *Sampler) recomputeInvariants() {
|
||||
if len(s.slots) == 0 {
|
||||
s.allSameOpts = true
|
||||
s.anyLogprobs = false
|
||||
s.maxTopLogprobs = 0
|
||||
return
|
||||
}
|
||||
if s.RepeatLastN > 0 && len(history) > s.RepeatLastN {
|
||||
history = history[len(history)-s.RepeatLastN:]
|
||||
first := s.slots[0].opts
|
||||
s.allSameOpts = true
|
||||
s.anyLogprobs = false
|
||||
s.maxTopLogprobs = 0
|
||||
for _, slot := range s.slots {
|
||||
if slot.opts != first {
|
||||
s.allSameOpts = false
|
||||
}
|
||||
if slot.opts.Logprobs {
|
||||
s.anyLogprobs = true
|
||||
if slot.opts.TopLogprobs > s.maxTopLogprobs {
|
||||
s.maxTopLogprobs = slot.opts.TopLogprobs
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(history) == 0 {
|
||||
s.setHistory(nil, 0)
|
||||
}
|
||||
|
||||
// Remove releases the slot. The pool tensor is rebuilt to drop the row.
|
||||
func (s *Sampler) Remove(seqID int) {
|
||||
slot, ok := s.byID[seqID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
delete(s.byID, seqID)
|
||||
|
||||
row := slices.Index(s.slots, slot)
|
||||
s.slots = slices.Delete(s.slots, row, row+1)
|
||||
s.recomputeInvariants()
|
||||
|
||||
if s.history == nil {
|
||||
return
|
||||
}
|
||||
|
||||
tokens := append([]int32(nil), history...)
|
||||
s.setHistory(mlx.NewArrayInt32(tokens, []int32{int32(len(tokens))}), len(tokens))
|
||||
}
|
||||
|
||||
func (s *Sampler) AppendToken(token *mlx.Array) {
|
||||
if !s.usesHistory() || token == nil {
|
||||
return
|
||||
}
|
||||
|
||||
next := token.AsType(mlx.DTypeInt32)
|
||||
nextLen := next.Size()
|
||||
|
||||
if s.history != nil && s.historyLen > 0 {
|
||||
next = s.history.Concatenate(0, next)
|
||||
nextLen += s.historyLen
|
||||
}
|
||||
|
||||
if s.RepeatLastN > 0 && nextLen > s.RepeatLastN {
|
||||
trim := nextLen - s.RepeatLastN
|
||||
next = next.Slice(mlx.Slice(trim, nextLen))
|
||||
nextLen = s.RepeatLastN
|
||||
}
|
||||
|
||||
s.setHistory(next, nextLen)
|
||||
n := s.history.Dim(0)
|
||||
var newHistory *mlx.Array
|
||||
switch {
|
||||
case n == 1:
|
||||
newHistory = nil
|
||||
case row == 0:
|
||||
newHistory = s.history.Slice(mlx.Slice(1, n), mlx.Slice())
|
||||
case row == n-1:
|
||||
newHistory = s.history.Slice(mlx.Slice(0, row), mlx.Slice())
|
||||
default:
|
||||
before := s.history.Slice(mlx.Slice(0, row), mlx.Slice())
|
||||
after := s.history.Slice(mlx.Slice(row+1, n), mlx.Slice())
|
||||
newHistory = before.Concatenate(0, after)
|
||||
}
|
||||
|
||||
mlx.Pin(newHistory)
|
||||
mlx.Unpin(s.history)
|
||||
s.history = newHistory
|
||||
}
|
||||
|
||||
// Free releases the pooled history tensor and resets the sampler to the
|
||||
// New-equivalent state so it may be reused.
|
||||
func (s *Sampler) Free() {
|
||||
s.setHistory(nil, 0)
|
||||
mlx.Unpin(s.history)
|
||||
*s = Sampler{
|
||||
byID: make(map[int]*slotState),
|
||||
allSameOpts: true,
|
||||
numCtx: s.numCtx,
|
||||
}
|
||||
}
|
||||
|
||||
// Sample runs the configured transform chain on the raw per-token logits
|
||||
// and returns the sampled token id plus, when configured, the reported
|
||||
// log-probability tensors for the selected token and the top-K tokens.
|
||||
func (s *Sampler) Sample(logits *mlx.Array) Result {
|
||||
scores := logits
|
||||
for _, transform := range s.transforms {
|
||||
scores = transform(s, scores)
|
||||
// Sample draws one token per row of logits ([B,V]); seqIDs[i] names the
|
||||
// slot whose logits live at row i. Each sampled token is appended to its
|
||||
// slot's ring. Slots not named in seqIDs are untouched.
|
||||
func (s *Sampler) Sample(seqIDs []int, logits *mlx.Array) Result {
|
||||
if len(seqIDs) == 0 {
|
||||
return Result{}
|
||||
}
|
||||
res := Result{Token: scores}
|
||||
|
||||
if s.Logprobs {
|
||||
// Compute log_softmax in fp32 and subtract the max before
|
||||
// logsumexp so the final subtraction stays on small values.
|
||||
// Otherwise it cancels two large numbers and loses precision.
|
||||
slots := make([]*slotState, len(seqIDs))
|
||||
for i, id := range seqIDs {
|
||||
slot, ok := s.byID[id]
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("sample.Sampler.Sample: seqID %d not registered", id))
|
||||
}
|
||||
slots[i] = slot
|
||||
}
|
||||
|
||||
var token *mlx.Array
|
||||
if opts0, ok := s.canBatch(slots); ok {
|
||||
token = s.sampleTokensUniform(slots, opts0, logits)
|
||||
} else {
|
||||
token = s.sampleTokensSerial(slots, logits)
|
||||
}
|
||||
|
||||
res := Result{Token: token}
|
||||
if s.anyLogprobs {
|
||||
// Log-softmax over original logits so every row holds a truthful
|
||||
// value (compute-for-all; consumers filter per-slot). Subtract
|
||||
// max first for numerical stability in the logsumexp.
|
||||
lp := logits.AsType(mlx.DTypeFloat32)
|
||||
lp = lp.Subtract(lp.MaxAxis(-1, true))
|
||||
lp = lp.Subtract(lp.Logsumexp(true))
|
||||
res.Logprob = lp.TakeAlongAxis(res.Token.ExpandDims(-1), -1)
|
||||
if k := s.TopLogprobs; k > 0 {
|
||||
lp = lp.Subtract(lp.LogsumexpAxis(-1, true))
|
||||
res.Logprob = lp.TakeAlongAxis(token.ExpandDims(-1), -1)
|
||||
if s.maxTopLogprobs > 0 {
|
||||
k := s.maxTopLogprobs
|
||||
if vocab := lp.Dim(lp.NumDims() - 1); k > vocab {
|
||||
k = vocab
|
||||
}
|
||||
@@ -176,55 +349,180 @@ func (s *Sampler) Sample(logits *mlx.Array) Result {
|
||||
return res
|
||||
}
|
||||
|
||||
func greedy(_ *Sampler, scores *mlx.Array) *mlx.Array {
|
||||
return scores.Argmax(-1, false)
|
||||
// canBatch reports whether the call can take the uniform batched path.
|
||||
// All slots must share Options; when penalties are active the call must
|
||||
// additionally cover every registered slot in registration order with a
|
||||
// full ring, because the uniform path indexes the pool positionally.
|
||||
func (s *Sampler) canBatch(slots []*slotState) (Options, bool) {
|
||||
if !s.allSameOpts {
|
||||
return Options{}, false
|
||||
}
|
||||
// slots is non-empty (Sample guards) and every slot is registered,
|
||||
// so s.slots[0].opts is the canonical shared value.
|
||||
shared := s.slots[0].opts
|
||||
if !shared.usesHistory() {
|
||||
return shared, true
|
||||
}
|
||||
if len(slots) != len(s.slots) {
|
||||
return Options{}, false
|
||||
}
|
||||
for i, slot := range slots {
|
||||
if s.slots[i] != slot || slot.historyLen < shared.RepeatLastN {
|
||||
return Options{}, false
|
||||
}
|
||||
}
|
||||
return shared, true
|
||||
}
|
||||
|
||||
func temperature(s *Sampler, scores *mlx.Array) *mlx.Array {
|
||||
return mlx.DivScalar(scores, s.Temperature).Categorical(-1)
|
||||
// sampleTokensUniform runs one fused transform pass over the whole batch.
|
||||
// Reached only when canBatch is true, which lets the pool be used in place
|
||||
// with a single PutAlongAxis write-back and no gather.
|
||||
func (s *Sampler) sampleTokensUniform(slots []*slotState, opts Options, logits *mlx.Array) *mlx.Array {
|
||||
B := len(slots)
|
||||
|
||||
var hist *mlx.Array
|
||||
if opts.usesHistory() {
|
||||
hist = s.history
|
||||
if s.historyWidth() > opts.RepeatLastN {
|
||||
hist = hist.Slice(mlx.Slice(), mlx.Slice(0, opts.RepeatLastN))
|
||||
}
|
||||
}
|
||||
|
||||
ctx := &slotCtx{opts: opts, history: hist}
|
||||
scores := logits
|
||||
for _, t := range slots[0].transforms {
|
||||
scores = t(ctx, scores)
|
||||
}
|
||||
token := scores
|
||||
|
||||
if !opts.usesHistory() {
|
||||
return token
|
||||
}
|
||||
|
||||
writeIdxData := make([]int32, B)
|
||||
for i, slot := range slots {
|
||||
writeIdxData[i] = int32(slot.historyLen % opts.RepeatLastN)
|
||||
slot.historyLen++
|
||||
}
|
||||
writeIdx := mlx.NewArrayInt32(writeIdxData, []int32{int32(B), 1})
|
||||
|
||||
s.history.Set(s.history.PutAlongAxis(writeIdx, token.ExpandDims(-1), 1))
|
||||
return token
|
||||
}
|
||||
|
||||
// sampleTokensSerial runs each slot's transforms against its own row of
|
||||
// logits.
|
||||
func (s *Sampler) sampleTokensSerial(slots []*slotState, logits *mlx.Array) *mlx.Array {
|
||||
perSlotTokens := make([]*mlx.Array, len(slots))
|
||||
|
||||
rowOf := make(map[*slotState]int, len(s.slots))
|
||||
for i, slot := range s.slots {
|
||||
rowOf[slot] = i
|
||||
}
|
||||
|
||||
for i, slot := range slots {
|
||||
row := logits.Slice(mlx.Slice(i, i+1), mlx.Slice())
|
||||
|
||||
var hist *mlx.Array
|
||||
if slot.opts.usesHistory() && slot.historyLen > 0 && s.history != nil {
|
||||
poolRow := rowOf[slot]
|
||||
fill := min(slot.historyLen, slot.opts.RepeatLastN)
|
||||
hist = s.history.Slice(
|
||||
mlx.Slice(poolRow, poolRow+1),
|
||||
mlx.Slice(0, fill),
|
||||
)
|
||||
}
|
||||
|
||||
ctx := &slotCtx{opts: slot.opts, history: hist}
|
||||
scores := row
|
||||
for _, t := range slot.transforms {
|
||||
scores = t(ctx, scores)
|
||||
}
|
||||
perSlotTokens[i] = scores
|
||||
}
|
||||
|
||||
token := mlx.Concatenate(perSlotTokens, 0)
|
||||
|
||||
if s.history != nil {
|
||||
// For each writing slot collect its flat (row-major) pool offset
|
||||
// and the call-order position of its token. One PutAlongAxis on a
|
||||
// flat view of the pool scatters all writes in a single op.
|
||||
flatOffsets := make([]int32, 0, len(slots))
|
||||
tokenPos := make([]int32, 0, len(slots))
|
||||
for i, slot := range slots {
|
||||
if !slot.opts.usesHistory() {
|
||||
continue
|
||||
}
|
||||
ringPos := slot.historyLen % slot.opts.RepeatLastN
|
||||
flatOffsets = append(flatOffsets, int32(rowOf[slot]*s.historyWidth()+ringPos))
|
||||
tokenPos = append(tokenPos, int32(i))
|
||||
slot.historyLen++
|
||||
}
|
||||
|
||||
if len(flatOffsets) > 0 {
|
||||
m := len(flatOffsets)
|
||||
flatIdx := mlx.NewArrayInt32(flatOffsets, []int32{int32(m), 1})
|
||||
writingTokens := token
|
||||
if m != len(slots) {
|
||||
tokenPosIdx := mlx.NewArrayInt32(tokenPos, []int32{int32(m)})
|
||||
writingTokens = token.TakeAxis(tokenPosIdx, 0)
|
||||
}
|
||||
flatHist := s.history.Reshape(s.history.Dim(0)*s.historyWidth(), 1)
|
||||
s.history.Set(flatHist.PutAlongAxis(flatIdx, writingTokens.ExpandDims(-1), 0).Reshape(s.history.Dim(0), s.historyWidth()))
|
||||
}
|
||||
}
|
||||
return token
|
||||
}
|
||||
|
||||
func greedy(_ *slotCtx, scores *mlx.Array) *mlx.Array {
|
||||
return scores.Argmax(-1, false).AsType(mlx.DTypeInt32)
|
||||
}
|
||||
|
||||
func temperature(ctx *slotCtx, scores *mlx.Array) *mlx.Array {
|
||||
return mlx.DivScalar(scores, ctx.opts.Temperature).Categorical(-1).AsType(mlx.DTypeInt32)
|
||||
}
|
||||
|
||||
// topKTopP applies top-P in a descending sort pass and, when top-K is also
|
||||
// configured, masks any surviving value below the K-th largest in the same
|
||||
// pass. Callers dispatch here whenever top-P is enabled — the top-K-only
|
||||
// case uses a cheaper partial sort via the topK transform.
|
||||
func topKTopP(s *Sampler, scores *mlx.Array) *mlx.Array {
|
||||
// pass. Callers dispatch here whenever top-P is enabled — the top-K-only case
|
||||
// uses a cheaper partial sort via the topK transform.
|
||||
func topKTopP(ctx *slotCtx, scores *mlx.Array) *mlx.Array {
|
||||
vocab := scores.Dim(scores.NumDims() - 1)
|
||||
applyTopK := s.TopK > 0 && s.TopK < vocab
|
||||
applyTopK := ctx.opts.TopK > 0 && ctx.opts.TopK < vocab
|
||||
|
||||
order := scores.Negative().ArgsortAxis(-1)
|
||||
sorted := scores.TakeAlongAxis(order, -1)
|
||||
negInf := mlx.FromValue(float32(math.Inf(-1)))
|
||||
|
||||
// Top-P: in descending order, keep tokens whose exclusive cumulative
|
||||
// probability is still below s.TopP.
|
||||
// probability is still below TopP.
|
||||
probs := mlx.SoftmaxAxis(sorted, -1, true)
|
||||
prevCumProbs := probs.Cumsum(-1, false, true).Subtract(probs)
|
||||
keep := prevCumProbs.Less(mlx.FromValue(s.TopP))
|
||||
keep := prevCumProbs.Less(mlx.FromValue(ctx.opts.TopP))
|
||||
sorted = mlx.Where(keep, sorted, negInf)
|
||||
|
||||
out := scores.PutAlongAxis(order, sorted, -1)
|
||||
|
||||
// Top-K: sorted is already in descending order, so positions [K, V)
|
||||
// are the ones to drop. Scatter -inf through their original-layout
|
||||
// indices (order[K:]). Positional (not value-based) so exactly K
|
||||
// tokens survive — ties at the K-th logit get broken by the sort
|
||||
// order rather than promoted through the filter.
|
||||
// Top-K: sorted is already in descending order, so positions [K, V) are
|
||||
// the ones to drop. Scatter -inf through their original-layout indices
|
||||
// (order[K:]). Positional (not value-based) so exactly K tokens survive —
|
||||
// ties at the K-th logit get broken by the sort order rather than
|
||||
// promoted through the filter.
|
||||
if applyTopK {
|
||||
dropOrder := order.Slice(mlx.Slice(), mlx.Slice(s.TopK, mlx.End))
|
||||
dropOrder := order.Slice(mlx.Slice(), mlx.Slice(ctx.opts.TopK, mlx.End))
|
||||
out = out.PutAlongAxis(dropOrder, negInf, -1)
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func minP(s *Sampler, scores *mlx.Array) *mlx.Array {
|
||||
if s.MinP <= 0 || s.MinP > 1 {
|
||||
func minP(ctx *slotCtx, scores *mlx.Array) *mlx.Array {
|
||||
if ctx.opts.MinP <= 0 || ctx.opts.MinP > 1 {
|
||||
return scores
|
||||
}
|
||||
|
||||
maxScore := scores.MaxAxis(-1, true)
|
||||
threshold := mlx.AddScalar(maxScore, float32(math.Log(float64(s.MinP))))
|
||||
threshold := mlx.AddScalar(maxScore, float32(math.Log(float64(ctx.opts.MinP))))
|
||||
|
||||
return mlx.Where(
|
||||
scores.Less(threshold),
|
||||
@@ -233,48 +531,43 @@ func minP(s *Sampler, scores *mlx.Array) *mlx.Array {
|
||||
)
|
||||
}
|
||||
|
||||
func topK(s *Sampler, scores *mlx.Array) *mlx.Array {
|
||||
if s.TopK <= 0 {
|
||||
func topK(ctx *slotCtx, scores *mlx.Array) *mlx.Array {
|
||||
if ctx.opts.TopK <= 0 {
|
||||
return scores
|
||||
}
|
||||
|
||||
vocab := scores.Dim(scores.NumDims() - 1)
|
||||
if s.TopK >= vocab {
|
||||
if ctx.opts.TopK >= vocab {
|
||||
return scores
|
||||
}
|
||||
|
||||
mask := scores.Negative().ArgpartitionAxis(s.TopK-1, -1).Slice(mlx.Slice(), mlx.Slice(s.TopK, mlx.End))
|
||||
mask := scores.Negative().ArgpartitionAxis(ctx.opts.TopK-1, -1).Slice(mlx.Slice(), mlx.Slice(ctx.opts.TopK, mlx.End))
|
||||
return scores.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1)
|
||||
}
|
||||
|
||||
func penalty(s *Sampler, scores *mlx.Array) *mlx.Array {
|
||||
if s.historyLen == 0 {
|
||||
func penalty(ctx *slotCtx, scores *mlx.Array) *mlx.Array {
|
||||
tokenIndices := ctx.history
|
||||
if tokenIndices == nil {
|
||||
return scores
|
||||
}
|
||||
|
||||
tokenIndices := s.history
|
||||
if scores.NumDims() > 1 {
|
||||
tokenIndices = tokenIndices.ExpandDims(0)
|
||||
}
|
||||
|
||||
if s.RepeatPenalty != 1 || s.PresencePenalty != 0 {
|
||||
if ctx.opts.RepeatPenalty != 1 || ctx.opts.PresencePenalty != 0 {
|
||||
adjusted := scores.TakeAlongAxis(tokenIndices, -1)
|
||||
if s.RepeatPenalty != 1 {
|
||||
if ctx.opts.RepeatPenalty != 1 {
|
||||
factor := mlx.Where(
|
||||
adjusted.Less(mlx.FromValue(float32(0))),
|
||||
mlx.FromValue(s.RepeatPenalty),
|
||||
mlx.FromValue(1/s.RepeatPenalty),
|
||||
mlx.FromValue(ctx.opts.RepeatPenalty),
|
||||
mlx.FromValue(1/ctx.opts.RepeatPenalty),
|
||||
)
|
||||
adjusted = adjusted.Multiply(factor)
|
||||
}
|
||||
if s.PresencePenalty != 0 {
|
||||
adjusted = mlx.AddScalar(adjusted, -s.PresencePenalty)
|
||||
if ctx.opts.PresencePenalty != 0 {
|
||||
adjusted = mlx.AddScalar(adjusted, -ctx.opts.PresencePenalty)
|
||||
}
|
||||
scores = scores.PutAlongAxis(tokenIndices, adjusted, -1)
|
||||
}
|
||||
|
||||
if s.FrequencyPenalty != 0 {
|
||||
scores = scores.ScatterAddAxis(tokenIndices, mlx.FromValue(-s.FrequencyPenalty), -1)
|
||||
if ctx.opts.FrequencyPenalty != 0 {
|
||||
scores = scores.ScatterAddAxis(tokenIndices, mlx.FromValue(-ctx.opts.FrequencyPenalty), -1)
|
||||
}
|
||||
|
||||
return scores
|
||||
|
||||
@@ -9,93 +9,283 @@ import (
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
func TestPresencePenaltyUsesAppendedTokenImmediately(t *testing.T) {
|
||||
s := New(Options{RepeatLastN: 1, PresencePenalty: 6})
|
||||
defer func() {
|
||||
// slotLogits builds a [1, V] logits tensor for a single-slot Sample call.
|
||||
func slotLogits(values []float32) *mlx.Array {
|
||||
return mlx.FromValues(values, 1, len(values))
|
||||
}
|
||||
|
||||
// batchLogits stacks per-row float32 slices of equal length into a [B, V]
|
||||
// logits tensor.
|
||||
func batchLogits(rows ...[]float32) *mlx.Array {
|
||||
v := len(rows[0])
|
||||
flat := make([]float32, 0, len(rows)*v)
|
||||
for _, r := range rows {
|
||||
if len(r) != v {
|
||||
panic("batchLogits: rows must share vocab size")
|
||||
}
|
||||
flat = append(flat, r...)
|
||||
}
|
||||
return mlx.FromValues(flat, len(rows), v)
|
||||
}
|
||||
|
||||
// sampleOne runs Sample on a freshly-added single slot and returns the
|
||||
// sampled token id. Used both for the single-slot options table and as the
|
||||
// reference oracle for the batched-equivalence test.
|
||||
func sampleOne(t *testing.T, opts Options, priorTokens []int32, values []float32) int {
|
||||
t.Helper()
|
||||
s := New(128)
|
||||
t.Cleanup(func() {
|
||||
s.Free()
|
||||
mlx.Sweep()
|
||||
}()
|
||||
})
|
||||
s.Add(0, opts, priorTokens)
|
||||
|
||||
s.ResetHistory([]int32{0})
|
||||
s.AppendToken(mlx.NewArrayInt32([]int32{1}, []int32{1}))
|
||||
|
||||
logits := mlx.FromValues([]float32{0, 5, 4}, 3)
|
||||
got := s.Sample(logits).Token
|
||||
got := s.Sample([]int{0}, slotLogits(values)).Token
|
||||
mlx.Eval(got)
|
||||
return got.Int()
|
||||
}
|
||||
|
||||
// logits will be [0, -1, 4] after the penalty
|
||||
// and then (index) 2 after the greedy sampler
|
||||
gotInt := got.Int()
|
||||
if gotInt != 2 {
|
||||
t.Fatalf("got %d, want 2", gotInt)
|
||||
// logOf returns log(p) as a float32 so tests can build logits that softmax to
|
||||
// a chosen probability distribution.
|
||||
func logOf(p float64) float32 { return float32(math.Log(p)) }
|
||||
|
||||
// TestSampleSingleSlotOptions pins the per-slot behavior of each Options
|
||||
// knob against a concrete expected token. Expected values are worked out by
|
||||
// hand from the math of each transform, not from a second call into the
|
||||
// sampler — so a regression in any single transform shows up here.
|
||||
func TestSampleSingleSlotOptions(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
opts Options
|
||||
priors []int32
|
||||
logits []float32
|
||||
want int
|
||||
}{
|
||||
{
|
||||
name: "presence penalty",
|
||||
opts: Options{RepeatLastN: 1, PresencePenalty: 6},
|
||||
priors: []int32{1},
|
||||
logits: []float32{0, 5, 4},
|
||||
want: 2, // token 1: 5 - 6 = -1, argmax shifts to 2
|
||||
},
|
||||
{
|
||||
name: "repeat penalty on positive logits",
|
||||
opts: Options{RepeatLastN: 1, RepeatPenalty: 2},
|
||||
priors: []int32{1},
|
||||
logits: []float32{0, 5, 4},
|
||||
want: 2, // token 1 positive → divided: 5/2 = 2.5, argmax shifts to 2
|
||||
},
|
||||
{
|
||||
name: "repeat penalty on negative logits",
|
||||
opts: Options{RepeatLastN: 1, RepeatPenalty: 4},
|
||||
priors: []int32{1},
|
||||
logits: []float32{-5, -1, -3},
|
||||
want: 2, // token 1 negative → multiplied: -1*4 = -4, argmax shifts to 2
|
||||
},
|
||||
{
|
||||
name: "frequency penalty",
|
||||
opts: Options{RepeatLastN: 4, FrequencyPenalty: 2},
|
||||
priors: []int32{1, 1},
|
||||
logits: []float32{0, 5, 4},
|
||||
want: 2, // 5 - 2*count(1)=2*2=4 → 1, argmax shifts to 2
|
||||
},
|
||||
{
|
||||
name: "top-k",
|
||||
opts: Options{Temperature: 1, TopK: 1},
|
||||
logits: []float32{1, 5, 4},
|
||||
want: 1, // only argmax survives → deterministic even with temperature
|
||||
},
|
||||
{
|
||||
name: "top-p",
|
||||
opts: Options{Temperature: 1, TopP: 0.4},
|
||||
logits: []float32{logOf(0.5), logOf(0.3), logOf(0.2)},
|
||||
want: 0, // exclusive cumsum below 0.4 keeps only token 0
|
||||
},
|
||||
{
|
||||
name: "min-p",
|
||||
opts: Options{Temperature: 1, MinP: 0.7},
|
||||
logits: []float32{logOf(0.5), logOf(0.3), logOf(0.2)},
|
||||
want: 0, // threshold 0.5*0.7=0.35 drops all but the top token
|
||||
},
|
||||
{
|
||||
name: "RepeatLastN=0 disables penalties",
|
||||
opts: Options{RepeatLastN: 0, RepeatPenalty: 2, PresencePenalty: 10},
|
||||
priors: []int32{1},
|
||||
logits: []float32{0, 5, 4},
|
||||
want: 1, // 0 = disabled per API contract, argmax unchanged
|
||||
},
|
||||
{
|
||||
name: "RepeatLastN=-1 resolves to num_ctx",
|
||||
opts: Options{RepeatLastN: -1, PresencePenalty: 6},
|
||||
priors: []int32{1},
|
||||
logits: []float32{0, 5, 4},
|
||||
want: 2, // -1 → num_ctx (128); penalty applies, argmax shifts
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := sampleOne(t, tc.opts, tc.priors, tc.logits); got != tc.want {
|
||||
t.Errorf("got %d, want %d", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepeatPenaltyUsesHistoryWithoutPresencePenalty(t *testing.T) {
|
||||
s := New(Options{RepeatLastN: 1, RepeatPenalty: 2})
|
||||
defer func() {
|
||||
// TestSampleHistoryWindow verifies that penalty history respects the
|
||||
// RepeatLastN window: priors longer than RepeatLastN are trimmed on Add,
|
||||
// and once the ring wraps, tokens that rotate out no longer contribute
|
||||
// to penalties.
|
||||
func TestSampleHistoryWindow(t *testing.T) {
|
||||
s := New(128)
|
||||
t.Cleanup(func() {
|
||||
s.Free()
|
||||
mlx.Sweep()
|
||||
}()
|
||||
})
|
||||
|
||||
s.ResetHistory([]int32{1})
|
||||
// RepeatLastN=2 with priors {1, 2, 3}: makeHistoryRow keeps only
|
||||
// {2, 3}. Token 1 was trimmed — its penalty is NOT active.
|
||||
s.Add(0, Options{RepeatLastN: 2, PresencePenalty: 10}, []int32{1, 2, 3})
|
||||
|
||||
logits := mlx.FromValues([]float32{0, 5, 4}, 3)
|
||||
got := s.Sample(logits).Token
|
||||
mlx.Eval(got)
|
||||
// Step 1: logits favor token 1 (trimmed). If the trim were broken it
|
||||
// would be penalized and the argmax would move.
|
||||
step1 := s.Sample([]int{0}, slotLogits([]float32{0, 5, 0, 0, 0})).Token
|
||||
mlx.Eval(step1)
|
||||
if got := step1.Int(); got != 1 {
|
||||
t.Fatalf("step 1 = %d, want 1 (token 1 trimmed from priors)", got)
|
||||
}
|
||||
// After step 1 the ring holds {1, 3}; token 2 has rotated out.
|
||||
|
||||
// token 1 is repeated and positive, so 5 / 2 falls below token 2.
|
||||
gotInt := got.Int()
|
||||
if gotInt != 2 {
|
||||
t.Fatalf("got %d, want 2", gotInt)
|
||||
// Step 2: logits favor token 2 (rotated out). If the ring wrap were
|
||||
// wrong, token 2 would still be penalized.
|
||||
step2 := s.Sample([]int{0}, slotLogits([]float32{0, 0, 5, 0, 0})).Token
|
||||
mlx.Eval(step2)
|
||||
if got := step2.Int(); got != 2 {
|
||||
t.Fatalf("step 2 = %d, want 2 (token 2 rotated out of ring)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFrequencyPenaltyUsesTokenCounts(t *testing.T) {
|
||||
s := New(Options{RepeatLastN: 4, FrequencyPenalty: 2})
|
||||
defer func() {
|
||||
s.Free()
|
||||
mlx.Sweep()
|
||||
}()
|
||||
// TestBatchSamplingPreservesPerSlotBehavior is the core equivalence test:
|
||||
// for every representative dispatch branch (uniform, serial on mixed opts,
|
||||
// serial on partial ring, subset/out-of-order), a batched Sample call must
|
||||
// produce the same token per row as running the same slot alone.
|
||||
func TestBatchSamplingPreservesPerSlotBehavior(t *testing.T) {
|
||||
type slot struct {
|
||||
id int
|
||||
opts Options
|
||||
priors []int32
|
||||
}
|
||||
|
||||
s.ResetHistory([]int32{1, 1})
|
||||
cases := []struct {
|
||||
name string
|
||||
slots []slot
|
||||
sample []int
|
||||
rows [][]float32
|
||||
}{
|
||||
{
|
||||
name: "uniform",
|
||||
slots: []slot{
|
||||
{10, Options{RepeatLastN: 2, PresencePenalty: 5}, []int32{1, 2}},
|
||||
{20, Options{RepeatLastN: 2, PresencePenalty: 5}, []int32{0, 2}},
|
||||
},
|
||||
sample: []int{10, 20},
|
||||
rows: [][]float32{{0, 5, 4}, {3, 0, 0}},
|
||||
},
|
||||
{
|
||||
name: "serial — mixed opts",
|
||||
slots: []slot{
|
||||
{1, Options{RepeatLastN: 1, RepeatPenalty: 2}, []int32{1}},
|
||||
{2, Options{Temperature: 1, TopK: 1}, nil},
|
||||
},
|
||||
sample: []int{1, 2},
|
||||
rows: [][]float32{{0, 5, 4, 1}, {2, 1, 5, 3}},
|
||||
},
|
||||
{
|
||||
name: "serial — partial ring",
|
||||
slots: []slot{
|
||||
{1, Options{RepeatLastN: 4, PresencePenalty: 5}, []int32{1, 1, 1, 1}},
|
||||
{2, Options{RepeatLastN: 4, PresencePenalty: 5}, []int32{2}},
|
||||
},
|
||||
sample: []int{1, 2},
|
||||
rows: [][]float32{{0, 5, 4}, {0, 4, 5}},
|
||||
},
|
||||
{
|
||||
name: "subset out-of-order",
|
||||
slots: []slot{
|
||||
{10, Options{RepeatLastN: 2, PresencePenalty: 10}, []int32{1, 1}},
|
||||
{20, Options{RepeatLastN: 2, PresencePenalty: 10}, []int32{2, 2}},
|
||||
{30, Options{RepeatLastN: 2, PresencePenalty: 10}, []int32{3, 3}},
|
||||
},
|
||||
sample: []int{30, 10},
|
||||
rows: [][]float32{{5, 5, 5, 0, 5, 5}, {5, 0, 5, 5, 0, 5}},
|
||||
},
|
||||
}
|
||||
|
||||
logits := mlx.FromValues([]float32{0, 5, 4}, 3)
|
||||
got := s.Sample(logits).Token
|
||||
mlx.Eval(got)
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Per-slot reference for each sampled seq.
|
||||
want := make([]int, len(tc.sample))
|
||||
for i, id := range tc.sample {
|
||||
var spec slot
|
||||
for _, s := range tc.slots {
|
||||
if s.id == id {
|
||||
spec = s
|
||||
break
|
||||
}
|
||||
}
|
||||
want[i] = sampleOne(t, spec.opts, spec.priors, tc.rows[i])
|
||||
}
|
||||
|
||||
// token 1 appears twice, so 5 - (2 * 2) falls below token 2.
|
||||
gotInt := got.Int()
|
||||
if gotInt != 2 {
|
||||
t.Fatalf("got %d, want 2", gotInt)
|
||||
// Batched call.
|
||||
s := New(128)
|
||||
t.Cleanup(func() {
|
||||
s.Free()
|
||||
mlx.Sweep()
|
||||
})
|
||||
for _, spec := range tc.slots {
|
||||
s.Add(spec.id, spec.opts, spec.priors)
|
||||
}
|
||||
res := s.Sample(tc.sample, batchLogits(tc.rows...))
|
||||
mlx.Eval(res.Token)
|
||||
got := res.Token.Ints()
|
||||
|
||||
for i, id := range tc.sample {
|
||||
if got[i] != want[i] {
|
||||
t.Errorf("seq %d: batched = %d, per-slot = %d", id, got[i], want[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMinPMasksTokensBelowThreshold(t *testing.T) {
|
||||
s := New(Options{MinP: 0.5})
|
||||
defer func() {
|
||||
// TestRemoveDoesNotLeakHistory: after Remove, a newly-added slot at the
|
||||
// recycled row must start from its own priors only — no carryover from
|
||||
// the removed slot's history.
|
||||
func TestRemoveDoesNotLeakHistory(t *testing.T) {
|
||||
opts := Options{RepeatLastN: 1, PresencePenalty: 10}
|
||||
s := New(128)
|
||||
t.Cleanup(func() {
|
||||
s.Free()
|
||||
mlx.Sweep()
|
||||
}()
|
||||
})
|
||||
s.Add(1, opts, []int32{1})
|
||||
s.Add(2, opts, []int32{2})
|
||||
s.Remove(1)
|
||||
s.Add(3, opts, []int32{0})
|
||||
|
||||
logits := mlx.FromValues([]float32{
|
||||
float32(math.Log(0.5)),
|
||||
float32(math.Log(0.3)),
|
||||
float32(math.Log(0.2)),
|
||||
}, 3)
|
||||
got := minP(s, logits)
|
||||
mlx.Eval(got)
|
||||
|
||||
gotFloats := got.Floats()
|
||||
if len(gotFloats) != 3 {
|
||||
t.Fatalf("got %d scores, want 3", len(gotFloats))
|
||||
// Slot 2 retains history {2}; slot 3 retains history {0}. With
|
||||
// equal logits and PresencePenalty=10 the argmax drops to the first
|
||||
// unpenalized token.
|
||||
res := s.Sample([]int{2, 3}, batchLogits(
|
||||
[]float32{3, 3, 0},
|
||||
[]float32{3, 3, 0},
|
||||
))
|
||||
mlx.Eval(res.Token)
|
||||
tokens := res.Token.Ints()
|
||||
if tokens[0] != 0 {
|
||||
t.Errorf("slot 2 = %d, want 0 (token 2 penalized)", tokens[0])
|
||||
}
|
||||
|
||||
if math.IsInf(float64(gotFloats[0]), -1) || math.IsInf(float64(gotFloats[1]), -1) {
|
||||
t.Fatalf("kept tokens were masked: %v", gotFloats)
|
||||
}
|
||||
|
||||
if !math.IsInf(float64(gotFloats[2]), -1) {
|
||||
t.Fatalf("lowest-probability token should be masked, got %v", gotFloats)
|
||||
if tokens[1] != 1 {
|
||||
t.Errorf("slot 3 = %d, want 1 (token 0 penalized, no slot-1 carryover)", tokens[1])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -93,7 +93,7 @@ func Execute(args []string) error {
|
||||
}
|
||||
|
||||
request.Pipeline = runner.TextGenerationPipeline
|
||||
request.Sampler = sample.New(sample.Options{
|
||||
request.SamplerOpts = sample.Options{
|
||||
Temperature: request.Options.Temperature,
|
||||
TopP: request.Options.TopP,
|
||||
MinP: request.Options.MinP,
|
||||
@@ -104,7 +104,12 @@ func Execute(args []string) error {
|
||||
FrequencyPenalty: request.Options.FrequencyPenalty,
|
||||
Logprobs: request.Logprobs,
|
||||
TopLogprobs: request.TopLogprobs,
|
||||
})
|
||||
}
|
||||
|
||||
if err := runner.Prepare(&request); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var cancel context.CancelFunc
|
||||
request.Ctx, cancel = context.WithCancel(r.Context())
|
||||
|
||||
Reference in New Issue
Block a user