mirror of
https://github.com/ollama/ollama.git
synced 2026-04-18 00:54:05 +02:00
727 lines
18 KiB
Go
727 lines
18 KiB
Go
//go:build mlx
|
|
|
|
// Package schema converts OpenAI-compatible JSON Schema into constrained grammars.
|
|
package schema
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"regexp"
|
|
"sort"
|
|
"strings"
|
|
|
|
"github.com/ollama/ollama/x/grammar"
|
|
)
|
|
|
|
// schemaNode is the subset of JSON Schema understood by this package, in the
// shape used by OpenAI-compatible structured outputs. It is decoded directly
// from the request's schema JSON with encoding/json.
// See: https://platform.openai.com/docs/guides/structured-outputs
type schemaNode struct {
	// Type is the decoded "type" keyword: a string, a []interface{} of
	// strings, or nil when the keyword is absent.
	Type interface{} `json:"type"`

	// Keywords that apply to objects.
	Properties           map[string]*schemaNode `json:"properties"`
	Required             []string               `json:"required"`
	AdditionalProperties interface{}            `json:"additionalProperties"`

	// Keywords that apply to arrays.
	Items    *schemaNode `json:"items"`
	MinItems *int        `json:"minItems"`
	MaxItems *int        `json:"maxItems"`

	// Keywords that apply to strings.
	Pattern string `json:"pattern"` // regular expression
	Format  string `json:"format"`  // e.g. date-time, email, uuid

	// Numeric keywords. They are decoded but not turned into grammar
	// constraints.
	Minimum          *float64 `json:"minimum"`
	Maximum          *float64 `json:"maximum"`
	ExclusiveMinimum *float64 `json:"exclusiveMinimum"`
	ExclusiveMaximum *float64 `json:"exclusiveMaximum"`
	MultipleOf       *float64 `json:"multipleOf"`

	// Fixed-value keywords.
	Enum  []interface{} `json:"enum"`
	Const interface{}   `json:"const"`

	// Composition keywords. oneOf is handled like anyOf by the converter.
	AnyOf []*schemaNode `json:"anyOf"`
	OneOf []*schemaNode `json:"oneOf"`

	// References and the definitions they point at.
	Ref  string                 `json:"$ref"`
	Defs map[string]*schemaNode `json:"$defs"`

	// Description is documentation only; it does not affect the grammar.
	Description string `json:"description"`
}
|
|
|
|
// converter handles JSON Schema to EBNF conversion with state.
|
|
type converter struct {
|
|
schema *schemaNode
|
|
definitions map[string]*schemaNode // Resolved $defs
|
|
usedTypes map[string]bool
|
|
rules []string
|
|
ruleNum int
|
|
definedRefs map[string]bool // Track which refs we've already defined as rules
|
|
}
|
|
|
|
// EBNF converts a JSON Schema to EBNF grammar
|
|
func EBNF(schemaJSON string) (string, error) {
|
|
var schema schemaNode
|
|
if err := json.Unmarshal([]byte(schemaJSON), &schema); err != nil {
|
|
return "", fmt.Errorf("failed to parse JSON Schema: %w", err)
|
|
}
|
|
|
|
conv := &converter{
|
|
schema: &schema,
|
|
definitions: schema.Defs,
|
|
usedTypes: make(map[string]bool),
|
|
definedRefs: make(map[string]bool),
|
|
}
|
|
|
|
return conv.convert()
|
|
}
|
|
|
|
func (c *converter) convert() (string, error) {
|
|
var b strings.Builder
|
|
|
|
// Generate root rule
|
|
rootExpr := c.schemaToExpr(c.schema, "root")
|
|
b.WriteString("root = ")
|
|
b.WriteString(rootExpr)
|
|
b.WriteString(" .\n")
|
|
|
|
// Add generated rules (refs, items, etc.)
|
|
for _, rule := range c.rules {
|
|
b.WriteString(rule)
|
|
b.WriteString("\n")
|
|
}
|
|
|
|
// Add primitives based on usage
|
|
c.addPrimitives(&b)
|
|
|
|
return b.String(), nil
|
|
}
|
|
|
|
func (c *converter) addPrimitives(b *strings.Builder) {
|
|
if c.usedTypes["string"] {
|
|
b.WriteString(`
|
|
string = "\"" { character } "\"" .
|
|
`)
|
|
}
|
|
|
|
if c.usedTypes["string"] || c.usedTypes["character"] {
|
|
b.WriteString(`
|
|
character = unescaped | escaped .
|
|
unescaped = " " | "!" | "#" … "[" | "]" … "~" .
|
|
escaped = "\\" ( "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" | unicode ) .
|
|
unicode = "u" hex hex hex hex .
|
|
`)
|
|
}
|
|
|
|
if c.usedTypes["number"] {
|
|
b.WriteString(`
|
|
number = [ "-" ] integer [ fraction ] [ exponent ] .
|
|
integer = "0" | onenine { digit } .
|
|
fraction = "." digit { digit } .
|
|
exponent = ( "e" | "E" ) [ "+" | "-" ] digit { digit } .
|
|
`)
|
|
}
|
|
|
|
if c.usedTypes["integer"] {
|
|
b.WriteString(`
|
|
int = [ "-" ] ( "0" | onenine { digit } ) .
|
|
`)
|
|
}
|
|
|
|
if c.usedTypes["number"] || c.usedTypes["integer"] || c.usedTypes["digit"] {
|
|
b.WriteString(`
|
|
digit = "0" … "9" .
|
|
`)
|
|
}
|
|
|
|
// onenine only needed for number/integer, not for digit-only formats
|
|
if c.usedTypes["number"] || c.usedTypes["integer"] {
|
|
b.WriteString(`onenine = "1" … "9" .
|
|
`)
|
|
}
|
|
|
|
if c.usedTypes["string"] || c.usedTypes["character"] || c.usedTypes["hex"] {
|
|
b.WriteString(`
|
|
hex = "0" … "9" | "A" … "F" | "a" … "f" .
|
|
`)
|
|
}
|
|
|
|
if c.usedTypes["ws"] {
|
|
b.WriteString(`
|
|
ws = { " " | "\t" | "\n" | "\r" } .
|
|
`)
|
|
}
|
|
}
|
|
|
|
func (c *converter) schemaToExpr(schema *schemaNode, name string) string {
|
|
if schema == nil {
|
|
c.usedTypes["string"] = true
|
|
c.usedTypes["number"] = true
|
|
return "( string | number | object | array | \"true\" | \"false\" | \"null\" )"
|
|
}
|
|
|
|
// Handle $ref first
|
|
if schema.Ref != "" {
|
|
return c.resolveRef(schema.Ref)
|
|
}
|
|
|
|
// Handle const
|
|
if schema.Const != nil {
|
|
return c.constToExpr(schema.Const)
|
|
}
|
|
|
|
// Handle enum
|
|
if len(schema.Enum) > 0 {
|
|
return c.enumToExpr(schema.Enum)
|
|
}
|
|
|
|
// Handle anyOf / oneOf
|
|
if len(schema.AnyOf) > 0 {
|
|
return c.anyOfToExpr(schema.AnyOf, name)
|
|
}
|
|
if len(schema.OneOf) > 0 {
|
|
return c.anyOfToExpr(schema.OneOf, name)
|
|
}
|
|
|
|
// Handle type
|
|
types := c.getTypes(schema.Type)
|
|
if len(types) == 0 {
|
|
// No type specified, could be anything
|
|
c.usedTypes["string"] = true
|
|
c.usedTypes["number"] = true
|
|
return "( string | number | \"true\" | \"false\" | \"null\" )"
|
|
}
|
|
|
|
if len(types) == 1 {
|
|
return c.typeToExpr(types[0], schema, name)
|
|
}
|
|
|
|
// Multiple types (e.g., ["string", "null"])
|
|
var parts []string
|
|
for _, t := range types {
|
|
parts = append(parts, c.typeToExpr(t, schema, name))
|
|
}
|
|
return "( " + strings.Join(parts, " | ") + " )"
|
|
}
|
|
|
|
func (c *converter) typeToExpr(typeName string, schema *schemaNode, name string) string {
|
|
switch typeName {
|
|
case "object":
|
|
return c.objectToExpr(schema, name)
|
|
case "array":
|
|
return c.arrayToExpr(schema, name)
|
|
case "string":
|
|
return c.stringToExpr(schema, name)
|
|
case "number":
|
|
c.usedTypes["number"] = true
|
|
return "number"
|
|
case "integer":
|
|
c.usedTypes["integer"] = true
|
|
c.usedTypes["digit"] = true
|
|
return "int"
|
|
case "boolean":
|
|
return `( "true" | "false" )`
|
|
case "null":
|
|
return `"null"`
|
|
default:
|
|
c.usedTypes["string"] = true
|
|
c.usedTypes["number"] = true
|
|
return "string"
|
|
}
|
|
}
|
|
|
|
func (c *converter) objectToExpr(schema *schemaNode, name string) string {
|
|
c.usedTypes["ws"] = true
|
|
|
|
if len(schema.Properties) == 0 {
|
|
return `"{" ws "}"`
|
|
}
|
|
|
|
// Sort properties for deterministic output
|
|
// Required properties come first, in their required order
|
|
var propOrder []string
|
|
requiredSet := make(map[string]bool)
|
|
for _, r := range schema.Required {
|
|
requiredSet[r] = true
|
|
propOrder = append(propOrder, r)
|
|
}
|
|
|
|
// Add any non-required properties (though OpenAI requires all to be required)
|
|
var optionalProps []string
|
|
for propName := range schema.Properties {
|
|
if !requiredSet[propName] {
|
|
optionalProps = append(optionalProps, propName)
|
|
}
|
|
}
|
|
sort.Strings(optionalProps)
|
|
propOrder = append(propOrder, optionalProps...)
|
|
|
|
var propExprs []string
|
|
first := true
|
|
|
|
for _, propName := range propOrder {
|
|
propSchema, exists := schema.Properties[propName]
|
|
if !exists {
|
|
continue
|
|
}
|
|
|
|
propExpr := c.schemaToExpr(propSchema, propName)
|
|
|
|
prefix := ""
|
|
if !first {
|
|
prefix = `"," ws `
|
|
}
|
|
first = false
|
|
|
|
propExprs = append(propExprs, fmt.Sprintf(`%s"\"%s\"" ws ":" ws %s`, prefix, propName, propExpr))
|
|
}
|
|
|
|
if len(propExprs) == 0 {
|
|
return `"{" ws "}"`
|
|
}
|
|
|
|
return `"{" ws ` + strings.Join(propExprs, " ") + ` ws "}"`
|
|
}
|
|
|
|
func (c *converter) arrayToExpr(schema *schemaNode, name string) string {
|
|
c.usedTypes["ws"] = true
|
|
|
|
itemExpr := "value"
|
|
if schema.Items != nil {
|
|
itemExpr = c.schemaToExpr(schema.Items, name+"_item")
|
|
} else {
|
|
c.usedTypes["string"] = true
|
|
c.usedTypes["number"] = true
|
|
}
|
|
|
|
// Create item rule
|
|
c.ruleNum++
|
|
itemRule := fmt.Sprintf("item%d", c.ruleNum)
|
|
c.rules = append(c.rules, fmt.Sprintf("%s = %s .", itemRule, itemExpr))
|
|
|
|
// Handle minItems/maxItems
|
|
if schema.MinItems != nil || schema.MaxItems != nil {
|
|
return c.arrayWithBounds(itemRule, schema.MinItems, schema.MaxItems)
|
|
}
|
|
|
|
// Default: zero or more items
|
|
return fmt.Sprintf(`( "[" ws "]" | "[" ws %s { "," ws %s } ws "]" )`, itemRule, itemRule)
|
|
}
|
|
|
|
func (c *converter) arrayWithBounds(itemRule string, minItems, maxItems *int) string {
|
|
min := 0
|
|
max := -1 // unlimited
|
|
|
|
if minItems != nil {
|
|
min = *minItems
|
|
}
|
|
if maxItems != nil {
|
|
max = *maxItems
|
|
}
|
|
|
|
if min == 0 && max < 0 {
|
|
// No constraints
|
|
return fmt.Sprintf(`( "[" ws "]" | "[" ws %s { "," ws %s } ws "]" )`, itemRule, itemRule)
|
|
}
|
|
|
|
if min == 0 && max == 0 {
|
|
return `"[" ws "]"`
|
|
}
|
|
|
|
// Build pattern for bounded array
|
|
// For min=2, max=4: item "," item [ "," item ] [ "," item ]
|
|
var parts []string
|
|
|
|
// Required items
|
|
for i := 0; i < min; i++ {
|
|
if i > 0 {
|
|
parts = append(parts, `"," ws`)
|
|
}
|
|
parts = append(parts, itemRule)
|
|
}
|
|
|
|
// Optional items up to max
|
|
if max > min {
|
|
for i := min; i < max; i++ {
|
|
if i == 0 {
|
|
parts = append(parts, fmt.Sprintf(`[ %s`, itemRule))
|
|
} else {
|
|
parts = append(parts, fmt.Sprintf(`[ "," ws %s`, itemRule))
|
|
}
|
|
}
|
|
// Close all optional brackets
|
|
for i := min; i < max; i++ {
|
|
parts = append(parts, "]")
|
|
}
|
|
} else if max < 0 {
|
|
// Unlimited after min
|
|
if min > 0 {
|
|
parts = append(parts, fmt.Sprintf(`{ "," ws %s }`, itemRule))
|
|
} else {
|
|
parts = append(parts, fmt.Sprintf(`[ %s { "," ws %s } ]`, itemRule, itemRule))
|
|
}
|
|
}
|
|
|
|
if min == 0 {
|
|
return fmt.Sprintf(`( "[" ws "]" | "[" ws %s ws "]" )`, strings.Join(parts, " "))
|
|
}
|
|
return fmt.Sprintf(`"[" ws %s ws "]"`, strings.Join(parts, " "))
|
|
}
|
|
|
|
func (c *converter) stringToExpr(schema *schemaNode, name string) string {
|
|
// Handle format
|
|
if schema.Format != "" {
|
|
return c.formatToExpr(schema.Format)
|
|
}
|
|
|
|
// Handle pattern (regex)
|
|
if schema.Pattern != "" {
|
|
return c.patternToExpr(schema.Pattern, name)
|
|
}
|
|
|
|
// Default string
|
|
c.usedTypes["string"] = true
|
|
if name == "root" {
|
|
c.usedTypes["character"] = true
|
|
return `"\"" { character } "\""`
|
|
}
|
|
return "string"
|
|
}
|
|
|
|
func (c *converter) formatToExpr(format string) string {
|
|
switch format {
|
|
case "date":
|
|
// YYYY-MM-DD
|
|
c.ruleNum++
|
|
c.usedTypes["digit"] = true
|
|
ruleName := fmt.Sprintf("date%d", c.ruleNum)
|
|
c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" digit digit digit digit "-" digit digit "-" digit digit "\"" .`, ruleName))
|
|
return ruleName
|
|
|
|
case "time":
|
|
// HH:MM:SS
|
|
c.ruleNum++
|
|
c.usedTypes["digit"] = true
|
|
ruleName := fmt.Sprintf("time%d", c.ruleNum)
|
|
c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" digit digit ":" digit digit ":" digit digit "\"" .`, ruleName))
|
|
return ruleName
|
|
|
|
case "date-time":
|
|
// YYYY-MM-DDTHH:MM:SSZ or with offset
|
|
c.ruleNum++
|
|
c.usedTypes["digit"] = true
|
|
ruleName := fmt.Sprintf("datetime%d", c.ruleNum)
|
|
c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" digit digit digit digit "-" digit digit "-" digit digit "T" digit digit ":" digit digit ":" digit digit ( "Z" | ( "+" | "-" ) digit digit ":" digit digit ) "\"" .`, ruleName))
|
|
return ruleName
|
|
|
|
case "email":
|
|
// Simplified email pattern
|
|
c.ruleNum++
|
|
ruleName := fmt.Sprintf("email%d", c.ruleNum)
|
|
c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" emailchar { emailchar } "@" emailchar { emailchar } "." emailchar { emailchar } "\"" .`, ruleName))
|
|
c.rules = append(c.rules, `emailchar = "a" … "z" | "A" … "Z" | "0" … "9" | "." | "-" | "_" .`)
|
|
return ruleName
|
|
|
|
case "uuid":
|
|
// 8-4-4-4-12 hex pattern
|
|
c.ruleNum++
|
|
ruleName := fmt.Sprintf("uuid%d", c.ruleNum)
|
|
c.usedTypes["hex"] = true
|
|
c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" hex hex hex hex hex hex hex hex "-" hex hex hex hex "-" hex hex hex hex "-" hex hex hex hex "-" hex hex hex hex hex hex hex hex hex hex hex hex "\"" .`, ruleName))
|
|
return ruleName
|
|
|
|
case "ipv4":
|
|
c.ruleNum++
|
|
c.usedTypes["digit"] = true
|
|
ruleName := fmt.Sprintf("ipv4_%d", c.ruleNum)
|
|
c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" digit { digit } "." digit { digit } "." digit { digit } "." digit { digit } "\"" .`, ruleName))
|
|
return ruleName
|
|
|
|
case "uri", "hostname":
|
|
// Fallback to general string for complex formats
|
|
c.usedTypes["string"] = true
|
|
return "string"
|
|
|
|
default:
|
|
c.usedTypes["string"] = true
|
|
return "string"
|
|
}
|
|
}
|
|
|
|
func (c *converter) patternToExpr(pattern string, name string) string {
|
|
// Try to convert simple regex patterns to EBNF
|
|
// This handles common cases; complex regex falls back to string
|
|
|
|
// Remove anchors
|
|
pattern = strings.TrimPrefix(pattern, "^")
|
|
pattern = strings.TrimSuffix(pattern, "$")
|
|
|
|
// Try to parse and convert
|
|
expr, ok := c.regexToEBNF(pattern)
|
|
if !ok {
|
|
// Fallback to general string
|
|
c.usedTypes["string"] = true
|
|
return "string"
|
|
}
|
|
|
|
c.ruleNum++
|
|
ruleName := fmt.Sprintf("pattern%d", c.ruleNum)
|
|
c.rules = append(c.rules, fmt.Sprintf(`%s = "\"" %s "\"" .`, ruleName, expr))
|
|
return ruleName
|
|
}
|
|
|
|
func (c *converter) regexToEBNF(pattern string) (string, bool) {
|
|
// Simple regex to EBNF converter
|
|
// Handles: literals, [a-z], [A-Z], [0-9], +, *, ?, basic groups
|
|
|
|
var result strings.Builder
|
|
i := 0
|
|
|
|
for i < len(pattern) {
|
|
ch := pattern[i]
|
|
|
|
switch ch {
|
|
case '[':
|
|
// Character class
|
|
end := strings.Index(pattern[i:], "]")
|
|
if end == -1 {
|
|
return "", false
|
|
}
|
|
class := pattern[i+1 : i+end]
|
|
ebnfClass, ok := c.charClassToEBNF(class)
|
|
if !ok {
|
|
return "", false
|
|
}
|
|
result.WriteString(ebnfClass)
|
|
i += end + 1
|
|
|
|
case '(':
|
|
// Group - find matching )
|
|
depth := 1
|
|
start := i + 1
|
|
j := start
|
|
for j < len(pattern) && depth > 0 {
|
|
if pattern[j] == '(' {
|
|
depth++
|
|
} else if pattern[j] == ')' {
|
|
depth--
|
|
}
|
|
j++
|
|
}
|
|
if depth != 0 {
|
|
return "", false
|
|
}
|
|
groupContent := pattern[start : j-1]
|
|
groupExpr, ok := c.regexToEBNF(groupContent)
|
|
if !ok {
|
|
return "", false
|
|
}
|
|
result.WriteString("( ")
|
|
result.WriteString(groupExpr)
|
|
result.WriteString(" )")
|
|
i = j
|
|
|
|
case '|':
|
|
result.WriteString(" | ")
|
|
i++
|
|
|
|
case '+':
|
|
// One or more - wrap previous in { } and add one required
|
|
// This is a simplification
|
|
return "", false // TODO: handle properly
|
|
|
|
case '*':
|
|
// Zero or more - need to wrap previous
|
|
return "", false // TODO: handle properly
|
|
|
|
case '?':
|
|
// Optional - need to wrap previous in [ ]
|
|
return "", false // TODO: handle properly
|
|
|
|
case '\\':
|
|
// Escape sequence
|
|
if i+1 >= len(pattern) {
|
|
return "", false
|
|
}
|
|
next := pattern[i+1]
|
|
switch next {
|
|
case 'd':
|
|
result.WriteString("digit")
|
|
c.usedTypes["digit"] = true
|
|
case 'w':
|
|
result.WriteString(`( "a" … "z" | "A" … "Z" | "0" … "9" | "_" )`)
|
|
case 's':
|
|
result.WriteString(`( " " | "\t" )`)
|
|
default:
|
|
result.WriteString(fmt.Sprintf(`"%c"`, next))
|
|
}
|
|
i += 2
|
|
|
|
default:
|
|
// Literal character
|
|
if (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') || ch == '_' || ch == '-' || ch == '.' {
|
|
result.WriteString(fmt.Sprintf(`"%c" `, ch))
|
|
} else {
|
|
// Special char, try to escape
|
|
result.WriteString(fmt.Sprintf(`"%c" `, ch))
|
|
}
|
|
i++
|
|
}
|
|
}
|
|
|
|
return strings.TrimSpace(result.String()), true
|
|
}
|
|
|
|
func (c *converter) charClassToEBNF(class string) (string, bool) {
|
|
// Handle character classes like a-z, A-Z, 0-9
|
|
if class == "a-zA-Z0-9_" || class == "a-zA-Z_" {
|
|
return `( "a" … "z" | "A" … "Z" | "0" … "9" | "_" )`, true
|
|
}
|
|
if class == "a-zA-Z0-9" {
|
|
return `( "a" … "z" | "A" … "Z" | "0" … "9" )`, true
|
|
}
|
|
if class == "a-z" {
|
|
return `"a" … "z"`, true
|
|
}
|
|
if class == "A-Z" {
|
|
return `"A" … "Z"`, true
|
|
}
|
|
if class == "0-9" {
|
|
c.usedTypes["digit"] = true
|
|
return "digit", true
|
|
}
|
|
|
|
// Try to parse range patterns
|
|
if matched, _ := regexp.MatchString(`^[a-zA-Z]-[a-zA-Z]$`, class); matched {
|
|
return fmt.Sprintf(`"%c" … "%c"`, class[0], class[2]), true
|
|
}
|
|
if matched, _ := regexp.MatchString(`^[0-9]-[0-9]$`, class); matched {
|
|
return fmt.Sprintf(`"%c" … "%c"`, class[0], class[2]), true
|
|
}
|
|
|
|
return "", false
|
|
}
|
|
|
|
func (c *converter) anyOfToExpr(schemas []*schemaNode, name string) string {
|
|
var parts []string
|
|
for i, s := range schemas {
|
|
expr := c.schemaToExpr(s, fmt.Sprintf("%s_opt%d", name, i))
|
|
parts = append(parts, expr)
|
|
}
|
|
return "( " + strings.Join(parts, " | ") + " )"
|
|
}
|
|
|
|
func (c *converter) enumToExpr(values []interface{}) string {
|
|
var parts []string
|
|
for _, v := range values {
|
|
parts = append(parts, c.constToExpr(v))
|
|
}
|
|
return "( " + strings.Join(parts, " | ") + " )"
|
|
}
|
|
|
|
func (c *converter) constToExpr(v interface{}) string {
|
|
switch val := v.(type) {
|
|
case string:
|
|
return fmt.Sprintf(`"\"%s\""`, c.escapeString(val))
|
|
case float64:
|
|
if val == float64(int(val)) {
|
|
return fmt.Sprintf(`"%d"`, int(val))
|
|
}
|
|
return fmt.Sprintf(`"%v"`, val)
|
|
case bool:
|
|
if val {
|
|
return `"true"`
|
|
}
|
|
return `"false"`
|
|
case nil:
|
|
return `"null"`
|
|
default:
|
|
c.usedTypes["string"] = true
|
|
return "string"
|
|
}
|
|
}
|
|
|
|
func (c *converter) resolveRef(ref string) string {
|
|
// Handle #/$defs/name references
|
|
if strings.HasPrefix(ref, "#/$defs/") {
|
|
defName := strings.TrimPrefix(ref, "#/$defs/")
|
|
return c.resolveDefRef(defName)
|
|
}
|
|
|
|
// Handle root recursion #
|
|
if ref == "#" {
|
|
return "root"
|
|
}
|
|
|
|
// Unknown ref format
|
|
c.usedTypes["string"] = true
|
|
return "string"
|
|
}
|
|
|
|
func (c *converter) resolveDefRef(defName string) string {
|
|
// Check if we've already defined this as a rule
|
|
ruleName := "def_" + defName
|
|
if c.definedRefs[defName] {
|
|
return ruleName
|
|
}
|
|
|
|
// Mark as defined to prevent infinite recursion
|
|
c.definedRefs[defName] = true
|
|
|
|
// Look up the definition
|
|
if c.definitions == nil {
|
|
c.usedTypes["string"] = true
|
|
return "string"
|
|
}
|
|
|
|
defSchema, ok := c.definitions[defName]
|
|
if !ok {
|
|
c.usedTypes["string"] = true
|
|
return "string"
|
|
}
|
|
|
|
// Generate the rule
|
|
expr := c.schemaToExpr(defSchema, ruleName)
|
|
c.rules = append(c.rules, fmt.Sprintf("%s = %s .", ruleName, expr))
|
|
|
|
return ruleName
|
|
}
|
|
|
|
func (c *converter) getTypes(t interface{}) []string {
|
|
switch v := t.(type) {
|
|
case string:
|
|
return []string{v}
|
|
case []interface{}:
|
|
var types []string
|
|
for _, item := range v {
|
|
if s, ok := item.(string); ok {
|
|
types = append(types, s)
|
|
}
|
|
}
|
|
return types
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *converter) escapeString(s string) string {
|
|
s = strings.ReplaceAll(s, `\`, `\\`)
|
|
s = strings.ReplaceAll(s, `"`, `\"`)
|
|
return s
|
|
}
|
|
|
|
// Grammar converts a JSON Schema string into a compiled grammar.
|
|
func Grammar(schemaJSON string) (*grammar.Grammar, error) {
|
|
ebnf, err := EBNF(schemaJSON)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return grammar.ParseEBNF(ebnf, "root")
|
|
}
|