Files
ollama-ollama/x/tokenizer/tokenizer_bpe.go
Daniel Hiltgen 10e51c5177 MLX: add header vendoring and remove go build tag (#14642)
* prefer rocm v6 on windows

Avoid building with v7 - more changes are needed

* MLX: add header vendoring and remove go build tag

This switches to using a vendoring approach for the mlx-c headers so that Go
can build without requiring a cmake first.  This enables building the new MLX
based code by default.  Every time cmake runs, the headers are refreshed, so we
can easily keep them in sync when we bump mlx versions.  Basic Windows
and Linux support are verified.

* ci: harden for flaky choco repo servers

CI sometimes fails due to choco not actually installing cache.  Since it just speeds up the build, we can proceed without.

* review comments
2026-03-09 17:24:45 -07:00

174 lines
3.5 KiB
Go

package tokenizer
import "container/heap"
// bpeMergeNode is one entry in the index-linked list that encodeBPEMerge
// collapses while applying merges.
type bpeMergeNode struct {
	// prev and next hold the slice indexes of the neighbouring nodes.
	// encodeBPEMerge initialises them to i-1 and i+1, so the first node has
	// prev == -1 and the last node has next == len(nodes).
	prev, next int

	// token is the text currently held by this node. encodeBPEMerge sets it
	// to "" once the node has been merged into its left neighbour.
	token string
}
// bpePair is a candidate merge of two adjacent nodes, queued on a bpePairHeap.
type bpePair struct {
	left  int    // index of the left node in the node slice
	right int    // index of the right node in the node slice
	rank  int    // merge rank; lower ranks are popped first
	value string // left token + right token at the time the pair was queued
}

// bpePairHeap is a min-heap of merge candidates for use with container/heap.
// Pairs are ordered by rank, with the smaller left index winning ties.
type bpePairHeap []*bpePair

// Len reports the number of queued pairs.
func (h bpePairHeap) Len() int { return len(h) }

// Less orders pairs by ascending rank, breaking ties by ascending left index.
func (h bpePairHeap) Less(i, j int) bool {
	if h[i].rank != h[j].rank {
		return h[i].rank < h[j].rank
	}
	return h[i].left < h[j].left
}

// Swap exchanges the pairs at positions i and j.
func (h bpePairHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }

// Push appends x, which must be a *bpePair, to the end of the heap slice.
// It is meant to be called by container/heap, not directly.
func (h *bpePairHeap) Push(x any) {
	*h = append(*h, x.(*bpePair))
}

// Pop removes and returns the last element of the heap slice.
// It is meant to be called by container/heap, not directly.
func (h *bpePairHeap) Pop() any {
	old := *h
	last := len(old) - 1
	item := old[last]
	// Clear the vacated slot so the backing array does not keep the popped
	// pair reachable after the caller is done with it.
	old[last] = nil
	*h = old[:last]
	return item
}
// encodeBPEMerge runs the BPE merge loop over encoded and appends the
// resulting token IDs to ids, returning the extended slice. An empty input
// returns ids unchanged.
//
// It follows the heap/linked-list strategy from
// tokenizer/bytepairencoding.go: every rune starts as its own node, adjacent
// nodes that form a known merge are queued on a min-heap, and each step
// applies the lowest-rank pair that is still valid and then re-examines only
// the two pairs touching the merged node.
func (t *Tokenizer) encodeBPEMerge(encoded string, ids []int32) []int32 {
	chars := []rune(encoded)
	count := len(chars)
	if count == 0 {
		return ids
	}

	// One node per rune, linked to its neighbours by slice index.
	nodes := make([]bpeMergeNode, count)
	for idx, r := range chars {
		nodes[idx].prev = idx - 1
		nodes[idx].next = idx + 1
		nodes[idx].token = string(r)
	}

	var queue bpePairHeap

	// enqueue pushes the pair (l, r) onto the heap when both indexes are in
	// range, both nodes are live, "left right" is a known merge, and the
	// merged text is itself a vocabulary entry. Otherwise it does nothing.
	enqueue := func(l, r int) {
		if l < 0 || r >= count {
			return
		}
		lhs, rhs := nodes[l].token, nodes[r].token
		if lhs == "" || rhs == "" {
			return
		}
		rank, found := t.vocab.Merges[lhs+" "+rhs]
		if !found {
			return
		}
		merged := lhs + rhs
		if _, known := t.vocab.Reverse[merged]; !known {
			return
		}
		heap.Push(&queue, &bpePair{left: l, right: r, rank: rank, value: merged})
	}

	for i := 1; i < count; i++ {
		enqueue(i-1, i)
	}

	for queue.Len() > 0 {
		cand := heap.Pop(&queue).(*bpePair)
		l, r := cand.left, cand.right
		lhs, rhs := nodes[l], nodes[r]

		// A queued pair is stale if either node was consumed, the two nodes
		// are no longer neighbours, or either token changed since queuing.
		stale := lhs.token == "" || rhs.token == "" ||
			lhs.next != r || rhs.prev != l ||
			lhs.token+rhs.token != cand.value
		if stale {
			continue
		}

		// Fold the right node into the left one and unlink it.
		nodes[l].token = cand.value
		nodes[r].token = ""
		nodes[l].next = rhs.next
		if rhs.next < count {
			nodes[rhs.next].prev = l
		}

		// Only the pairs on either side of the merged node can have changed.
		enqueue(nodes[l].prev, l)
		enqueue(l, nodes[l].next)
	}

	// Emit the surviving nodes in order; tokens missing from the vocabulary
	// go through the byte fallback.
	for i := range nodes {
		tok := nodes[i].token
		if tok == "" {
			continue
		}
		id, ok := t.vocab.Reverse[tok]
		if !ok {
			ids = t.appendByteFallback(ids, tok)
			continue
		}
		ids = append(ids, id)
	}
	return ids
}
// appendByteFallback appends byte-level token IDs for token to ids and
// returns the extended slice.
//
// For TokenizerBPE each rune is decoded with decodeByteLevelRune; runes that
// do not decode are skipped. For any other tokenizer type the raw UTF-8 bytes
// of token are used. In both modes a byte whose byteTokens entry is negative
// is skipped (presumably meaning the vocabulary has no token for that byte).
func (t *Tokenizer) appendByteFallback(ids []int32, token string) []int32 {
	if t.typ != TokenizerBPE {
		// SentencePiece fallback uses the UTF-8 bytes for <0xNN> tokens.
		for i := 0; i < len(token); i++ {
			id := t.vocab.byteTokens[token[i]]
			if id < 0 {
				continue
			}
			ids = append(ids, id)
		}
		return ids
	}

	for _, r := range token {
		b, ok := decodeByteLevelRune(r)
		if !ok {
			continue
		}
		id := t.vocab.byteTokens[b]
		if id < 0 {
			continue
		}
		ids = append(ids, id)
	}
	return ids
}
// decodeByteLevelRune maps a byte-level rune back to the byte it stands for.
// The second result is false when r is outside the ranges handled here.
//
// Mapping:
//
//	U+0000..U+00FF -> the same value as a byte
//	U+0100..U+0120 -> 0x00..0x20
//	U+0121..U+0142 -> 0x7F..0xA0
//	U+0143         -> 0xAD
//
// NOTE(review): the shifted ranges look like the GPT-2 style byte-to-unicode
// table; confirm against the encoder side.
func decodeByteLevelRune(r rune) (byte, bool) {
	if r >= 0x00 && r <= 0xFF {
		return byte(r), true
	}
	if r == 0x0143 {
		return 0xad, true
	}
	if r >= 0x0100 && r <= 0x0120 {
		return byte(r - 0x0100), true
	}
	if r >= 0x0121 && r <= 0x0142 {
		return byte(r - 0x00a2), true
	}
	return 0, false
}