mirror of
https://github.com/ollama/ollama.git
synced 2026-04-25 18:25:42 +02:00
106 lines
2.6 KiB
Go
106 lines
2.6 KiB
Go
//go:build mlx
|
|
|
|
package grammar
|
|
|
|
import "unicode/utf8"
|
|
|
|
// terminalType enumerates the kinds of terminal a grammar can contain.
type terminalType int

const (
	// terminalLiteral is an exact string, e.g. "true" or "{".
	terminalLiteral terminalType = iota
	// terminalRange is a character range, e.g. [a-z] or [0-9].
	terminalRange
)
|
|
|
|
// terminal represents a compiled grammar terminal
|
|
type terminal struct {
|
|
ID int
|
|
Type terminalType
|
|
Pattern string // Original pattern from grammar
|
|
Unescaped string // Unescaped literal (for terminalLiteral)
|
|
LowRune rune // For unicode ranges: low bound
|
|
HighRune rune // For unicode ranges: high bound
|
|
}
|
|
|
|
// terminalMatch records that a terminal matched at some position.
type terminalMatch struct {
	// TerminalID is the ID of the terminal that matched.
	TerminalID int
	// Length is the number of bytes the match consumed.
	Length int
}
|
|
|
|
// trieNode is a single node of the trie used to match literal terminals.
type trieNode struct {
	// children has one slot per possible next byte; a nil slot means no edge.
	children [256]*trieNode
	// terminalID is the ID of the terminal accepted at this node,
	// or -1 when the node is not accepting.
	terminalID int
}
|
|
|
|
// terminalMatcher tests which terminals match at a position in a byte slice
|
|
type terminalMatcher struct {
|
|
// Trie for literal matching (fast path)
|
|
literalTrie *trieNode
|
|
|
|
// Range terminals (single-byte matches)
|
|
ranges []terminal
|
|
|
|
// All terminals for enumeration
|
|
terminals []terminal
|
|
|
|
// Pattern to terminal ID map for fast lookup (keyed by raw pattern)
|
|
patternToID map[string]int
|
|
}
|
|
|
|
// addLiteralToTrie adds a literal pattern to the trie
|
|
func (m *terminalMatcher) addLiteralToTrie(pattern string, terminalID int) {
|
|
node := m.literalTrie
|
|
for i := 0; i < len(pattern); i++ {
|
|
c := pattern[i]
|
|
if node.children[c] == nil {
|
|
node.children[c] = &trieNode{terminalID: -1}
|
|
}
|
|
node = node.children[c]
|
|
}
|
|
node.terminalID = terminalID
|
|
}
|
|
|
|
// matchesAt returns all terminals that match at pos in data
|
|
func (m *terminalMatcher) matchesAt(data []byte, pos int) []terminalMatch {
|
|
if pos >= len(data) {
|
|
return nil
|
|
}
|
|
|
|
var matches []terminalMatch
|
|
|
|
// Check literal matches via trie
|
|
node := m.literalTrie
|
|
for i := pos; i < len(data) && node != nil; i++ {
|
|
c := data[i]
|
|
node = node.children[c]
|
|
if node != nil && node.terminalID >= 0 {
|
|
matches = append(matches, terminalMatch{
|
|
TerminalID: node.terminalID,
|
|
Length: i - pos + 1,
|
|
})
|
|
}
|
|
}
|
|
|
|
// Check range matches (unicode-aware)
|
|
r, runeLen := utf8.DecodeRune(data[pos:])
|
|
if r != utf8.RuneError {
|
|
for _, rng := range m.ranges {
|
|
if r >= rng.LowRune && r <= rng.HighRune {
|
|
matches = append(matches, terminalMatch{
|
|
TerminalID: rng.ID,
|
|
Length: runeLen,
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
return matches
|
|
}
|
|
|
|
// terminalCount returns the number of terminals
|
|
func (m *terminalMatcher) terminalCount() int {
|
|
return len(m.terminals)
|
|
}
|