mirror of
https://github.com/ollama/ollama.git
synced 2026-04-17 19:54:03 +02:00
* tokenizer: add byte fallback for SentencePiece BPE encoding When BPE merging produces tokens not in the vocabulary, fall back to encoding each UTF-8 byte as <0xHH> byte tokens instead of silently dropping the character. Also teach Decode to convert <0xHH> tokens back to raw bytes. Fixes #15229, fixes #15231 * tokenizer fixes
368 lines
8.8 KiB
Go
package tokenizer
|
|
|
|
import (
|
|
"cmp"
|
|
"fmt"
|
|
"iter"
|
|
"log/slog"
|
|
"slices"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/dlclark/regexp2"
|
|
heap "github.com/emirpasic/gods/v2/trees/binaryheap"
|
|
"github.com/ollama/ollama/logutil"
|
|
)
|
|
|
|
// BytePairEncoding tokenizes text with byte-pair encoding over a Vocabulary.
// It supports two normalization modes: GPT-2 byte-level (bytes shifted into
// printable Unicode codepoints) and SentencePiece-style (spaces replaced by ▁).
type BytePairEncoding struct {
	// vocab supplies token lookup, merge ranks, and special-token handling.
	vocab *Vocabulary
	// regexps are the compiled pretokenizer patterns applied in order by split.
	regexps []*regexp2.Regexp
	spaceToSpmSep bool // When true, normalize spaces to ▁ instead of GPT-2 byte-level encoding
}

