package tokenizer import ( "bufio" "encoding/json" "math" "os" "path/filepath" "slices" "strconv" "strings" "testing" "github.com/google/go-cmp/cmp" ) func llama(t testing.TB) BytePairEncoding { t.Helper() f, err := os.Open(filepath.FromSlash("testdata/llama3.2/encoder.json")) if err != nil { t.Fatal(err) } defer f.Close() vocab := make(map[string]int32) if err := json.NewDecoder(f).Decode(&vocab); err != nil { t.Fatal(err) } types := make([]int32, len(vocab)) tokens := make([]string, len(vocab)) for token, id := range vocab { tokens[id] = token types[id] = 1 } for _, token := range []string{"<|begin_of_text|>", "<|end_of_text|>"} { if _, ok := vocab[token]; !ok { tokens = append(tokens, token) //nolint:makezero types = append(types, 3) //nolint:makezero vocab[token] = int32(len(vocab)) } } f, err = os.Open(filepath.FromSlash("testdata/llama3.2/vocab.bpe")) if err != nil { t.Fatal(err) } defer f.Close() merges := make([]string, 0, 50000) scanner := bufio.NewScanner(f) for scanner.Scan() { if !strings.HasPrefix(scanner.Text(), "#") { merges = append(merges, scanner.Text()) } } return NewBytePairEncoding( &Vocabulary{ Values: tokens, Types: types, Merges: merges, }, "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", ) } func TestLlama(t *testing.T) { tokenizer := llama(t) t.Run("simple", func(t *testing.T) { t.Parallel() ids, err := tokenizer.Encode("hello world", true) if err != nil { t.Error(err) } if diff := cmp.Diff([]int32{15339, 1917}, ids); diff != "" { t.Errorf("no match (-theirs +ours):\n%s", diff) } s, err := tokenizer.Decode([]int32{15339, 1917}) if err != nil { t.Fatal(err) } if s != "hello world" { t.Errorf("got %q, want hello world", s) } ids, err = tokenizer.Encode("hello <|end_of_text|>", true) if err != nil { t.Error(err) } if diff := cmp.Diff([]int32{15339, 220, 128001}, ids); diff != "" { t.Errorf("no match (-theirs +ours):\n%s", diff) } }) t.Run("simple repeated", func(t *testing.T) { t.Parallel() cases := map[string][]int32{ strings.Repeat("0", 1): {15}, strings.Repeat("0", 2): {410}, strings.Repeat("0", 3): {931}, strings.Repeat("0", 4): {931, 15}, strings.Repeat("0", 5): {931, 410}, strings.Repeat("0", 6): {931, 931}, strings.Repeat("0", 7): {931, 931, 15}, strings.Repeat("0", 8): {931, 931, 410}, strings.Repeat("0", 9): {931, 931, 931}, strings.Repeat("0", 10): {931, 931, 931, 15}, strings.Repeat("0", 11): {931, 931, 931, 410}, strings.Repeat("0", 12): {931, 931, 931, 931}, strings.Repeat("0", 13): {931, 931, 931, 931, 15}, strings.Repeat("0", 14): {931, 931, 931, 931, 410}, strings.Repeat("0", 15): {931, 931, 931, 931, 931}, strings.Repeat("0", 16): {931, 931, 931, 931, 931, 15}, strings.Repeat("0", 17): {931, 931, 931, 931, 931, 410}, } for s, want := range cases { ids, err := tokenizer.Encode(s, true) if err != nil { t.Error(err) } if diff := cmp.Diff(want, ids); diff != "" { t.Errorf("%q no match (-theirs +ours):\n%s", s, diff) } } }) t.Run("basic roundtrip", func(t *testing.T) { t.Parallel() cases := []string{ "hello", "hello ", "hello ", " hello", " hello ", " hello ", "hello world", "请考试我的软件!12345", } for _, want := range cases { ids, err := tokenizer.Encode(want, true) if err != nil { t.Error(err) } if got, err := tokenizer.Decode(ids); err != nil { t.Fatal(err) } else if got != want { t.Errorf("got %q, want %q", got, want) } } }) t.Run("special", func(t *testing.T) { t.Parallel() cases := map[string][]int32{ "<|begin_of_text|>A B!": {128000, 32, 426, 0}, "<|begin_of_text|>A<|end_of_text|>B!": {128000, 32, 128001, 33, 0}, "<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!": {128000, 32, 128001, 33, 128000, 0}, "<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!<|end_of_text|>": {128000, 32, 128001, 33, 128000, 0, 128001}, } for s, want := range cases { ids, err := tokenizer.Encode(s, true) if err != nil { t.Fatal(err) } if diff := cmp.Diff(want, ids); diff != "" { t.Errorf("no match (-theirs +ours):\n%s", diff) } } }) t.Run("split", func(t *testing.T) { t.Parallel() cases := map[string][]string{ "Hello World!": {"Hello", " World", "!"}, "I'm don't won't": {"I", "'m", " don", "'t", " won", "'t"}, "In 2024 there are 366 days": {"In", " ", "202", "4", " there", " are", " ", "366", " days"}, "Hello!! ...world": {"Hello", "!!", " ...", "world"}, "Hello World": {"Hello", " ", " World"}, "Hello\nWorld": {"Hello", "\n", "World"}, "Hello, WORLD!! How's it going?": {"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?"}, } for s, want := range cases { got := slices.Collect(tokenizer.split(s)) if diff := cmp.Diff(want, got); diff != "" { t.Errorf("no match (-theirs +ours):\n%s", diff) } } }) t.Run("roundtriping 0x00-0xFF", func(t *testing.T) { t.Parallel() for b := 0x00; b <= 0xFF; b++ { input := string(rune(b)) ids, err := tokenizer.Encode(input, false) if err != nil { t.Errorf("failed to encode rune 0x%02X: %v", b, err) continue } decoded, err := tokenizer.Decode(ids) if err != nil { t.Errorf("failed to decode rune 0x%02X: %v", b, err) continue } if b == 0x00 { if len(decoded) != 0 { t.Errorf("Decode(Encode(0x00)) should be empty, got %v", ids) } continue } if decoded != input { t.Errorf("rune 0x%02X failed roundtrip: got %q, want %q", b, decoded, input) } } }) } // spmBPE builds a SentencePiece-style BPE tokenizer for testing. // // Models that use SentencePiece BPE differ from GPT-2 BPE in how they // handle spaces: the vocabulary stores ▁ (U+2581) instead of GPT-2's // shifted-byte encoding (0x0100–0x0143). Without WithSentencePieceNormalizer, // spaces are mapped through the GPT-2 byte table which produces wrong token // IDs for any vocabulary that uses ▁-prefixed tokens. The decode path has // the inverse problem: high codepoints like CJK characters and ▁ itself // would be mangled by the GPT-2 reverse mapping instead of being passed // through (or converted to spaces in the ▁ case). func spmBPE(t testing.TB) BytePairEncoding { t.Helper() tokens := []string{ // Control tokens (low IDs, as in real SentencePiece vocabs) "", // 0 "", // 1 "", // 2 "<|start>", // 3 - asymmetric open/close special tokens "", // 4 "<|q>", // 5 - short special token (like <|"|>) // ▁-prefixed word tokens (the core of what SPM BPE changes) "▁hello", // 6 "▁world", // 7 "hello", // 8 "▁Run", // 9 "▁a", // 10 // Punctuation and structure ",", // 11 "!", // 12 ":", // 13 "{", // 14 "}", // 15 // Whitespace separator "▁", // 16 // Subword tokens used in tool-declaration-like patterns "description", // 17 "▁command", // 18 "declaration", // 19 // Unicode token for decode passthrough testing (must be > U+0143 // to exercise the SPM decode path rather than GPT-2 byte reversal) "▁中文", // 20 } types := make([]int32, len(tokens)) for i := range types { types[i] = TOKEN_TYPE_NORMAL } types[0] = TOKEN_TYPE_CONTROL // types[1] = TOKEN_TYPE_CONTROL // types[2] = TOKEN_TYPE_CONTROL // types[3] = TOKEN_TYPE_USER_DEFINED // <|start> types[4] = TOKEN_TYPE_USER_DEFINED // types[5] = TOKEN_TYPE_USER_DEFINED // <|q> return NewBytePairEncodingWithOptions( &Vocabulary{ Values: tokens, Types: types, BOS: []int32{2}, EOS: []int32{1}, AddBOS: false, }, // Empty pretokenizer list: falls back to the default pattern. // Real SentencePiece BPE models are configured this way. []string{}, WithSentencePieceNormalizer(), ) } func TestSentencePieceBPE(t *testing.T) { tok := spmBPE(t) // Test 1: Space-to-▁ normalization and roundtrip. // // This is the core behavior that WithSentencePieceNormalizer enables. // Without it, " hello" would be byte-mapped through the GPT-2 table // (producing Ġhello or similar shifted codepoints) which would never // match the ▁-prefixed vocab entry. t.Run("spm space normalization roundtrip", func(t *testing.T) { t.Parallel() cases := map[string][]int32{ "hello": {8}, // no space → no ▁ prefix → "hello"(8) " hello": {6}, // leading space → "▁hello"(6) "hello, world!": {8, 11, 7, 12}, // pretokenizer splits punctuation; // " world" normalizes to "▁world" } for input, wantIDs := range cases { ids, err := tok.Encode(input, false) if err != nil { t.Fatalf("Encode(%q): %v", input, err) } if diff := cmp.Diff(wantIDs, ids); diff != "" { t.Errorf("Encode(%q) mismatch (-want +got):\n%s", input, diff) } got, err := tok.Decode(ids) if err != nil { t.Fatalf("Decode(%v): %v", ids, err) } if got != input { t.Errorf("roundtrip %q: Decode(Encode) = %q", input, got) } } }) // Test 2: Special tokens interleaved with SPM-normalized text. // // This mimics tool declaration patterns like: // <|tool>declaration:bash{description:<|"|>Run a command<|"|>} // where special tokens (<|tool>, <|"|>, ) must be extracted // first, then the remaining text fragments go through SPM normalization. // Without the SPM normalizer, the text between special tokens would be // encoded with GPT-2 byte mapping, producing entirely wrong IDs. t.Run("special tokens with spm text fragments", func(t *testing.T) { t.Parallel() // Pattern: <|start>declaration:description:<|q>Run a command<|q>} input := "<|start>declaration:description:<|q> Run a command<|q>}" ids, err := tok.Encode(input, false) if err != nil { t.Fatal(err) } // Special tokens should be extracted as single IDs, and the text // between them should be SPM-normalized (spaces → ▁). want := []int32{ 3, // <|start> 19, // "declaration" (text fragment, no leading space) 13, // ":" 17, // "description" 13, // ":" 5, // <|q> 9, // "▁Run" (space before "Run" becomes ▁) 10, // "▁a" 18, // "▁command" 5, // <|q> 15, // "}" 4, // } if diff := cmp.Diff(want, ids); diff != "" { t.Errorf("mismatch (-want +got):\n%s", diff) } }) // Test 3: Decode handles non-GPT2 Unicode correctly. // // GPT-2 BPE decode reverses the byte→codepoint shift for runes in // 0x0100–0x0143. But SentencePiece vocabs store real Unicode (CJK, // accented chars, etc.) which have codepoints well above 0x0143. // Without the > 0x0143 passthrough in Decode, these would be mangled // by the GPT-2 reverse mapping (e.g., written as raw bytes instead // of the original characters). t.Run("decode non-gpt2 unicode passthrough", func(t *testing.T) { t.Parallel() cases := map[string][]int32{ " 中文": {20}, // ▁→space, then CJK passes through as-is } for want, ids := range cases { got, err := tok.Decode(ids) if err != nil { t.Fatalf("Decode(%v): %v", ids, err) } if got != want { t.Errorf("Decode(%v) = %q, want %q", ids, got, want) } } }) } func BenchmarkBytePairEncoding(b *testing.B) { tokenizer := llama(b) bts, err := os.ReadFile(filepath.Join("testdata", "war-and-peace.txt")) if err != nil { b.Fatal(err) } for i := range 8 { n := min(int(math.Pow10(i)), len(bts)) bts := bts[:n] b.Run("encode"+strconv.Itoa(n), func(b *testing.B) { b.ResetTimer() for b.Loop() { _, err := tokenizer.Encode(string(bts), true) if err != nil { b.Fatal(err) } } }) b.Run("decode"+strconv.Itoa(n), func(b *testing.B) { ids, err := tokenizer.Encode(string(bts), true) if err != nil { b.Fatal(err) } b.ResetTimer() for b.Loop() { _, err := tokenizer.Decode(ids) if err != nil { b.Fatal(err) } } }) b.Run("split"+strconv.Itoa(n), func(b *testing.B) { b.ResetTimer() for b.Loop() { slices.Collect(tokenizer.split(string(bts))) } }) } } func TestSplit(t *testing.T) { cases := []struct { name string patterns, want []string }{ { name: "default", want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " 123", " 一二三"}, }, { name: "unicode", patterns: []string{ "\\p{N}{1,3}", `[一-龥぀-ゟ゠-ヿ]+`, "[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+", }, want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " ", "123", " ", "一二三"}, }, { name: "individual digits", patterns: []string{ "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }, want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " ", "1", "2", "3", " 一二三"}, }, } for _, tt := range cases { t.Run(tt.name, func(t *testing.T) { tokenizer := NewBytePairEncoding(nil, tt.patterns...) if diff := cmp.Diff(tt.want, slices.Collect(tokenizer.split("Hello, WORLD!! How's it going? 123 一二三"))); diff != "" { t.Errorf("no match (-theirs +ours):\n%s", diff) } }) } }