mirror of
https://github.com/ollama/ollama.git
synced 2026-04-18 11:54:36 +02:00
* prefer rocm v6 on windows Avoid building with v7 - more changes are needed * MLX: add header vendoring and remove go build tag This switches to using a vendoring approach for the mlx-c headers so that Go can build without requiring a cmake first. This enables building the new MLX based code by default. Every time cmake runs, the headers are refreshed, so we can easily keep them in sync when we bump mlx versions. Basic Windows and Linux support are verified. * ci: harden for flaky choco repo servers CI sometimes fails due to choco not actually installing cache. Since it just speeds up the build, we can proceed without. * review comments
206 lines
5.1 KiB
Go
206 lines
5.1 KiB
Go
package tokenizer
|
||
|
||
import (
|
||
"bufio"
|
||
"encoding/json"
|
||
"os"
|
||
"path/filepath"
|
||
"runtime"
|
||
"strings"
|
||
"testing"
|
||
)
|
||
|
||
// llama32GGMLFixturePath returns the path of file inside the shared
// tokenizer/testdata/llama3.2 fixture directory.
//
// The directory is derived from the path of this source file as reported by
// runtime.Caller(0) (two levels up, then tokenizer/testdata/llama3.2), not from
// the process working directory. The test is aborted via tb.Fatal if the
// caller information is unavailable.
func llama32GGMLFixturePath(tb testing.TB, file string) string {
	tb.Helper()

	_, thisFile, _, found := runtime.Caller(0)
	if !found {
		tb.Fatal("failed to resolve test file path")
	}

	fixtureDir := filepath.Join(filepath.Dir(thisFile), "..", "..", "tokenizer", "testdata", "llama3.2")
	return filepath.Join(fixtureDir, file)
}
|
||
|
||
func loadLlama32FromGGMLFixture(tb testing.TB) *Tokenizer {
|
||
tb.Helper()
|
||
|
||
f, err := os.Open(llama32GGMLFixturePath(tb, "encoder.json"))
|
||
if err != nil {
|
||
tb.Fatalf("failed to open encoder.json: %v", err)
|
||
}
|
||
defer f.Close()
|
||
|
||
vocab := make(map[string]int32)
|
||
if err := json.NewDecoder(f).Decode(&vocab); err != nil {
|
||
tb.Fatalf("failed to decode encoder.json: %v", err)
|
||
}
|
||
|
||
type addedToken struct {
|
||
ID int32 `json:"id"`
|
||
Content string `json:"content"`
|
||
Special bool `json:"special"`
|
||
}
|
||
var addedTokens []addedToken
|
||
for _, token := range []string{"<|begin_of_text|>", "<|end_of_text|>"} {
|
||
if _, ok := vocab[token]; !ok {
|
||
id := int32(len(vocab))
|
||
vocab[token] = id
|
||
addedTokens = append(addedTokens, addedToken{ID: id, Content: token, Special: true})
|
||
}
|
||
}
|
||
|
||
mf, err := os.Open(llama32GGMLFixturePath(tb, "vocab.bpe"))
|
||
if err != nil {
|
||
tb.Fatalf("failed to open vocab.bpe: %v", err)
|
||
}
|
||
defer mf.Close()
|
||
|
||
var merges []string
|
||
scanner := bufio.NewScanner(mf)
|
||
for scanner.Scan() {
|
||
line := scanner.Text()
|
||
if strings.HasPrefix(line, "#") {
|
||
continue
|
||
}
|
||
line = strings.TrimSpace(line)
|
||
if line != "" {
|
||
merges = append(merges, line)
|
||
}
|
||
}
|
||
if err := scanner.Err(); err != nil {
|
||
tb.Fatalf("failed to read vocab.bpe: %v", err)
|
||
}
|
||
|
||
payload := struct {
|
||
Model struct {
|
||
Type string `json:"type"`
|
||
Vocab map[string]int32 `json:"vocab"`
|
||
Merges []string `json:"merges"`
|
||
} `json:"model"`
|
||
PreTokenizer struct {
|
||
Type string `json:"type"`
|
||
Pretokenizers []struct {
|
||
Type string `json:"type"`
|
||
Pattern struct {
|
||
Regex string `json:"Regex"`
|
||
} `json:"pattern"`
|
||
} `json:"pretokenizers"`
|
||
} `json:"pre_tokenizer"`
|
||
AddedTokens []addedToken `json:"added_tokens"`
|
||
}{}
|
||
|
||
payload.Model.Type = "BPE"
|
||
payload.Model.Vocab = vocab
|
||
payload.Model.Merges = merges
|
||
payload.PreTokenizer.Type = "Sequence"
|
||
payload.PreTokenizer.Pretokenizers = []struct {
|
||
Type string `json:"type"`
|
||
Pattern struct {
|
||
Regex string `json:"Regex"`
|
||
} `json:"pattern"`
|
||
}{
|
||
{
|
||
Type: "Split",
|
||
Pattern: struct {
|
||
Regex string `json:"Regex"`
|
||
}{
|
||
Regex: `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
||
},
|
||
},
|
||
}
|
||
payload.AddedTokens = addedTokens
|
||
|
||
data, err := json.Marshal(payload)
|
||
if err != nil {
|
||
tb.Fatalf("failed to marshal synthetic tokenizer.json: %v", err)
|
||
}
|
||
|
||
tok, err := LoadFromBytes(data)
|
||
if err != nil {
|
||
tb.Fatalf("failed to load tokenizer from fixture data: %v", err)
|
||
}
|
||
return tok
|
||
}
|
||
|
||
func TestGGMLLlamaKnownEncodings(t *testing.T) {
|
||
tok := loadLlama32FromGGMLFixture(t)
|
||
|
||
cases := map[string][]int32{
|
||
"hello world": {15339, 1917},
|
||
"hello <|end_of_text|>": {15339, 220, 128001},
|
||
"<|begin_of_text|>A B!": {128000, 32, 426, 0},
|
||
"<|begin_of_text|>A<|end_of_text|>B!": {128000, 32, 128001, 33, 0},
|
||
"<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!": {128000, 32, 128001, 33, 128000, 0},
|
||
"<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!<|end_of_text|>": {128000, 32, 128001, 33, 128000, 0, 128001},
|
||
}
|
||
|
||
for input, want := range cases {
|
||
got := tok.Encode(input, false)
|
||
if !equalIDs(got, want) {
|
||
t.Fatalf("encode mismatch for %q:\n got: %v\n want: %v", input, got, want)
|
||
}
|
||
}
|
||
}
|
||
|
||
func TestGGMLLlamaRepeatedZeros(t *testing.T) {
|
||
tok := loadLlama32FromGGMLFixture(t)
|
||
|
||
cases := map[int][]int32{
|
||
1: {15},
|
||
2: {410},
|
||
3: {931},
|
||
4: {931, 15},
|
||
5: {931, 410},
|
||
6: {931, 931},
|
||
7: {931, 931, 15},
|
||
8: {931, 931, 410},
|
||
9: {931, 931, 931},
|
||
10: {931, 931, 931, 15},
|
||
11: {931, 931, 931, 410},
|
||
12: {931, 931, 931, 931},
|
||
13: {931, 931, 931, 931, 15},
|
||
14: {931, 931, 931, 931, 410},
|
||
15: {931, 931, 931, 931, 931},
|
||
16: {931, 931, 931, 931, 931, 15},
|
||
17: {931, 931, 931, 931, 931, 410},
|
||
}
|
||
|
||
for n, want := range cases {
|
||
input := strings.Repeat("0", n)
|
||
got := tok.Encode(input, false)
|
||
if !equalIDs(got, want) {
|
||
t.Fatalf("encode mismatch for %q:\n got: %v\n want: %v", input, got, want)
|
||
}
|
||
}
|
||
}
|
||
|
||
func TestGGMLLlamaRoundtripAndByteBehavior(t *testing.T) {
|
||
tok := loadLlama32FromGGMLFixture(t)
|
||
|
||
cases := []string{
|
||
"hello",
|
||
"hello ",
|
||
"hello ",
|
||
" hello",
|
||
" hello ",
|
||
" hello ",
|
||
"hello world",
|
||
"请考试我的软件!12345",
|
||
}
|
||
|
||
for _, input := range cases {
|
||
ids := tok.Encode(input, false)
|
||
got := tok.Decode(ids)
|
||
if got != input {
|
||
t.Fatalf("roundtrip mismatch for %q: got %q", input, got)
|
||
}
|
||
}
|
||
|
||
// Match GGML tokenizer behavior: 0x00 is omitted when decoding.
|
||
ids := tok.Encode(string(rune(0x00)), false)
|
||
got := tok.Decode(ids)
|
||
if got != "" {
|
||
t.Fatalf("expected empty decode for 0x00, got %q (ids=%v)", got, ids)
|
||
}
|
||
}
|