Files
ollama-ollama/tokenizer/bytepairencoding_test.go
Daniel Hiltgen cb0033598e tokenizer: add SentencePiece-style BPE support (#15162)
* tokenizer: add SentencePiece-style BPE support

Add WithSentencePieceNormalizer option to BytePairEncoding for models
that use BPE with SentencePiece-style space markers (space to/from
U+2581).

NewBytePairEncoding is unchanged; the new NewBytePairEncodingWithOptions
constructor accepts BPEOption functions. Decoding handles the reverse
mapping of U+2581 back to spaces.

* review comments
2026-03-31 17:00:36 -07:00

503 lines
14 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package tokenizer
import (
"bufio"
"encoding/json"
"math"
"os"
"path/filepath"
"slices"
"strconv"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
)
// llama builds a BytePairEncoding from the Llama 3.2 test vocabulary in
// testdata. It loads the token-to-id map from encoder.json, registers
// <|begin_of_text|> and <|end_of_text|> as control tokens (type 3) when
// they are not already present, and reads the merge rules from
// vocab.bpe, skipping comment lines that begin with "#".
func llama(t testing.TB) BytePairEncoding {
	t.Helper()

	f, err := os.Open(filepath.FromSlash("testdata/llama3.2/encoder.json"))
	if err != nil {
		t.Fatal(err)
	}
	defer f.Close()

	vocab := make(map[string]int32)
	if err := json.NewDecoder(f).Decode(&vocab); err != nil {
		t.Fatal(err)
	}

	// Invert the map: index by id so Values/Types line up positionally.
	types := make([]int32, len(vocab))
	tokens := make([]string, len(vocab))
	for token, id := range vocab {
		tokens[id] = token
		types[id] = 1
	}

	for _, token := range []string{"<|begin_of_text|>", "<|end_of_text|>"} {
		if _, ok := vocab[token]; !ok {
			tokens = append(tokens, token) //nolint:makezero
			types = append(types, 3)       //nolint:makezero
			vocab[token] = int32(len(vocab))
		}
	}

	f, err = os.Open(filepath.FromSlash("testdata/llama3.2/vocab.bpe"))
	if err != nil {
		t.Fatal(err)
	}
	defer f.Close()

	merges := make([]string, 0, 50000)
	scanner := bufio.NewScanner(f)
	for scanner.Scan() {
		if !strings.HasPrefix(scanner.Text(), "#") {
			merges = append(merges, scanner.Text())
		}
	}
	// A Scan loop can terminate early on a read error; surface it
	// instead of silently returning a truncated merge list.
	if err := scanner.Err(); err != nil {
		t.Fatal(err)
	}

	return NewBytePairEncoding(
		&Vocabulary{
			Values: tokens,
			Types:  types,
			Merges: merges,
		},
		"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
	)
}
// TestLlama exercises the GPT-2-style BPE paths against the Llama 3.2
// test vocabulary: plain encode/decode, repeated-digit merging, text
// roundtrips, special-token extraction, pretokenizer splitting, and
// single-byte roundtripping.
func TestLlama(t *testing.T) {
	tokenizer := llama(t)

	t.Run("simple", func(t *testing.T) {
		t.Parallel()

		// Plain text encodes to known token IDs and decodes back.
		ids, err := tokenizer.Encode("hello world", true)
		if err != nil {
			t.Error(err)
		}

		if diff := cmp.Diff([]int32{15339, 1917}, ids); diff != "" {
			t.Errorf("no match (-theirs +ours):\n%s", diff)
		}

		s, err := tokenizer.Decode([]int32{15339, 1917})
		if err != nil {
			t.Fatal(err)
		}

		if s != "hello world" {
			t.Errorf("got %q, want hello world", s)
		}

		// A special token embedded in text is matched as a single ID
		// (128001 is <|end_of_text|>; 220 is the lone space before it).
		ids, err = tokenizer.Encode("hello <|end_of_text|>", true)
		if err != nil {
			t.Error(err)
		}

		if diff := cmp.Diff([]int32{15339, 220, 128001}, ids); diff != "" {
			t.Errorf("no match (-theirs +ours):\n%s", diff)
		}
	})

	t.Run("simple repeated", func(t *testing.T) {
		t.Parallel()

		// Runs of "0" merge greedily into "000"(931), "00"(410), and
		// "0"(15) chunks; the table covers every remainder mod 3.
		cases := map[string][]int32{
			strings.Repeat("0", 1):  {15},
			strings.Repeat("0", 2):  {410},
			strings.Repeat("0", 3):  {931},
			strings.Repeat("0", 4):  {931, 15},
			strings.Repeat("0", 5):  {931, 410},
			strings.Repeat("0", 6):  {931, 931},
			strings.Repeat("0", 7):  {931, 931, 15},
			strings.Repeat("0", 8):  {931, 931, 410},
			strings.Repeat("0", 9):  {931, 931, 931},
			strings.Repeat("0", 10): {931, 931, 931, 15},
			strings.Repeat("0", 11): {931, 931, 931, 410},
			strings.Repeat("0", 12): {931, 931, 931, 931},
			strings.Repeat("0", 13): {931, 931, 931, 931, 15},
			strings.Repeat("0", 14): {931, 931, 931, 931, 410},
			strings.Repeat("0", 15): {931, 931, 931, 931, 931},
			strings.Repeat("0", 16): {931, 931, 931, 931, 931, 15},
			strings.Repeat("0", 17): {931, 931, 931, 931, 931, 410},
		}

		for s, want := range cases {
			ids, err := tokenizer.Encode(s, true)
			if err != nil {
				t.Error(err)
			}

			if diff := cmp.Diff(want, ids); diff != "" {
				t.Errorf("%q no match (-theirs +ours):\n%s", s, diff)
			}
		}
	})

	t.Run("basic roundtrip", func(t *testing.T) {
		t.Parallel()

		// Decode(Encode(s)) must reproduce s exactly, including
		// leading/trailing whitespace and non-ASCII text.
		cases := []string{
			"hello",
			"hello ",
			"hello ",
			" hello",
			" hello ",
			" hello ",
			"hello world",
			"请考试我的软件12345",
		}

		for _, want := range cases {
			ids, err := tokenizer.Encode(want, true)
			if err != nil {
				t.Error(err)
			}

			if got, err := tokenizer.Decode(ids); err != nil {
				t.Fatal(err)
			} else if got != want {
				t.Errorf("got %q, want %q", got, want)
			}
		}
	})

	t.Run("special", func(t *testing.T) {
		t.Parallel()

		// Special tokens (<|begin_of_text|> = 128000,
		// <|end_of_text|> = 128001) are extracted as single IDs in any
		// position, with the surrounding text encoded normally.
		cases := map[string][]int32{
			"<|begin_of_text|>A B!":                                   {128000, 32, 426, 0},
			"<|begin_of_text|>A<|end_of_text|>B!":                     {128000, 32, 128001, 33, 0},
			"<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!":    {128000, 32, 128001, 33, 128000, 0},
			"<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!<|end_of_text|>": {128000, 32, 128001, 33, 128000, 0, 128001},
		}

		for s, want := range cases {
			ids, err := tokenizer.Encode(s, true)
			if err != nil {
				t.Fatal(err)
			}

			if diff := cmp.Diff(want, ids); diff != "" {
				t.Errorf("no match (-theirs +ours):\n%s", diff)
			}
		}
	})

	t.Run("split", func(t *testing.T) {
		t.Parallel()

		// Exercises the pretokenizer regex directly: contractions,
		// digit grouping, punctuation runs, and newline handling.
		cases := map[string][]string{
			"Hello World!":                   {"Hello", " World", "!"},
			"I'm don't won't":                {"I", "'m", " don", "'t", " won", "'t"},
			"In 2024 there are 366 days":     {"In", " ", "202", "4", " there", " are", " ", "366", " days"},
			"Hello!! ...world":               {"Hello", "!!", " ...", "world"},
			"Hello World":                    {"Hello", " ", " World"},
			"Hello\nWorld":                   {"Hello", "\n", "World"},
			"Hello, WORLD!! How's it going?": {"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?"},
		}

		for s, want := range cases {
			got := slices.Collect(tokenizer.split(s))
			if diff := cmp.Diff(want, got); diff != "" {
				t.Errorf("no match (-theirs +ours):\n%s", diff)
			}
		}
	})

	t.Run("roundtriping 0x00-0xFF", func(t *testing.T) {
		t.Parallel()

		// Every single byte value must survive Encode/Decode via the
		// GPT-2 byte table, except NUL which decodes to empty.
		for b := 0x00; b <= 0xFF; b++ {
			input := string(rune(b))
			ids, err := tokenizer.Encode(input, false)
			if err != nil {
				t.Errorf("failed to encode rune 0x%02X: %v", b, err)
				continue
			}

			decoded, err := tokenizer.Decode(ids)
			if err != nil {
				t.Errorf("failed to decode rune 0x%02X: %v", b, err)
				continue
			}

			if b == 0x00 {
				if len(decoded) != 0 {
					t.Errorf("Decode(Encode(0x00)) should be empty, got %v", ids)
				}
				continue
			}

			if decoded != input {
				t.Errorf("rune 0x%02X failed roundtrip: got %q, want %q", b, decoded, input)
			}
		}
	})
}
// spmBPE constructs a small SentencePiece-style BPE tokenizer for tests.
//
// SentencePiece BPE vocabularies mark word boundaries with ▁ (U+2581)
// rather than GPT-2's shifted-byte space encoding (0x0100-0x0143).
// Without WithSentencePieceNormalizer, spaces would be mapped through
// the GPT-2 byte table and never match ▁-prefixed vocab entries; the
// decode path has the inverse problem, where high codepoints such as
// CJK characters and ▁ itself would be mangled by the GPT-2 reverse
// mapping instead of being passed through (or converted back to spaces
// in the ▁ case).
func spmBPE(t testing.TB) BytePairEncoding {
	t.Helper()

	vocabulary := []string{
		// Control tokens (low IDs, as in real SentencePiece vocabs).
		"<pad>",    // 0
		"<eos>",    // 1
		"<bos>",    // 2
		"<|start>", // 3 - asymmetric open/close special tokens
		"<end|>",   // 4
		"<|q>",     // 5 - short special token (like <|"|>)
		// ▁-prefixed word tokens: the core of what SPM BPE changes.
		"▁hello", // 6
		"▁world", // 7
		"hello",  // 8
		"▁Run",   // 9
		"▁a",     // 10
		// Punctuation and structure.
		",", // 11
		"!", // 12
		":", // 13
		"{", // 14
		"}", // 15
		// Whitespace separator.
		"▁", // 16
		// Subword tokens used in tool-declaration-like patterns.
		"description", // 17
		"▁command",    // 18
		"declaration", // 19
		// Unicode token for decode passthrough testing; must be above
		// U+0143 so Decode exercises the SPM passthrough rather than
		// GPT-2 byte reversal.
		"▁中文", // 20
	}

	kinds := make([]int32, len(vocabulary))
	for i := range kinds {
		kinds[i] = TOKEN_TYPE_NORMAL
	}
	for _, id := range []int{0, 1, 2} { // <pad>, <eos>, <bos>
		kinds[id] = TOKEN_TYPE_CONTROL
	}
	for _, id := range []int{3, 4, 5} { // <|start>, <end|>, <|q>
		kinds[id] = TOKEN_TYPE_USER_DEFINED
	}

	return NewBytePairEncodingWithOptions(
		&Vocabulary{
			Values: vocabulary,
			Types:  kinds,
			BOS:    []int32{2},
			EOS:    []int32{1},
			AddBOS: false,
		},
		// An empty pretokenizer list falls back to the default pattern,
		// which is how real SentencePiece BPE models are configured.
		[]string{},
		WithSentencePieceNormalizer(),
	)
}
// TestSentencePieceBPE verifies the behaviors enabled by
// WithSentencePieceNormalizer: space-to-▁ normalization on encode,
// special-token extraction around SPM-normalized text fragments, and
// ▁/high-codepoint handling on decode.
func TestSentencePieceBPE(t *testing.T) {
	tok := spmBPE(t)

	// Test 1: Space-to-▁ normalization and roundtrip.
	//
	// This is the core behavior that WithSentencePieceNormalizer enables.
	// Without it, " hello" would be byte-mapped through the GPT-2 table
	// (producing Ġhello or similar shifted codepoints) which would never
	// match the ▁-prefixed vocab entry.
	t.Run("spm space normalization roundtrip", func(t *testing.T) {
		t.Parallel()

		cases := map[string][]int32{
			"hello":         {8},             // no space → no ▁ prefix → "hello"(8)
			" hello":        {6},             // leading space → "▁hello"(6)
			"hello, world!": {8, 11, 7, 12},  // pretokenizer splits punctuation;
			// " world" normalizes to "▁world"
		}

		for input, wantIDs := range cases {
			ids, err := tok.Encode(input, false)
			if err != nil {
				t.Fatalf("Encode(%q): %v", input, err)
			}

			if diff := cmp.Diff(wantIDs, ids); diff != "" {
				t.Errorf("Encode(%q) mismatch (-want +got):\n%s", input, diff)
			}

			got, err := tok.Decode(ids)
			if err != nil {
				t.Fatalf("Decode(%v): %v", ids, err)
			}

			if got != input {
				t.Errorf("roundtrip %q: Decode(Encode) = %q", input, got)
			}
		}
	})

	// Test 2: Special tokens interleaved with SPM-normalized text.
	//
	// This mimics tool declaration patterns like:
	//   <|tool>declaration:bash{description:<|"|>Run a command<|"|>}<tool|>
	// where special tokens (<|tool>, <|"|>, <tool|>) must be extracted
	// first, then the remaining text fragments go through SPM normalization.
	// Without the SPM normalizer, the text between special tokens would be
	// encoded with GPT-2 byte mapping, producing entirely wrong IDs.
	t.Run("special tokens with spm text fragments", func(t *testing.T) {
		t.Parallel()

		// Pattern: <|start>declaration:description:<|q>Run a command<|q>}<end|>
		input := "<|start>declaration:description:<|q> Run a command<|q>}<end|>"

		ids, err := tok.Encode(input, false)
		if err != nil {
			t.Fatal(err)
		}

		// Special tokens should be extracted as single IDs, and the text
		// between them should be SPM-normalized (spaces → ▁).
		want := []int32{
			3,  // <|start>
			19, // "declaration" (text fragment, no leading space)
			13, // ":"
			17, // "description"
			13, // ":"
			5,  // <|q>
			9,  // "▁Run" (space before "Run" becomes ▁)
			10, // "▁a"
			18, // "▁command"
			5,  // <|q>
			15, // "}"
			4,  // <end|>
		}

		if diff := cmp.Diff(want, ids); diff != "" {
			t.Errorf("mismatch (-want +got):\n%s", diff)
		}
	})

	// Test 3: Decode handles non-GPT2 Unicode correctly.
	//
	// GPT-2 BPE decode reverses the byte→codepoint shift for runes in
	// 0x0100-0x0143. But SentencePiece vocabs store real Unicode (CJK,
	// accented chars, etc.) which have codepoints well above 0x0143.
	// Without the > 0x0143 passthrough in Decode, these would be mangled
	// by the GPT-2 reverse mapping (e.g., written as raw bytes instead
	// of the original characters).
	t.Run("decode non-gpt2 unicode passthrough", func(t *testing.T) {
		t.Parallel()

		cases := map[string][]int32{
			" 中文": {20}, // ▁→space, then CJK passes through as-is
		}

		for want, ids := range cases {
			got, err := tok.Decode(ids)
			if err != nil {
				t.Fatalf("Decode(%v): %v", ids, err)
			}

			if got != want {
				t.Errorf("Decode(%v) = %q, want %q", ids, got, want)
			}
		}
	})
}
// BenchmarkBytePairEncoding measures Encode, Decode, and split
// throughput on prefixes of war-and-peace.txt that grow by powers of
// ten (capped at the file length).
func BenchmarkBytePairEncoding(b *testing.B) {
	tokenizer := llama(b)
	bts, err := os.ReadFile(filepath.Join("testdata", "war-and-peace.txt"))
	if err != nil {
		b.Fatal(err)
	}

	for i := range 8 {
		n := min(int(math.Pow10(i)), len(bts))
		// Convert to string once per size so the timed loops measure
		// the tokenizer, not a per-iteration []byte-to-string copy.
		// b.Loop manages the timer itself, so no ResetTimer is needed.
		s := string(bts[:n])

		b.Run("encode"+strconv.Itoa(n), func(b *testing.B) {
			for b.Loop() {
				if _, err := tokenizer.Encode(s, true); err != nil {
					b.Fatal(err)
				}
			}
		})

		b.Run("decode"+strconv.Itoa(n), func(b *testing.B) {
			ids, err := tokenizer.Encode(s, true)
			if err != nil {
				b.Fatal(err)
			}

			for b.Loop() {
				if _, err := tokenizer.Decode(ids); err != nil {
					b.Fatal(err)
				}
			}
		})

		b.Run("split"+strconv.Itoa(n), func(b *testing.B) {
			for b.Loop() {
				slices.Collect(tokenizer.split(s))
			}
		})
	}
}
// TestSplit checks the pretokenizer for the built-in default pattern
// and for caller-supplied override patterns, using one fixed input.
func TestSplit(t *testing.T) {
	const input = "Hello, WORLD!! How's it going? 123 一二三"

	cases := []struct {
		name     string
		patterns []string
		want     []string
	}{
		{
			name: "default",
			want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " 123", " 一二三"},
		},
		{
			name: "unicode",
			patterns: []string{
				"\\p{N}{1,3}",
				`[一-龥぀-ゟ゠-ヿ]+`,
				"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
			},
			want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " ", "123", " ", "一二三"},
		},
		{
			name: "individual digits",
			patterns: []string{
				"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
			},
			want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " ", "1", "2", "3", " 一二三"},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			tok := NewBytePairEncoding(nil, tc.patterns...)
			got := slices.Collect(tok.split(input))
			if diff := cmp.Diff(tc.want, got); diff != "" {
				t.Errorf("no match (-theirs +ours):\n%s", diff)
			}
		})
	}
}