Files
ollama/x/mlxrunner/model/base/base.go
Daniel Hiltgen 10e51c5177 MLX: add header vendoring and remove go build tag (#14642)
* prefer rocm v6 on windows

Avoid building with v7 - more changes are needed

* MLX: add header vendoring and remove go build tag

This switches to using a vendoring approach for the mlx-c headers so that Go
can build without requiring a cmake first.  This enables building the new MLX
based code by default.  Every time cmake runs, the headers are refreshed, so we
can easily keep them in sync when we bump mlx versions.  Basic Windows
and Linux support are verified.

* ci: harden for flaky choco repo servers

CI sometimes fails because choco does not actually install cache. Since it only speeds up the build, we can proceed without it.

* review comments
2026-03-09 17:24:45 -07:00

98 lines
2.6 KiB
Go

package base
import (
"encoding/json"
"fmt"
"log/slog"
"sync"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/tokenizer"
)
// Model is the interface that model implementations must satisfy.
type Model interface {
Forward(inputs *mlx.Array, cache []cache.Cache) *mlx.Array
Unembed(x *mlx.Array) *mlx.Array
NumLayers() int
Tokenizer() *tokenizer.Tokenizer
MaxContextLength() int
// LoadWeights receives all tensors loaded from the manifest and assigns
// them to model fields. Model-specific logic (MLA absorption, expert
// stacking, quantized layer creation) happens here.
LoadWeights(tensors map[string]*mlx.Array) error
}
// mu guards registry; Register and New both hold it while touching the map.
var mu sync.Mutex

// registry maps an architecture name to the constructor stored for it by
// Register.
var registry = map[string]func(root *model.Root) (Model, error){}
// Register associates the constructor fn with the architecture name arch.
// It is meant to be called from init() in model packages.
//
// Registering the same arch twice is treated as a programming error and
// panics; the first registration is left untouched in that case.
func Register(arch string, fn func(root *model.Root) (Model, error)) {
	mu.Lock()
	defer mu.Unlock()

	if _, taken := registry[arch]; !taken {
		registry[arch] = fn
		return
	}
	panic(fmt.Sprintf("model architecture %q already registered", arch))
}
// New constructs the model described by root.
//
// It reads config.json through root.Manifest, takes the first entry of the
// "architectures" list as the architecture name, and invokes the constructor
// that was registered under that name. The returned model has its config
// parsed and its struct created, but its weights are not loaded yet.
//
// An error is returned when config.json cannot be read or parsed, when it
// lists no architectures, or when no constructor is registered for the
// selected architecture. Errors from the constructor are passed through
// unchanged.
func New(root *model.Root) (Model, error) {
	raw, err := root.Manifest.ReadConfig("config.json")
	if err != nil {
		return nil, fmt.Errorf("failed to read config.json: %w", err)
	}

	// Only the architecture list is decoded here; every other field of
	// config.json is ignored at this stage.
	type archHeader struct {
		Architectures []string `json:"architectures"`
	}
	var hdr archHeader
	if err := json.Unmarshal(raw, &hdr); err != nil {
		return nil, fmt.Errorf("failed to parse config.json: %w", err)
	}
	if len(hdr.Architectures) < 1 {
		return nil, fmt.Errorf("no architectures found in config.json")
	}

	name := hdr.Architectures[0]
	slog.Info("Model architecture", "arch", name)

	// Hold the lock only for the map read; the constructor runs unlocked.
	build, found := func() (func(root *model.Root) (Model, error), bool) {
		mu.Lock()
		defer mu.Unlock()
		ctor, present := registry[name]
		return ctor, present
	}()
	if !found {
		return nil, fmt.Errorf("unsupported architecture: %s", name)
	}

	return build(root)
}
// Weights wraps m in a loader callback. When the callback is invoked with the
// tensor map it first delegates to m.LoadWeights; on failure that error is
// returned and nothing else happens. On success it collects the arrays
// reachable from the model struct, pins each of them, sweeps everything
// else, and finally evaluates the collected arrays.
func Weights(m Model) func(map[string]*mlx.Array) error {
	load := func(tensors map[string]*mlx.Array) error {
		err := m.LoadWeights(tensors)
		if err != nil {
			return err
		}

		// Pin before sweeping so the model's own arrays survive the sweep.
		kept := mlx.Collect(m)
		for i := range kept {
			mlx.Pin(kept[i])
		}
		mlx.Sweep()

		mlx.Eval(kept...)
		return nil
	}
	return load
}