// Compile-time check that *BytePairEncoding satisfies the Tokenizer interface.
var _ Tokenizer = (*BytePairEncoding)(nil)
|
|
|
|
// BPEOption configures BytePairEncoding behavior
|
|
type BPEOption func(*BytePairEncoding)
|
|
|
|
// WithSentencePieceNormalizer enables ▁ space normalization instead of GPT-2 byte-level encoding.
|
|
func WithSentencePieceNormalizer() BPEOption {
|
|
return func(bpe *BytePairEncoding) {
|
|
bpe.spaceToSpmSep = true
|
|
}
|
|
}
|
|
|
|
// NewBytePairEncoding returns a BytePairEncoding over vocab using the given
// pretokenizer patterns. When none are supplied, newBytePairEncoding installs
// the default GPT-2 byte-level pattern.
func NewBytePairEncoding(vocab *Vocabulary, pretokenizer ...string) BytePairEncoding {
	return newBytePairEncoding(vocab, pretokenizer)
}
|
|
|
|
func NewBytePairEncodingWithOptions(vocab *Vocabulary, pretokenizer []string, opts ...BPEOption) BytePairEncoding {
|
|
bpe := newBytePairEncoding(vocab, pretokenizer, opts...)
|
|
return bpe
|
|
}
|
|
|
|
func newBytePairEncoding(vocab *Vocabulary, pretokenizer []string, opts ...BPEOption) BytePairEncoding {
|
|
bpe := BytePairEncoding{
|
|
vocab: vocab,
|
|
}
|
|
|
|
for _, opt := range opts {
|
|
opt(&bpe)
|
|
}
|
|
|
|
if len(pretokenizer) == 0 && !bpe.spaceToSpmSep {
|
|
// set default byte-level pretokenizer if none provided, e.g.
|
|
// https://github.com/huggingface/tokenizer/blob/main/tokenizer/src/pre_tokenizer/byte_level.rs#L44
|
|
pretokenizer = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
|
|
}
|
|
|
|
bpe.regexps = slices.Collect(func(yield func(*regexp2.Regexp) bool) {
|
|
for _, p := range pretokenizer {
|
|
if !yield(regexp2.MustCompile(p, regexp2.RE2)) {
|
|
return
|
|
}
|
|
}
|
|
})
|
|
|
|
return bpe
|
|
}
|
|
|
|
// Vocabulary returns the vocabulary backing this tokenizer.
func (bpe BytePairEncoding) Vocabulary() *Vocabulary {
	return bpe.vocab
}
|
|
|
|
// Is reports whether token id belongs to the given special class,
// delegating to the vocabulary.
func (bpe BytePairEncoding) Is(id int32, special Special) bool {
	return bpe.vocab.Is(id, special)
}
|
|
|
|
func (bpe *BytePairEncoding) split(s string) iter.Seq[string] {
|
|
parts := []string{s}
|
|
for _, re := range bpe.regexps {
|
|
parts = slices.Collect(func(yield func(string) bool) {
|
|
for _, part := range parts {
|
|
r := []rune(part)
|
|
var offset int
|
|
for m, _ := re.FindRunesMatch(r); m != nil; m, _ = re.FindNextMatch(m) {
|
|
if offset-m.Index != 0 {
|
|
if !yield(string(r[:m.Index])) {
|
|
return
|
|
}
|
|
}
|
|
|
|
if !yield(m.String()) {
|
|
return
|
|
}
|
|
|
|
offset = m.Index + m.Length
|
|
}
|
|
|
|
if offset < len(r) {
|
|
if !yield(string(r[offset:])) {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
})
|
|
}
|
|
|
|
return slices.Values(parts)
|
|
}
|
|
|
|
// fragment is a string fragment and its corresponding token IDs; ids is
// non-empty only once the fragment has been resolved (e.g. as a special
// token), which exempts it from further splitting and BPE merging.
type fragment struct {
	value string
	ids []int32
}
|
|
|
|
// pair is a candidate BPE merge of two adjacent merge-list entries: a and b
// index into the merges slice, rank is the vocabulary's merge priority
// (lower merges first), and value is the concatenated result used to detect
// stale heap entries.
type pair struct {
	a, b int
	rank int
	value string
}
|
|
|
|
// merge is a node in a doubly-linked list over the runes of a fragment:
// p and n index the previous and next live entries, and runes holds this
// entry's accumulated text (set to nil when the entry is merged away).
type merge struct {
	p, n int
	runes []rune
}
|
|
|
|
// Encode tokenizes s into vocabulary token IDs. Literal occurrences of
// special tokens are carved out first so they map directly to their IDs;
// the remaining text is pretokenized via split, normalized (GPT-2
// byte-level or SentencePiece ▁ depending on configuration), and merged
// bottom-up by vocabulary merge rank. When addSpecial is true, the
// vocabulary's configured specials are added around the result.
func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
	fragments := []fragment{{value: s}}
	for _, special := range bpe.vocab.SpecialVocabulary() {
		// TODO: process special tokens concurrently
		id := bpe.vocab.Encode(special)
		for i := 0; i < len(fragments); i++ {
			frag := fragments[i]
			// Already resolved (e.g. a previously matched special token).
			if len(frag.ids) > 0 {
				continue
			}

			// Split frag around the first occurrence of special, if any.
			// NOTE: the switch's i shadows the loop index; the splice below
			// intentionally uses the outer i (the fragment position).
			var middle []fragment
			switch i := strings.Index(frag.value, special); {
			case i < 0:
				// Not present: keep the fragment unchanged.
				middle = append(middle, frag)
			case i > 0:
				// Text before the special token stays unresolved.
				middle = append(middle, fragment{value: frag.value[:i]})
				fallthrough
			default:
				// The special token itself, pre-resolved to its ID.
				middle = append(middle, fragment{value: special, ids: []int32{id}})
				if rest := frag.value[i+len(special):]; rest != "" {
					// Remainder is rescanned on later iterations for
					// further occurrences.
					middle = append(middle, fragment{value: rest})
				}
			}

			// Replace fragments[i] with the (1-3 element) middle slice.
			fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
		}
	}

	var ids []int32
	for _, frag := range fragments {
		if len(frag.ids) > 0 {
			ids = append(ids, frag.ids...)
			continue
		}

		for split := range bpe.split(frag.value) {
			// TODO: process splits concurrently
			var normalized string
			if bpe.spaceToSpmSep {
				// SentencePiece-style: replace spaces with ▁
				normalized = strings.ReplaceAll(split, " ", spmWhitespaceSep)
			} else {
				// GPT-2 byte-level: map bytes to shifted Unicode codepoints
				// so every byte becomes a printable rune the vocab can hold.
				var sb strings.Builder
				for _, b := range []byte(split) {
					r := rune(b)
					switch {
					case r == 0x00ad:
						// soft hyphen gets its own slot past the main ranges
						r = 0x0143
					case r <= 0x0020:
						// control chars and space -> 0x0100..0x0120
						r = r + 0x0100
					case r >= 0x007f && r <= 0x00a0:
						// DEL and C1 range -> 0x0121..0x0142
						r = r + 0x00a2
					}
					sb.WriteRune(r)
				}
				normalized = sb.String()
			}

			// short circuit if the fragment is in the vocabulary
			if id := bpe.vocab.Encode(normalized); id >= 0 {
				ids = append(ids, id)
				continue
			}

			// Seed the merge list: one entry per rune, doubly linked by index.
			runes := []rune(normalized)
			merges := make([]merge, len(runes))
			for r := range runes {
				merges[r] = merge{
					p: r - 1,
					n: r + 1,
					runes: []rune{runes[r]},
				}
			}

			// pairwise builds a merge candidate for adjacent entries a and b,
			// or returns nil when out of range or no merge rank exists.
			pairwise := func(a, b int) *pair {
				if a < 0 || b >= len(runes) {
					return nil
				}

				left, right := string(merges[a].runes), string(merges[b].runes)
				rank := bpe.vocab.Merge(left, right)
				if rank < 0 {
					return nil
				}

				return &pair{
					a: a,
					b: b,
					rank: rank,
					value: left + right,
				}
			}

			// Min-heap ordered by merge rank: lowest rank merges first.
			pairs := heap.NewWith(func(i, j *pair) int {
				return cmp.Compare(i.rank, j.rank)
			})

			for i := range len(runes) - 1 {
				if pair := pairwise(i, i+1); pair != nil {
					pairs.Push(pair)
				}
			}

			for !pairs.Empty() {
				pair, _ := pairs.Pop()

				// Skip stale heap entries whose endpoints were merged away
				// (runes emptied) or whose contents changed since queueing.
				left, right := merges[pair.a], merges[pair.b]
				if len(left.runes) == 0 || len(right.runes) == 0 ||
					string(left.runes)+string(right.runes) != pair.value {
					continue
				}

				// Only apply merges whose result is an actual vocab token.
				if id := bpe.vocab.Encode(pair.value); id < 0 {
					continue
				}

				// Fold b into a and unlink b from the list.
				merges[pair.a].runes = append(left.runes, right.runes...)
				merges[pair.b].runes = nil

				merges[pair.a].n = right.n
				if right.n < len(merges) {
					merges[right.n].p = pair.a
				}

				// Queue the new candidate pairs on either side of a.
				if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
					pairs.Push(pair)
				}

				if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
					pairs.Push(pair)
				}
			}

			// Emit the surviving entries in order.
			for _, merge := range merges {
				if len(merge.runes) > 0 {
					if id := bpe.vocab.Encode(string(merge.runes)); id >= 0 {
						ids = append(ids, id)
					} else if bpe.spaceToSpmSep {
						// SentencePiece byte fallback: encode each UTF-8 byte as <0xHH>
						for _, b := range []byte(string(merge.runes)) {
							if id := bpe.vocab.Encode(fmt.Sprintf("<0x%02X>", b)); id >= 0 {
								ids = append(ids, id)
							} else {
								// Vocab lacks byte tokens; drop the byte.
								slog.Debug("unknown byte token", "byte", b)
							}
						}
					}
				}
			}
		}
	}

	if addSpecial {
		ids = bpe.vocab.addSpecials(ids)
	}

	logutil.Trace("encoded", "string", s, "ids", ids)
	return ids, nil
}
|
|
|
|
// lazyIdsString wraps a token ID slice so formatting is deferred until the
// log record is actually emitted.
type lazyIdsString struct {
	ids []int32
}

