mirror of
https://github.com/ollama/ollama.git
synced 2026-04-18 15:54:11 +02:00
158 lines
3.8 KiB
Go
158 lines
3.8 KiB
Go
package main
|
|
|
|
import (
|
|
"embed"
|
|
"flag"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"slices"
|
|
"strings"
|
|
"text/template"
|
|
|
|
tree_sitter "github.com/tree-sitter/go-tree-sitter"
|
|
tree_sitter_cpp "github.com/tree-sitter/tree-sitter-cpp/bindings/go"
|
|
)
|
|
|
|
//go:embed *.gotmpl
|
|
var fsys embed.FS
|
|
|
|
// optionalSymbols lists symbols that may not be present in all builds
|
|
// (e.g., float16/bfloat16 are unavailable in CUDA builds of MLX).
|
|
var optionalSymbols = map[string]bool{
|
|
"mlx_array_item_float16": true,
|
|
"mlx_array_item_bfloat16": true,
|
|
"mlx_array_data_float16": true,
|
|
"mlx_array_data_bfloat16": true,
|
|
}
|
|
|
|
type Function struct {
|
|
Type,
|
|
Name,
|
|
Parameters,
|
|
Args string
|
|
Optional bool
|
|
}
|
|
|
|
func ParseFunction(node *tree_sitter.Node, tc *tree_sitter.TreeCursor, source []byte) Function {
|
|
var fn Function
|
|
fn.Name = node.ChildByFieldName("declarator").Utf8Text(source)
|
|
if params := node.ChildByFieldName("parameters"); params != nil {
|
|
fn.Parameters = params.Utf8Text(source)
|
|
fn.Args = ParseParameters(params, tc, source)
|
|
}
|
|
|
|
var types []string
|
|
for node.Parent() != nil && node.Parent().Kind() != "declaration" {
|
|
if node.Parent().Kind() == "pointer_declarator" {
|
|
types = append(types, "*")
|
|
}
|
|
node = node.Parent()
|
|
}
|
|
|
|
for sibling := node.PrevSibling(); sibling != nil; sibling = sibling.PrevSibling() {
|
|
types = append(types, sibling.Utf8Text(source))
|
|
}
|
|
|
|
slices.Reverse(types)
|
|
fn.Type = strings.Join(types, " ")
|
|
return fn
|
|
}
|
|
|
|
func ParseParameters(node *tree_sitter.Node, tc *tree_sitter.TreeCursor, source []byte) string {
|
|
var s []string
|
|
for _, child := range node.Children(tc) {
|
|
if child.IsNamed() {
|
|
child := child.ChildByFieldName("declarator")
|
|
for child != nil && child.Kind() != "identifier" {
|
|
if child.Kind() == "parenthesized_declarator" {
|
|
child = child.Child(1)
|
|
} else {
|
|
child = child.ChildByFieldName("declarator")
|
|
}
|
|
}
|
|
|
|
if child != nil {
|
|
s = append(s, child.Utf8Text(source))
|
|
}
|
|
}
|
|
}
|
|
return strings.Join(s, ", ")
|
|
}
|
|
|
|
func main() {
|
|
var output string
|
|
flag.StringVar(&output, "output", ".", "Output directory for generated files")
|
|
flag.Parse()
|
|
|
|
parser := tree_sitter.NewParser()
|
|
defer parser.Close()
|
|
|
|
language := tree_sitter.NewLanguage(tree_sitter_cpp.Language())
|
|
parser.SetLanguage(language)
|
|
|
|
query, _ := tree_sitter.NewQuery(language, `(function_declarator declarator: (identifier)) @func`)
|
|
defer query.Close()
|
|
|
|
qc := tree_sitter.NewQueryCursor()
|
|
defer qc.Close()
|
|
|
|
var files []string
|
|
for _, arg := range flag.Args() {
|
|
matches, err := filepath.Glob(arg)
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "Error expanding glob %s: %v\n", arg, err)
|
|
continue
|
|
}
|
|
files = append(files, matches...)
|
|
}
|
|
|
|
var funs []Function
|
|
for _, arg := range files {
|
|
bts, err := os.ReadFile(arg)
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "Error reading file %s: %v\n", arg, err)
|
|
continue
|
|
}
|
|
|
|
tree := parser.Parse(bts, nil)
|
|
defer tree.Close()
|
|
|
|
tc := tree.Walk()
|
|
defer tc.Close()
|
|
|
|
matches := qc.Matches(query, tree.RootNode(), bts)
|
|
for match := matches.Next(); match != nil; match = matches.Next() {
|
|
for _, capture := range match.Captures {
|
|
fn := ParseFunction(&capture.Node, tc, bts)
|
|
fn.Optional = optionalSymbols[fn.Name]
|
|
funs = append(funs, fn)
|
|
}
|
|
}
|
|
}
|
|
|
|
tmpl, err := template.New("").ParseFS(fsys, "*.gotmpl")
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "Error parsing template: %v\n", err)
|
|
return
|
|
}
|
|
|
|
for _, tmpl := range tmpl.Templates() {
|
|
name := filepath.Join(output, strings.TrimSuffix(tmpl.Name(), ".gotmpl"))
|
|
|
|
fmt.Println("Generating", name)
|
|
f, err := os.Create(name)
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "Error creating file %s: %v\n", name, err)
|
|
continue
|
|
}
|
|
defer f.Close()
|
|
|
|
if err := tmpl.Execute(f, map[string]any{
|
|
"Functions": funs,
|
|
}); err != nil {
|
|
fmt.Fprintf(os.Stderr, "Error executing template %s: %v\n", tmpl.Name(), err)
|
|
}
|
|
}
|
|
}
|