mirror of
https://github.com/ollama/ollama.git
synced 2026-04-18 02:54:17 +02:00
This change adds a new x/tokenizer package, which includes: * New BPE and SentencePiece tokenizers * Removal of the dependency on the imagegen tokenizers * Fixes to multibyte decoding in the pipeline * Various correctness and benchmark tests. Not included in this PR is the WordPiece tokenizer for BERT models, which will be added when we add embedding models. The imagegen tokenizers will also be removed in a follow-up PR.
176 lines
3.5 KiB
Go
176 lines
3.5 KiB
Go
//go:build mlx
|
|
|
|
package tokenizer
|
|
|
|
import "container/heap"
|
|
|
|
// bpeMergeNode is one element of the linked list that encodeBPEMerge builds
// over the input runes, stored in a slice and linked by index.
type bpeMergeNode struct {
	// prev and next are the slice indices of the neighboring nodes; -1 marks
	// "no node before" and len(nodes) marks "no node after".
	prev, next int
	// token is the node's current text. It is set to "" once the node has
	// been absorbed into its left neighbor by a merge.
	token string
}
|
|
|
|
// bpePair is a candidate merge of two adjacent list nodes, identified by
// their indices into the node slice.
type bpePair struct {
	left  int    // index of the left node
	right int    // index of the right node
	rank  int    // merge priority; the heap pops lower ranks first
	value string // token text produced by merging the two nodes
}

// bpePairHeap implements container/heap's heap.Interface as a min-heap of
// merge candidates. Candidates are ordered by rank, and equal ranks are
// ordered by left index so that ties resolve leftmost-first.
type bpePairHeap []*bpePair

func (h bpePairHeap) Len() int { return len(h) }

func (h bpePairHeap) Less(i, j int) bool {
	a, b := h[i], h[j]
	if a.rank != b.rank {
		return a.rank < b.rank
	}
	return a.left < b.left
}

func (h bpePairHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }

// Push appends x, which must be a *bpePair. Call heap.Push, not this method.
func (h *bpePairHeap) Push(x any) { *h = append(*h, x.(*bpePair)) }

// Pop removes and returns the last element. Call heap.Pop, not this method.
func (h *bpePairHeap) Pop() any {
	last := len(*h) - 1
	item := (*h)[last]
	*h = (*h)[:last]
	return item
}
|
|
|
|
// encodeBPEMerge encodes using BPE merge algorithm.
|
|
// Uses the heap/linked-list pair merge strategy from tokenizer/bytepairencoding.go:
|
|
// merge the lowest-rank valid pair, then only recheck adjacent pairs.
|
|
func (t *Tokenizer) encodeBPEMerge(encoded string, ids []int32) []int32 {
|
|
runes := []rune(encoded)
|
|
if len(runes) == 0 {
|
|
return ids
|
|
}
|
|
|
|
nodes := make([]bpeMergeNode, len(runes))
|
|
for i := range runes {
|
|
nodes[i] = bpeMergeNode{
|
|
prev: i - 1,
|
|
next: i + 1,
|
|
token: string(runes[i]),
|
|
}
|
|
}
|
|
|
|
pairwise := func(left, right int) *bpePair {
|
|
if left < 0 || right >= len(nodes) {
|
|
return nil
|
|
}
|
|
if nodes[left].token == "" || nodes[right].token == "" {
|
|
return nil
|
|
}
|
|
|
|
leftToken, rightToken := nodes[left].token, nodes[right].token
|
|
rank, ok := t.vocab.Merges[leftToken+" "+rightToken]
|
|
if !ok {
|
|
return nil
|
|
}
|
|
|
|
value := leftToken + rightToken
|
|
if _, ok := t.vocab.Reverse[value]; !ok {
|
|
return nil
|
|
}
|
|
|
|
return &bpePair{
|
|
left: left,
|
|
right: right,
|
|
rank: rank,
|
|
value: value,
|
|
}
|
|
}
|
|
|
|
pairs := bpePairHeap{}
|
|
heap.Init(&pairs)
|
|
for i := 0; i < len(runes)-1; i++ {
|
|
if pair := pairwise(i, i+1); pair != nil {
|
|
heap.Push(&pairs, pair)
|
|
}
|
|
}
|
|
|
|
for pairs.Len() > 0 {
|
|
pair := heap.Pop(&pairs).(*bpePair)
|
|
left, right := nodes[pair.left], nodes[pair.right]
|
|
if left.token == "" || right.token == "" {
|
|
continue
|
|
}
|
|
if left.next != pair.right || right.prev != pair.left {
|
|
continue
|
|
}
|
|
if left.token+right.token != pair.value {
|
|
continue
|
|
}
|
|
|
|
nodes[pair.left].token = pair.value
|
|
nodes[pair.right].token = ""
|
|
nodes[pair.left].next = right.next
|
|
if right.next < len(nodes) {
|
|
nodes[right.next].prev = pair.left
|
|
}
|
|
|
|
if pair := pairwise(nodes[pair.left].prev, pair.left); pair != nil {
|
|
heap.Push(&pairs, pair)
|
|
}
|
|
if pair := pairwise(pair.left, nodes[pair.left].next); pair != nil {
|
|
heap.Push(&pairs, pair)
|
|
}
|
|
}
|
|
|
|
for _, node := range nodes {
|
|
if node.token == "" {
|
|
continue
|
|
}
|
|
|
|
if id, ok := t.vocab.Reverse[node.token]; ok {
|
|
ids = append(ids, id)
|
|
continue
|
|
}
|
|
|
|
ids = t.appendByteFallback(ids, node.token)
|
|
}
|
|
|
|
return ids
|
|
}
|
|
|
|
func (t *Tokenizer) appendByteFallback(ids []int32, token string) []int32 {
|
|
if t.typ == TokenizerBPE {
|
|
for _, r := range token {
|
|
if b, ok := decodeByteLevelRune(r); ok {
|
|
if id := t.vocab.byteTokens[b]; id >= 0 {
|
|
ids = append(ids, id)
|
|
}
|
|
}
|
|
}
|
|
return ids
|
|
}
|
|
|
|
// SentencePiece fallback uses the UTF-8 bytes for <0xNN> tokens.
|
|
for _, b := range []byte(token) {
|
|
if id := t.vocab.byteTokens[b]; id >= 0 {
|
|
ids = append(ids, id)
|
|
}
|
|
}
|
|
return ids
|
|
}
|
|
|
|
// decodeByteLevelRune maps a rune from byte-level BPE text back to the raw
// byte it stands for. It inverts the GPT-2 style byte-to-unicode table:
//
//   - U+0000–U+00FF decode to themselves. This also accepts code points the
//     table never emits, such as raw control characters.
//   - U+0100–U+0120 decode to bytes 0x00–0x20.
//   - U+0121–U+0142 decode to bytes 0x7F–0xA0.
//   - U+0143 decodes to byte 0xAD.
//
// Any other rune returns (0, false).
func decodeByteLevelRune(r rune) (byte, bool) {
	if r >= 0 && r <= 0xFF {
		return byte(r), true
	}

	switch {
	case r >= 0x0100 && r <= 0x0120:
		return byte(r - 0x0100), true
	case r >= 0x0121 && r <= 0x0142:
		return byte(r - 0x00A2), true
	case r == 0x0143:
		return 0xAD, true
	}

	return 0, false
}
|