mirror of
https://github.com/ollama/ollama.git
synced 2026-04-21 16:25:42 +02:00
x/grammar: add experimental GPU accelerated constrained decoding package
This commit is contained in:
726
x/grammar/schema/schema.go
Normal file
726
x/grammar/schema/schema.go
Normal file
@@ -0,0 +1,726 @@
|
||||
//go:build mlx
|
||||
|
||||
// Package schema converts OpenAI-compatible JSON Schema into constrained grammars.
|
||||
package schema
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/x/grammar"
|
||||
)
|
||||
|
||||
// schemaNode represents OpenAI-compatible JSON Schema for structured outputs.
|
||||
// See: https://platform.openai.com/docs/guides/structured-outputs
|
||||
type schemaNode struct {
|
||||
// Core types
|
||||
Type interface{} `json:"type"` // string, []string, or nil
|
||||
|
||||
// Object properties
|
||||
Properties map[string]*schemaNode `json:"properties"`
|
||||
Required []string `json:"required"`
|
||||
AdditionalProperties interface{} `json:"additionalProperties"`
|
||||
|
||||
// Array properties
|
||||
Items *schemaNode `json:"items"`
|
||||
MinItems *int `json:"minItems"`
|
||||
MaxItems *int `json:"maxItems"`
|
||||
|
||||
// String properties
|
||||
Pattern string `json:"pattern"` // Regex pattern
|
||||
Format string `json:"format"` // date-time, email, uuid, etc.
|
||||
|
||||
// Number properties (noted but not enforced in grammar - validated post-generation)
|
||||
Minimum *float64 `json:"minimum"`
|
||||
Maximum *float64 `json:"maximum"`
|
||||
ExclusiveMinimum *float64 `json:"exclusiveMinimum"`
|
||||
ExclusiveMaximum *float64 `json:"exclusiveMaximum"`
|
||||
MultipleOf *float64 `json:"multipleOf"`
|
||||
|
||||
// Enum and const
|
||||
Enum []interface{} `json:"enum"`
|
||||
Const interface{} `json:"const"`
|
||||
|
||||
// Composition
|
||||
AnyOf []*schemaNode `json:"anyOf"`
|
||||
OneOf []*schemaNode `json:"oneOf"` // Treated same as anyOf for grammar
|
||||
|
||||
// References and definitions
|
||||
Ref string `json:"$ref"`
|
||||
Defs map[string]*schemaNode `json:"$defs"`
|
||||
|
||||
// Description (ignored for grammar but useful for docs)
|
||||
Description string `json:"description"`
|
||||
}
|
||||
|
||||
// converter handles JSON Schema to EBNF conversion with state.
|
||||
type converter struct {
|
||||
schema *schemaNode
|
||||
definitions map[string]*schemaNode // Resolved $defs
|
||||
usedTypes map[string]bool
|
||||
rules []string
|
||||
ruleNum int
|
||||
definedRefs map[string]bool // Track which refs we've already defined as rules
|
||||
}
|
||||
|
||||
// EBNF converts a JSON Schema to EBNF grammar
|
||||
func EBNF(schemaJSON string) (string, error) {
|
||||
var schema schemaNode
|
||||
if err := json.Unmarshal([]byte(schemaJSON), &schema); err != nil {
|
||||
return "", fmt.Errorf("failed to parse JSON Schema: %w", err)
|
||||
}
|
||||
|
||||
conv := &converter{
|
||||
schema: &schema,
|
||||
definitions: schema.Defs,
|
||||
usedTypes: make(map[string]bool),
|
||||
definedRefs: make(map[string]bool),
|
||||
}
|
||||
|
||||
return conv.convert()
|
||||
}
|
||||
|
||||
func (c *converter) convert() (string, error) {
|
||||
var b strings.Builder
|
||||
|
||||
// Generate root rule
|
||||
rootExpr := c.schemaToExpr(c.schema, "root")
|
||||
b.WriteString("root = ")
|
||||
b.WriteString(rootExpr)
|
||||
b.WriteString(" .\n")
|
||||
|
||||
// Add generated rules (refs, items, etc.)
|
||||
for _, rule := range c.rules {
|
||||
b.WriteString(rule)
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
// Add primitives based on usage
|
||||
c.addPrimitives(&b)
|
||||
|
||||
return b.String(), nil
|
||||
}
|
||||
|
||||
func (c *converter) addPrimitives(b *strings.Builder) {
|
||||
if c.usedTypes["string"] {
|
||||
b.WriteString(`
|
||||
string = "\"" { character } "\"" .
|
||||
`)
|
||||
}
|
||||
|
||||
if c.usedTypes["string"] || c.usedTypes["character"] {
|
||||
b.WriteString(`
|
||||
character = unescaped | escaped .
|
||||
unescaped = " " | "!" | "#" … "[" | "]" … "~" .
|
||||
escaped = "\\" ( "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" | unicode ) .
|
||||
unicode = "u" hex hex hex hex .
|
||||
`)
|
||||
}
|
||||
|
||||
if c.usedTypes["number"] {
|
||||
b.WriteString(`
|
||||
number = [ "-" ] integer [ fraction ] [ exponent ] .
|
||||
integer = "0" | onenine { digit } .
|
||||
fraction = "." digit { digit } .
|
||||
exponent = ( "e" | "E" ) [ "+" | "-" ] digit { digit } .
|
||||
`)
|
||||
}
|
||||
|
||||
if c.usedTypes["integer"] {
|
||||
b.WriteString(`
|
||||
int = [ "-" ] ( "0" | onenine { digit } ) .
|
||||
`)
|
||||
}
|
||||
|
||||
if c.usedTypes["number"] || c.usedTypes["integer"] || c.usedTypes["digit"] {
|
||||
b.WriteString(`
|
||||
digit = "0" … "9" .
|
||||
`)
|
||||
}
|
||||
|
||||
// onenine only needed for number/integer, not for digit-only formats
|
||||
if c.usedTypes["number"] || c.usedTypes["integer"] {
|
||||
b.WriteString(`onenine = "1" … "9" .
|
||||
`)
|
||||
}
|
||||
|
||||
if c.usedTypes["string"] || c.usedTypes["character"] || c.usedTypes["hex"] {
|
||||
b.WriteString(`
|
||||
hex = "0" … "9" | "A" … "F" | "a" … "f" .
|
||||
`)
|
||||
}
|
||||
|
||||
if c.usedTypes["ws"] {
|
||||
b.WriteString(`
|
||||
ws = { " " | "\t" | "\n" | "\r" } .
|
||||
`)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *converter) schemaToExpr(schema *schemaNode, name string) string {
|
||||
if schema == nil {
|
||||
c.usedTypes["string"] = true
|
||||
c.usedTypes["number"] = true
|
||||
return "( string | number | object | array | \"true\" | \"false\" | \"null\" )"
|
||||
}
|
||||
|
||||
// Handle $ref first
|
||||
if schema.Ref != "" {
|
||||
return c.resolveRef(schema.Ref)
|
||||
}
|
||||
|
||||
// Handle const
|
||||
if schema.Const != nil {
|
||||
return c.constToExpr(schema.Const)
|
||||
}
|
||||
|
||||
// Handle enum
|
||||
if len(schema.Enum) > 0 {
|
||||
return c.enumToExpr(schema.Enum)
|
||||
}
|
||||
|
||||
// Handle anyOf / oneOf
|
||||
if len(schema.AnyOf) > 0 {
|
||||
return c.anyOfToExpr(schema.AnyOf, name)
|
||||
}
|
||||
if len(schema.OneOf) > 0 {
|
||||
return c.anyOfToExpr(schema.OneOf, name)
|
||||
}
|
||||
|
||||
// Handle type
|
||||
types := c.getTypes(schema.Type)
|
||||
if len(types) == 0 {
|
||||
// No type specified, could be anything
|
||||
c.usedTypes["string"] = true
|
||||
c.usedTypes["number"] = true
|
||||
return "( string | number | \"true\" | \"false\" | \"null\" )"
|
||||
}
|
||||
|
||||
if len(types) == 1 {
|
||||
return c.typeToExpr(types[0], schema, name)
|
||||
}
|
||||
|
||||
// Multiple types (e.g., ["string", "null"])
|
||||
var parts []string
|
||||
for _, t := range types {
|
||||
parts = append(parts, c.typeToExpr(t, schema, name))
|
||||
}
|
||||
return "( " + strings.Join(parts, " | ") + " )"
|
||||
}
|
||||
|
||||
func (c *converter) typeToExpr(typeName string, schema *schemaNode, name string) string {
|
||||
switch typeName {
|
||||
case "object":
|
||||
return c.objectToExpr(schema, name)
|
||||
case "array":
|
||||
return c.arrayToExpr(schema, name)
|
||||
case "string":
|
||||
return c.stringToExpr(schema, name)
|
||||
case "number":
|
||||
c.usedTypes["number"] = true
|
||||
return "number"
|
||||
case "integer":
|
||||
c.usedTypes["integer"] = true
|
||||
c.usedTypes["digit"] = true
|
||||
return "int"
|
||||
case "boolean":
|
||||
return `( "true" | "false" )`
|
||||
case "null":
|
||||
return `"null"`
|
||||
default:
|
||||
c.usedTypes["string"] = true
|
||||
c.usedTypes["number"] = true
|
||||
return "string"
|
||||
}
|
||||
}
|
||||
|
||||
func (c *converter) objectToExpr(schema *schemaNode, name string) string {
|
||||
c.usedTypes["ws"] = true
|
||||
|
||||
if len(schema.Properties) == 0 {
|
||||
return `"{" ws "}"`
|
||||
}
|
||||
|
||||
// Sort properties for deterministic output
|
||||
// Required properties come first, in their required order
|
||||
var propOrder []string
|
||||
requiredSet := make(map[string]bool)
|
||||
for _, r := range schema.Required {
|
||||
requiredSet[r] = true
|
||||
propOrder = append(propOrder, r)
|
||||
}
|
||||
|
||||
// Add any non-required properties (though OpenAI requires all to be required)
|
||||
var optionalProps []string
|
||||
for propName := range schema.Properties {
|
||||
if !requiredSet[propName] {
|
||||
optionalProps = append(optionalProps, propName)
|
||||
}
|
||||
}
|
||||
sort.Strings(optionalProps)
|
||||
propOrder = append(propOrder, optionalProps...)
|
||||
|
||||
var propExprs []string
|
||||
first := true
|
||||
|
||||
for _, propName := range propOrder {
|
||||
propSchema, exists := schema.Properties[propName]
|
||||
if !exists {
|
||||
continue
|
||||
}
|
||||
|
||||
propExpr := c.schemaToExpr(propSchema, propName)
|
||||
|
||||
prefix := ""
|
||||
if !first {
|
||||
prefix = `"," ws `
|
||||
}
|
||||
first = false
|
||||
|
||||
propExprs = append(propExprs, fmt.Sprintf(`%s"\"%s\"" ws ":" ws %s`, prefix, propName, propExpr))
|
||||
}
|
||||
|
||||
if len(propExprs) == 0 {
|
||||
return `"{" ws "}"`
|
||||
}
|
||||
|
||||
return `"{" ws ` + strings.Join(propExprs, " ") + ` ws "}"`
|
||||
}
|
||||
|
||||
func (c *converter) arrayToExpr(schema *schemaNode, name string) string {
|
||||
c.usedTypes["ws"] = true
|
||||
|
||||
itemExpr := "value"
|
||||
if schema.Items != nil {
|
||||
itemExpr = c.schemaToExpr(schema.Items, name+"_item")
|
||||
} else {
|
||||
c.usedTypes["string"] = true
|
||||
c.usedTypes["number"] = true
|
||||
}
|
||||
|
||||
// Create item rule
|
||||
c.ruleNum++
|
||||
itemRule := fmt.Sprintf("item%d", c.ruleNum)
|
||||
c.rules = append(c.rules, fmt.Sprintf("%s = %s .", itemRule, itemExpr))
|
||||
|
||||
// Handle minItems/maxItems
|
||||
if schema.MinItems != nil || schema.MaxItems != nil {
|
||||
return c.arrayWithBounds(itemRule, schema.MinItems, schema.MaxItems)
|
||||
}
|
||||
|
||||
// Default: zero or more items
|
||||
return fmt.Sprintf(`( "[" ws "]" | "[" ws %s { "," ws %s } ws "]" )`, itemRule, itemRule)
|
||||
}
|
||||
|
||||
func (c *converter) arrayWithBounds(itemRule string, minItems, maxItems *int) string {
|
||||
min := 0
|
||||
max := -1 // unlimited
|
||||
|
||||
if minItems != nil {
|
||||
min = *minItems
|
||||
}
|
||||
if maxItems != nil {
|
||||
max = *maxItems
|
||||
}
|
||||
|
||||
if min == 0 && max < 0 {
|
||||
// No constraints
|
||||
return fmt.Sprintf(`( "[" ws "]" | "[" ws %s { "," ws %s } ws "]" )`, itemRule, itemRule)
|
||||
}
|
||||
|
||||
if min == 0 && max == 0 {
|
||||
return `"[" ws "]"`
|
||||
}
|
||||
|
||||
// Build pattern for bounded array
|
||||
// For min=2, max=4: item "," item [ "," item ] [ "," item ]
|
||||
var parts []string
|
||||
|
||||
// Required items
|
||||
for i := 0; i < min; i++ {
|
||||
if i > 0 {
|
||||
parts = append(parts, `"," ws`)
|
||||
}
|
||||
parts = append(parts, itemRule)
|
||||
}
|
||||
|
||||
// Optional items up to max
|
||||
if max > min {
|
||||
for i := min; i < max; i++ {
|
||||
if i == 0 {
|
||||
parts = append(parts, fmt.Sprintf(`[ %s`, itemRule))
|
||||
} else {
|
||||
parts = append(parts, fmt.Sprintf(`[ "," ws %s`, itemRule))
|
||||
}
|
||||
}
|
||||
// Close all optional brackets
|
||||
for i := min; i < max; i++ {
|
||||
parts = append(parts, "]")
|
||||
}
|
||||
} else if max < 0 {
|
||||
// Unlimited after min
|
||||
if min > 0 {
|
||||
parts = append(parts, fmt.Sprintf(`{ "," ws %s }`, itemRule))
|
||||
} else {
|
||||
parts = append(parts, fmt.Sprintf(`[ %s { "," ws %s } ]`, itemRule, itemRule))
|
||||
}
|
||||
}
|
||||
|
||||
if min == 0 {
|
||||
return fmt.Sprintf(`( "[" ws "]" | "[" ws %s ws "]" )`, strings.Join(parts, " "))
|
||||
}
|
||||
return fmt.Sprintf(`"[" ws %s ws "]"`, strings.Join(parts, " "))
|
||||
}
|
||||
|
||||
func (c *converter) stringToExpr(schema *schemaNode, name string) string {
|
||||
// Handle format
|
||||
if schema.Format != "" {
|
||||
return c.formatToExpr(schema.Format)
|
||||
}
|
||||
|
||||
// Handle pattern (regex)
|
||||
if schema.Pattern != "" {
|
||||
return c.patternToExpr(schema.Pattern, name)
|
||||
}
|
||||
|
||||
// Default string
|
||||
c.usedTypes["string"] = true
|
||||
if name == "root" {
|
||||
c.usedTypes["character"] = true
|
||||
return `"\"" { character } "\""`
|
||||
}
|
||||
return "string"
|
||||
}
|
||||
|
||||
func (c *converter) formatToExpr(format string) string {
|
||||
switch format {
|
||||
case "date":
|
||||
// YYYY-MM-DD
|
||||
c.ruleNum++
|
||||
c.usedTypes["digit"] = true
|
||||
ruleName := fmt.Sprintf("date%d", c.ruleNum)
|
||||
c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" digit digit digit digit "-" digit digit "-" digit digit "\"" .`, ruleName))
|
||||
return ruleName
|
||||
|
||||
case "time":
|
||||
// HH:MM:SS
|
||||
c.ruleNum++
|
||||
c.usedTypes["digit"] = true
|
||||
ruleName := fmt.Sprintf("time%d", c.ruleNum)
|
||||
c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" digit digit ":" digit digit ":" digit digit "\"" .`, ruleName))
|
||||
return ruleName
|
||||
|
||||
case "date-time":
|
||||
// YYYY-MM-DDTHH:MM:SSZ or with offset
|
||||
c.ruleNum++
|
||||
c.usedTypes["digit"] = true
|
||||
ruleName := fmt.Sprintf("datetime%d", c.ruleNum)
|
||||
c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" digit digit digit digit "-" digit digit "-" digit digit "T" digit digit ":" digit digit ":" digit digit ( "Z" | ( "+" | "-" ) digit digit ":" digit digit ) "\"" .`, ruleName))
|
||||
return ruleName
|
||||
|
||||
case "email":
|
||||
// Simplified email pattern
|
||||
c.ruleNum++
|
||||
ruleName := fmt.Sprintf("email%d", c.ruleNum)
|
||||
c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" emailchar { emailchar } "@" emailchar { emailchar } "." emailchar { emailchar } "\"" .`, ruleName))
|
||||
c.rules = append(c.rules, `emailchar = "a" … "z" | "A" … "Z" | "0" … "9" | "." | "-" | "_" .`)
|
||||
return ruleName
|
||||
|
||||
case "uuid":
|
||||
// 8-4-4-4-12 hex pattern
|
||||
c.ruleNum++
|
||||
ruleName := fmt.Sprintf("uuid%d", c.ruleNum)
|
||||
c.usedTypes["hex"] = true
|
||||
c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" hex hex hex hex hex hex hex hex "-" hex hex hex hex "-" hex hex hex hex "-" hex hex hex hex "-" hex hex hex hex hex hex hex hex hex hex hex hex "\"" .`, ruleName))
|
||||
return ruleName
|
||||
|
||||
case "ipv4":
|
||||
c.ruleNum++
|
||||
c.usedTypes["digit"] = true
|
||||
ruleName := fmt.Sprintf("ipv4_%d", c.ruleNum)
|
||||
c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" digit { digit } "." digit { digit } "." digit { digit } "." digit { digit } "\"" .`, ruleName))
|
||||
return ruleName
|
||||
|
||||
case "uri", "hostname":
|
||||
// Fallback to general string for complex formats
|
||||
c.usedTypes["string"] = true
|
||||
return "string"
|
||||
|
||||
default:
|
||||
c.usedTypes["string"] = true
|
||||
return "string"
|
||||
}
|
||||
}
|
||||
|
||||
func (c *converter) patternToExpr(pattern string, name string) string {
|
||||
// Try to convert simple regex patterns to EBNF
|
||||
// This handles common cases; complex regex falls back to string
|
||||
|
||||
// Remove anchors
|
||||
pattern = strings.TrimPrefix(pattern, "^")
|
||||
pattern = strings.TrimSuffix(pattern, "$")
|
||||
|
||||
// Try to parse and convert
|
||||
expr, ok := c.regexToEBNF(pattern)
|
||||
if !ok {
|
||||
// Fallback to general string
|
||||
c.usedTypes["string"] = true
|
||||
return "string"
|
||||
}
|
||||
|
||||
c.ruleNum++
|
||||
ruleName := fmt.Sprintf("pattern%d", c.ruleNum)
|
||||
c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" %s "\"" .`, ruleName, expr))
|
||||
return ruleName
|
||||
}
|
||||
|
||||
func (c *converter) regexToEBNF(pattern string) (string, bool) {
|
||||
// Simple regex to EBNF converter
|
||||
// Handles: literals, [a-z], [A-Z], [0-9], +, *, ?, basic groups
|
||||
|
||||
var result strings.Builder
|
||||
i := 0
|
||||
|
||||
for i < len(pattern) {
|
||||
ch := pattern[i]
|
||||
|
||||
switch ch {
|
||||
case '[':
|
||||
// Character class
|
||||
end := strings.Index(pattern[i:], "]")
|
||||
if end == -1 {
|
||||
return "", false
|
||||
}
|
||||
class := pattern[i+1 : i+end]
|
||||
ebnfClass, ok := c.charClassToEBNF(class)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
result.WriteString(ebnfClass)
|
||||
i += end + 1
|
||||
|
||||
case '(':
|
||||
// Group - find matching )
|
||||
depth := 1
|
||||
start := i + 1
|
||||
j := start
|
||||
for j < len(pattern) && depth > 0 {
|
||||
if pattern[j] == '(' {
|
||||
depth++
|
||||
} else if pattern[j] == ')' {
|
||||
depth--
|
||||
}
|
||||
j++
|
||||
}
|
||||
if depth != 0 {
|
||||
return "", false
|
||||
}
|
||||
groupContent := pattern[start : j-1]
|
||||
groupExpr, ok := c.regexToEBNF(groupContent)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
result.WriteString("( ")
|
||||
result.WriteString(groupExpr)
|
||||
result.WriteString(" )")
|
||||
i = j
|
||||
|
||||
case '|':
|
||||
result.WriteString(" | ")
|
||||
i++
|
||||
|
||||
case '+':
|
||||
// One or more - wrap previous in { } and add one required
|
||||
// This is a simplification
|
||||
return "", false // TODO: handle properly
|
||||
|
||||
case '*':
|
||||
// Zero or more - need to wrap previous
|
||||
return "", false // TODO: handle properly
|
||||
|
||||
case '?':
|
||||
// Optional - need to wrap previous in [ ]
|
||||
return "", false // TODO: handle properly
|
||||
|
||||
case '\\':
|
||||
// Escape sequence
|
||||
if i+1 >= len(pattern) {
|
||||
return "", false
|
||||
}
|
||||
next := pattern[i+1]
|
||||
switch next {
|
||||
case 'd':
|
||||
result.WriteString("digit")
|
||||
c.usedTypes["digit"] = true
|
||||
case 'w':
|
||||
result.WriteString(`( "a" … "z" | "A" … "Z" | "0" … "9" | "_" )`)
|
||||
case 's':
|
||||
result.WriteString(`( " " | "\t" )`)
|
||||
default:
|
||||
result.WriteString(fmt.Sprintf(`"%c"`, next))
|
||||
}
|
||||
i += 2
|
||||
|
||||
default:
|
||||
// Literal character
|
||||
if (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') || ch == '_' || ch == '-' || ch == '.' {
|
||||
result.WriteString(fmt.Sprintf(`"%c" `, ch))
|
||||
} else {
|
||||
// Special char, try to escape
|
||||
result.WriteString(fmt.Sprintf(`"%c" `, ch))
|
||||
}
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
return strings.TrimSpace(result.String()), true
|
||||
}
|
||||
|
||||
func (c *converter) charClassToEBNF(class string) (string, bool) {
|
||||
// Handle character classes like a-z, A-Z, 0-9
|
||||
if class == "a-zA-Z0-9_" || class == "a-zA-Z_" {
|
||||
return `( "a" … "z" | "A" … "Z" | "0" … "9" | "_" )`, true
|
||||
}
|
||||
if class == "a-zA-Z0-9" {
|
||||
return `( "a" … "z" | "A" … "Z" | "0" … "9" )`, true
|
||||
}
|
||||
if class == "a-z" {
|
||||
return `"a" … "z"`, true
|
||||
}
|
||||
if class == "A-Z" {
|
||||
return `"A" … "Z"`, true
|
||||
}
|
||||
if class == "0-9" {
|
||||
c.usedTypes["digit"] = true
|
||||
return "digit", true
|
||||
}
|
||||
|
||||
// Try to parse range patterns
|
||||
if matched, _ := regexp.MatchString(`^[a-zA-Z]-[a-zA-Z]$`, class); matched {
|
||||
return fmt.Sprintf(`"%c" … "%c"`, class[0], class[2]), true
|
||||
}
|
||||
if matched, _ := regexp.MatchString(`^[0-9]-[0-9]$`, class); matched {
|
||||
return fmt.Sprintf(`"%c" … "%c"`, class[0], class[2]), true
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
func (c *converter) anyOfToExpr(schemas []*schemaNode, name string) string {
|
||||
var parts []string
|
||||
for i, s := range schemas {
|
||||
expr := c.schemaToExpr(s, fmt.Sprintf("%s_opt%d", name, i))
|
||||
parts = append(parts, expr)
|
||||
}
|
||||
return "( " + strings.Join(parts, " | ") + " )"
|
||||
}
|
||||
|
||||
func (c *converter) enumToExpr(values []interface{}) string {
|
||||
var parts []string
|
||||
for _, v := range values {
|
||||
parts = append(parts, c.constToExpr(v))
|
||||
}
|
||||
return "( " + strings.Join(parts, " | ") + " )"
|
||||
}
|
||||
|
||||
func (c *converter) constToExpr(v interface{}) string {
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
return fmt.Sprintf(`"\"%s\""`, c.escapeString(val))
|
||||
case float64:
|
||||
if val == float64(int(val)) {
|
||||
return fmt.Sprintf(`"%d"`, int(val))
|
||||
}
|
||||
return fmt.Sprintf(`"%v"`, val)
|
||||
case bool:
|
||||
if val {
|
||||
return `"true"`
|
||||
}
|
||||
return `"false"`
|
||||
case nil:
|
||||
return `"null"`
|
||||
default:
|
||||
c.usedTypes["string"] = true
|
||||
return "string"
|
||||
}
|
||||
}
|
||||
|
||||
func (c *converter) resolveRef(ref string) string {
|
||||
// Handle #/$defs/name references
|
||||
if strings.HasPrefix(ref, "#/$defs/") {
|
||||
defName := strings.TrimPrefix(ref, "#/$defs/")
|
||||
return c.resolveDefRef(defName)
|
||||
}
|
||||
|
||||
// Handle root recursion #
|
||||
if ref == "#" {
|
||||
return "root"
|
||||
}
|
||||
|
||||
// Unknown ref format
|
||||
c.usedTypes["string"] = true
|
||||
return "string"
|
||||
}
|
||||
|
||||
func (c *converter) resolveDefRef(defName string) string {
|
||||
// Check if we've already defined this as a rule
|
||||
ruleName := "def_" + defName
|
||||
if c.definedRefs[defName] {
|
||||
return ruleName
|
||||
}
|
||||
|
||||
// Mark as defined to prevent infinite recursion
|
||||
c.definedRefs[defName] = true
|
||||
|
||||
// Look up the definition
|
||||
if c.definitions == nil {
|
||||
c.usedTypes["string"] = true
|
||||
return "string"
|
||||
}
|
||||
|
||||
defSchema, ok := c.definitions[defName]
|
||||
if !ok {
|
||||
c.usedTypes["string"] = true
|
||||
return "string"
|
||||
}
|
||||
|
||||
// Generate the rule
|
||||
expr := c.schemaToExpr(defSchema, ruleName)
|
||||
c.rules = append(c.rules, fmt.Sprintf("%s = %s .", ruleName, expr))
|
||||
|
||||
return ruleName
|
||||
}
|
||||
|
||||
func (c *converter) getTypes(t interface{}) []string {
|
||||
switch v := t.(type) {
|
||||
case string:
|
||||
return []string{v}
|
||||
case []interface{}:
|
||||
var types []string
|
||||
for _, item := range v {
|
||||
if s, ok := item.(string); ok {
|
||||
types = append(types, s)
|
||||
}
|
||||
}
|
||||
return types
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *converter) escapeString(s string) string {
|
||||
s = strings.ReplaceAll(s, `\`, `\\`)
|
||||
s = strings.ReplaceAll(s, `"`, `\"`)
|
||||
return s
|
||||
}
|
||||
|
||||
// Grammar converts a JSON Schema string into a compiled grammar.
|
||||
func Grammar(schemaJSON string) (*grammar.Grammar, error) {
|
||||
ebnf, err := EBNF(schemaJSON)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return grammar.ParseEBNF(ebnf, "root")
|
||||
}
|
||||
336
x/grammar/schema/schema_test.go
Normal file
336
x/grammar/schema/schema_test.go
Normal file
@@ -0,0 +1,336 @@
|
||||
//go:build mlx
|
||||
|
||||
package schema
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
gram "github.com/ollama/ollama/x/grammar"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
func TestJSONEBNF(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
schema string
|
||||
}{
|
||||
{
|
||||
name: "simple object",
|
||||
schema: `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"age": {"type": "integer"}
|
||||
},
|
||||
"required": ["name", "age"]
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "with enum",
|
||||
schema: `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"status": {"enum": ["active", "inactive", "pending"]}
|
||||
},
|
||||
"required": ["status"]
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "array of objects",
|
||||
schema: `{
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"id": {"type": "integer"}
|
||||
},
|
||||
"required": ["id"]
|
||||
}
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "nested object",
|
||||
schema: `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"user": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"email": {"type": "string"}
|
||||
},
|
||||
"required": ["email"]
|
||||
}
|
||||
},
|
||||
"required": ["user"]
|
||||
}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ebnf, err := EBNF(tc.schema)
|
||||
if err != nil {
|
||||
t.Fatalf("EBNF failed: %v", err)
|
||||
}
|
||||
|
||||
// Try to compile it
|
||||
grammar, err := gram.ParseEBNF(ebnf, "root")
|
||||
if err != nil {
|
||||
t.Fatalf("ParseEBNF failed: %v", err)
|
||||
}
|
||||
|
||||
if grammar == nil {
|
||||
t.Fatal("grammar is nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGrammarEngine(t *testing.T) {
|
||||
schema := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"age": {"type": "integer"}
|
||||
},
|
||||
"required": ["name", "age"]
|
||||
}`
|
||||
|
||||
grammar, err := Grammar(schema)
|
||||
if err != nil {
|
||||
t.Fatalf("Grammar failed: %v", err)
|
||||
}
|
||||
|
||||
vocab := []string{
|
||||
"{", "}", "[", "]", ":", ",",
|
||||
"\"name\"", "\"age\"", "\"test\"",
|
||||
"\"", "a", "b", "c",
|
||||
"0", "1", "2", "3", "4", "5", "6", "7", "8", "9",
|
||||
" ", "\n",
|
||||
"true", "false", "null",
|
||||
}
|
||||
|
||||
engine, err := gram.NewEngine(grammar, vocab)
|
||||
if err != nil {
|
||||
t.Fatalf("grammar.NewEngine failed: %v", err)
|
||||
}
|
||||
defer engine.Close()
|
||||
|
||||
logits := mlx.Ones(int32(len(vocab)))
|
||||
mlx.Keep(logits)
|
||||
|
||||
// Test that we can apply mask
|
||||
masked := engine.ApplyMask(logits)
|
||||
mlx.Eval(masked)
|
||||
}
|
||||
|
||||
// TestOpenAIStructuredOutputs tests features required for OpenAI compatibility
|
||||
func TestOpenAIStructuredOutputs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
schema string
|
||||
}{
|
||||
{
|
||||
name: "anyOf union",
|
||||
schema: `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"value": {
|
||||
"anyOf": [
|
||||
{"type": "string"},
|
||||
{"type": "integer"}
|
||||
]
|
||||
}
|
||||
},
|
||||
"required": ["value"]
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "nullable string via type array",
|
||||
schema: `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": ["string", "null"]}
|
||||
},
|
||||
"required": ["name"]
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "$ref with $defs",
|
||||
schema: `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"person": {"$ref": "#/$defs/Person"}
|
||||
},
|
||||
"required": ["person"],
|
||||
"$defs": {
|
||||
"Person": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"age": {"type": "integer"}
|
||||
},
|
||||
"required": ["name", "age"]
|
||||
}
|
||||
}
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "const value",
|
||||
schema: `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {"const": "user"}
|
||||
},
|
||||
"required": ["type"]
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "format date-time",
|
||||
schema: `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"created": {"type": "string", "format": "date-time"}
|
||||
},
|
||||
"required": ["created"]
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "format date",
|
||||
schema: `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"birthday": {"type": "string", "format": "date"}
|
||||
},
|
||||
"required": ["birthday"]
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "format email",
|
||||
schema: `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"email": {"type": "string", "format": "email"}
|
||||
},
|
||||
"required": ["email"]
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "format uuid",
|
||||
schema: `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"id": {"type": "string", "format": "uuid"}
|
||||
},
|
||||
"required": ["id"]
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "array with minItems maxItems",
|
||||
schema: `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"tags": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"minItems": 1,
|
||||
"maxItems": 3
|
||||
}
|
||||
},
|
||||
"required": ["tags"]
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "deeply nested with refs",
|
||||
schema: `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"company": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"employees": {
|
||||
"type": "array",
|
||||
"items": {"$ref": "#/$defs/Employee"}
|
||||
}
|
||||
},
|
||||
"required": ["name", "employees"]
|
||||
}
|
||||
},
|
||||
"required": ["company"],
|
||||
"$defs": {
|
||||
"Employee": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"role": {"enum": ["engineer", "manager", "intern"]}
|
||||
},
|
||||
"required": ["name", "role"]
|
||||
}
|
||||
}
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "multiple refs same def",
|
||||
schema: `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"from": {"$ref": "#/$defs/Address"},
|
||||
"to": {"$ref": "#/$defs/Address"}
|
||||
},
|
||||
"required": ["from", "to"],
|
||||
"$defs": {
|
||||
"Address": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {"type": "string"},
|
||||
"zip": {"type": "string"}
|
||||
},
|
||||
"required": ["city", "zip"]
|
||||
}
|
||||
}
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "oneOf variant",
|
||||
schema: `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"result": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"success": {"type": "boolean"}},
|
||||
"required": ["success"]
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"error": {"type": "string"}},
|
||||
"required": ["error"]
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"required": ["result"]
|
||||
}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ebnf, err := EBNF(tc.schema)
|
||||
if err != nil {
|
||||
t.Fatalf("EBNF failed: %v", err)
|
||||
}
|
||||
|
||||
grammar, err := gram.ParseEBNF(ebnf, "root")
|
||||
if err != nil {
|
||||
t.Fatalf("ParseEBNF failed: %v", err)
|
||||
}
|
||||
|
||||
if grammar == nil {
|
||||
t.Fatal("grammar is nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user