// LogValue implements slog.LogValuer, rendering the IDs in their default
// fmt representation (e.g. "[1 2 3]").
func (l lazyIdsString) LogValue() slog.Value {
	return slog.AnyValue(fmt.Sprintf("%v", l.ids))
}
|
|
|
|
// Decode converts token IDs back into a string, reversing Encode's
// normalization: SentencePiece vocabularies pass runes through directly
// (mapping ▁ to space and <0xHH> byte tokens to raw bytes), while GPT-2
// byte-level vocabularies remap shifted codepoints back to their bytes.
func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
	var sb strings.Builder

	// SentencePiece-style BPE stores true Unicode codepoints in the vocab
	// (plus ▁ as a whitespace marker), so decoding should pass runes through
	// directly instead of applying the GPT-2 byte-level reverse mapping.
	// Without this, codepoints in the 0x0100-0x0142 range (e.g. ą ę ć ł)
	// get mangled by the GPT-2 reversal into control characters.
	if bpe.spaceToSpmSep {
		for _, id := range ids {
			data := bpe.vocab.Decode(id)

			// SentencePiece byte tokens: "<0xHH>" → raw byte
			// (mirrors the byte fallback in Encode; a token of the right
			// shape with non-hex digits fails the parse and falls through).
			if len(data) == 6 && strings.HasPrefix(data, "<0x") && strings.HasSuffix(data, ">") {
				if b, err := strconv.ParseUint(data[3:5], 16, 8); err == nil {
					sb.WriteByte(byte(b))
					continue
				}
			}

			for _, r := range data {
				if r == 0x2581 { // ▁ (LOWER ONE EIGHTH BLOCK)
					sb.WriteByte(' ')
				} else {
					sb.WriteRune(r)
				}
			}
		}

		logutil.Trace("decoded", "string", sb.String(), "from", lazyIdsString{ids: ids})
		return sb.String(), nil
	}

	for _, id := range ids {
		for _, r := range bpe.vocab.Decode(id) {
			// GPT-2 byte-level BPE uses Unicode chars in the 0x0100-0x0143
			// range to represent bytes. Remap them back to actual bytes.
			switch {
			case r == 0x0100:
				// this produces 0x00 aka NULL
				continue
			case r == 0x0143:
				// soft hyphen (U+00AD)
				r = 0x00ad
			case r > 0x0100 && r <= 0x0120:
				// control chars and space (0x01..0x20)
				r = r - 0x0100
			case r > 0x0120 && r <= 0x0142:
				// DEL and C1 range (0x7f..0xa0)
				r = r - 0x00a2
			case r > 0x0143:
				// Non-GPT2 rune (e.g., SentencePiece-style BPE).
				// Handle ▁ as word separator, otherwise write the rune as-is.
				if r == 0x2581 { // ▁ (LOWER ONE EIGHTH BLOCK)
					sb.WriteByte(' ')
				} else {
					sb.WriteRune(r)
				}
				continue
			}

			// NOTE: not using WriteRune here because it writes the UTF-8
			// encoding of the rune which is _not_ what we want
			if err := sb.WriteByte(byte(r)); err != nil {
				return "", err
			}
		}
	}

	logutil.Trace("decoded", "string", sb.String(), "from", lazyIdsString{ids: ids})
	return sb.String(), nil
}
|