// tokenizer.go - BPE and SentencePiece tokenizer for HuggingFace models
//
// Based on the standard BPE algorithm (Sennrich et al., 2015) with:
// - GPT-2 byte-level encoding (OpenAI tiktoken)
// - HuggingFace tokenizer.json pretokenizer patterns
// - SentencePiece ▁-style space handling

package tokenizer

import (
	"encoding/json"
	"fmt"
	"os"
	"regexp"
	"runtime"
	"sort"
	"strconv"
	"strings"
	"sync"
	"unicode"
	"unicode/utf8"
)

// TokenizerType identifies the tokenization algorithm
type TokenizerType int

const (
	TokenizerBPE           TokenizerType = iota // GPT-2 style byte-level BPE
	TokenizerSentencePiece                      // SentencePiece with ▁ for spaces
	TokenizerWordPiece                          // BERT style with ## continuations
)

// Vocabulary holds the tokenizer vocabulary and merges
type Vocabulary struct {
	Values  []string
	Reverse map[string]int32
	Merges  map[string]int

	BOS    int32
	EOS    []int32 // Multiple EOS tokens supported (e.g., Gemma has <eos> and <end_of_turn>)
	PAD    int32   // Padding token (often <|endoftext|> or <pad>)
	AddBOS bool
	AddEOS bool

	// Precomputed byte token IDs for <0xNN> fallback (256 entries, -1 if not found)
	byteTokens [256]int32
}

// Tokenizer handles BPE, SentencePiece, and WordPiece tokenization
type Tokenizer struct {
	vocab         *Vocabulary
	pretokenizer  *regexp.Regexp
	specialTokens map[string]int32 // Special tokens for direct lookup
	typ           TokenizerType    // Algorithm type
	unkToken      int32            // [UNK] token ID for WordPiece fallback
}

// Precomputed GPT-2 byte-level encoding table.
// Maps byte values to their encoded rune equivalents.
var byteToRune [256]rune

func init() {
	for b := 0; b < 256; b++ {
		r := rune(b)
		switch {
		case r == 0x00ad:
			r = 0x0143
		case r <= 0x0020:
			r = r + 0x0100
		case r >= 0x007f && r <= 0x00a0:
			r = r + 0x00a2
		}
		byteToRune[b] = r
	}
}
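
// The table reproduces GPT-2's bytes_to_unicode mapping: printable bytes map to
// themselves and the excluded ranges are shifted into printable code points.
// A few concrete entries, derived from the switch above (illustrative only):
//
//	byteToRune[' ']  == 'Ġ' // 0x20 + 0x100 = U+0120
//	byteToRune['\n'] == 'Ċ' // 0x0A + 0x100 = U+010A
//	byteToRune['a']  == 'a' // printable bytes are unchanged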

// loadSpecialTokenConfig loads special token configuration from HuggingFace companion files.
//
// Loading priority for EOS tokens (can be single int or []int):
// 1. generation_config.json - eos_token_id (preferred, matches HuggingFace generation)
// 2. config.json - eos_token_id (model config fallback)
// 3. tokenizer_config.json - eos_token string + add_bos/add_eos flags
// 4. special_tokens_map.json - final fallback
func loadSpecialTokenConfig(dir string, t *Tokenizer) {
	// Helper to parse eos_token_id which can be int or []int
	parseTokenIDs := func(v interface{}) []int32 {
		switch val := v.(type) {
		case float64:
			return []int32{int32(val)}
		case []interface{}:
			ids := make([]int32, 0, len(val))
			for _, id := range val {
				if f, ok := id.(float64); ok {
					ids = append(ids, int32(f))
				}
			}
			return ids
		}
		return nil
	}

	// Priority 1: generation_config.json (eos_token_id can be int or []int)
	if data, err := os.ReadFile(dir + "generation_config.json"); err == nil {
		var config struct {
			EOSTokenID interface{} `json:"eos_token_id"`
			BOSTokenID interface{} `json:"bos_token_id"`
		}
		if err := json.Unmarshal(data, &config); err == nil {
			if ids := parseTokenIDs(config.EOSTokenID); len(ids) > 0 {
				t.vocab.EOS = ids
			}
			if ids := parseTokenIDs(config.BOSTokenID); len(ids) > 0 {
				t.vocab.BOS = ids[0]
			}
		}
	}

	// Priority 2: config.json (model config, same format)
	if len(t.vocab.EOS) == 0 || t.vocab.BOS < 0 {
		if data, err := os.ReadFile(dir + "config.json"); err == nil {
			var config struct {
				EOSTokenID interface{} `json:"eos_token_id"`
				BOSTokenID interface{} `json:"bos_token_id"`
			}
			if err := json.Unmarshal(data, &config); err == nil {
				if len(t.vocab.EOS) == 0 {
					if ids := parseTokenIDs(config.EOSTokenID); len(ids) > 0 {
						t.vocab.EOS = ids
					}
				}
				if t.vocab.BOS < 0 {
					if ids := parseTokenIDs(config.BOSTokenID); len(ids) > 0 {
						t.vocab.BOS = ids[0]
					}
				}
			}
		}
	}

	// Priority 3: tokenizer_config.json (token strings + add_bos/add_eos flags)
	if data, err := os.ReadFile(dir + "tokenizer_config.json"); err == nil {
		var config struct {
			BOSToken    interface{} `json:"bos_token"`
			EOSToken    interface{} `json:"eos_token"`
			PADToken    interface{} `json:"pad_token"`
			AddBOSToken *bool       `json:"add_bos_token"`
			AddEOSToken *bool       `json:"add_eos_token"`
		}
		if err := json.Unmarshal(data, &config); err == nil {
			if t.vocab.BOS < 0 {
				if bosStr := extractTokenString(config.BOSToken); bosStr != "" {
					if id, ok := t.specialTokens[bosStr]; ok {
						t.vocab.BOS = id
					}
				}
			}
			if len(t.vocab.EOS) == 0 {
				if eosStr := extractTokenString(config.EOSToken); eosStr != "" {
					if id, ok := t.specialTokens[eosStr]; ok {
						t.vocab.EOS = []int32{id}
					}
				}
			}
			if t.vocab.PAD < 0 {
				if padStr := extractTokenString(config.PADToken); padStr != "" {
					if id, ok := t.specialTokens[padStr]; ok {
						t.vocab.PAD = id
					}
				}
			}
			if config.AddBOSToken != nil {
				t.vocab.AddBOS = *config.AddBOSToken
			}
			if config.AddEOSToken != nil {
				t.vocab.AddEOS = *config.AddEOSToken
			}
		}
	}

	// Priority 4: special_tokens_map.json (final fallback)
	if data, err := os.ReadFile(dir + "special_tokens_map.json"); err == nil {
		var tokensMap map[string]interface{}
		if err := json.Unmarshal(data, &tokensMap); err == nil {
			if t.vocab.BOS < 0 {
				if bosStr := extractTokenString(tokensMap["bos_token"]); bosStr != "" {
					if id, ok := t.specialTokens[bosStr]; ok {
						t.vocab.BOS = id
					}
				}
			}
			if len(t.vocab.EOS) == 0 {
				if eosStr := extractTokenString(tokensMap["eos_token"]); eosStr != "" {
					if id, ok := t.specialTokens[eosStr]; ok {
						t.vocab.EOS = []int32{id}
					}
				}
			}
			if t.vocab.PAD < 0 {
				if padStr := extractTokenString(tokensMap["pad_token"]); padStr != "" {
					if id, ok := t.specialTokens[padStr]; ok {
						t.vocab.PAD = id
					}
				}
			}
		}
	}
}
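
// For reference, the shapes this loader accepts look roughly like the following
// (illustrative JSON, not tied to any particular model):
//
//	// generation_config.json — eos_token_id may be a single id or a list
//	{"bos_token_id": 1, "eos_token_id": [2, 107]}
//
//	// tokenizer_config.json — tokens may be plain strings or {"content": ...} objects
//	{"bos_token": "<s>", "eos_token": {"content": "</s>"}, "add_bos_token": true}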

// extractTokenString extracts the token string from various formats used in HuggingFace configs.
// Tokens can be represented as:
// - string: "token"
// - object: {"content": "token", ...}
func extractTokenString(v interface{}) string {
	if v == nil {
		return ""
	}
	// Direct string
	if s, ok := v.(string); ok {
		return s
	}
	// Object with content field
	if m, ok := v.(map[string]interface{}); ok {
		if content, ok := m["content"].(string); ok {
			return content
		}
	}
	return ""
}

// rewritePatternForRE2 rewrites HuggingFace pretokenizer regex patterns to be
// compatible with Go's regexp package (RE2). HuggingFace patterns use PCRE features:
// - (?!\S) negative lookahead - RE2 doesn't support this
// - (?i:...) inline case-insensitive groups - RE2 doesn't support this
//
// We replace \s+(?!\S)|\s+ with \s+ and fix whitespace boundaries in Encode().
// The lookahead version splits "a b" into ["a", " ", " b"] (space prepended to word).
// Simple \s+ would give ["a", " ", "b"]. We post-process to match Python's behavior.
func rewritePatternForRE2(pattern string) string {
	// Replace lookahead pattern with simple \s+ - we fix boundaries in Encode()
	pattern = strings.ReplaceAll(pattern, `\s+(?!\S)|\s+`, `\s+`)

	// Handle the pattern when it appears with a ? suffix (optional contractions in GPT-4o style)
	// IMPORTANT: Must be done before the non-optional version to avoid partial replacement
	pattern = strings.ReplaceAll(pattern,
		`(?i:'s|'t|'re|'ve|'m|'ll|'d)?`,
		`(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?`)

	// Expand case-insensitive contraction pattern to explicit alternations
	// (?i:'s|'t|'re|'ve|'m|'ll|'d) -> '[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD]
	pattern = strings.ReplaceAll(pattern,
		`(?i:'s|'t|'re|'ve|'m|'ll|'d)`,
		`(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])`)

	return pattern
}
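
// As a concrete example (pattern abbreviated here for readability), a cl100k-style
// pattern such as
//
//	(?i:'s|'t|'re|'ve|'m|'ll|'d)|...|\s+(?!\S)|\s+
//
// is rewritten so the (?i:...) group becomes the explicit (?:'[sS]|'[tT]|...)
// alternation and the trailing \s+(?!\S)|\s+ collapses to \s+, which RE2 can compile.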

// LoadFromBytes loads a tokenizer from tokenizer.json bytes.
// This is useful when loading from blob storage where the file content is already in memory.
// Note: This won't load special token config from companion files. Use LoadFromBytesWithConfig
// to provide tokenizer_config.json data for proper PAD/EOS token loading.
func LoadFromBytes(data []byte) (*Tokenizer, error) {
	return loadFromTokenizerJSON(data, "")
}

// TokenizerConfig holds optional configuration data that can be passed to LoadFromBytesWithConfig.
type TokenizerConfig struct {
	TokenizerConfigJSON  []byte // tokenizer_config.json content
	GenerationConfigJSON []byte // generation_config.json content
	SpecialTokensMapJSON []byte // special_tokens_map.json content
	ConfigJSON           []byte // config.json content
}

// LoadFromBytesWithConfig loads a tokenizer from tokenizer.json bytes with additional config files.
// This is useful when loading from blob storage where companion config files are also blobs.
func LoadFromBytesWithConfig(data []byte, config *TokenizerConfig) (*Tokenizer, error) {
	t, err := loadFromTokenizerJSON(data, "")
	if err != nil {
		return nil, err
	}

	if config == nil {
		return t, nil
	}

	// Apply special token configs from provided data
	loadSpecialTokenConfigFromBytes(t, config)

	return t, nil
}
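
// A minimal usage sketch, assuming tokenizer.json and the companion files have
// already been read into memory (the variable names are illustrative):
//
//	tok, err := LoadFromBytesWithConfig(tokenizerJSON, &TokenizerConfig{
//		TokenizerConfigJSON:  tokenizerConfigJSON,
//		GenerationConfigJSON: generationConfigJSON,
//	})
//	if err != nil {
//		return err
//	}
//	ids := tok.Encode("hello", true)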

// loadSpecialTokenConfigFromBytes loads special token configuration from byte slices.
func loadSpecialTokenConfigFromBytes(t *Tokenizer, config *TokenizerConfig) {
	// Helper to parse eos_token_id which can be int or []int
	parseTokenIDs := func(v interface{}) []int32 {
		switch val := v.(type) {
		case float64:
			return []int32{int32(val)}
		case []interface{}:
			ids := make([]int32, 0, len(val))
			for _, id := range val {
				if f, ok := id.(float64); ok {
					ids = append(ids, int32(f))
				}
			}
			return ids
		}
		return nil
	}

	// Priority 1: generation_config.json
	if len(config.GenerationConfigJSON) > 0 {
		var genConfig struct {
			EOSTokenID interface{} `json:"eos_token_id"`
			BOSTokenID interface{} `json:"bos_token_id"`
		}
		if err := json.Unmarshal(config.GenerationConfigJSON, &genConfig); err == nil {
			if ids := parseTokenIDs(genConfig.EOSTokenID); len(ids) > 0 {
				t.vocab.EOS = ids
			}
			if ids := parseTokenIDs(genConfig.BOSTokenID); len(ids) > 0 {
				t.vocab.BOS = ids[0]
			}
		}
	}

	// Priority 2: config.json
	if len(config.ConfigJSON) > 0 && (len(t.vocab.EOS) == 0 || t.vocab.BOS < 0) {
		var modelConfig struct {
			EOSTokenID interface{} `json:"eos_token_id"`
			BOSTokenID interface{} `json:"bos_token_id"`
		}
		if err := json.Unmarshal(config.ConfigJSON, &modelConfig); err == nil {
			if len(t.vocab.EOS) == 0 {
				if ids := parseTokenIDs(modelConfig.EOSTokenID); len(ids) > 0 {
					t.vocab.EOS = ids
				}
			}
			if t.vocab.BOS < 0 {
				if ids := parseTokenIDs(modelConfig.BOSTokenID); len(ids) > 0 {
					t.vocab.BOS = ids[0]
				}
			}
		}
	}

	// Priority 3: tokenizer_config.json
	if len(config.TokenizerConfigJSON) > 0 {
		var tokConfig struct {
			BOSToken    interface{} `json:"bos_token"`
			EOSToken    interface{} `json:"eos_token"`
			PADToken    interface{} `json:"pad_token"`
			AddBOSToken *bool       `json:"add_bos_token"`
			AddEOSToken *bool       `json:"add_eos_token"`
		}
		if err := json.Unmarshal(config.TokenizerConfigJSON, &tokConfig); err == nil {
			if t.vocab.BOS < 0 {
				if bosStr := extractTokenString(tokConfig.BOSToken); bosStr != "" {
					if id, ok := t.specialTokens[bosStr]; ok {
						t.vocab.BOS = id
					}
				}
			}
			if len(t.vocab.EOS) == 0 {
				if eosStr := extractTokenString(tokConfig.EOSToken); eosStr != "" {
					if id, ok := t.specialTokens[eosStr]; ok {
						t.vocab.EOS = []int32{id}
					}
				}
			}
			if t.vocab.PAD < 0 {
				if padStr := extractTokenString(tokConfig.PADToken); padStr != "" {
					if id, ok := t.specialTokens[padStr]; ok {
						t.vocab.PAD = id
					}
				}
			}
			if tokConfig.AddBOSToken != nil {
				t.vocab.AddBOS = *tokConfig.AddBOSToken
			}
			if tokConfig.AddEOSToken != nil {
				t.vocab.AddEOS = *tokConfig.AddEOSToken
			}
		}
	}

	// Priority 4: special_tokens_map.json
	if len(config.SpecialTokensMapJSON) > 0 {
		var tokensMap map[string]interface{}
		if err := json.Unmarshal(config.SpecialTokensMapJSON, &tokensMap); err == nil {
			if t.vocab.BOS < 0 {
				if bosStr := extractTokenString(tokensMap["bos_token"]); bosStr != "" {
					if id, ok := t.specialTokens[bosStr]; ok {
						t.vocab.BOS = id
					}
				}
			}
			if len(t.vocab.EOS) == 0 {
				if eosStr := extractTokenString(tokensMap["eos_token"]); eosStr != "" {
					if id, ok := t.specialTokens[eosStr]; ok {
						t.vocab.EOS = []int32{id}
					}
				}
			}
			if t.vocab.PAD < 0 {
				if padStr := extractTokenString(tokensMap["pad_token"]); padStr != "" {
					if id, ok := t.specialTokens[padStr]; ok {
						t.vocab.PAD = id
					}
				}
			}
		}
	}
}

// Load loads a tokenizer from a path which can be:
// - A tokenizer.json file
// - A directory containing tokenizer.json or vocab.json + merges.txt
func Load(path string) (*Tokenizer, error) {
	// Check if path is a directory
	if info, err := os.Stat(path); err == nil && info.IsDir() {
		dir := strings.TrimSuffix(path, "/") + "/"
		// Try tokenizer.json first
		if data, err := os.ReadFile(dir + "tokenizer.json"); err == nil {
			return loadFromTokenizerJSON(data, dir)
		}
		// Fall back to vocab.json + merges.txt
		return LoadVocabMerges(path)
	}

	// It's a file - read it directly
	data, err := os.ReadFile(path)
	if err != nil {
		return nil, fmt.Errorf("failed to read tokenizer: %w", err)
	}

	// Get directory for loading companion files
	dir := ""
	if idx := strings.LastIndex(path, "/"); idx >= 0 {
		dir = path[:idx+1]
	}
	return loadFromTokenizerJSON(data, dir)
}
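
// End-to-end usage sketch (the directory path is hypothetical; token IDs depend
// entirely on the loaded vocabulary):
//
//	tok, err := Load("/models/my-model/") // reads tokenizer.json + companion configs
//	if err != nil {
//		return err
//	}
//	ids := tok.Encode("Hello, world!", false) // pass true to prepend BOS when defined
//	text := tok.Decode(ids)                   // for byte-level BPE this round-trips the input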

// loadFromTokenizerJSON parses a tokenizer.json file
func loadFromTokenizerJSON(data []byte, dir string) (*Tokenizer, error) {
	var raw struct {
		Model struct {
			Type   string           `json:"type"` // "BPE" or "WordPiece"
			Vocab  map[string]int32 `json:"vocab"`
			Merges json.RawMessage  `json:"merges"` // Can be []string or [][]string (BPE only)
		} `json:"model"`
		PreTokenizer json.RawMessage `json:"pre_tokenizer"`
		Decoder      json.RawMessage `json:"decoder"`
		AddedTokens  []struct {
			ID      int32  `json:"id"`
			Content string `json:"content"`
			Special bool   `json:"special"`
		} `json:"added_tokens"`
	}

	if err := json.Unmarshal(data, &raw); err != nil {
		return nil, fmt.Errorf("failed to parse tokenizer: %w", err)
	}

	// Parse merges - can be []string (Llama) or [][]string (GPT-OSS)
	// WordPiece models don't have merges
	var mergesStrings []string
	if raw.Model.Type != "WordPiece" && raw.Model.Merges != nil {
		var mergesArrays [][]string
		if err := json.Unmarshal(raw.Model.Merges, &mergesStrings); err != nil {
			// Try array of arrays format
			if err := json.Unmarshal(raw.Model.Merges, &mergesArrays); err != nil {
				return nil, fmt.Errorf("failed to parse merges: %w", err)
			}
			// Convert [][]string to []string
			mergesStrings = make([]string, len(mergesArrays))
			for i, pair := range mergesArrays {
				mergesStrings[i] = pair[0] + " " + pair[1]
			}
		}
	}

	// Build tokenizer
	t := &Tokenizer{
		vocab: &Vocabulary{
			Values:  make([]string, len(raw.Model.Vocab)),
			Reverse: raw.Model.Vocab,
			Merges:  make(map[string]int, len(mergesStrings)),
			BOS:     -1,
			PAD:     -1,
		},
		specialTokens: make(map[string]int32),
	}

	// Build values array
	for token, id := range raw.Model.Vocab {
		if int(id) >= len(t.vocab.Values) {
			newValues := make([]string, id+1)
			copy(newValues, t.vocab.Values)
			t.vocab.Values = newValues
		}
		t.vocab.Values[id] = token
	}

	// Build merges map
	for i, merge := range mergesStrings {
		t.vocab.Merges[merge] = i
	}

	// Add all added_tokens to vocabulary and special tokens map.
	// HuggingFace treats ALL added_tokens as special for tokenization purposes -
	// they bypass BPE and get their own token ID. The "special" flag just indicates
	// if it's a "truly special" token like BOS/EOS/PAD, but for tokenization we need
	// to treat all added_tokens as special to match HuggingFace behavior.
	for _, tok := range raw.AddedTokens {
		if int(tok.ID) >= len(t.vocab.Values) {
			newValues := make([]string, tok.ID+1)
			copy(newValues, t.vocab.Values)
			t.vocab.Values = newValues
		}
		t.vocab.Values[tok.ID] = tok.Content
		t.specialTokens[tok.Content] = tok.ID // Add ALL added_tokens to special tokens
	}

	// Load special token configuration from companion files
	loadSpecialTokenConfig(dir, t)

	// Precompute byte token IDs for <0xNN> fallback
	initByteTokens(t)

	// Determine tokenizer type
	switch {
	case raw.Model.Type == "WordPiece":
		t.typ = TokenizerWordPiece
	case detectSentencePiece(raw.Decoder):
		t.typ = TokenizerSentencePiece
	default:
		t.typ = TokenizerBPE
	}

	// Parse and compile pretokenizer pattern (BPE only - SentencePiece doesn't use pretokenizer)
	if t.typ == TokenizerBPE {
		pattern := extractPretokenizer(raw.PreTokenizer)
		if pattern == "" {
			pattern = `'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`
		}
		re, err := regexp.Compile(rewritePatternForRE2(pattern))
		if err != nil {
			return nil, fmt.Errorf("failed to compile pretokenizer regex %q: %w", pattern, err)
		}
		t.pretokenizer = re
	}

	return t, nil
}

// detectSentencePiece checks if the decoder uses SentencePiece-style (▁ for spaces)
// vs GPT-2 byte-level encoding
func detectSentencePiece(data json.RawMessage) bool {
	if data == nil {
		return false
	}

	// Check for Sequence decoder with Replace step (SentencePiece style)
	var seq struct {
		Type     string `json:"type"`
		Decoders []struct {
			Type    string `json:"type"`
			Pattern struct {
				String string `json:"String"`
			} `json:"pattern"`
		} `json:"decoders"`
	}
	if err := json.Unmarshal(data, &seq); err == nil {
		if seq.Type == "Sequence" {
			for _, dec := range seq.Decoders {
				// Look for Replace decoder that converts ▁ to space
				if dec.Type == "Replace" && dec.Pattern.String == "▁" {
					return true
				}
			}
		}
	}

	// Check for direct ByteLevel decoder (GPT-2 style)
	var simple struct {
		Type string `json:"type"`
	}
	if err := json.Unmarshal(data, &simple); err == nil {
		if simple.Type == "ByteLevel" {
			return false
		}
	}

	return false
}
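
// The SentencePiece-style decoder this detects looks roughly like the following
// tokenizer.json fragment (abridged, illustrative):
//
//	"decoder": {
//	  "type": "Sequence",
//	  "decoders": [
//	    {"type": "Replace", "pattern": {"String": "▁"}, "content": " "},
//	    {"type": "ByteFallback"},
//	    {"type": "Fuse"}
//	  ]
//	}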

// initByteTokens precomputes byte token IDs for <0xNN> fallback encoding
func initByteTokens(t *Tokenizer) {
	for i := range t.vocab.byteTokens {
		t.vocab.byteTokens[i] = -1
	}
	for b := 0; b < 256; b++ {
		token := fmt.Sprintf("<0x%02X>", b)
		if id, ok := t.vocab.Reverse[token]; ok {
			t.vocab.byteTokens[b] = id
		}
	}
}

// extractPretokenizer extracts the regex pattern from the pre_tokenizer config
func extractPretokenizer(data json.RawMessage) string {
	if data == nil {
		return ""
	}

	// Try to parse as a single Split pretokenizer
	var single struct {
		Type    string `json:"type"`
		Pattern struct {
			Regex string `json:"Regex"`
		} `json:"pattern"`
	}
	if err := json.Unmarshal(data, &single); err == nil && single.Pattern.Regex != "" {
		return single.Pattern.Regex
	}

	// Try to parse as Sequence of pretokenizers - use first Split pattern
	var seq struct {
		Type          string `json:"type"`
		Pretokenizers []struct {
			Type    string `json:"type"`
			Pattern struct {
				Regex string `json:"Regex"`
			} `json:"pattern"`
		} `json:"pretokenizers"`
	}
	if err := json.Unmarshal(data, &seq); err == nil && seq.Type == "Sequence" {
		for _, pt := range seq.Pretokenizers {
			if pt.Type == "Split" && pt.Pattern.Regex != "" {
				return pt.Pattern.Regex
			}
		}
	}

	return ""
}
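
// The pre_tokenizer shapes handled above look roughly like this abridged,
// illustrative tokenizer.json fragment:
//
//	"pre_tokenizer": {
//	  "type": "Sequence",
//	  "pretokenizers": [
//	    {"type": "Split", "pattern": {"Regex": "(?i:'s|'t|'re|'ve|'m|'ll|'d)|..."}},
//	    {"type": "ByteLevel"}
//	  ]
//	}
//
// A plain {"type": "Split", "pattern": {"Regex": ...}} pre_tokenizer is covered by
// the first branch.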

// isNonNewlineWhitespace reports whether s is non-empty and contains only
// whitespace characters other than '\n' and '\r'
func isNonNewlineWhitespace(s string) bool {
	if s == "" {
		return false
	}
	for _, r := range s {
		if r == '\n' || r == '\r' {
			return false
		}
		if !unicode.IsSpace(r) {
			return false
		}
	}
	return true
}

// splitBySpecialTokens splits text into parts, keeping special tokens as separate elements
func (t *Tokenizer) splitBySpecialTokens(s string) []string {
	if len(t.specialTokens) == 0 {
		return []string{s}
	}

	// Sort special tokens by length (longest first) to match greedily
	tokens := make([]string, 0, len(t.specialTokens))
	for tok := range t.specialTokens {
		tokens = append(tokens, tok)
	}
	sort.Slice(tokens, func(i, j int) bool {
		return len(tokens[i]) > len(tokens[j])
	})

	var result []string
	remaining := s

	for len(remaining) > 0 {
		found := false
		for _, tok := range tokens {
			if strings.HasPrefix(remaining, tok) {
				result = append(result, tok)
				remaining = remaining[len(tok):]
				found = true
				break
			}
		}
		if !found {
			// Find next special token position
			nextPos := len(remaining)
			for _, tok := range tokens {
				if idx := strings.Index(remaining, tok); idx != -1 && idx < nextPos {
					nextPos = idx
				}
			}
			if nextPos > 0 {
				result = append(result, remaining[:nextPos])
			}
			remaining = remaining[nextPos:]
		}
	}

	return result
}
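
// For example, with "<|user|>" registered as a special token (the name is
// illustrative; the real set comes from added_tokens in tokenizer.json):
//
//	splitBySpecialTokens("<|user|>hi there")  ->  ["<|user|>", "hi there"]
//
// The special token is later emitted as a single ID while "hi there" goes through
// the pretokenizer and BPE.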

// Encode tokenizes text to token IDs. Parallelizes for large inputs (>4KB).
func (t *Tokenizer) Encode(s string, addBOS bool) []int32 {
	// First: split by special tokens
	parts := t.splitBySpecialTokens(s)

	// Second: collect all pretokenizer chunks
	type chunk struct {
		text      string
		isSpecial bool
	}
	var allChunks []chunk

	if t.pretokenizer != nil {
		re := t.pretokenizer
		for _, part := range parts {
			if _, ok := t.specialTokens[part]; ok {
				allChunks = append(allChunks, chunk{part, true})
				continue
			}

			// Split by pretokenizer regex
			type match struct{ start, end int }
			var matches []match
			offset := 0
			for offset < len(part) {
				loc := re.FindStringIndex(part[offset:])
				if loc == nil {
					break
				}
				matches = append(matches, match{offset + loc[0], offset + loc[1]})
				offset += loc[1]
			}

			// Apply whitespace boundary fix for Python regex compatibility
			for i := 0; i < len(matches)-1; i++ {
				m := part[matches[i].start:matches[i].end]
				next := part[matches[i+1].start:matches[i+1].end]

				if isNonNewlineWhitespace(m) && len(next) > 0 {
					firstRune, _ := utf8.DecodeRuneInString(next)
					if unicode.IsLetter(firstRune) {
						lastSpaceStart := matches[i].end
						for j := matches[i].end; j > matches[i].start; {
							r, size := utf8.DecodeLastRuneInString(part[matches[i].start:j])
							if unicode.IsSpace(r) {
								lastSpaceStart = j - size
								break
							}
							j -= size
						}
						if lastSpaceStart > matches[i].start {
							matches[i].end = lastSpaceStart
							matches[i+1].start = lastSpaceStart
						} else {
							matches[i+1].start = matches[i].start
							matches[i].end = matches[i].start
						}
					}
				}
			}

			for _, m := range matches {
				if m.end > m.start {
					allChunks = append(allChunks, chunk{part[m.start:m.end], false})
				}
			}
		}
	} else {
		// No pretokenizer - treat each part as a single chunk
		for _, part := range parts {
			if _, ok := t.specialTokens[part]; ok {
				allChunks = append(allChunks, chunk{part, true})
			} else {
				allChunks = append(allChunks, chunk{part, false})
			}
		}
	}

	// Encode chunks - parallel for large inputs (>4KB), sequential otherwise
	var ids []int32
	if len(s) < 4096 {
		for _, c := range allChunks {
			if c.isSpecial {
				if id, ok := t.specialTokens[c.text]; ok {
					ids = append(ids, id)
				}
			} else {
				ids = t.encodeChunkInto(c.text, ids)
			}
		}
	} else {
		numWorkers := runtime.GOMAXPROCS(0)
		if numWorkers > len(allChunks) {
			numWorkers = len(allChunks)
		}

		chunksPer := (len(allChunks) + numWorkers - 1) / numWorkers
		results := make([][]int32, numWorkers)
		var wg sync.WaitGroup

		for i := 0; i < numWorkers; i++ {
			start := i * chunksPer
			end := start + chunksPer
			if end > len(allChunks) {
				end = len(allChunks)
			}
			if start >= end {
				continue
			}

			wg.Add(1)
			go func(i int, chunks []chunk) {
				defer wg.Done()
				var r []int32
				for _, c := range chunks {
					if c.isSpecial {
						if id, ok := t.specialTokens[c.text]; ok {
							r = append(r, id)
						}
					} else {
						r = t.encodeChunkInto(c.text, r)
					}
				}
				results[i] = r
			}(i, allChunks[start:end])
		}
		wg.Wait()

		for _, r := range results {
			ids = append(ids, r...)
		}
	}

	if addBOS && t.vocab.BOS >= 0 {
		ids = append([]int32{t.vocab.BOS}, ids...)
	}
	return ids
}
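
// The "whitespace boundary fix" above reproduces the \s+(?!\S) lookahead that RE2
// cannot express: the last space of a whitespace run is moved onto the following
// word. For example, with the default GPT-2 pattern (splits shown before
// byte-level encoding):
//
//	"a  b"  ->  ["a", " ", " b"]   // with the fix, matching HuggingFace/Python
//	"a  b"  ->  ["a", "  ", "b"]   // what a plain \s+ split would otherwise produce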

// encodeChunkInto appends encoded tokens to ids and returns the extended slice.
// BPE and SentencePiece input goes through the merge algorithm; WordPiece input
// is delegated to greedy longest-match.
func (t *Tokenizer) encodeChunkInto(s string, ids []int32) []int32 {
	if t.typ == TokenizerWordPiece {
		return t.encodeWordPieceInto(s, ids)
	}

	if s == "" {
		return ids
	}

	// Apply encoding transformation
	// SentencePiece: replace space with ▁
	// BPE: convert bytes using precomputed table (GPT-2 byte-level encoding)
	var encoded string
	if t.typ == TokenizerSentencePiece {
		encoded = strings.ReplaceAll(s, " ", "▁")
	} else {
		var sb strings.Builder
		sb.Grow(len(s) * 2)
		for i := 0; i < len(s); i++ {
			sb.WriteRune(byteToRune[s[i]])
		}
		encoded = sb.String()
	}

	// Fast path: check if entire chunk is a single token
	if id, ok := t.vocab.Reverse[encoded]; ok {
		return append(ids, id)
	}

	return t.encodeBPEMerge(encoded, ids)
}

// encodeBPEMerge encodes using BPE merge algorithm.
// Repeatedly merges the pair with lowest rank until no more merges possible.
// Works correctly with empty merges (falls back to individual rune/byte encoding).
func (t *Tokenizer) encodeBPEMerge(encoded string, ids []int32) []int32 {
	// Start with individual runes as parts
	runes := []rune(encoded)
	parts := make([]string, len(runes))
	for i, r := range runes {
		parts[i] = string(r)
	}

	// Repeatedly merge lowest-rank pair
	for len(parts) > 1 {
		minRank := int(0x7FFFFFFF)
		minIdx := -1

		for i := 0; i < len(parts)-1; i++ {
			// Merge key format: "token1 token2" (space-separated)
			mergeKey := parts[i] + " " + parts[i+1]
			if rank, ok := t.vocab.Merges[mergeKey]; ok {
				if rank < minRank {
					minRank = rank
					minIdx = i
				}
			}
		}

		if minIdx < 0 {
			break // No more merges possible
		}

		// Merge the pair
		parts[minIdx] = parts[minIdx] + parts[minIdx+1]
		parts = append(parts[:minIdx+1], parts[minIdx+2:]...)
	}

	// Convert parts to token IDs
	for _, part := range parts {
		if id, ok := t.vocab.Reverse[part]; ok {
			ids = append(ids, id)
		} else {
			// Byte fallback for unknown parts
			for _, b := range []byte(part) {
				if id := t.vocab.byteTokens[b]; id >= 0 {
					ids = append(ids, id)
				}
			}
		}
	}

	return ids
}
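
// Worked example with a hypothetical merge table
// {"h e": 0, "he l": 1, "hel l": 2, "hell o": 3} encoding the chunk "hello":
//
//	["h" "e" "l" "l" "o"]   // start from individual runes
//	["he" "l" "l" "o"]      // "h e" has the lowest rank (0)
//	["hel" "l" "o"]         // then "he l" (1)
//	["hell" "o"]            // then "hel l" (2)
//	["hello"]               // then "hell o" (3); no pairs left to merge
//
// Each surviving part is then looked up in the vocabulary, falling back to the
// <0xNN> byte tokens for anything unknown.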

// encodeWordPieceInto appends WordPiece tokens to ids and returns extended slice.
// Uses greedy longest-match with ## prefix for continuation tokens.
func (t *Tokenizer) encodeWordPieceInto(s string, ids []int32) []int32 {
	if s == "" {
		return ids
	}

	// Check if entire string is in vocabulary (common case)
	if id, ok := t.vocab.Reverse[s]; ok {
		return append(ids, id)
	}

	runes := []rune(s)
	start := 0

	for start < len(runes) {
		end := len(runes)
		found := false

		// Greedy longest-match
		for end > start {
			substr := string(runes[start:end])
			if start > 0 {
				// Continuation token: prefix with ##
				substr = "##" + substr
			}

			if id, ok := t.vocab.Reverse[substr]; ok {
				ids = append(ids, id)
				found = true
				start = end
				break
			}
			end--
		}

		if !found {
			// No match found - use [UNK] token or skip
			if t.unkToken >= 0 {
				ids = append(ids, t.unkToken)
			}
			start++
		}
	}

	return ids
}
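
// Classic WordPiece example, assuming a BERT-style vocabulary that contains
// "un", "##aff" and "##able" but not "unaffable":
//
//	encodeWordPieceInto("unaffable", nil)  ->  IDs for ["un", "##aff", "##able"]
//
// The longest prefix in the vocabulary is consumed first; every later piece is
// looked up with the "##" continuation prefix.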

// Decode converts token IDs back to text
func (t *Tokenizer) Decode(ids []int32) string {
	var sb strings.Builder

	for _, id := range ids {
		if int(id) >= len(t.vocab.Values) {
			continue
		}

		token := t.vocab.Values[id]

		switch t.typ {
		case TokenizerWordPiece:
			// WordPiece style: strip ## prefix from continuation tokens
			if strings.HasPrefix(token, "##") {
				sb.WriteString(token[2:])
			} else {
				sb.WriteString(token)
			}
		case TokenizerSentencePiece:
			// SentencePiece style: replace ▁ with space, decode byte tokens
			token = strings.ReplaceAll(token, "▁", " ")
			// Handle byte fallback tokens like <0x0D>
			if len(token) == 6 && token[0] == '<' && token[1] == '0' && token[2] == 'x' && token[5] == '>' {
				if v, err := strconv.ParseUint(token[3:5], 16, 8); err == nil {
					sb.WriteByte(byte(v))
					continue
				}
			}
			sb.WriteString(token)
		default:
			// GPT-2 BPE style: decode byte-level encoding
			for _, r := range token {
				switch {
				case r == 0x0100:
					// NULL byte (0x00 encoded as 0x0100)
					sb.WriteByte(0)
					continue
				case r == 0x0143:
					r = 0x00ad
				case r > 0x0100 && r <= 0x0120:
					r = r - 0x0100
				case r > 0x0120 && r <= 0x0142:
					r = r - 0x00a2
				}

				// Write as byte, not UTF-8 encoded rune
				sb.WriteByte(byte(r))
			}
		}
	}

	return sb.String()
}

// VocabSize returns the vocabulary size
func (t *Tokenizer) VocabSize() int {
	return len(t.vocab.Values)
}

// BOS returns the beginning of sequence token ID
func (t *Tokenizer) BOS() int32 {
	return t.vocab.BOS
}

// EOS returns the first end of sequence token ID (for backwards compatibility)
func (t *Tokenizer) EOS() int32 {
	if len(t.vocab.EOS) > 0 {
		return t.vocab.EOS[0]
	}
	return -1
}

// EOSTokens returns all end of sequence token IDs
func (t *Tokenizer) EOSTokens() []int32 {
	return t.vocab.EOS
}

// PAD returns the padding token ID, or -1 if not set
func (t *Tokenizer) PAD() int32 {
	return t.vocab.PAD
}

// IsEOS returns true if the token ID is an end of sequence token
func (t *Tokenizer) IsEOS(id int32) bool {
	for _, eos := range t.vocab.EOS {
		if id == eos {
			return true
		}
	}
	return false
}

// GetSpecialToken returns the token ID for a special token string
func (t *Tokenizer) GetSpecialToken(name string) (int32, bool) {
	id, ok := t.specialTokens[name]
	return id, ok
}

// LoadVocabMerges loads a tokenizer from vocab.json + merges.txt format (GPT-style)
func LoadVocabMerges(dir string) (*Tokenizer, error) {
	vocabPath := dir + "/vocab.json"
	mergesPath := dir + "/merges.txt"
	addedTokensPath := dir + "/added_tokens.json"

	// Load vocab
	vocabData, err := os.ReadFile(vocabPath)
	if err != nil {
		return nil, fmt.Errorf("failed to read vocab.json: %w", err)
	}

	vocabMap := make(map[string]int32)
	if err := json.Unmarshal(vocabData, &vocabMap); err != nil {
		return nil, fmt.Errorf("failed to parse vocab.json: %w", err)
	}

	// Load merges
	mergesData, err := os.ReadFile(mergesPath)
	if err != nil {
		return nil, fmt.Errorf("failed to read merges.txt: %w", err)
	}

	mergesLines := strings.Split(string(mergesData), "\n")
	var mergesStrings []string
	for _, line := range mergesLines {
		line = strings.TrimSpace(line)
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}
		mergesStrings = append(mergesStrings, line)
	}

	// Build tokenizer
	t := &Tokenizer{
		vocab: &Vocabulary{
			Values:  make([]string, len(vocabMap)),
			Reverse: vocabMap,
			Merges:  make(map[string]int, len(mergesStrings)),
			BOS:     -1,
			PAD:     -1,
		},
		specialTokens: make(map[string]int32),
	}

	// Load added tokens if present
	if addedData, err := os.ReadFile(addedTokensPath); err == nil {
		addedMap := make(map[string]int32)
		if err := json.Unmarshal(addedData, &addedMap); err == nil {
			for token, id := range addedMap {
				vocabMap[token] = id
				t.specialTokens[token] = id
			}
		}
	}

	// Build values array
	for token, id := range vocabMap {
		if int(id) >= len(t.vocab.Values) {
			newValues := make([]string, id+1)
			copy(newValues, t.vocab.Values)
			t.vocab.Values = newValues
		}
		t.vocab.Values[id] = token
	}

	// Build merges map
	for i, merge := range mergesStrings {
		t.vocab.Merges[merge] = i
	}

	// Load special token configuration from companion files
	loadSpecialTokenConfig(dir+"/", t)

	// Precompute byte token IDs for <0xNN> fallback
	initByteTokens(t)

	// Default tiktoken-style pretokenizer pattern (cl100k_base)
	pattern := `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`
	re, err := regexp.Compile(rewritePatternForRE2(pattern))
	if err != nil {
		return nil, fmt.Errorf("failed to compile pretokenizer regex: %w", err)
	}
	t.pretokenizer = re

	return t, nil
}
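
// The vocab.json / merges.txt layout this expects looks like the following
// (illustrative excerpt with made-up tokens and IDs):
//
//	vocab.json:   {"h": 0, "e": 1, "he": 256, "hello": 912, ...}
//	merges.txt:   #version: 0.2
//	              h e
//	              he l
//	              ...
//
// Each merges.txt line is a space-separated pair; its line order defines the merge
// rank used by encodeBPEMerge, and "#"-prefixed header lines are skipped.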