diff --git a/x/mlxrunner/model/base/base.go b/x/mlxrunner/model/base/base.go index fcc8b8627..6d3a25798 100644 --- a/x/mlxrunner/model/base/base.go +++ b/x/mlxrunner/model/base/base.go @@ -8,10 +8,10 @@ import ( "log/slog" "sync" - "github.com/ollama/ollama/x/imagegen/tokenizer" "github.com/ollama/ollama/x/mlxrunner/cache" "github.com/ollama/ollama/x/mlxrunner/mlx" "github.com/ollama/ollama/x/mlxrunner/model" + "github.com/ollama/ollama/x/tokenizer" ) // Model is the interface that model implementations must satisfy. diff --git a/x/mlxrunner/pipeline.go b/x/mlxrunner/pipeline.go index 0da5862c8..274fc9be6 100644 --- a/x/mlxrunner/pipeline.go +++ b/x/mlxrunner/pipeline.go @@ -7,7 +7,6 @@ import ( "errors" "log/slog" "time" - "unicode/utf8" "github.com/ollama/ollama/x/mlxrunner/cache" "github.com/ollama/ollama/x/mlxrunner/mlx" @@ -126,13 +125,5 @@ func (r Runner) Decode(sample int32, b *bytes.Buffer) string { return "" } - if text := b.String(); utf8.ValidString(text) { - b.Reset() - return text - } else if b.Len() >= utf8.UTFMax { - b.Reset() - return text - } - - return "" + return flushValidUTF8Prefix(b) } diff --git a/x/mlxrunner/runner.go b/x/mlxrunner/runner.go index 826281c31..0b24fdb3d 100644 --- a/x/mlxrunner/runner.go +++ b/x/mlxrunner/runner.go @@ -12,12 +12,12 @@ import ( "golang.org/x/sync/errgroup" - "github.com/ollama/ollama/x/imagegen/tokenizer" "github.com/ollama/ollama/x/mlxrunner/cache" "github.com/ollama/ollama/x/mlxrunner/mlx" "github.com/ollama/ollama/x/mlxrunner/model" "github.com/ollama/ollama/x/mlxrunner/model/base" "github.com/ollama/ollama/x/mlxrunner/sample" + "github.com/ollama/ollama/x/tokenizer" ) type Request struct { diff --git a/x/mlxrunner/utf8_buffer.go b/x/mlxrunner/utf8_buffer.go new file mode 100644 index 000000000..5d155b478 --- /dev/null +++ b/x/mlxrunner/utf8_buffer.go @@ -0,0 +1,47 @@ +package mlxrunner + +import ( + "bytes" + "unicode/utf8" +) + +// flushValidUTF8Prefix returns and consumes the longest valid UTF-8 prefix +// currently buffered, leaving any incomplete trailing bytes in place. +func flushValidUTF8Prefix(b *bytes.Buffer) string { + data := b.Bytes() + if len(data) == 0 { + return "" + } + + prefix := validUTF8PrefixLen(data) + if prefix == 0 { + return "" + } + + text := string(data[:prefix]) + b.Next(prefix) + return text +} + +func validUTF8PrefixLen(data []byte) int { + i := 0 + prefix := 0 + for i < len(data) { + r, size := utf8.DecodeRune(data[i:]) + if r == utf8.RuneError && size == 1 { + if !utf8.FullRune(data[i:]) { + break + } + + // Invalid UTF-8 byte; consume one byte to guarantee forward progress. 
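+ // The invalid byte still lands in the flushed prefix, so no input is silently dropped.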
+ i++ + prefix = i + continue + } + + i += size + prefix = i + } + + return prefix +} diff --git a/x/mlxrunner/utf8_buffer_test.go b/x/mlxrunner/utf8_buffer_test.go new file mode 100644 index 000000000..aaaf77b63 --- /dev/null +++ b/x/mlxrunner/utf8_buffer_test.go @@ -0,0 +1,46 @@ +package mlxrunner + +import ( + "bytes" + "testing" +) + +func TestFlushValidUTF8Prefix_PreservesIncompleteRune(t *testing.T) { + var b bytes.Buffer + + b.Write([]byte{0xE3, 0x81}) + if got := flushValidUTF8Prefix(&b); got != "" { + t.Fatalf("first flush = %q, want empty", got) + } + + b.Write([]byte{0x93, 0xE3}) + if got := flushValidUTF8Prefix(&b); got != "こ" { + t.Fatalf("second flush = %q, want %q", got, "こ") + } + + if got := b.Bytes(); !bytes.Equal(got, []byte{0xE3}) { + t.Fatalf("buffer after second flush = %v, want %v", got, []byte{0xE3}) + } + + b.Write([]byte{0x82, 0x93}) + if got := flushValidUTF8Prefix(&b); got != "ん" { + t.Fatalf("third flush = %q, want %q", got, "ん") + } + + if b.Len() != 0 { + t.Fatalf("buffer not empty after third flush: %d", b.Len()) + } +} + +func TestFlushValidUTF8Prefix_ValidText(t *testing.T) { + var b bytes.Buffer + b.WriteString("hello 世界") + + if got := flushValidUTF8Prefix(&b); got != "hello 世界" { + t.Fatalf("flush = %q, want %q", got, "hello 世界") + } + + if b.Len() != 0 { + t.Fatalf("buffer not empty after flush: %d", b.Len()) + } +} diff --git a/x/models/gemma3/gemma3.go b/x/models/gemma3/gemma3.go index f35ef2a75..7ba24d294 100644 --- a/x/models/gemma3/gemma3.go +++ b/x/models/gemma3/gemma3.go @@ -8,12 +8,12 @@ import ( "fmt" "math" - "github.com/ollama/ollama/x/imagegen/tokenizer" "github.com/ollama/ollama/x/mlxrunner/cache" "github.com/ollama/ollama/x/mlxrunner/mlx" "github.com/ollama/ollama/x/mlxrunner/model" "github.com/ollama/ollama/x/mlxrunner/model/base" "github.com/ollama/ollama/x/models/nn" + "github.com/ollama/ollama/x/tokenizer" ) func init() { diff --git a/x/models/glm4_moe_lite/glm4_moe_lite.go b/x/models/glm4_moe_lite/glm4_moe_lite.go index 65e26244d..a1ec55972 100644 --- a/x/models/glm4_moe_lite/glm4_moe_lite.go +++ b/x/models/glm4_moe_lite/glm4_moe_lite.go @@ -9,12 +9,12 @@ import ( "fmt" "math" - "github.com/ollama/ollama/x/imagegen/tokenizer" "github.com/ollama/ollama/x/mlxrunner/cache" "github.com/ollama/ollama/x/mlxrunner/mlx" "github.com/ollama/ollama/x/mlxrunner/model" "github.com/ollama/ollama/x/mlxrunner/model/base" "github.com/ollama/ollama/x/models/nn" + "github.com/ollama/ollama/x/tokenizer" ) func init() { diff --git a/x/models/llama/llama.go b/x/models/llama/llama.go index f82678d3b..61e51b35c 100644 --- a/x/models/llama/llama.go +++ b/x/models/llama/llama.go @@ -8,12 +8,12 @@ import ( "fmt" "math" - "github.com/ollama/ollama/x/imagegen/tokenizer" "github.com/ollama/ollama/x/mlxrunner/cache" "github.com/ollama/ollama/x/mlxrunner/mlx" "github.com/ollama/ollama/x/mlxrunner/model" "github.com/ollama/ollama/x/mlxrunner/model/base" "github.com/ollama/ollama/x/models/nn" + "github.com/ollama/ollama/x/tokenizer" ) func init() { diff --git a/x/models/qwen3/qwen3.go b/x/models/qwen3/qwen3.go index 7a49cf37a..76170881a 100644 --- a/x/models/qwen3/qwen3.go +++ b/x/models/qwen3/qwen3.go @@ -8,12 +8,12 @@ import ( "fmt" "math" - "github.com/ollama/ollama/x/imagegen/tokenizer" "github.com/ollama/ollama/x/mlxrunner/cache" "github.com/ollama/ollama/x/mlxrunner/mlx" "github.com/ollama/ollama/x/mlxrunner/model" "github.com/ollama/ollama/x/mlxrunner/model/base" "github.com/ollama/ollama/x/models/nn" + "github.com/ollama/ollama/x/tokenizer" ) func init() { 
diff --git a/x/tokenizer/tokenizer.go b/x/tokenizer/tokenizer.go new file mode 100644 index 000000000..301e51aea --- /dev/null +++ b/x/tokenizer/tokenizer.go @@ -0,0 +1,108 @@ +//go:build mlx + +// tokenizer.go - BPE and SentencePiece tokenizer for HuggingFace models +// +// Based on standard BPE algorithm (Sennrich et al. 2015) with: +// - GPT-2 byte-level encoding (OpenAI tiktoken) +// - HuggingFace tokenizer.json pretokenizer patterns +// - SentencePiece ▁-style space handling + +package tokenizer + +import "regexp" + +// TokenizerType identifies the tokenization algorithm +type TokenizerType int + +const ( + TokenizerBPE TokenizerType = iota // GPT-2 style byte-level BPE + TokenizerSentencePiece // SentencePiece with ▁ for spaces +) + +// Vocabulary holds the tokenizer vocabulary and merges +type Vocabulary struct { + Values []string + Reverse map[string]int32 + Merges map[string]int + + BOS int32 + EOS []int32 // Multiple EOS tokens supported (e.g., Gemma has <eos> and <end_of_turn>) + PAD int32 // Padding token (often <|endoftext|> or <pad>) + AddBOS bool + AddEOS bool + + // Precomputed byte token IDs for <0xNN> fallback (256 entries, -1 if not found) + byteTokens [256]int32 +} + +// Tokenizer handles BPE and SentencePiece tokenization +type Tokenizer struct { + vocab *Vocabulary + pretokenizer *regexp.Regexp + specialTokens map[string]int32 // Special tokens for direct lookup + sortedSpecialTokens []string // Special tokens sorted by length, longest first + typ TokenizerType // Algorithm type +} + +// Precomputed GPT-2 byte-level encoding table +// Maps byte values to their encoded rune equivalents +var byteToRune [256]rune + +func init() { + for b := 0; b < 256; b++ { + r := rune(b) + switch { + case r == 0x00ad: + r = 0x0143 + case r <= 0x0020: + r = r + 0x0100 + case r >= 0x007f && r <= 0x00a0: + r = r + 0x00a2 + } + byteToRune[b] = r + } +} + +// VocabSize returns the vocabulary size +func (t *Tokenizer) VocabSize() int { + return len(t.vocab.Values) +} + +// BOS returns the beginning of sequence token ID +func (t *Tokenizer) BOS() int32 { + return t.vocab.BOS +} + +// EOS returns the first end of sequence token ID (for backwards compatibility) +func (t *Tokenizer) EOS() int32 { + if len(t.vocab.EOS) > 0 { + return t.vocab.EOS[0] + } + return -1 +} + +// EOSTokens returns all end of sequence token IDs +func (t *Tokenizer) EOSTokens() []int32 { + return t.vocab.EOS +} + +// PAD returns the padding token ID, or -1 if not set +func (t *Tokenizer) PAD() int32 { + return t.vocab.PAD +} + +// IsEOS returns true if the token ID is an end of sequence token +func (t *Tokenizer) IsEOS(id int32) bool { + for _, eos := range t.vocab.EOS { + if id == eos { + return true + } + } + return false +} + +// GetSpecialToken returns the token ID for a special token string +func (t *Tokenizer) GetSpecialToken(name string) (int32, bool) { + id, ok := t.specialTokens[name] + return id, ok +} diff --git a/x/tokenizer/tokenizer_benchmark_test.go b/x/tokenizer/tokenizer_benchmark_test.go new file mode 100644 index 000000000..e65a59786 --- /dev/null +++ b/x/tokenizer/tokenizer_benchmark_test.go @@ -0,0 +1,251 @@ +//go:build mlx + +package tokenizer + +import ( + "os" + "path/filepath" + "runtime" + "strings" + "testing" ) + +var ( + benchmarkSinkIDs []int32 + benchmarkSinkStr string + benchmarkSinkTok *Tokenizer +) + +const benchmarkWordPieceJSON = `{ + "model": { + "type": "WordPiece", + "vocab": { + "[UNK]": 0, + "hello": 1, + "##world": 2, + "##ly": 3, + "##hello": 4 + } + }, + "added_tokens": [] +}` + +const
benchmarkSentencePieceJSON = `{ + "model": { + "type": "BPE", + "vocab": { + "\u2581": 0, + "h": 1, + "e": 2, + "l": 3, + "o": 4, + "w": 5, + "r": 6, + "d": 7, + "<0x0A>": 8 + }, + "merges": [] + }, + "decoder": { + "type": "Sequence", + "decoders": [ + { + "type": "Replace", + "pattern": { + "String": "\u2581" + } + } + ] + }, + "added_tokens": [] +}` + +func benchmarkMiniLlamaPath(tb testing.TB) string { + tb.Helper() + + _, filename, _, ok := runtime.Caller(0) + if !ok { + tb.Fatal("failed to resolve benchmark file path") + } + + return filepath.Join(filepath.Dir(filename), "..", "imagegen", "tokenizer", "testdata", "mini_llama.json") +} + +func benchmarkLoadMiniLlama(tb testing.TB) *Tokenizer { + tb.Helper() + + data := benchmarkLoadMiniLlamaBytes(tb) + tok, err := LoadFromBytes(data) + if err != nil { + tb.Fatalf("failed to load mini llama tokenizer: %v", err) + } + return tok +} + +func benchmarkLoadMiniLlamaBytes(tb testing.TB) []byte { + tb.Helper() + + data, err := os.ReadFile(benchmarkMiniLlamaPath(tb)) + if err != nil { + tb.Fatalf("failed to read mini llama tokenizer: %v", err) + } + return data +} + +func benchmarkLoadFromBytes(tb testing.TB, data []byte) *Tokenizer { + tb.Helper() + + tok, err := LoadFromBytes(data) + if err != nil { + tb.Fatalf("failed to load tokenizer from bytes: %v", err) + } + return tok +} + +func BenchmarkTokenizerEncodeBPE(b *testing.B) { + tok := benchmarkLoadMiniLlama(b) + + inputs := []struct { + name string + text string + }{ + {name: "short", text: "Hello, world!"}, + {name: "medium", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 16)}, + {name: "long_sequential", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 80)}, + {name: "long_parallel", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 160)}, + {name: "huge_parallel", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 640)}, + {name: "special_tokens", text: "<|begin_of_text|>system\nYou are concise.<|end_of_text|>"}, + } + + for _, input := range inputs { + b.Run(input.name, func(b *testing.B) { + b.ReportAllocs() + b.SetBytes(int64(len(input.text))) + b.ResetTimer() + + for i := 0; i < b.N; i++ { + benchmarkSinkIDs = tok.Encode(input.text, false) + } + }) + } +} + +func BenchmarkTokenizerDecodeBPE(b *testing.B) { + tok := benchmarkLoadMiniLlama(b) + + inputs := []struct { + name string + text string + }{ + {name: "medium", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 16)}, + {name: "long", text: strings.Repeat("The quick brown fox jumps over the lazy dog. 
", 160)}, + } + + for _, input := range inputs { + ids := tok.Encode(input.text, false) + b.Run(input.name, func(b *testing.B) { + b.ReportAllocs() + b.SetBytes(int64(len(input.text))) + b.ResetTimer() + + for i := 0; i < b.N; i++ { + benchmarkSinkStr = tok.Decode(ids) + } + }) + } +} + +func BenchmarkTokenizerLoadFromBytes(b *testing.B) { + data := benchmarkLoadMiniLlamaBytes(b) + + config := &TokenizerConfig{ + TokenizerConfigJSON: []byte(`{ + "bos_token": {"content": "<|begin_of_text|>"}, + "eos_token": {"content": "<|end_of_text|>"}, + "add_bos_token": true + }`), + GenerationConfigJSON: []byte(`{"bos_token_id": 128000, "eos_token_id": 128001}`), + } + + b.Run("without_config", func(b *testing.B) { + b.ReportAllocs() + b.SetBytes(int64(len(data))) + b.ResetTimer() + + for i := 0; i < b.N; i++ { + tok, err := LoadFromBytes(data) + if err != nil { + b.Fatalf("LoadFromBytes failed: %v", err) + } + benchmarkSinkTok = tok + } + }) + + b.Run("with_config", func(b *testing.B) { + b.ReportAllocs() + b.SetBytes(int64(len(data))) + b.ResetTimer() + + for i := 0; i < b.N; i++ { + tok, err := LoadFromBytesWithConfig(data, config) + if err != nil { + b.Fatalf("LoadFromBytesWithConfig failed: %v", err) + } + benchmarkSinkTok = tok + } + }) +} + +func BenchmarkTokenizerEncodeWordPiece(b *testing.B) { + tok := benchmarkLoadFromBytes(b, []byte(benchmarkWordPieceJSON)) + text := strings.Repeat("helloworldly", 16) + + b.ReportAllocs() + b.SetBytes(int64(len(text))) + b.ResetTimer() + + for i := 0; i < b.N; i++ { + benchmarkSinkIDs = tok.Encode(text, false) + } +} + +func BenchmarkTokenizerDecodeWordPiece(b *testing.B) { + tok := benchmarkLoadFromBytes(b, []byte(benchmarkWordPieceJSON)) + text := strings.Repeat("helloworldly", 16) + ids := tok.Encode(text, false) + + b.ReportAllocs() + b.SetBytes(int64(len(text))) + b.ResetTimer() + + for i := 0; i < b.N; i++ { + benchmarkSinkStr = tok.Decode(ids) + } +} + +func BenchmarkTokenizerEncodeSentencePiece(b *testing.B) { + tok := benchmarkLoadFromBytes(b, []byte(benchmarkSentencePieceJSON)) + text := strings.Repeat("hello world\n", 64) + + b.ReportAllocs() + b.SetBytes(int64(len(text))) + b.ResetTimer() + + for i := 0; i < b.N; i++ { + benchmarkSinkIDs = tok.Encode(text, false) + } +} + +func BenchmarkTokenizerDecodeSentencePiece(b *testing.B) { + tok := benchmarkLoadFromBytes(b, []byte(benchmarkSentencePieceJSON)) + text := strings.Repeat("hello world\n", 64) + ids := tok.Encode(text, false) + + b.ReportAllocs() + b.SetBytes(int64(len(text))) + b.ResetTimer() + + for i := 0; i < b.N; i++ { + benchmarkSinkStr = tok.Decode(ids) + } +} diff --git a/x/tokenizer/tokenizer_bpe.go b/x/tokenizer/tokenizer_bpe.go new file mode 100644 index 000000000..1e625c20a --- /dev/null +++ b/x/tokenizer/tokenizer_bpe.go @@ -0,0 +1,175 @@ +//go:build mlx + +package tokenizer + +import "container/heap" + +type bpeMergeNode struct { + prev int + next int + token string +} + +type bpePair struct { + left int + right int + rank int + value string +} + +type bpePairHeap []*bpePair + +func (h bpePairHeap) Len() int { return len(h) } + +func (h bpePairHeap) Less(i, j int) bool { + return h[i].rank < h[j].rank || (h[i].rank == h[j].rank && h[i].left < h[j].left) +} + +func (h bpePairHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } + +func (h *bpePairHeap) Push(x any) { + *h = append(*h, x.(*bpePair)) +} + +func (h *bpePairHeap) Pop() any { + old := *h + n := len(old) + item := old[n-1] + *h = old[:n-1] + return item +} + +// encodeBPEMerge encodes using BPE merge algorithm. 
+// Uses the heap/linked-list pair merge strategy from tokenizer/bytepairencoding.go: +// merge the lowest-rank valid pair, then only recheck adjacent pairs. +func (t *Tokenizer) encodeBPEMerge(encoded string, ids []int32) []int32 { + runes := []rune(encoded) + if len(runes) == 0 { + return ids + } + + nodes := make([]bpeMergeNode, len(runes)) + for i := range runes { + nodes[i] = bpeMergeNode{ + prev: i - 1, + next: i + 1, + token: string(runes[i]), + } + } + + pairwise := func(left, right int) *bpePair { + if left < 0 || right >= len(nodes) { + return nil + } + if nodes[left].token == "" || nodes[right].token == "" { + return nil + } + + leftToken, rightToken := nodes[left].token, nodes[right].token + rank, ok := t.vocab.Merges[leftToken+" "+rightToken] + if !ok { + return nil + } + + value := leftToken + rightToken + if _, ok := t.vocab.Reverse[value]; !ok { + return nil + } + + return &bpePair{ + left: left, + right: right, + rank: rank, + value: value, + } + } + + pairs := bpePairHeap{} + heap.Init(&pairs) + for i := 0; i < len(runes)-1; i++ { + if pair := pairwise(i, i+1); pair != nil { + heap.Push(&pairs, pair) + } + } + + for pairs.Len() > 0 { + pair := heap.Pop(&pairs).(*bpePair) + left, right := nodes[pair.left], nodes[pair.right] + if left.token == "" || right.token == "" { + continue + } + if left.next != pair.right || right.prev != pair.left { + continue + } + if left.token+right.token != pair.value { + continue + } + + nodes[pair.left].token = pair.value + nodes[pair.right].token = "" + nodes[pair.left].next = right.next + if right.next < len(nodes) { + nodes[right.next].prev = pair.left + } + + if pair := pairwise(nodes[pair.left].prev, pair.left); pair != nil { + heap.Push(&pairs, pair) + } + if pair := pairwise(pair.left, nodes[pair.left].next); pair != nil { + heap.Push(&pairs, pair) + } + } + + for _, node := range nodes { + if node.token == "" { + continue + } + + if id, ok := t.vocab.Reverse[node.token]; ok { + ids = append(ids, id) + continue + } + + ids = t.appendByteFallback(ids, node.token) + } + + return ids +} + +func (t *Tokenizer) appendByteFallback(ids []int32, token string) []int32 { + if t.typ == TokenizerBPE { + for _, r := range token { + if b, ok := decodeByteLevelRune(r); ok { + if id := t.vocab.byteTokens[b]; id >= 0 { + ids = append(ids, id) + } + } + } + return ids + } + + // SentencePiece fallback uses the UTF-8 bytes for <0xNN> tokens. 
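+ // Bytes with no <0xNN> vocab entry are silently skipped.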
+ for _, b := range []byte(token) { + if id := t.vocab.byteTokens[b]; id >= 0 { + ids = append(ids, id) + } + } + return ids +} + +func decodeByteLevelRune(r rune) (byte, bool) { + switch { + case r >= 0x00 && r <= 0xFF: + return byte(r), true + case r == 0x0100: + return 0x00, true + case r == 0x0143: + return 0x00ad, true + case r > 0x0100 && r <= 0x0120: + return byte(r - 0x0100), true + case r > 0x0120 && r <= 0x0142: + return byte(r - 0x00a2), true + default: + return 0, false + } +} diff --git a/x/tokenizer/tokenizer_correctness_test.go b/x/tokenizer/tokenizer_correctness_test.go new file mode 100644 index 000000000..2fe94d279 --- /dev/null +++ b/x/tokenizer/tokenizer_correctness_test.go @@ -0,0 +1,137 @@ +//go:build mlx + +package tokenizer + +import ( + "runtime" + "strings" + "testing" +) + +func equalIDs(a, b []int32) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} + +func TestEncodeRoundtripMiniLlama(t *testing.T) { + tok := benchmarkLoadMiniLlama(t) + + inputs := []string{ + "", + "hello", + "hello world", + " hello world ", + "don't we'll they're", + "1234567890", + "こんにちは世界", + "Hello 世界", + "func main() {}", + "<|begin_of_text|>system\nYou are concise.<|end_of_text|>", + strings.Repeat("The quick brown fox jumps over the lazy dog. ", 32), + } + + for _, input := range inputs { + ids := tok.Encode(input, false) + got := tok.Decode(ids) + if got != input { + t.Fatalf("roundtrip mismatch for %q: got %q", input, got) + } + } +} + +func TestSplitBySpecialTokensGreedyLongest(t *testing.T) { + data := []byte(`{ + "model": { + "type": "BPE", + "vocab": {"a": 0, "b": 1}, + "merges": [] + }, + "added_tokens": [ + {"id": 2, "content": "<s>", "special": true}, + {"id": 3, "content": "<s>x", "special": true} + ] + }`) + + tok, err := LoadFromBytes(data) + if err != nil { + t.Fatalf("failed to load tokenizer: %v", err) + } + + input := "a<s>xb" + want := []string{"a", "<s>x", "b"} + + got := tok.splitBySpecialTokens(input) + if len(got) != len(want) { + t.Fatalf("split length mismatch: got %v want %v", got, want) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("split mismatch at %d: got %v want %v", i, got, want) + } + } +} + +func TestSplitBySpecialTokensFallbackWithoutCache(t *testing.T) { + data := []byte(`{ + "model": { + "type": "BPE", + "vocab": {"a": 0, "b": 1}, + "merges": [] + }, + "added_tokens": [ + {"id": 2, "content": "<s>", "special": true}, + {"id": 3, "content": "<s>x", "special": true} + ] + }`) + + tok, err := LoadFromBytes(data) + if err != nil { + t.Fatalf("failed to load tokenizer: %v", err) + } + + input := "a<s>xb" + want := []string{"a", "<s>x", "b"} + + // Simulate construction outside loader path where cache is not set. + tok.sortedSpecialTokens = nil + + got := tok.splitBySpecialTokens(input) + if len(got) != len(want) { + t.Fatalf("split length mismatch: got %v want %v", got, want) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("split mismatch at %d: got %v want %v", i, got, want) + } + } +} + +func TestEncodeDeterministicAcrossGOMAXPROCS(t *testing.T) { + tok := benchmarkLoadMiniLlama(t) + + input := strings.Repeat("The quick brown fox jumps over the lazy dog. 
", 640) + + prev := runtime.GOMAXPROCS(0) + defer runtime.GOMAXPROCS(prev) + + runtime.GOMAXPROCS(1) + seq := tok.Encode(input, false) + + if prev < 2 { + runtime.GOMAXPROCS(2) + } else { + runtime.GOMAXPROCS(prev) + } + par := tok.Encode(input, false) + + if !equalIDs(seq, par) { + t.Fatalf("encode mismatch between sequential and parallel paths: seq=%d par=%d", len(seq), len(par)) + } +} diff --git a/x/tokenizer/tokenizer_decode.go b/x/tokenizer/tokenizer_decode.go new file mode 100644 index 000000000..e02d2a88b --- /dev/null +++ b/x/tokenizer/tokenizer_decode.go @@ -0,0 +1,56 @@ +//go:build mlx + +package tokenizer + +import ( + "strconv" + "strings" +) + +// Decode converts token IDs back to text +func (t *Tokenizer) Decode(ids []int32) string { + var sb strings.Builder + + for _, id := range ids { + if int(id) >= len(t.vocab.Values) { + continue + } + + token := t.vocab.Values[id] + + switch t.typ { + case TokenizerSentencePiece: + // SentencePiece style: replace ▁ with space, decode byte tokens + token = strings.ReplaceAll(token, "▁", " ") + // Handle byte fallback tokens like <0x0D> + if len(token) == 6 && token[0] == '<' && token[1] == '0' && token[2] == 'x' && token[5] == '>' { + if v, err := strconv.ParseUint(token[3:5], 16, 8); err == nil { + sb.WriteByte(byte(v)) + continue + } + } + sb.WriteString(token) + default: + // GPT-2 BPE style: decode byte-level encoding + for _, r := range token { + switch { + case r == 0x0100: + // Mirror GGML tokenizer behavior for NULL byte. + // 0x00 is omitted during decode. + continue + case r == 0x0143: + r = 0x00ad + case r > 0x0100 && r <= 0x0120: + r = r - 0x0100 + case r > 0x0120 && r <= 0x0142: + r = r - 0x00a2 + } + + // Write as byte, not UTF-8 encoded rune + sb.WriteByte(byte(r)) + } + } + } + + return sb.String() +} diff --git a/x/tokenizer/tokenizer_encode.go b/x/tokenizer/tokenizer_encode.go new file mode 100644 index 000000000..1b71ea6d3 --- /dev/null +++ b/x/tokenizer/tokenizer_encode.go @@ -0,0 +1,289 @@ +//go:build mlx + +package tokenizer + +import ( + "runtime" + "sort" + "strings" + "sync" + "unicode" + "unicode/utf8" +) + +const ( + encodeParallelMinInputBytes = 4 * 1024 + encodeParallelMinChunksPerWorker = 8 +) + +type tokenMatch struct { + start int + end int +} + +type encodeChunk struct { + text string + isSpecial bool +} + +// isNonNewlineWhitespace returns true if s contains only whitespace characters (no newlines) +func isNonNewlineWhitespace(s string) bool { + if s == "" { + return false + } + for _, r := range s { + if r == '\n' || r == '\r' { + return false + } + if !unicode.IsSpace(r) { + return false + } + } + return true +} + +// splitBySpecialTokens splits text into parts, keeping special tokens as separate elements +func (t *Tokenizer) splitBySpecialTokens(s string) []string { + if len(t.specialTokens) == 0 { + return []string{s} + } + + tokens := t.sortedSpecialTokens + if len(tokens) == 0 { + // Fallback for tokenizers constructed outside the loaders. 
+ tokens = make([]string, 0, len(t.specialTokens)) + for tok := range t.specialTokens { + tokens = append(tokens, tok) + } + sort.Slice(tokens, func(i, j int) bool { + return len(tokens[i]) > len(tokens[j]) + }) + } + + var result []string + remaining := s + + for len(remaining) > 0 { + found := false + for _, tok := range tokens { + if strings.HasPrefix(remaining, tok) { + result = append(result, tok) + remaining = remaining[len(tok):] + found = true + break + } + } + if !found { + // Find next special token position + nextPos := len(remaining) + for _, tok := range tokens { + if idx := strings.Index(remaining, tok); idx != -1 && idx < nextPos { + nextPos = idx + } + } + if nextPos > 0 { + result = append(result, remaining[:nextPos]) + } + remaining = remaining[nextPos:] + } + } + + return result +} + +func adjustWhitespaceBoundary(part string, curr, next *tokenMatch) { + m := part[curr.start:curr.end] + nextText := part[next.start:next.end] + + if !isNonNewlineWhitespace(m) || len(nextText) == 0 { + return + } + + firstRune, _ := utf8.DecodeRuneInString(nextText) + if !unicode.IsLetter(firstRune) { + return + } + + lastSpaceStart := curr.end + for j := curr.end; j > curr.start; { + r, size := utf8.DecodeLastRuneInString(part[curr.start:j]) + if unicode.IsSpace(r) { + lastSpaceStart = j - size + break + } + j -= size + } + if lastSpaceStart > curr.start { + curr.end = lastSpaceStart + next.start = lastSpaceStart + } else { + next.start = curr.start + curr.end = curr.start + } +} + +func (t *Tokenizer) forEachPartChunk(part string, fn func(encodeChunk)) { + if _, ok := t.specialTokens[part]; ok { + fn(encodeChunk{text: part, isSpecial: true}) + return + } + + if t.pretokenizer == nil { + fn(encodeChunk{text: part, isSpecial: false}) + return + } + + re := t.pretokenizer + offset := 0 + loc := re.FindStringIndex(part[offset:]) + if loc == nil { + return + } + + curr := tokenMatch{start: offset + loc[0], end: offset + loc[1]} + offset += loc[1] + + for { + loc = re.FindStringIndex(part[offset:]) + if loc == nil { + if curr.end > curr.start { + fn(encodeChunk{text: part[curr.start:curr.end], isSpecial: false}) + } + return + } + + next := tokenMatch{start: offset + loc[0], end: offset + loc[1]} + offset += loc[1] + + adjustWhitespaceBoundary(part, &curr, &next) + + if curr.end > curr.start { + fn(encodeChunk{text: part[curr.start:curr.end], isSpecial: false}) + } + curr = next + } +} + +func (t *Tokenizer) appendEncodedChunk(ids []int32, c encodeChunk) []int32 { + if c.isSpecial { + if id, ok := t.specialTokens[c.text]; ok { + return append(ids, id) + } + return ids + } + + return t.encodeChunkInto(c.text, ids) +} + +// Encode tokenizes text to token IDs. +// Parallel encoding is used only for very large inputs with enough chunks per worker. +func (t *Tokenizer) Encode(s string, addBOS bool) []int32 { + // First: split by special tokens + parts := t.splitBySpecialTokens(s) + + // Fast path: encode sequentially without materializing chunk slices. + if len(s) < encodeParallelMinInputBytes { + var ids []int32 + for _, part := range parts { + t.forEachPartChunk(part, func(c encodeChunk) { + ids = t.appendEncodedChunk(ids, c) + }) + } + + if addBOS && t.vocab.BOS >= 0 { + ids = append([]int32{t.vocab.BOS}, ids...) + } + return ids + } + + // For large inputs collect chunks to enable parallel processing. + var allChunks []encodeChunk + for _, part := range parts { + t.forEachPartChunk(part, func(c encodeChunk) { + allChunks = append(allChunks, c) + }) + } + + // Encode chunks. 
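Per-worker results are concatenated in worker index order, so the parallel path always yields the same token order as sequential encoding.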
Use the parallel path only when the chunk count is + // large enough to amortize goroutine/synchronization overhead. + useParallel := true + numWorkers := runtime.GOMAXPROCS(0) + if numWorkers > len(allChunks) { + numWorkers = len(allChunks) + } + if numWorkers < 2 || len(allChunks) < numWorkers*encodeParallelMinChunksPerWorker { + useParallel = false + } + + var ids []int32 + if !useParallel { + for _, c := range allChunks { + ids = t.appendEncodedChunk(ids, c) + } + } else { + chunksPer := (len(allChunks) + numWorkers - 1) / numWorkers + results := make([][]int32, numWorkers) + var wg sync.WaitGroup + + for i := 0; i < numWorkers; i++ { + start := i * chunksPer + end := start + chunksPer + if end > len(allChunks) { + end = len(allChunks) + } + if start >= end { + continue + } + + wg.Add(1) + go func(i int, chunks []encodeChunk) { + defer wg.Done() + var r []int32 + for _, c := range chunks { + r = t.appendEncodedChunk(r, c) + } + results[i] = r + }(i, allChunks[start:end]) + } + wg.Wait() + + for _, r := range results { + ids = append(ids, r...) + } + } + + if addBOS && t.vocab.BOS >= 0 { + ids = append([]int32{t.vocab.BOS}, ids...) + } + return ids +} + +// encodeChunkInto appends encoded tokens to ids and returns the extended slice. +// Uses BPE merge algorithm for both BPE and SentencePiece tokenization. +func (t *Tokenizer) encodeChunkInto(s string, ids []int32) []int32 { + if s == "" { + return ids + } + + // Apply encoding transformation + // SentencePiece: replace space with ▁ + // BPE: convert bytes using precomputed table (GPT-2 byte-level encoding) + var encoded string + if t.typ == TokenizerSentencePiece { + encoded = strings.ReplaceAll(s, " ", "▁") + } else { + var sb strings.Builder + sb.Grow(len(s) * 2) + for i := 0; i < len(s); i++ { + sb.WriteRune(byteToRune[s[i]]) + } + encoded = sb.String() + } + + // Fast path: check if entire chunk is a single token + if id, ok := t.vocab.Reverse[encoded]; ok { + return append(ids, id) + } + + return t.encodeBPEMerge(encoded, ids) +} diff --git a/x/tokenizer/tokenizer_ggml_parity_test.go b/x/tokenizer/tokenizer_ggml_parity_test.go new file mode 100644 index 000000000..4cef3d3dd --- /dev/null +++ b/x/tokenizer/tokenizer_ggml_parity_test.go @@ -0,0 +1,207 @@ +//go:build mlx + +package tokenizer + +import ( + "bufio" + "encoding/json" + "os" + "path/filepath" + "runtime" + "strings" + "testing" +) + +func llama32GGMLFixturePath(tb testing.TB, file string) string { + tb.Helper() + + _, filename, _, ok := runtime.Caller(0) + if !ok { + tb.Fatal("failed to resolve test file path") + } + + return filepath.Join(filepath.Dir(filename), "..", "..", "tokenizer", "testdata", "llama3.2", file) +} + +func loadLlama32FromGGMLFixture(tb testing.TB) *Tokenizer { + tb.Helper() + + f, err := os.Open(llama32GGMLFixturePath(tb, "encoder.json")) + if err != nil { + tb.Fatalf("failed to open encoder.json: %v", err) + } + defer f.Close() + + vocab := make(map[string]int32) + if err := json.NewDecoder(f).Decode(&vocab); err != nil { + tb.Fatalf("failed to decode encoder.json: %v", err) + } + + type addedToken struct { + ID int32 `json:"id"` + Content string `json:"content"` + Special bool `json:"special"` + } + var addedTokens []addedToken + for _, token := range []string{"<|begin_of_text|>", "<|end_of_text|>"} { + if _, ok := vocab[token]; !ok { + id := int32(len(vocab)) + vocab[token] = id + addedTokens = append(addedTokens, addedToken{ID: id, Content: token, Special: true}) + } + } + + mf, err := os.Open(llama32GGMLFixturePath(tb, "vocab.bpe")) + if err != 
nil { + tb.Fatalf("failed to open vocab.bpe: %v", err) + } + defer mf.Close() + + var merges []string + scanner := bufio.NewScanner(mf) + for scanner.Scan() { + line := scanner.Text() + if strings.HasPrefix(line, "#") { + continue + } + line = strings.TrimSpace(line) + if line != "" { + merges = append(merges, line) + } + } + if err := scanner.Err(); err != nil { + tb.Fatalf("failed to read vocab.bpe: %v", err) + } + + payload := struct { + Model struct { + Type string `json:"type"` + Vocab map[string]int32 `json:"vocab"` + Merges []string `json:"merges"` + } `json:"model"` + PreTokenizer struct { + Type string `json:"type"` + Pretokenizers []struct { + Type string `json:"type"` + Pattern struct { + Regex string `json:"Regex"` + } `json:"pattern"` + } `json:"pretokenizers"` + } `json:"pre_tokenizer"` + AddedTokens []addedToken `json:"added_tokens"` + }{} + + payload.Model.Type = "BPE" + payload.Model.Vocab = vocab + payload.Model.Merges = merges + payload.PreTokenizer.Type = "Sequence" + payload.PreTokenizer.Pretokenizers = []struct { + Type string `json:"type"` + Pattern struct { + Regex string `json:"Regex"` + } `json:"pattern"` + }{ + { + Type: "Split", + Pattern: struct { + Regex string `json:"Regex"` + }{ + Regex: `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`, + }, + }, + } + payload.AddedTokens = addedTokens + + data, err := json.Marshal(payload) + if err != nil { + tb.Fatalf("failed to marshal synthetic tokenizer.json: %v", err) + } + + tok, err := LoadFromBytes(data) + if err != nil { + tb.Fatalf("failed to load tokenizer from fixture data: %v", err) + } + return tok +} + +func TestGGMLLlamaKnownEncodings(t *testing.T) { + tok := loadLlama32FromGGMLFixture(t) + + cases := map[string][]int32{ + "hello world": {15339, 1917}, + "hello <|end_of_text|>": {15339, 220, 128001}, + "<|begin_of_text|>A B!": {128000, 32, 426, 0}, + "<|begin_of_text|>A<|end_of_text|>B!": {128000, 32, 128001, 33, 0}, + "<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!": {128000, 32, 128001, 33, 128000, 0}, + "<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!<|end_of_text|>": {128000, 32, 128001, 33, 128000, 0, 128001}, + } + + for input, want := range cases { + got := tok.Encode(input, false) + if !equalIDs(got, want) { + t.Fatalf("encode mismatch for %q:\n got: %v\n want: %v", input, got, want) + } + } +} + +func TestGGMLLlamaRepeatedZeros(t *testing.T) { + tok := loadLlama32FromGGMLFixture(t) + + cases := map[int][]int32{ + 1: {15}, + 2: {410}, + 3: {931}, + 4: {931, 15}, + 5: {931, 410}, + 6: {931, 931}, + 7: {931, 931, 15}, + 8: {931, 931, 410}, + 9: {931, 931, 931}, + 10: {931, 931, 931, 15}, + 11: {931, 931, 931, 410}, + 12: {931, 931, 931, 931}, + 13: {931, 931, 931, 931, 15}, + 14: {931, 931, 931, 931, 410}, + 15: {931, 931, 931, 931, 931}, + 16: {931, 931, 931, 931, 931, 15}, + 17: {931, 931, 931, 931, 931, 410}, + } + + for n, want := range cases { + input := strings.Repeat("0", n) + got := tok.Encode(input, false) + if !equalIDs(got, want) { + t.Fatalf("encode mismatch for %q:\n got: %v\n want: %v", input, got, want) + } + } +} + +func TestGGMLLlamaRoundtripAndByteBehavior(t *testing.T) { + tok := loadLlama32FromGGMLFixture(t) + + cases := []string{ + "hello", + "hello ", + "hello ", + " hello", + " hello ", + " hello ", + "hello world", + "请考试我的软件!12345", + } + + for _, input := range cases { + ids := tok.Encode(input, false) + got := tok.Decode(ids) + if got != input { + t.Fatalf("roundtrip mismatch for %q: got %q", 
input, got) + } + } + + // Match GGML tokenizer behavior: 0x00 is omitted when decoding. + ids := tok.Encode(string(rune(0x00)), false) + got := tok.Decode(ids) + if got != "" { + t.Fatalf("expected empty decode for 0x00, got %q (ids=%v)", got, ids) + } +} diff --git a/x/tokenizer/tokenizer_load.go b/x/tokenizer/tokenizer_load.go new file mode 100644 index 000000000..d2a253e17 --- /dev/null +++ b/x/tokenizer/tokenizer_load.go @@ -0,0 +1,458 @@ +//go:build mlx + +package tokenizer + +import ( + "encoding/json" + "fmt" + "regexp" + "sort" + "strings" +) + +// TokenizerConfig holds optional configuration data that can be passed to LoadFromBytesWithConfig. +type TokenizerConfig struct { + TokenizerConfigJSON []byte // tokenizer_config.json content + GenerationConfigJSON []byte // generation_config.json content + SpecialTokensMapJSON []byte // special_tokens_map.json content + ConfigJSON []byte // config.json content +} + +// LoadFromBytes loads a tokenizer from tokenizer.json bytes. +// This is useful when loading from blob storage where the file content is already in memory. +// Note: This won't load special token config from companion files. Use LoadFromBytesWithConfig +// to provide tokenizer_config.json data for proper PAD/EOS token loading. +func LoadFromBytes(data []byte) (*Tokenizer, error) { + return loadFromTokenizerJSON(data) +} + +// LoadFromBytesWithConfig loads a tokenizer from tokenizer.json bytes with additional config files. +// This is useful when loading from blob storage where companion config files are also blobs. +func LoadFromBytesWithConfig(data []byte, config *TokenizerConfig) (*Tokenizer, error) { + t, err := loadFromTokenizerJSON(data) + if err != nil { + return nil, err + } + + if config == nil { + return t, nil + } + + // Apply special token configs from provided data + loadSpecialTokenConfigFromBytes(t, config) + + return t, nil +} + +// loadFromTokenizerJSON parses tokenizer.json content from bytes. +func loadFromTokenizerJSON(data []byte) (*Tokenizer, error) { + + var raw struct { + Model struct { + Type string `json:"type"` // "BPE" + Vocab map[string]int32 `json:"vocab"` + Merges json.RawMessage `json:"merges"` // Can be []string or [][]string (BPE only) + } `json:"model"` + PreTokenizer json.RawMessage `json:"pre_tokenizer"` + Decoder json.RawMessage `json:"decoder"` + AddedTokens []struct { + ID int32 `json:"id"` + Content string `json:"content"` + Special bool `json:"special"` + } `json:"added_tokens"` + } + + if err := json.Unmarshal(data, &raw); err != nil { + return nil, fmt.Errorf("failed to parse tokenizer: %w", err) + } + + // Covers SentencePiece and BPE models + if raw.Model.Type != "BPE" { + return nil, fmt.Errorf("unsupported tokenizer type: %s", raw.Model.Type) + } + + // Parse merges - can be []string (Llama) or [][]string (GPT-OSS). 
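+ // Merge rank is implicit in list position: earlier entries get lower rank and are preferred during encoding.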
+ var mergesStrings []string + if raw.Model.Merges != nil { + var mergesArrays [][]string + if err := json.Unmarshal(raw.Model.Merges, &mergesStrings); err != nil { + // Try array of arrays format + if err := json.Unmarshal(raw.Model.Merges, &mergesArrays); err != nil { + return nil, fmt.Errorf("failed to parse merges: %w", err) + } + // Convert [][]string to []string + mergesStrings = make([]string, len(mergesArrays)) + for i, pair := range mergesArrays { + if len(pair) != 2 { + return nil, fmt.Errorf("failed to parse merges: expected merge pair of length 2, got %d", len(pair)) + } + mergesStrings[i] = pair[0] + " " + pair[1] + } + } + } + + // Build tokenizer + t := &Tokenizer{ + vocab: &Vocabulary{ + Values: make([]string, len(raw.Model.Vocab)), + Reverse: raw.Model.Vocab, + Merges: make(map[string]int, len(mergesStrings)), + BOS: -1, + PAD: -1, + }, + specialTokens: make(map[string]int32), + } + + // Build values array + for token, id := range raw.Model.Vocab { + if int(id) >= len(t.vocab.Values) { + newValues := make([]string, id+1) + copy(newValues, t.vocab.Values) + t.vocab.Values = newValues + } + t.vocab.Values[id] = token + } + + // Build merges map + for i, merge := range mergesStrings { + t.vocab.Merges[merge] = i + } + + // Add all added_tokens to vocabulary and special tokens map. + // HuggingFace treats ALL added_tokens as special for tokenization purposes - + // they bypass BPE and get their own token ID. The "special" flag just indicates + // if it's a "truly special" token like BOS/EOS/PAD, but for tokenization we need + // to treat all added_tokens as special to match HuggingFace behavior. + for _, tok := range raw.AddedTokens { + if int(tok.ID) >= len(t.vocab.Values) { + newValues := make([]string, tok.ID+1) + copy(newValues, t.vocab.Values) + t.vocab.Values = newValues + } + t.vocab.Values[tok.ID] = tok.Content + t.specialTokens[tok.Content] = tok.ID // Add ALL added_tokens to special tokens + } + + // Precompute byte token IDs for <0xNN> fallback + initByteTokens(t) + + // Determine tokenizer type + switch { + case detectSentencePiece(raw.Decoder): + t.typ = TokenizerSentencePiece + default: + t.typ = TokenizerBPE + } + + // Parse and compile pretokenizer pattern (BPE only - SentencePiece doesn't use pretokenizer) + if t.typ == TokenizerBPE { + pattern := extractPretokenizer(raw.PreTokenizer) + if pattern == "" { + pattern = `'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+` + } + re, err := regexp.Compile(rewritePatternForRE2(pattern)) + if err != nil { + return nil, fmt.Errorf("failed to compile pretokenizer regex %q: %w", pattern, err) + } + t.pretokenizer = re + } + + cacheSortedSpecialTokens(t) + + return t, nil +} + +func cacheSortedSpecialTokens(t *Tokenizer) { + if len(t.specialTokens) == 0 { + t.sortedSpecialTokens = nil + return + } + + tokens := make([]string, 0, len(t.specialTokens)) + for tok := range t.specialTokens { + tokens = append(tokens, tok) + } + sort.Slice(tokens, func(i, j int) bool { + return len(tokens[i]) > len(tokens[j]) + }) + t.sortedSpecialTokens = tokens +} + +type specialTokenConfigData struct { + tokenizerConfigJSON []byte + generationConfigJSON []byte + specialTokensMapJSON []byte + configJSON []byte +} + +func applySpecialTokenConfig(t *Tokenizer, config specialTokenConfigData) { + parseTokenIDs := func(v interface{}) []int32 { + switch val := v.(type) { + case float64: + return []int32{int32(val)} + case []interface{}: + ids := make([]int32, 0, len(val)) + for _, id := range val { + if f, ok := 
id.(float64); ok { + ids = append(ids, int32(f)) + } + } + return ids + } + return nil + } + + // Priority 1: generation_config.json + if len(config.generationConfigJSON) > 0 { + var genConfig struct { + EOSTokenID interface{} `json:"eos_token_id"` + BOSTokenID interface{} `json:"bos_token_id"` + } + if err := json.Unmarshal(config.generationConfigJSON, &genConfig); err == nil { + if ids := parseTokenIDs(genConfig.EOSTokenID); len(ids) > 0 { + t.vocab.EOS = ids + } + if ids := parseTokenIDs(genConfig.BOSTokenID); len(ids) > 0 { + t.vocab.BOS = ids[0] + } + } + } + + // Priority 2: config.json + if len(config.configJSON) > 0 && (len(t.vocab.EOS) == 0 || t.vocab.BOS < 0) { + var modelConfig struct { + EOSTokenID interface{} `json:"eos_token_id"` + BOSTokenID interface{} `json:"bos_token_id"` + } + if err := json.Unmarshal(config.configJSON, &modelConfig); err == nil { + if len(t.vocab.EOS) == 0 { + if ids := parseTokenIDs(modelConfig.EOSTokenID); len(ids) > 0 { + t.vocab.EOS = ids + } + } + if t.vocab.BOS < 0 { + if ids := parseTokenIDs(modelConfig.BOSTokenID); len(ids) > 0 { + t.vocab.BOS = ids[0] + } + } + } + } + + // Priority 3: tokenizer_config.json + if len(config.tokenizerConfigJSON) > 0 { + var tokConfig struct { + BOSToken interface{} `json:"bos_token"` + EOSToken interface{} `json:"eos_token"` + PADToken interface{} `json:"pad_token"` + AddBOSToken *bool `json:"add_bos_token"` + AddEOSToken *bool `json:"add_eos_token"` + } + if err := json.Unmarshal(config.tokenizerConfigJSON, &tokConfig); err == nil { + if t.vocab.BOS < 0 { + if bosStr := extractTokenString(tokConfig.BOSToken); bosStr != "" { + if id, ok := t.specialTokens[bosStr]; ok { + t.vocab.BOS = id + } + } + } + if len(t.vocab.EOS) == 0 { + if eosStr := extractTokenString(tokConfig.EOSToken); eosStr != "" { + if id, ok := t.specialTokens[eosStr]; ok { + t.vocab.EOS = []int32{id} + } + } + } + if t.vocab.PAD < 0 { + if padStr := extractTokenString(tokConfig.PADToken); padStr != "" { + if id, ok := t.specialTokens[padStr]; ok { + t.vocab.PAD = id + } + } + } + if tokConfig.AddBOSToken != nil { + t.vocab.AddBOS = *tokConfig.AddBOSToken + } + if tokConfig.AddEOSToken != nil { + t.vocab.AddEOS = *tokConfig.AddEOSToken + } + } + } + + // Priority 4: special_tokens_map.json + if len(config.specialTokensMapJSON) > 0 { + var tokensMap map[string]interface{} + if err := json.Unmarshal(config.specialTokensMapJSON, &tokensMap); err == nil { + if t.vocab.BOS < 0 { + if bosStr := extractTokenString(tokensMap["bos_token"]); bosStr != "" { + if id, ok := t.specialTokens[bosStr]; ok { + t.vocab.BOS = id + } + } + } + if len(t.vocab.EOS) == 0 { + if eosStr := extractTokenString(tokensMap["eos_token"]); eosStr != "" { + if id, ok := t.specialTokens[eosStr]; ok { + t.vocab.EOS = []int32{id} + } + } + } + if t.vocab.PAD < 0 { + if padStr := extractTokenString(tokensMap["pad_token"]); padStr != "" { + if id, ok := t.specialTokens[padStr]; ok { + t.vocab.PAD = id + } + } + } + } + } +} + +// extractTokenString extracts the token string from various formats used in HuggingFace configs. 
+// Tokens can be represented as: +// - string: "token" +// - object: {"content": "token", ...} +func extractTokenString(v interface{}) string { + if v == nil { + return "" + } + // Direct string + if s, ok := v.(string); ok { + return s + } + // Object with content field + if m, ok := v.(map[string]interface{}); ok { + if content, ok := m["content"].(string); ok { + return content + } + } + return "" +} + +// rewritePatternForRE2 rewrites HuggingFace pretokenizer regex patterns to be +// compatible with Go's regexp package (RE2). Two constructs are rewritten: +// - (?!\S) negative lookahead - RE2 doesn't support lookahead +// - (?i:...) inline case-insensitive contraction groups - expanded to explicit alternations +// +// We replace \s+(?!\S)|\s+ with \s+ and fix whitespace boundaries in adjustWhitespaceBoundary(). +// The lookahead version splits "a b" into ["a", " ", " b"] (space prepended to word). +// Simple \s+ would give ["a", " ", "b"]. We post-process to match Python's behavior. +func rewritePatternForRE2(pattern string) string { + // Replace lookahead pattern with simple \s+ - boundaries are fixed in adjustWhitespaceBoundary() + pattern = strings.ReplaceAll(pattern, `\s+(?!\S)|\s+`, `\s+`) + + // Handle the pattern when it appears with a ? suffix (optional contractions in GPT-4o style) + // IMPORTANT: Must be done before the non-optional version to avoid partial replacement + pattern = strings.ReplaceAll(pattern, + `(?i:'s|'t|'re|'ve|'m|'ll|'d)?`, + `(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?`) + + // Expand case-insensitive contraction pattern to explicit alternations + // (?i:'s|'t|'re|'ve|'m|'ll|'d) -> '[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD] + pattern = strings.ReplaceAll(pattern, + `(?i:'s|'t|'re|'ve|'m|'ll|'d)`, + `(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])`) + + return pattern +} + +// loadSpecialTokenConfigFromBytes loads special token configuration from byte slices. 
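+// It adapts the exported TokenizerConfig fields into the internal form consumed by applySpecialTokenConfig.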
+func loadSpecialTokenConfigFromBytes(t *Tokenizer, config *TokenizerConfig) { + applySpecialTokenConfig(t, specialTokenConfigData{ + tokenizerConfigJSON: config.TokenizerConfigJSON, + generationConfigJSON: config.GenerationConfigJSON, + specialTokensMapJSON: config.SpecialTokensMapJSON, + configJSON: config.ConfigJSON, + }) +} + +// detectSentencePiece checks if the decoder uses SentencePiece-style (▁ for spaces) +// vs GPT-2 byte-level encoding +func detectSentencePiece(data json.RawMessage) bool { + if data == nil { + return false + } + + // Check for Sequence decoder with Replace step (SentencePiece style) + var seq struct { + Type string `json:"type"` + Decoders []struct { + Type string `json:"type"` + Pattern struct { + String string `json:"String"` + } `json:"pattern"` + } `json:"decoders"` + } + if err := json.Unmarshal(data, &seq); err == nil { + if seq.Type == "Sequence" { + for _, dec := range seq.Decoders { + // Look for Replace decoder that converts ▁ to space + if dec.Type == "Replace" && dec.Pattern.String == "▁" { + return true + } + } + } + } + + // Check for direct ByteLevel decoder (GPT-2 style) + var simple struct { + Type string `json:"type"` + } + if err := json.Unmarshal(data, &simple); err == nil { + if simple.Type == "ByteLevel" { + return false + } + } + + return false +} + +// initByteTokens precomputes byte token IDs for <0xNN> fallback encoding +func initByteTokens(t *Tokenizer) { + for i := range t.vocab.byteTokens { + t.vocab.byteTokens[i] = -1 + } + for b := 0; b < 256; b++ { + token := fmt.Sprintf("<0x%02X>", b) + if id, ok := t.vocab.Reverse[token]; ok { + t.vocab.byteTokens[b] = id + } + } +} + +// extractPretokenizer extracts the regex pattern from the pre_tokenizer config +func extractPretokenizer(data json.RawMessage) string { + if data == nil { + return "" + } + + // Try to parse as a single Split pretokenizer + var single struct { + Type string `json:"type"` + Pattern struct { + Regex string `json:"Regex"` + } `json:"pattern"` + } + if err := json.Unmarshal(data, &single); err == nil && single.Pattern.Regex != "" { + return single.Pattern.Regex + } + + // Try to parse as Sequence of pretokenizers - use first Split pattern + var seq struct { + Type string `json:"type"` + Pretokenizers []struct { + Type string `json:"type"` + Pattern struct { + Regex string `json:"Regex"` + } `json:"pattern"` + } `json:"pretokenizers"` + } + if err := json.Unmarshal(data, &seq); err == nil && seq.Type == "Sequence" { + for _, pt := range seq.Pretokenizers { + if pt.Type == "Split" && pt.Pattern.Regex != "" { + return pt.Pattern.Regex + } + } + } + + return "" +} diff --git a/x/tokenizer/tokenizer_load_test.go b/x/tokenizer/tokenizer_load_test.go new file mode 100644 index 000000000..136399c2e --- /dev/null +++ b/x/tokenizer/tokenizer_load_test.go @@ -0,0 +1,26 @@ +//go:build mlx + +package tokenizer + +import ( + "strings" + "testing" +) + +func TestLoadFromBytesRejectsWordPiece(t *testing.T) { + data := []byte(`{ + "model": { + "type": "WordPiece", + "vocab": {"[UNK]": 0, "hello": 1} + }, + "added_tokens": [] + }`) + + _, err := LoadFromBytes(data) + if err == nil { + t.Fatal("expected WordPiece load to fail") + } + if !strings.Contains(err.Error(), "unsupported tokenizer type: WordPiece") { + t.Fatalf("unexpected error: %v", err) + } +}
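As a quick illustration of the streaming behavior the new utf8_buffer.go gives Runner.Decode, here is a minimal standalone sketch (not part of the diff; the flush helper and the sample chunks are illustrative): token byte sequences can split a rune across samples, so complete runes are emitted immediately while a trailing partial rune stays buffered until its continuation bytes arrive.

package main

import (
	"bytes"
	"fmt"
	"unicode/utf8"
)

// flush mirrors flushValidUTF8Prefix: emit the longest decodable prefix,
// keep an incomplete trailing rune buffered for a later call.
func flush(b *bytes.Buffer) string {
	data := b.Bytes()
	n := 0
	for n < len(data) {
		r, size := utf8.DecodeRune(data[n:])
		if r == utf8.RuneError && size == 1 && !utf8.FullRune(data[n:]) {
			break // incomplete trailing rune: wait for more bytes
		}
		n += size // a valid rune, or an invalid byte consumed for forward progress
	}
	out := string(data[:n])
	b.Next(n)
	return out
}

func main() {
	var b bytes.Buffer
	// "こん" is six UTF-8 bytes arriving here in three odd-sized chunks.
	for _, chunk := range [][]byte{{0xE3, 0x81}, {0x93, 0xE3}, {0x82, 0x93}} {
		b.Write(chunk)
		fmt.Printf("flushed %q, buffered %d byte(s)\n", flush(&b), b.Len())
	}
	// flushed "", buffered 2 byte(s)
	// flushed "こ", buffered 1 byte(s)
	// flushed "ん", buffered 0 byte(s)
}

Because the flush consumes what it returns, at most three bytes (an incomplete rune, less than utf8.UTFMax) stay buffered between calls, which is why Decode in pipeline.go can invoke it after every sampled token without ever emitting a broken rune.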