From cb0033598ec36c9d10eceef8a0ba4e02329c5b35 Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Tue, 31 Mar 2026 17:00:36 -0700 Subject: [PATCH] tokenizer: add SentencePiece-style BPE support (#15162) * tokenizer: add SentencePiece-style BPE support Add WithSentencePieceNormalizer option to BytePairEncoding for models that use BPE with SentencePiece-style space markers (space to/from U+2581). NewBytePairEncoding is unchanged; the new NewBytePairEncodingWithOptions constructor accepts BPEOption functions. Decoding handles the reverse mapping of U+2581 back to spaces. * review comments --- tokenizer/bytepairencoding.go | 78 ++++++++++--- tokenizer/bytepairencoding_test.go | 180 +++++++++++++++++++++++++++++ 2 files changed, 241 insertions(+), 17 deletions(-) diff --git a/tokenizer/bytepairencoding.go b/tokenizer/bytepairencoding.go index b592aeedb..6d78925c3 100644 --- a/tokenizer/bytepairencoding.go +++ b/tokenizer/bytepairencoding.go @@ -14,20 +14,40 @@ import ( ) type BytePairEncoding struct { - vocab *Vocabulary - regexps []*regexp2.Regexp + vocab *Vocabulary + regexps []*regexp2.Regexp + spaceToSpmSep bool // When true, normalize spaces to ▁ instead of GPT-2 byte-level encoding } var _ Tokenizer = (*BytePairEncoding)(nil) +// BPEOption configures BytePairEncoding behavior +type BPEOption func(*BytePairEncoding) + +// WithSentencePieceNormalizer enables ▁ space normalization instead of GPT-2 byte-level encoding. +func WithSentencePieceNormalizer() BPEOption { + return func(bpe *BytePairEncoding) { + bpe.spaceToSpmSep = true + } +} + func NewBytePairEncoding(vocab *Vocabulary, pretokenizer ...string) BytePairEncoding { + return newBytePairEncoding(vocab, pretokenizer) +} + +func NewBytePairEncodingWithOptions(vocab *Vocabulary, pretokenizer []string, opts ...BPEOption) BytePairEncoding { + bpe := newBytePairEncoding(vocab, pretokenizer, opts...) + return bpe +} + +func newBytePairEncoding(vocab *Vocabulary, pretokenizer []string, opts ...BPEOption) BytePairEncoding { if len(pretokenizer) == 0 { // set default byte-level pretokenizer if none provided, e.g. // https://github.com/huggingface/tokenizer/blob/main/tokenizer/src/pre_tokenizer/byte_level.rs#L44 pretokenizer = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`} } - return BytePairEncoding{ + bpe := BytePairEncoding{ vocab: vocab, regexps: slices.Collect(func(yield func(*regexp2.Regexp) bool) { for _, p := range pretokenizer { @@ -37,6 +57,12 @@ func NewBytePairEncoding(vocab *Vocabulary, pretokenizer ...string) BytePairEnco } }), } + + for _, opt := range opts { + opt(&bpe) + } + + return bpe } func (bpe BytePairEncoding) Vocabulary() *Vocabulary { @@ -136,28 +162,35 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) { for split := range bpe.split(frag.value) { // TODO: process splits concurrently - var sb strings.Builder - for _, b := range []byte(split) { - r := rune(b) - switch { - case r == 0x00ad: - r = 0x0143 - case r <= 0x0020: - r = r + 0x0100 - case r >= 0x007f && r <= 0x00a0: - r = r + 0x00a2 + var normalized string + if bpe.spaceToSpmSep { + // SentencePiece-style: replace spaces with ▁ + normalized = strings.ReplaceAll(split, " ", spmWhitespaceSep) + } else { + // GPT-2 byte-level: map bytes to shifted Unicode codepoints + var sb strings.Builder + for _, b := range []byte(split) { + r := rune(b) + switch { + case r == 0x00ad: + r = 0x0143 + case r <= 0x0020: + r = r + 0x0100 + case r >= 0x007f && r <= 0x00a0: + r = r + 0x00a2 + } + sb.WriteRune(r) } - - sb.WriteRune(r) + normalized = sb.String() } // short circuit if the fragment is in the vocabulary - if id := bpe.vocab.Encode(sb.String()); id >= 0 { + if id := bpe.vocab.Encode(normalized); id >= 0 { ids = append(ids, id) continue } - runes := []rune(sb.String()) + runes := []rune(normalized) merges := make([]merge, len(runes)) for r := range runes { merges[r] = merge{ @@ -257,6 +290,8 @@ func (bpe BytePairEncoding) Decode(ids []int32) (string, error) { var sb strings.Builder for _, id := range ids { for _, r := range bpe.vocab.Decode(id) { + // GPT-2 byte-level BPE uses Unicode chars in the 0x0100-0x0143 + // range to represent bytes. Remap them back to actual bytes. switch { case r == 0x0100: // this produces 0x00 aka NULL @@ -267,6 +302,15 @@ func (bpe BytePairEncoding) Decode(ids []int32) (string, error) { r = r - 0x0100 case r > 0x0120 && r <= 0x0142: r = r - 0x00a2 + case r > 0x0143: + // Non-GPT2 rune (e.g., SentencePiece-style BPE). + // Handle ▁ as word separator, otherwise write the rune as-is. + if r == 0x2581 { // ▁ (LOWER ONE EIGHTH BLOCK) + sb.WriteByte(' ') + } else { + sb.WriteRune(r) + } + continue } // NOTE: not using WriteRune here because it writes the UTF-8 diff --git a/tokenizer/bytepairencoding_test.go b/tokenizer/bytepairencoding_test.go index 9b9e901a7..09d12ca13 100644 --- a/tokenizer/bytepairencoding_test.go +++ b/tokenizer/bytepairencoding_test.go @@ -239,6 +239,186 @@ func TestLlama(t *testing.T) { }) } +// spmBPE builds a SentencePiece-style BPE tokenizer for testing. +// +// Models that use SentencePiece BPE differ from GPT-2 BPE in how they +// handle spaces: the vocabulary stores ▁ (U+2581) instead of GPT-2's +// shifted-byte encoding (0x0100–0x0143). Without WithSentencePieceNormalizer, +// spaces are mapped through the GPT-2 byte table which produces wrong token +// IDs for any vocabulary that uses ▁-prefixed tokens. The decode path has +// the inverse problem: high codepoints like CJK characters and ▁ itself +// would be mangled by the GPT-2 reverse mapping instead of being passed +// through (or converted to spaces in the ▁ case). +func spmBPE(t testing.TB) BytePairEncoding { + t.Helper() + + tokens := []string{ + // Control tokens (low IDs, as in real SentencePiece vocabs) + "", // 0 + "", // 1 + "", // 2 + "<|start>", // 3 - asymmetric open/close special tokens + "", // 4 + "<|q>", // 5 - short special token (like <|"|>) + + // ▁-prefixed word tokens (the core of what SPM BPE changes) + "▁hello", // 6 + "▁world", // 7 + "hello", // 8 + "▁Run", // 9 + "▁a", // 10 + + // Punctuation and structure + ",", // 11 + "!", // 12 + ":", // 13 + "{", // 14 + "}", // 15 + + // Whitespace separator + "▁", // 16 + + // Subword tokens used in tool-declaration-like patterns + "description", // 17 + "▁command", // 18 + "declaration", // 19 + + // Unicode token for decode passthrough testing (must be > U+0143 + // to exercise the SPM decode path rather than GPT-2 byte reversal) + "▁中文", // 20 + } + + types := make([]int32, len(tokens)) + for i := range types { + types[i] = TOKEN_TYPE_NORMAL + } + types[0] = TOKEN_TYPE_CONTROL // + types[1] = TOKEN_TYPE_CONTROL // + types[2] = TOKEN_TYPE_CONTROL // + types[3] = TOKEN_TYPE_USER_DEFINED // <|start> + types[4] = TOKEN_TYPE_USER_DEFINED // + types[5] = TOKEN_TYPE_USER_DEFINED // <|q> + + return NewBytePairEncodingWithOptions( + &Vocabulary{ + Values: tokens, + Types: types, + BOS: []int32{2}, + EOS: []int32{1}, + AddBOS: false, + }, + // Empty pretokenizer list: falls back to the default pattern. + // Real SentencePiece BPE models are configured this way. + []string{}, + WithSentencePieceNormalizer(), + ) +} + +func TestSentencePieceBPE(t *testing.T) { + tok := spmBPE(t) + + // Test 1: Space-to-▁ normalization and roundtrip. + // + // This is the core behavior that WithSentencePieceNormalizer enables. + // Without it, " hello" would be byte-mapped through the GPT-2 table + // (producing Ġhello or similar shifted codepoints) which would never + // match the ▁-prefixed vocab entry. + t.Run("spm space normalization roundtrip", func(t *testing.T) { + t.Parallel() + + cases := map[string][]int32{ + "hello": {8}, // no space → no ▁ prefix → "hello"(8) + " hello": {6}, // leading space → "▁hello"(6) + "hello, world!": {8, 11, 7, 12}, // pretokenizer splits punctuation; + // " world" normalizes to "▁world" + } + + for input, wantIDs := range cases { + ids, err := tok.Encode(input, false) + if err != nil { + t.Fatalf("Encode(%q): %v", input, err) + } + if diff := cmp.Diff(wantIDs, ids); diff != "" { + t.Errorf("Encode(%q) mismatch (-want +got):\n%s", input, diff) + } + + got, err := tok.Decode(ids) + if err != nil { + t.Fatalf("Decode(%v): %v", ids, err) + } + if got != input { + t.Errorf("roundtrip %q: Decode(Encode) = %q", input, got) + } + } + }) + + // Test 2: Special tokens interleaved with SPM-normalized text. + // + // This mimics tool declaration patterns like: + // <|tool>declaration:bash{description:<|"|>Run a command<|"|>} + // where special tokens (<|tool>, <|"|>, ) must be extracted + // first, then the remaining text fragments go through SPM normalization. + // Without the SPM normalizer, the text between special tokens would be + // encoded with GPT-2 byte mapping, producing entirely wrong IDs. + t.Run("special tokens with spm text fragments", func(t *testing.T) { + t.Parallel() + + // Pattern: <|start>declaration:description:<|q>Run a command<|q>} + input := "<|start>declaration:description:<|q> Run a command<|q>}" + ids, err := tok.Encode(input, false) + if err != nil { + t.Fatal(err) + } + + // Special tokens should be extracted as single IDs, and the text + // between them should be SPM-normalized (spaces → ▁). + want := []int32{ + 3, // <|start> + 19, // "declaration" (text fragment, no leading space) + 13, // ":" + 17, // "description" + 13, // ":" + 5, // <|q> + 9, // "▁Run" (space before "Run" becomes ▁) + 10, // "▁a" + 18, // "▁command" + 5, // <|q> + 15, // "}" + 4, // + } + + if diff := cmp.Diff(want, ids); diff != "" { + t.Errorf("mismatch (-want +got):\n%s", diff) + } + }) + + // Test 3: Decode handles non-GPT2 Unicode correctly. + // + // GPT-2 BPE decode reverses the byte→codepoint shift for runes in + // 0x0100–0x0143. But SentencePiece vocabs store real Unicode (CJK, + // accented chars, etc.) which have codepoints well above 0x0143. + // Without the > 0x0143 passthrough in Decode, these would be mangled + // by the GPT-2 reverse mapping (e.g., written as raw bytes instead + // of the original characters). + t.Run("decode non-gpt2 unicode passthrough", func(t *testing.T) { + t.Parallel() + + cases := map[string][]int32{ + " 中文": {20}, // ▁→space, then CJK passes through as-is + } + + for want, ids := range cases { + got, err := tok.Decode(ids) + if err != nil { + t.Fatalf("Decode(%v): %v", ids, err) + } + if got != want { + t.Errorf("Decode(%v) = %q, want %q", ids, got, want) + } + } + }) +} + func BenchmarkBytePairEncoding(b *testing.B) { tokenizer := llama(b) bts, err := os.ReadFile(filepath.Join("testdata", "war-and-peace.txt"))