mirror of
https://github.com/ollama/ollama.git
synced 2026-04-23 01:05:47 +02:00
tokenizer: add SentencePiece-style BPE support (#15162)
* tokenizer: add SentencePiece-style BPE support

  Add a WithSentencePieceNormalizer option to BytePairEncoding for models that use BPE with SentencePiece-style space markers (spaces are mapped to and from U+2581). NewBytePairEncoding is unchanged; the new NewBytePairEncodingWithOptions constructor accepts BPEOption functions. Decoding handles the reverse mapping of U+2581 back to spaces.

* review comments
This commit is contained in:
@@ -16,18 +16,38 @@ import (
|
|||||||
type BytePairEncoding struct {
|
type BytePairEncoding struct {
|
||||||
vocab *Vocabulary
|
vocab *Vocabulary
|
||||||
regexps []*regexp2.Regexp
|
regexps []*regexp2.Regexp
|
||||||
|
spaceToSpmSep bool // When true, normalize spaces to ▁ instead of GPT-2 byte-level encoding
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ Tokenizer = (*BytePairEncoding)(nil)
|
var _ Tokenizer = (*BytePairEncoding)(nil)
|
||||||
|
|
||||||
|
// BPEOption configures BytePairEncoding behavior
|
||||||
|
type BPEOption func(*BytePairEncoding)
|
||||||
|
|
||||||
|
// WithSentencePieceNormalizer enables ▁ space normalization instead of GPT-2 byte-level encoding.
|
||||||
|
func WithSentencePieceNormalizer() BPEOption {
|
||||||
|
return func(bpe *BytePairEncoding) {
|
||||||
|
bpe.spaceToSpmSep = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func NewBytePairEncoding(vocab *Vocabulary, pretokenizer ...string) BytePairEncoding {
|
func NewBytePairEncoding(vocab *Vocabulary, pretokenizer ...string) BytePairEncoding {
|
||||||
|
return newBytePairEncoding(vocab, pretokenizer)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewBytePairEncodingWithOptions(vocab *Vocabulary, pretokenizer []string, opts ...BPEOption) BytePairEncoding {
|
||||||
|
bpe := newBytePairEncoding(vocab, pretokenizer, opts...)
|
||||||
|
return bpe
|
||||||
|
}
|
||||||
|
|
||||||
|
func newBytePairEncoding(vocab *Vocabulary, pretokenizer []string, opts ...BPEOption) BytePairEncoding {
|
||||||
if len(pretokenizer) == 0 {
|
if len(pretokenizer) == 0 {
|
||||||
// set default byte-level pretokenizer if none provided, e.g.
|
// set default byte-level pretokenizer if none provided, e.g.
|
||||||
// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs#L44
|
// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs#L44
|
||||||
pretokenizer = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
|
pretokenizer = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
|
||||||
}
|
}
|
||||||
|
|
||||||
return BytePairEncoding{
|
bpe := BytePairEncoding{
|
||||||
vocab: vocab,
|
vocab: vocab,
|
||||||
regexps: slices.Collect(func(yield func(*regexp2.Regexp) bool) {
|
regexps: slices.Collect(func(yield func(*regexp2.Regexp) bool) {
|
||||||
for _, p := range pretokenizer {
|
for _, p := range pretokenizer {
|
||||||
@@ -37,6 +57,12 @@ func NewBytePairEncoding(vocab *Vocabulary, pretokenizer ...string) BytePairEnco
|
|||||||
}
|
}
|
||||||
}),
|
}),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(&bpe)
|
||||||
|
}
|
||||||
|
|
||||||
|
return bpe
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bpe BytePairEncoding) Vocabulary() *Vocabulary {
|
func (bpe BytePairEncoding) Vocabulary() *Vocabulary {
|
||||||
@@ -136,6 +162,12 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
|
|||||||
|
|
||||||
for split := range bpe.split(frag.value) {
|
for split := range bpe.split(frag.value) {
|
||||||
// TODO: process splits concurrently
|
// TODO: process splits concurrently
|
||||||
|
var normalized string
|
||||||
|
if bpe.spaceToSpmSep {
|
||||||
|
// SentencePiece-style: replace spaces with ▁
|
||||||
|
normalized = strings.ReplaceAll(split, " ", spmWhitespaceSep)
|
||||||
|
} else {
|
||||||
|
// GPT-2 byte-level: map bytes to shifted Unicode codepoints
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
for _, b := range []byte(split) {
|
for _, b := range []byte(split) {
|
||||||
r := rune(b)
|
r := rune(b)
|
||||||
@@ -147,17 +179,18 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
|
|||||||
case r >= 0x007f && r <= 0x00a0:
|
case r >= 0x007f && r <= 0x00a0:
|
||||||
r = r + 0x00a2
|
r = r + 0x00a2
|
||||||
}
|
}
|
||||||
|
|
||||||
sb.WriteRune(r)
|
sb.WriteRune(r)
|
||||||
}
|
}
|
||||||
|
normalized = sb.String()
|
||||||
|
}
|
||||||
|
|
||||||
// short circuit if the fragment is in the vocabulary
|
// short circuit if the fragment is in the vocabulary
|
||||||
if id := bpe.vocab.Encode(sb.String()); id >= 0 {
|
if id := bpe.vocab.Encode(normalized); id >= 0 {
|
||||||
ids = append(ids, id)
|
ids = append(ids, id)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
runes := []rune(sb.String())
|
runes := []rune(normalized)
|
||||||
merges := make([]merge, len(runes))
|
merges := make([]merge, len(runes))
|
||||||
for r := range runes {
|
for r := range runes {
|
||||||
merges[r] = merge{
|
merges[r] = merge{
|
||||||
@@ -257,6 +290,8 @@ func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
|
|||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
for _, id := range ids {
|
for _, id := range ids {
|
||||||
for _, r := range bpe.vocab.Decode(id) {
|
for _, r := range bpe.vocab.Decode(id) {
|
||||||
|
// GPT-2 byte-level BPE uses Unicode chars in the 0x0100-0x0143
|
||||||
|
// range to represent bytes. Remap them back to actual bytes.
|
||||||
switch {
|
switch {
|
||||||
case r == 0x0100:
|
case r == 0x0100:
|
||||||
// this produces 0x00 aka NULL
|
// this produces 0x00 aka NULL
|
||||||
@@ -267,6 +302,15 @@ func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
|
|||||||
r = r - 0x0100
|
r = r - 0x0100
|
||||||
case r > 0x0120 && r <= 0x0142:
|
case r > 0x0120 && r <= 0x0142:
|
||||||
r = r - 0x00a2
|
r = r - 0x00a2
|
||||||
|
case r > 0x0143:
|
||||||
|
// Non-GPT2 rune (e.g., SentencePiece-style BPE).
|
||||||
|
// Handle ▁ as word separator, otherwise write the rune as-is.
|
||||||
|
if r == 0x2581 { // ▁ (LOWER ONE EIGHTH BLOCK)
|
||||||
|
sb.WriteByte(' ')
|
||||||
|
} else {
|
||||||
|
sb.WriteRune(r)
|
||||||
|
}
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// NOTE: not using WriteRune here because it writes the UTF-8
|
// NOTE: not using WriteRune here because it writes the UTF-8
|
||||||
|
|||||||
@@ -239,6 +239,186 @@ func TestLlama(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// spmBPE builds a SentencePiece-style BPE tokenizer for testing.
|
||||||
|
//
|
||||||
|
// Models that use SentencePiece BPE differ from GPT-2 BPE in how they
|
||||||
|
// handle spaces: the vocabulary stores ▁ (U+2581) instead of GPT-2's
|
||||||
|
// shifted-byte encoding (0x0100–0x0143). Without WithSentencePieceNormalizer,
|
||||||
|
// spaces are mapped through the GPT-2 byte table which produces wrong token
|
||||||
|
// IDs for any vocabulary that uses ▁-prefixed tokens. The decode path has
|
||||||
|
// the inverse problem: high codepoints like CJK characters and ▁ itself
|
||||||
|
// would be mangled by the GPT-2 reverse mapping instead of being passed
|
||||||
|
// through (or converted to spaces in the ▁ case).
|
||||||
|
func spmBPE(t testing.TB) BytePairEncoding {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
tokens := []string{
|
||||||
|
// Control tokens (low IDs, as in real SentencePiece vocabs)
|
||||||
|
"<pad>", // 0
|
||||||
|
"<eos>", // 1
|
||||||
|
"<bos>", // 2
|
||||||
|
"<|start>", // 3 - asymmetric open/close special tokens
|
||||||
|
"<end|>", // 4
|
||||||
|
"<|q>", // 5 - short special token (like <|"|>)
|
||||||
|
|
||||||
|
// ▁-prefixed word tokens (the core of what SPM BPE changes)
|
||||||
|
"▁hello", // 6
|
||||||
|
"▁world", // 7
|
||||||
|
"hello", // 8
|
||||||
|
"▁Run", // 9
|
||||||
|
"▁a", // 10
|
||||||
|
|
||||||
|
// Punctuation and structure
|
||||||
|
",", // 11
|
||||||
|
"!", // 12
|
||||||
|
":", // 13
|
||||||
|
"{", // 14
|
||||||
|
"}", // 15
|
||||||
|
|
||||||
|
// Whitespace separator
|
||||||
|
"▁", // 16
|
||||||
|
|
||||||
|
// Subword tokens used in tool-declaration-like patterns
|
||||||
|
"description", // 17
|
||||||
|
"▁command", // 18
|
||||||
|
"declaration", // 19
|
||||||
|
|
||||||
|
// Unicode token for decode passthrough testing (must be > U+0143
|
||||||
|
// to exercise the SPM decode path rather than GPT-2 byte reversal)
|
||||||
|
"▁中文", // 20
|
||||||
|
}
|
||||||
|
|
||||||
|
types := make([]int32, len(tokens))
|
||||||
|
for i := range types {
|
||||||
|
types[i] = TOKEN_TYPE_NORMAL
|
||||||
|
}
|
||||||
|
types[0] = TOKEN_TYPE_CONTROL // <pad>
|
||||||
|
types[1] = TOKEN_TYPE_CONTROL // <eos>
|
||||||
|
types[2] = TOKEN_TYPE_CONTROL // <bos>
|
||||||
|
types[3] = TOKEN_TYPE_USER_DEFINED // <|start>
|
||||||
|
types[4] = TOKEN_TYPE_USER_DEFINED // <end|>
|
||||||
|
types[5] = TOKEN_TYPE_USER_DEFINED // <|q>
|
||||||
|
|
||||||
|
return NewBytePairEncodingWithOptions(
|
||||||
|
&Vocabulary{
|
||||||
|
Values: tokens,
|
||||||
|
Types: types,
|
||||||
|
BOS: []int32{2},
|
||||||
|
EOS: []int32{1},
|
||||||
|
AddBOS: false,
|
||||||
|
},
|
||||||
|
// Empty pretokenizer list: falls back to the default pattern.
|
||||||
|
// Real SentencePiece BPE models are configured this way.
|
||||||
|
[]string{},
|
||||||
|
WithSentencePieceNormalizer(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSentencePieceBPE(t *testing.T) {
|
||||||
|
tok := spmBPE(t)
|
||||||
|
|
||||||
|
// Test 1: Space-to-▁ normalization and roundtrip.
|
||||||
|
//
|
||||||
|
// This is the core behavior that WithSentencePieceNormalizer enables.
|
||||||
|
// Without it, " hello" would be byte-mapped through the GPT-2 table
|
||||||
|
// (producing Ġhello or similar shifted codepoints) which would never
|
||||||
|
// match the ▁-prefixed vocab entry.
|
||||||
|
t.Run("spm space normalization roundtrip", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
cases := map[string][]int32{
|
||||||
|
"hello": {8}, // no space → no ▁ prefix → "hello"(8)
|
||||||
|
" hello": {6}, // leading space → "▁hello"(6)
|
||||||
|
"hello, world!": {8, 11, 7, 12}, // pretokenizer splits punctuation;
|
||||||
|
// " world" normalizes to "▁world"
|
||||||
|
}
|
||||||
|
|
||||||
|
for input, wantIDs := range cases {
|
||||||
|
ids, err := tok.Encode(input, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Encode(%q): %v", input, err)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(wantIDs, ids); diff != "" {
|
||||||
|
t.Errorf("Encode(%q) mismatch (-want +got):\n%s", input, diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := tok.Decode(ids)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Decode(%v): %v", ids, err)
|
||||||
|
}
|
||||||
|
if got != input {
|
||||||
|
t.Errorf("roundtrip %q: Decode(Encode) = %q", input, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test 2: Special tokens interleaved with SPM-normalized text.
|
||||||
|
//
|
||||||
|
// This mimics tool declaration patterns like:
|
||||||
|
// <|tool>declaration:bash{description:<|"|>Run a command<|"|>}<tool|>
|
||||||
|
// where special tokens (<|tool>, <|"|>, <tool|>) must be extracted
|
||||||
|
// first, then the remaining text fragments go through SPM normalization.
|
||||||
|
// Without the SPM normalizer, the text between special tokens would be
|
||||||
|
// encoded with GPT-2 byte mapping, producing entirely wrong IDs.
|
||||||
|
t.Run("special tokens with spm text fragments", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// Pattern: <|start>declaration:description:<|q>Run a command<|q>}<end|>
|
||||||
|
input := "<|start>declaration:description:<|q> Run a command<|q>}<end|>"
|
||||||
|
ids, err := tok.Encode(input, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Special tokens should be extracted as single IDs, and the text
|
||||||
|
// between them should be SPM-normalized (spaces → ▁).
|
||||||
|
want := []int32{
|
||||||
|
3, // <|start>
|
||||||
|
19, // "declaration" (text fragment, no leading space)
|
||||||
|
13, // ":"
|
||||||
|
17, // "description"
|
||||||
|
13, // ":"
|
||||||
|
5, // <|q>
|
||||||
|
9, // "▁Run" (space before "Run" becomes ▁)
|
||||||
|
10, // "▁a"
|
||||||
|
18, // "▁command"
|
||||||
|
5, // <|q>
|
||||||
|
15, // "}"
|
||||||
|
4, // <end|>
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(want, ids); diff != "" {
|
||||||
|
t.Errorf("mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test 3: Decode handles non-GPT2 Unicode correctly.
|
||||||
|
//
|
||||||
|
// GPT-2 BPE decode reverses the byte→codepoint shift for runes in
|
||||||
|
// 0x0100–0x0143. But SentencePiece vocabs store real Unicode (CJK,
|
||||||
|
// accented chars, etc.) which have codepoints well above 0x0143.
|
||||||
|
// Without the > 0x0143 passthrough in Decode, these would be mangled
|
||||||
|
// by the GPT-2 reverse mapping (e.g., written as raw bytes instead
|
||||||
|
// of the original characters).
|
||||||
|
t.Run("decode non-gpt2 unicode passthrough", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
cases := map[string][]int32{
|
||||||
|
" 中文": {20}, // ▁→space, then CJK passes through as-is
|
||||||
|
}
|
||||||
|
|
||||||
|
for want, ids := range cases {
|
||||||
|
got, err := tok.Decode(ids)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Decode(%v): %v", ids, err)
|
||||||
|
}
|
||||||
|
if got != want {
|
||||||
|
t.Errorf("Decode(%v) = %q, want %q", ids, got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkBytePairEncoding(b *testing.B) {
|
func BenchmarkBytePairEncoding(b *testing.B) {
|
||||||
tokenizer := llama(b)
|
tokenizer := llama(b)
|
||||||
bts, err := os.ReadFile(filepath.Join("testdata", "war-and-peace.txt"))
|
bts, err := os.ReadFile(filepath.Join("testdata", "war-and-peace.txt"))
|
||||||
|
|||||||
Reference in New Issue
Block a user