mirror of
https://github.com/ollama/ollama.git
synced 2026-04-17 15:53:27 +02:00
tokenizer: add byte fallback for SentencePiece BPE encoding (#15232)
* tokenizer: add byte fallback for SentencePiece BPE encoding. When BPE merging produces tokens that are not in the vocabulary, fall back to encoding each UTF-8 byte as a <0xHH> byte token instead of silently dropping the character. Also teach Decode to convert <0xHH> tokens back into raw bytes. Fixes #15229; fixes #15231. * tokenizer fixes
This commit is contained in:
341
model/models/gemma4/tokenizer_reference_test.go
Normal file
341
model/models/gemma4/tokenizer_reference_test.go
Normal file
@@ -0,0 +1,341 @@
|
|||||||
|
package gemma4
|
||||||
|
|
||||||
|
// TestGemma4TokenizerMatchesReference verifies our BPE tokenizer matches
|
||||||
|
// the Rust tokenizers library (the reference implementation) for Gemma 4.
|
||||||
|
//
|
||||||
|
// The test loads vocabulary from any local ollama gemma4 GGUF model.
|
||||||
|
// Skips if no gemma4 model is installed.
|
||||||
|
//
|
||||||
|
// Set VERIFY_HF_TOKENIZER=1 to verify against the Rust tokenizers library
|
||||||
|
// via Python. Requires python3 with tokenizers>=0.21 on PATH:
|
||||||
|
//
|
||||||
|
// VERIFY_HF_TOKENIZER=1 go test ./model/models/gemma4/ -run TestGemma4Tokenizer -v
|
||||||
|
//
|
||||||
|
// Workflow for adding a new test case:
|
||||||
|
// 1. Add {name: "...", input: "..."} to the test list (no want field)
|
||||||
|
// 2. Run with VERIFY_HF_TOKENIZER=1 — it prints the reference IDs
|
||||||
|
// 3. Paste those IDs into the want field
|
||||||
|
// 4. Run without VERIFY_HF_TOKENIZER — our tokenizer must match
|
||||||
|
|
||||||
|
import (
	"encoding/json"
	"fmt"
	"os"
	"os/exec"
	"path/filepath"
	"strconv"
	"strings"
	"testing"

	"github.com/ollama/ollama/envconfig"
	"github.com/ollama/ollama/fs/gguf"
	"github.com/ollama/ollama/tokenizer"
)
|
||||||
|
|
||||||
|
// tokenizerRefCase is one encode test case: an input string plus the
// reference token IDs produced by the Rust tokenizers library.
// A nil want means the case has not been verified yet (run with
// VERIFY_HF_TOKENIZER=1 to generate the reference IDs).
type tokenizerRefCase struct {
	name  string
	input string
	want  []int32
}
|
||||||
|
|
||||||
|
// Reference token IDs generated by the Rust tokenizers library using
// vocab/merges from a gemma4 GGUF with add_special_tokens=False.
//
// NOTE(review): this listing was recovered from a rendered diff; runs of
// multiple spaces in some inputs below (e.g. "double leading space",
// "only spaces") appear whitespace-collapsed by the rendering — confirm
// the exact space counts against the original file.
var gemma4TokenizerRefCases = []tokenizerRefCase{
	// Basic ASCII
	{name: "basic word", input: "hello", want: []int32{23391}},
	{name: "two words", input: "hello world", want: []int32{23391, 1902}},
	{name: "punctuation", input: "Hello, World!", want: []int32{9259, 236764, 4109, 236888}},

	// Space handling (pretokenizer bug: GPT-2 splitter mangled leading/multiple spaces)
	{name: "leading space", input: " hello", want: []int32{29104}},
	{name: "double leading space", input: " hello", want: []int32{138, 23391}},
	{name: "double space between words", input: "hello world", want: []int32{23391, 138, 12392}},
	{name: "only spaces", input: " ", want: []int32{139}},
	{name: "repeated spaces", input: " ", want: []int32{142}},
	{name: "leading spaces phrase", input: " leading spaces", want: []int32{5830, 9952}},
	{name: "multiple interior spaces", input: "multiple spaces", want: []int32{43819, 140, 35220}},

	// Polish diacritics (issue #15231 — Decode mangled U+0105-U+0142)
	{name: "polish diacritics", input: "ąęśćżźółń", want: []int32{237198, 237202, 14732, 237277, 238992, 24875, 238041}},
	{name: "polish sentence", input: "Zażółć gęślą jaźń", want: []int32{236953, 40512, 24875, 237289, 549, 237202, 62081, 237198, 4828, 238992, 238041}},

	// French accents (issue #15229 — Decode mangled U+00E0-U+00FF)
	{name: "french accents", input: "café résumé naïve", want: []int32{123125, 236859, 118515, 120362}},
	{name: "french with apostrophe", input: "L'élève a mangé", want: []int32{236798, 236789, 161654, 496, 14695, 236859}},

	// German umlauts
	{name: "german umlauts", input: "über Straße Größe", want: []int32{28223, 80176, 112880}},

	// Codepoints in GPT-2 byte reversal range (U+0100-U+0142)
	{name: "codepoints in gpt2 byte range", input: "ąęćł", want: []int32{237198, 226110, 237114}},
	{name: "latin extended A", input: "ĀāĂ㥹", want: []int32{241920, 237448, 241645, 237106, 243514, 237198}},

	// CJK & Japanese
	{name: "chinese", input: "你好世界", want: []int32{144626, 12811}},
	{name: "japanese hiragana", input: "こんにちは", want: []int32{85141}},

	// Mixed scripts
	{name: "mixed scripts", input: "hello ąęść world café 你好", want: []int32{23391, 236743, 237198, 237202, 14732, 1902, 33443, 43758, 237389}},

	// Whitespace
	{name: "empty string", input: "", want: []int32{}},
	{name: "newlines", input: "\n\n", want: []int32{108}},
	{name: "tabs", input: "\t\t", want: []int32{255969}},

	// Code-like content
	{name: "python code", input: "def foo(x): return x + 1", want: []int32{2063, 46293, 236769, 236781, 1473, 994, 1123, 900, 236743, 236770}},
	{name: "json", input: `{"key": "value"}`, want: []int32{14937, 2478, 1083, 623, 2394, 25938}},

	// Misc
	{name: "repeated char", input: "aaaaaa", want: []int32{50354, 9236}},
	{name: "emoji", input: "hello 👋 world", want: []int32{23391, 155818, 1902}},
	{name: "digits", input: "12345", want: []int32{236770, 236778, 236800, 236812, 236810}},
	{name: "float", input: "3.14159", want: []int32{236800, 236761, 236770, 236812, 236770, 236810, 236819}},
}
|
||||||
|
|
||||||
|
// findGemma4GGUF looks for any gemma4 model GGUF in the local ollama store.
|
||||||
|
func findGemma4GGUF() (string, error) {
|
||||||
|
modelsDir := envconfig.Models()
|
||||||
|
manifestDir := filepath.Join(modelsDir, "manifests", "registry.ollama.ai", "library", "gemma4")
|
||||||
|
entries, err := os.ReadDir(manifestDir)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("no gemma4 manifests in %s: %w", manifestDir, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
blobDir := filepath.Join(modelsDir, "blobs")
|
||||||
|
|
||||||
|
for _, entry := range entries {
|
||||||
|
if entry.IsDir() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := os.ReadFile(filepath.Join(manifestDir, entry.Name()))
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var manifest struct {
|
||||||
|
Layers []struct {
|
||||||
|
MediaType string `json:"mediaType"`
|
||||||
|
Digest string `json:"digest"`
|
||||||
|
} `json:"layers"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &manifest); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, layer := range manifest.Layers {
|
||||||
|
if layer.MediaType == "application/vnd.ollama.image.model" {
|
||||||
|
blobPath := filepath.Join(blobDir, strings.Replace(layer.Digest, ":", "-", 1))
|
||||||
|
if _, err := os.Stat(blobPath); err == nil {
|
||||||
|
return blobPath, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", fmt.Errorf("no gemma4 model blob found in %s", modelsDir)
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadGemma4Tokenizer opens a GGUF and builds a BPE tokenizer from its
|
||||||
|
// tokenizer metadata — the same configuration used at inference time.
|
||||||
|
func loadGemma4Tokenizer(t *testing.T, ggufPath string) tokenizer.BytePairEncoding {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
f, err := gguf.Open(ggufPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("gguf.Open: %v", err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
tokens := f.KeyValue("tokenizer.ggml.tokens").Strings()
|
||||||
|
if len(tokens) == 0 {
|
||||||
|
t.Fatal("no tokenizer.ggml.tokens in GGUF")
|
||||||
|
}
|
||||||
|
|
||||||
|
scores64 := f.KeyValue("tokenizer.ggml.scores").Floats()
|
||||||
|
scores := make([]float32, len(scores64))
|
||||||
|
for i, s := range scores64 {
|
||||||
|
scores[i] = float32(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
types64 := f.KeyValue("tokenizer.ggml.token_type").Ints()
|
||||||
|
types := make([]int32, len(types64))
|
||||||
|
for i, tt := range types64 {
|
||||||
|
types[i] = int32(tt)
|
||||||
|
}
|
||||||
|
|
||||||
|
merges := f.KeyValue("tokenizer.ggml.merges").Strings()
|
||||||
|
|
||||||
|
vocab := &tokenizer.Vocabulary{
|
||||||
|
Values: tokens,
|
||||||
|
Types: types,
|
||||||
|
Scores: scores,
|
||||||
|
Merges: merges,
|
||||||
|
BOS: []int32{2},
|
||||||
|
EOS: []int32{1},
|
||||||
|
AddBOS: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokenizer.NewBytePairEncodingWithOptions(vocab, []string{},
|
||||||
|
tokenizer.WithSentencePieceNormalizer())
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeTokenizerJSON reconstructs a tokenizer.json from GGUF metadata
|
||||||
|
// for the Rust tokenizers library to load as an independent reference.
|
||||||
|
func writeTokenizerJSON(t *testing.T, ggufPath string) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
f, err := gguf.Open(ggufPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("gguf.Open: %v", err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
tokens := f.KeyValue("tokenizer.ggml.tokens").Strings()
|
||||||
|
mergeStrs := f.KeyValue("tokenizer.ggml.merges").Strings()
|
||||||
|
|
||||||
|
vocab := make(map[string]int, len(tokens))
|
||||||
|
for i, tok := range tokens {
|
||||||
|
vocab[tok] = i
|
||||||
|
}
|
||||||
|
|
||||||
|
merges := make([][2]string, len(mergeStrs))
|
||||||
|
for i, m := range mergeStrs {
|
||||||
|
parts := strings.SplitN(m, " ", 2)
|
||||||
|
if len(parts) == 2 {
|
||||||
|
merges[i] = [2]string{parts[0], parts[1]}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tj := map[string]any{
|
||||||
|
"version": "1.0",
|
||||||
|
"model": map[string]any{
|
||||||
|
"type": "BPE",
|
||||||
|
"vocab": vocab,
|
||||||
|
"merges": merges,
|
||||||
|
},
|
||||||
|
"normalizer": map[string]any{
|
||||||
|
"type": "Replace",
|
||||||
|
"pattern": map[string]string{"String": " "},
|
||||||
|
"content": "\u2581",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpFile, err := os.CreateTemp(t.TempDir(), "gemma4_tokenizer_*.json")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create temp file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.NewEncoder(tmpFile).Encode(tj); err != nil {
|
||||||
|
tmpFile.Close()
|
||||||
|
t.Fatalf("encode tokenizer.json: %v", err)
|
||||||
|
}
|
||||||
|
tmpFile.Close()
|
||||||
|
|
||||||
|
return tmpFile.Name()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGemma4TokenizerMatchesReference(t *testing.T) {
|
||||||
|
ggufPath, err := findGemma4GGUF()
|
||||||
|
if err != nil {
|
||||||
|
t.Skipf("skipping: %v", err)
|
||||||
|
}
|
||||||
|
t.Logf("using GGUF: %s", ggufPath)
|
||||||
|
|
||||||
|
tok := loadGemma4Tokenizer(t, ggufPath)
|
||||||
|
|
||||||
|
verify := os.Getenv("VERIFY_HF_TOKENIZER") != ""
|
||||||
|
var tokenizerJSONPath string
|
||||||
|
if verify {
|
||||||
|
if err := exec.Command("python3", "-c", "from tokenizers import Tokenizer").Run(); err != nil {
|
||||||
|
t.Fatal("VERIFY_HF_TOKENIZER=1 requires python3 with tokenizers>=0.21 on PATH")
|
||||||
|
}
|
||||||
|
tokenizerJSONPath = writeTokenizerJSON(t, ggufPath)
|
||||||
|
defer os.Remove(tokenizerJSONPath)
|
||||||
|
t.Log("VERIFY_HF_TOKENIZER=1: verifying against Rust tokenizers library")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range gemma4TokenizerRefCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
ids, err := tok.Encode(tc.input, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Encode(%q): %v", tc.input, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tc.want != nil {
|
||||||
|
if fmt.Sprint(ids) != fmt.Sprint(tc.want) {
|
||||||
|
t.Errorf("Encode(%q):\n got: %v\n want: %v", tc.input, ids, tc.want)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
t.Errorf("no expected IDs for %q; our tokenizer produced: %v", tc.input, ids)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(ids) > 0 {
|
||||||
|
decoded, err := tok.Decode(ids)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Decode: %v", err)
|
||||||
|
}
|
||||||
|
if decoded != tc.input {
|
||||||
|
t.Errorf("roundtrip %q: Decode(Encode) = %q", tc.input, decoded)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if verify {
|
||||||
|
refIDs := encodeWithRustTokenizer(t, tokenizerJSONPath, tc.input)
|
||||||
|
|
||||||
|
if fmt.Sprint(refIDs) != fmt.Sprint(ids) {
|
||||||
|
fmt.Fprintf(os.Stderr, "\nREFERENCE OUTPUT for %s (copy-paste as want):\nwant: []int32{%s},\n\n",
|
||||||
|
tc.name, int32SliceStr(refIDs))
|
||||||
|
}
|
||||||
|
|
||||||
|
if tc.want != nil && fmt.Sprint(refIDs) != fmt.Sprint(tc.want) {
|
||||||
|
t.Errorf("hardcoded expected IDs don't match reference for %q:\n ref: %v\n hardcoded: %v",
|
||||||
|
tc.input, refIDs, tc.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// encodeWithRustTokenizer shells out to python3 to encode text with the
// Rust tokenizers library (loaded from the reconstructed tokenizer.json)
// and returns the resulting token IDs. Returns nil for empty input.
func encodeWithRustTokenizer(t *testing.T, tokenizerPath, text string) []int32 {
	t.Helper()

	if text == "" {
		return nil
	}

	// Go's %q escaping is accepted by Python string literals for the
	// characters these tests use, so the paths/text can be inlined directly.
	script := fmt.Sprintf(`
from tokenizers import Tokenizer
t = Tokenizer.from_file(%q)
ids = t.encode(%q, add_special_tokens=False).ids
print(",".join(str(i) for i in ids))
`, tokenizerPath, text)

	cmd := exec.Command("python3", "-c", script)
	var stdout, stderr strings.Builder
	cmd.Stdout = &stdout
	cmd.Stderr = &stderr
	if err := cmd.Run(); err != nil {
		t.Fatalf("python3 failed: %v\nstderr: %s", err, stderr.String())
	}

	parts := strings.Split(strings.TrimSpace(stdout.String()), ",")
	var ids []int32
	for _, p := range parts {
		if p == "" {
			continue
		}
		// The original used fmt.Sscanf and ignored its error, which would
		// silently append 0 for unparseable output; fail loudly instead.
		n, err := strconv.Atoi(p)
		if err != nil {
			t.Fatalf("unexpected token id %q from python: %v", p, err)
		}
		ids = append(ids, int32(n))
	}
	return ids
}
|
||||||
|
|
||||||
|
// int32SliceStr renders ids as a comma-separated list suitable for
// pasting into a []int32 literal, e.g. "1, 2, 3". Empty input yields "".
func int32SliceStr(ids []int32) string {
	var b strings.Builder
	for i, id := range ids {
		if i > 0 {
			b.WriteString(", ")
		}
		fmt.Fprintf(&b, "%d", id)
	}
	return b.String()
}
|
||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"iter"
|
"iter"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"slices"
|
"slices"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/dlclark/regexp2"
|
"github.com/dlclark/regexp2"
|
||||||
@@ -41,27 +42,28 @@ func NewBytePairEncodingWithOptions(vocab *Vocabulary, pretokenizer []string, op
|
|||||||
}
|
}
|
||||||
|
|
||||||
func newBytePairEncoding(vocab *Vocabulary, pretokenizer []string, opts ...BPEOption) BytePairEncoding {
|
func newBytePairEncoding(vocab *Vocabulary, pretokenizer []string, opts ...BPEOption) BytePairEncoding {
|
||||||
if len(pretokenizer) == 0 {
|
|
||||||
// set default byte-level pretokenizer if none provided, e.g.
|
|
||||||
// https://github.com/huggingface/tokenizer/blob/main/tokenizer/src/pre_tokenizer/byte_level.rs#L44
|
|
||||||
pretokenizer = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
|
|
||||||
}
|
|
||||||
|
|
||||||
bpe := BytePairEncoding{
|
bpe := BytePairEncoding{
|
||||||
vocab: vocab,
|
vocab: vocab,
|
||||||
regexps: slices.Collect(func(yield func(*regexp2.Regexp) bool) {
|
|
||||||
for _, p := range pretokenizer {
|
|
||||||
if !yield(regexp2.MustCompile(p, regexp2.RE2)) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, opt := range opts {
|
for _, opt := range opts {
|
||||||
opt(&bpe)
|
opt(&bpe)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(pretokenizer) == 0 && !bpe.spaceToSpmSep {
|
||||||
|
// set default byte-level pretokenizer if none provided, e.g.
|
||||||
|
// https://github.com/huggingface/tokenizer/blob/main/tokenizer/src/pre_tokenizer/byte_level.rs#L44
|
||||||
|
pretokenizer = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
|
||||||
|
}
|
||||||
|
|
||||||
|
bpe.regexps = slices.Collect(func(yield func(*regexp2.Regexp) bool) {
|
||||||
|
for _, p := range pretokenizer {
|
||||||
|
if !yield(regexp2.MustCompile(p, regexp2.RE2)) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
return bpe
|
return bpe
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -261,9 +263,17 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
|
|||||||
|
|
||||||
for _, merge := range merges {
|
for _, merge := range merges {
|
||||||
if len(merge.runes) > 0 {
|
if len(merge.runes) > 0 {
|
||||||
// TODO: handle the edge case where the rune isn't in the vocabulary
|
|
||||||
if id := bpe.vocab.Encode(string(merge.runes)); id >= 0 {
|
if id := bpe.vocab.Encode(string(merge.runes)); id >= 0 {
|
||||||
ids = append(ids, id)
|
ids = append(ids, id)
|
||||||
|
} else if bpe.spaceToSpmSep {
|
||||||
|
// SentencePiece byte fallback: encode each UTF-8 byte as <0xHH>
|
||||||
|
for _, b := range []byte(string(merge.runes)) {
|
||||||
|
if id := bpe.vocab.Encode(fmt.Sprintf("<0x%02X>", b)); id >= 0 {
|
||||||
|
ids = append(ids, id)
|
||||||
|
} else {
|
||||||
|
slog.Debug("unknown byte token", "byte", b)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -288,6 +298,37 @@ func (l lazyIdsString) LogValue() slog.Value {
|
|||||||
|
|
||||||
func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
|
func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
|
|
||||||
|
// SentencePiece-style BPE stores true Unicode codepoints in the vocab
|
||||||
|
// (plus ▁ as a whitespace marker), so decoding should pass runes through
|
||||||
|
// directly instead of applying the GPT-2 byte-level reverse mapping.
|
||||||
|
// Without this, codepoints in the 0x0100-0x0142 range (e.g. ą ę ć ł)
|
||||||
|
// get mangled by the GPT-2 reversal into control characters.
|
||||||
|
if bpe.spaceToSpmSep {
|
||||||
|
for _, id := range ids {
|
||||||
|
data := bpe.vocab.Decode(id)
|
||||||
|
|
||||||
|
// SentencePiece byte tokens: "<0xHH>" → raw byte
|
||||||
|
if len(data) == 6 && strings.HasPrefix(data, "<0x") && strings.HasSuffix(data, ">") {
|
||||||
|
if b, err := strconv.ParseUint(data[3:5], 16, 8); err == nil {
|
||||||
|
sb.WriteByte(byte(b))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, r := range data {
|
||||||
|
if r == 0x2581 { // ▁ (LOWER ONE EIGHTH BLOCK)
|
||||||
|
sb.WriteByte(' ')
|
||||||
|
} else {
|
||||||
|
sb.WriteRune(r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logutil.Trace("decoded", "string", sb.String(), "from", lazyIdsString{ids: ids})
|
||||||
|
return sb.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
for _, id := range ids {
|
for _, id := range ids {
|
||||||
for _, r := range bpe.vocab.Decode(id) {
|
for _, r := range bpe.vocab.Decode(id) {
|
||||||
// GPT-2 byte-level BPE uses Unicode chars in the 0x0100-0x0143
|
// GPT-2 byte-level BPE uses Unicode chars in the 0x0100-0x0143
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package tokenizer
|
|||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@@ -286,6 +287,22 @@ func spmBPE(t testing.TB) BytePairEncoding {
|
|||||||
// Unicode token for decode passthrough testing (must be > U+0143
|
// Unicode token for decode passthrough testing (must be > U+0143
|
||||||
// to exercise the SPM decode path rather than GPT-2 byte reversal)
|
// to exercise the SPM decode path rather than GPT-2 byte reversal)
|
||||||
"▁中文", // 20
|
"▁中文", // 20
|
||||||
|
|
||||||
|
// Unicode tokens with codepoints in the GPT-2 byte range (0x0100-0x0142).
|
||||||
|
// Without the SPM decode path, these get mangled by GPT-2 byte reversal.
|
||||||
|
"ą", // 21 (U+0105) — would become 0x05 via GPT-2 reversal
|
||||||
|
"ę", // 22 (U+0119) — would become 0x19
|
||||||
|
"ć", // 23 (U+0107) — would become 0x07
|
||||||
|
"ł", // 24 (U+0142) — would become 0xA0
|
||||||
|
|
||||||
|
// Byte fallback tokens (SentencePiece BYTE type)
|
||||||
|
"<0x00>", // 25
|
||||||
|
"<0x01>", // 26
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add all 256 byte tokens starting at index 27
|
||||||
|
for b := 2; b < 256; b++ {
|
||||||
|
tokens = append(tokens, fmt.Sprintf("<0x%02X>", b))
|
||||||
}
|
}
|
||||||
|
|
||||||
types := make([]int32, len(tokens))
|
types := make([]int32, len(tokens))
|
||||||
@@ -298,6 +315,9 @@ func spmBPE(t testing.TB) BytePairEncoding {
|
|||||||
types[3] = TOKEN_TYPE_USER_DEFINED // <|start>
|
types[3] = TOKEN_TYPE_USER_DEFINED // <|start>
|
||||||
types[4] = TOKEN_TYPE_USER_DEFINED // <end|>
|
types[4] = TOKEN_TYPE_USER_DEFINED // <end|>
|
||||||
types[5] = TOKEN_TYPE_USER_DEFINED // <|q>
|
types[5] = TOKEN_TYPE_USER_DEFINED // <|q>
|
||||||
|
for i := 21; i < len(types); i++ {
|
||||||
|
types[i] = TOKEN_TYPE_BYTE
|
||||||
|
}
|
||||||
|
|
||||||
return NewBytePairEncodingWithOptions(
|
return NewBytePairEncodingWithOptions(
|
||||||
&Vocabulary{
|
&Vocabulary{
|
||||||
@@ -319,27 +339,25 @@ func TestSentencePieceBPE(t *testing.T) {
|
|||||||
|
|
||||||
// Test 1: Space-to-▁ normalization and roundtrip.
|
// Test 1: Space-to-▁ normalization and roundtrip.
|
||||||
//
|
//
|
||||||
// This is the core behavior that WithSentencePieceNormalizer enables.
|
// SentencePiece BPE has no pretokenizer — the BPE merges handle word
|
||||||
// Without it, " hello" would be byte-mapped through the GPT-2 table
|
// boundaries via ▁ markers. With no merges in the test vocab, multi-char
|
||||||
// (producing Ġhello or similar shifted codepoints) which would never
|
// tokens won't be found, but the roundtrip must still be lossless.
|
||||||
// match the ▁-prefixed vocab entry.
|
|
||||||
t.Run("spm space normalization roundtrip", func(t *testing.T) {
|
t.Run("spm space normalization roundtrip", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
cases := map[string][]int32{
|
for _, input := range []string{
|
||||||
"hello": {8}, // no space → no ▁ prefix → "hello"(8)
|
"hello",
|
||||||
" hello": {6}, // leading space → "▁hello"(6)
|
" hello",
|
||||||
"hello, world!": {8, 11, 7, 12}, // pretokenizer splits punctuation;
|
"hello, world!",
|
||||||
// " world" normalizes to "▁world"
|
" leading spaces",
|
||||||
}
|
"multiple spaces",
|
||||||
|
} {
|
||||||
for input, wantIDs := range cases {
|
|
||||||
ids, err := tok.Encode(input, false)
|
ids, err := tok.Encode(input, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Encode(%q): %v", input, err)
|
t.Fatalf("Encode(%q): %v", input, err)
|
||||||
}
|
}
|
||||||
if diff := cmp.Diff(wantIDs, ids); diff != "" {
|
if len(ids) == 0 {
|
||||||
t.Errorf("Encode(%q) mismatch (-want +got):\n%s", input, diff)
|
t.Fatalf("Encode(%q) returned empty IDs", input)
|
||||||
}
|
}
|
||||||
|
|
||||||
got, err := tok.Decode(ids)
|
got, err := tok.Decode(ids)
|
||||||
@@ -358,41 +376,105 @@ func TestSentencePieceBPE(t *testing.T) {
|
|||||||
// <|tool>declaration:bash{description:<|"|>Run a command<|"|>}<tool|>
|
// <|tool>declaration:bash{description:<|"|>Run a command<|"|>}<tool|>
|
||||||
// where special tokens (<|tool>, <|"|>, <tool|>) must be extracted
|
// where special tokens (<|tool>, <|"|>, <tool|>) must be extracted
|
||||||
// first, then the remaining text fragments go through SPM normalization.
|
// first, then the remaining text fragments go through SPM normalization.
|
||||||
// Without the SPM normalizer, the text between special tokens would be
|
|
||||||
// encoded with GPT-2 byte mapping, producing entirely wrong IDs.
|
|
||||||
t.Run("special tokens with spm text fragments", func(t *testing.T) {
|
t.Run("special tokens with spm text fragments", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
// Pattern: <|start>declaration:description:<|q>Run a command<|q>}<end|>
|
|
||||||
input := "<|start>declaration:description:<|q> Run a command<|q>}<end|>"
|
input := "<|start>declaration:description:<|q> Run a command<|q>}<end|>"
|
||||||
ids, err := tok.Encode(input, false)
|
ids, err := tok.Encode(input, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Special tokens should be extracted as single IDs, and the text
|
// Special tokens should be extracted as single IDs at the right positions.
|
||||||
// between them should be SPM-normalized (spaces → ▁).
|
// The text between them is SPM-normalized and BPE-encoded (specific IDs
|
||||||
want := []int32{
|
// depend on merges, so we verify the special token positions + roundtrip).
|
||||||
3, // <|start>
|
specialPositions := map[int32]bool{3: true, 4: true, 5: true} // <|start>, <end|>, <|q>
|
||||||
19, // "declaration" (text fragment, no leading space)
|
foundSpecials := 0
|
||||||
13, // ":"
|
for _, id := range ids {
|
||||||
17, // "description"
|
if specialPositions[id] {
|
||||||
13, // ":"
|
foundSpecials++
|
||||||
5, // <|q>
|
}
|
||||||
9, // "▁Run" (space before "Run" becomes ▁)
|
}
|
||||||
10, // "▁a"
|
if foundSpecials != 4 { // <|start>, <|q>, <|q>, <end|>
|
||||||
18, // "▁command"
|
t.Errorf("expected 4 special tokens, found %d in %v", foundSpecials, ids)
|
||||||
5, // <|q>
|
|
||||||
15, // "}"
|
|
||||||
4, // <end|>
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if diff := cmp.Diff(want, ids); diff != "" {
|
// First token must be <|start>(3), last must be <end|>(4)
|
||||||
t.Errorf("mismatch (-want +got):\n%s", diff)
|
if ids[0] != 3 {
|
||||||
|
t.Errorf("first token = %d, want 3 (<|start>)", ids[0])
|
||||||
|
}
|
||||||
|
if ids[len(ids)-1] != 4 {
|
||||||
|
t.Errorf("last token = %d, want 4 (<end|>)", ids[len(ids)-1])
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
// Test 3: Decode handles non-GPT2 Unicode correctly.
|
// Test 3: Byte fallback for characters not in the vocabulary.
|
||||||
|
//
|
||||||
|
// SentencePiece vocabs include <0xHH> byte tokens for every byte value.
|
||||||
|
// When a character (e.g. "ą" = U+0105 = C4 85) isn't in the vocab as a
|
||||||
|
// direct token, the encoder must fall back to its UTF-8 bytes:
|
||||||
|
// <0xC4> <0x85>. Without this fallback, the character is silently dropped.
|
||||||
|
// See: https://github.com/ollama/ollama/issues/15229
|
||||||
|
t.Run("byte fallback for unknown chars", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// "ą" is not in the vocab — should fall back to byte tokens
|
||||||
|
ids, err := tok.Encode("ą", false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Encode(ą): %v", err)
|
||||||
|
}
|
||||||
|
if len(ids) == 0 {
|
||||||
|
t.Fatal("Encode(ą) returned empty IDs — character was silently dropped")
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := tok.Decode(ids)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Decode: %v", err)
|
||||||
|
}
|
||||||
|
if got != "ą" {
|
||||||
|
t.Errorf("roundtrip = %q, want %q", got, "ą")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test 4: Byte fallback preserves known tokens around unknown chars.
|
||||||
|
t.Run("byte fallback mixed with known tokens", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// "hello" is in vocab, "é" is not
|
||||||
|
ids, err := tok.Encode("helloé", false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Encode: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := tok.Decode(ids)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Decode: %v", err)
|
||||||
|
}
|
||||||
|
if got != "helloé" {
|
||||||
|
t.Errorf("roundtrip = %q, want %q", got, "helloé")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test 5: Decode doesn't mangle Unicode in the GPT-2 byte range.
|
||||||
|
//
|
||||||
|
// Characters like ą (U+0105), ę (U+0119), ć (U+0107), ł (U+0142) have
|
||||||
|
// codepoints in the 0x0100-0x0142 range that GPT-2 byte reversal would
|
||||||
|
// remap to control characters. SentencePiece decode must pass them through.
|
||||||
|
t.Run("decode unicode in gpt2 byte range", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// Token IDs 21-24 are ą, ę, ć, ł
|
||||||
|
ids := []int32{21, 22, 23, 24}
|
||||||
|
got, err := tok.Decode(ids)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Decode: %v", err)
|
||||||
|
}
|
||||||
|
if got != "ąęćł" {
|
||||||
|
t.Errorf("Decode = %q, want %q", got, "ąęćł")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test 6: Decode handles non-GPT2 Unicode correctly.
|
||||||
//
|
//
|
||||||
// GPT-2 BPE decode reverses the byte→codepoint shift for runes in
|
// GPT-2 BPE decode reverses the byte→codepoint shift for runes in
|
||||||
// 0x0100–0x0143. But SentencePiece vocabs store real Unicode (CJK,
|
// 0x0100–0x0143. But SentencePiece vocabs store real Unicode (CJK,
|
||||||
|
|||||||
Reference in New Issue
Block a user