mirror of
https://github.com/ollama/ollama.git
synced 2026-04-18 00:54:05 +02:00
tokenizer: add SentencePiece-style BPE support (#15162)
* tokenizer: add SentencePiece-style BPE support

  Add a WithSentencePieceNormalizer option to BytePairEncoding for models that use BPE with SentencePiece-style space markers (spaces are mapped to and from U+2581, "▁"). NewBytePairEncoding is unchanged; the new NewBytePairEncodingWithOptions constructor accepts BPEOption functions. Decoding handles the reverse mapping of U+2581 back to spaces.

* review comments
This commit is contained in:
@@ -14,20 +14,40 @@ import (
|
||||
)
|
||||
|
||||
type BytePairEncoding struct {
|
||||
vocab *Vocabulary
|
||||
regexps []*regexp2.Regexp
|
||||
vocab *Vocabulary
|
||||
regexps []*regexp2.Regexp
|
||||
spaceToSpmSep bool // When true, normalize spaces to ▁ instead of GPT-2 byte-level encoding
|
||||
}
|
||||
|
||||
// Compile-time assertion that *BytePairEncoding satisfies the Tokenizer interface.
var _ Tokenizer = (*BytePairEncoding)(nil)
|
||||
|
||||
// BPEOption is a functional option that configures a BytePairEncoding when it is constructed.
|
||||
type BPEOption func(*BytePairEncoding)
|
||||
|
||||
// WithSentencePieceNormalizer enables ▁ space normalization instead of GPT-2 byte-level encoding.
|
||||
func WithSentencePieceNormalizer() BPEOption {
|
||||
return func(bpe *BytePairEncoding) {
|
||||
bpe.spaceToSpmSep = true
|
||||
}
|
||||
}
|
||||
|
||||
func NewBytePairEncoding(vocab *Vocabulary, pretokenizer ...string) BytePairEncoding {
|
||||
return newBytePairEncoding(vocab, pretokenizer)
|
||||
}
|
||||
|
||||
func NewBytePairEncodingWithOptions(vocab *Vocabulary, pretokenizer []string, opts ...BPEOption) BytePairEncoding {
|
||||
bpe := newBytePairEncoding(vocab, pretokenizer, opts...)
|
||||
return bpe
|
||||
}
|
||||
|
||||
func newBytePairEncoding(vocab *Vocabulary, pretokenizer []string, opts ...BPEOption) BytePairEncoding {
|
||||
if len(pretokenizer) == 0 {
|
||||
// set default byte-level pretokenizer if none provided, e.g.
|
||||
// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs#L44
|
||||
pretokenizer = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
|
||||
}
|
||||
|
||||
return BytePairEncoding{
|
||||
bpe := BytePairEncoding{
|
||||
vocab: vocab,
|
||||
regexps: slices.Collect(func(yield func(*regexp2.Regexp) bool) {
|
||||
for _, p := range pretokenizer {
|
||||
@@ -37,6 +57,12 @@ func NewBytePairEncoding(vocab *Vocabulary, pretokenizer ...string) BytePairEnco
|
||||
}
|
||||
}),
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(&bpe)
|
||||
}
|
||||
|
||||
return bpe
|
||||
}
|
||||
|
||||
func (bpe BytePairEncoding) Vocabulary() *Vocabulary {
|
||||
@@ -136,28 +162,35 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||
|
||||
for split := range bpe.split(frag.value) {
|
||||
// TODO: process splits concurrently
|
||||
var sb strings.Builder
|
||||
for _, b := range []byte(split) {
|
||||
r := rune(b)
|
||||
switch {
|
||||
case r == 0x00ad:
|
||||
r = 0x0143
|
||||
case r <= 0x0020:
|
||||
r = r + 0x0100
|
||||
case r >= 0x007f && r <= 0x00a0:
|
||||
r = r + 0x00a2
|
||||
var normalized string
|
||||
if bpe.spaceToSpmSep {
|
||||
// SentencePiece-style: replace spaces with ▁
|
||||
normalized = strings.ReplaceAll(split, " ", spmWhitespaceSep)
|
||||
} else {
|
||||
// GPT-2 byte-level: map bytes to shifted Unicode codepoints
|
||||
var sb strings.Builder
|
||||
for _, b := range []byte(split) {
|
||||
r := rune(b)
|
||||
switch {
|
||||
case r == 0x00ad:
|
||||
r = 0x0143
|
||||
case r <= 0x0020:
|
||||
r = r + 0x0100
|
||||
case r >= 0x007f && r <= 0x00a0:
|
||||
r = r + 0x00a2
|
||||
}
|
||||
sb.WriteRune(r)
|
||||
}
|
||||
|
||||
sb.WriteRune(r)
|
||||
normalized = sb.String()
|
||||
}
|
||||
|
||||
// short circuit if the fragment is in the vocabulary
|
||||
if id := bpe.vocab.Encode(sb.String()); id >= 0 {
|
||||
if id := bpe.vocab.Encode(normalized); id >= 0 {
|
||||
ids = append(ids, id)
|
||||
continue
|
||||
}
|
||||
|
||||
runes := []rune(sb.String())
|
||||
runes := []rune(normalized)
|
||||
merges := make([]merge, len(runes))
|
||||
for r := range runes {
|
||||
merges[r] = merge{
|
||||
@@ -257,6 +290,8 @@ func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
|
||||
var sb strings.Builder
|
||||
for _, id := range ids {
|
||||
for _, r := range bpe.vocab.Decode(id) {
|
||||
// GPT-2 byte-level BPE uses Unicode chars in the 0x0100-0x0143
|
||||
// range to represent bytes. Remap them back to actual bytes.
|
||||
switch {
|
||||
case r == 0x0100:
|
||||
// this produces 0x00 aka NULL
|
||||
@@ -267,6 +302,15 @@ func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
|
||||
r = r - 0x0100
|
||||
case r > 0x0120 && r <= 0x0142:
|
||||
r = r - 0x00a2
|
||||
case r > 0x0143:
|
||||
// Non-GPT2 rune (e.g., SentencePiece-style BPE).
|
||||
// Handle ▁ as word separator, otherwise write the rune as-is.
|
||||
if r == 0x2581 { // ▁ (LOWER ONE EIGHTH BLOCK)
|
||||
sb.WriteByte(' ')
|
||||
} else {
|
||||
sb.WriteRune(r)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// NOTE: not using WriteRune here because it writes the UTF-8
|
||||
|
||||
Reference in New Issue
Block a user