tokenizer: add SentencePiece-style BPE support (#15162)

* tokenizer: add SentencePiece-style BPE support

Add WithSentencePieceNormalizer option to BytePairEncoding for models
that use BPE with SentencePiece-style space markers (space to/from
U+2581).

NewBytePairEncoding is unchanged; the new NewBytePairEncodingWithOptions
constructor accepts BPEOption functions. Decoding handles the reverse
mapping of U+2581 back to spaces.

* review comments
This commit is contained in:
Daniel Hiltgen
2026-03-31 17:00:36 -07:00
committed by GitHub
parent 4d14b0ff92
commit cb0033598e
2 changed files with 241 additions and 17 deletions

View File

@@ -14,20 +14,40 @@ import (
)
type BytePairEncoding struct {
vocab *Vocabulary
regexps []*regexp2.Regexp
vocab *Vocabulary
regexps []*regexp2.Regexp
spaceToSpmSep bool // When true, normalize spaces to ▁ instead of GPT-2 byte-level encoding
}
var _ Tokenizer = (*BytePairEncoding)(nil)
// BPEOption configures BytePairEncoding behavior
type BPEOption func(*BytePairEncoding)
// WithSentencePieceNormalizer enables ▁ space normalization instead of GPT-2 byte-level encoding.
func WithSentencePieceNormalizer() BPEOption {
return func(bpe *BytePairEncoding) {
bpe.spaceToSpmSep = true
}
}
func NewBytePairEncoding(vocab *Vocabulary, pretokenizer ...string) BytePairEncoding {
return newBytePairEncoding(vocab, pretokenizer)
}
func NewBytePairEncodingWithOptions(vocab *Vocabulary, pretokenizer []string, opts ...BPEOption) BytePairEncoding {
bpe := newBytePairEncoding(vocab, pretokenizer, opts...)
return bpe
}
func newBytePairEncoding(vocab *Vocabulary, pretokenizer []string, opts ...BPEOption) BytePairEncoding {
if len(pretokenizer) == 0 {
// set default byte-level pretokenizer if none provided, e.g.
// https://github.com/huggingface/tokenizer/blob/main/tokenizer/src/pre_tokenizer/byte_level.rs#L44
pretokenizer = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
}
return BytePairEncoding{
bpe := BytePairEncoding{
vocab: vocab,
regexps: slices.Collect(func(yield func(*regexp2.Regexp) bool) {
for _, p := range pretokenizer {
@@ -37,6 +57,12 @@ func NewBytePairEncoding(vocab *Vocabulary, pretokenizer ...string) BytePairEnco
}
}),
}
for _, opt := range opts {
opt(&bpe)
}
return bpe
}
func (bpe BytePairEncoding) Vocabulary() *Vocabulary {
@@ -136,28 +162,35 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
for split := range bpe.split(frag.value) {
// TODO: process splits concurrently
var sb strings.Builder
for _, b := range []byte(split) {
r := rune(b)
switch {
case r == 0x00ad:
r = 0x0143
case r <= 0x0020:
r = r + 0x0100
case r >= 0x007f && r <= 0x00a0:
r = r + 0x00a2
var normalized string
if bpe.spaceToSpmSep {
// SentencePiece-style: replace spaces with ▁
normalized = strings.ReplaceAll(split, " ", spmWhitespaceSep)
} else {
// GPT-2 byte-level: map bytes to shifted Unicode codepoints
var sb strings.Builder
for _, b := range []byte(split) {
r := rune(b)
switch {
case r == 0x00ad:
r = 0x0143
case r <= 0x0020:
r = r + 0x0100
case r >= 0x007f && r <= 0x00a0:
r = r + 0x00a2
}
sb.WriteRune(r)
}
sb.WriteRune(r)
normalized = sb.String()
}
// short circuit if the fragment is in the vocabulary
if id := bpe.vocab.Encode(sb.String()); id >= 0 {
if id := bpe.vocab.Encode(normalized); id >= 0 {
ids = append(ids, id)
continue
}
runes := []rune(sb.String())
runes := []rune(normalized)
merges := make([]merge, len(runes))
for r := range runes {
merges[r] = merge{
@@ -257,6 +290,8 @@ func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
var sb strings.Builder
for _, id := range ids {
for _, r := range bpe.vocab.Decode(id) {
// GPT-2 byte-level BPE uses Unicode chars in the 0x0100-0x0143
// range to represent bytes. Remap them back to actual bytes.
switch {
case r == 0x0100:
// this produces 0x00 aka NULL
@@ -267,6 +302,15 @@ func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
r = r - 0x0100
case r > 0x0120 && r <= 0x0142:
r = r - 0x00a2
case r > 0x0143:
// Non-GPT2 rune (e.g., SentencePiece-style BPE).
// Handle ▁ as word separator, otherwise write the rune as-is.
if r == 0x2581 { // ▁ (LOWER ONE EIGHTH BLOCK)
sb.WriteByte(' ')
} else {
sb.WriteRune(r)
}
continue
}
// NOTE: not using WriteRune here because it writes the UTF-8