Files
ollama/x/mlxrunner/model/base/base.go
Patrick Devine 97323d1c68 consolidate the tokenizer (#14327)
This change adds a new x/tokenizer package and includes:
  * New BPE and SentencePiece tokenizers
  * Removal of the dependency on the imagegen tokenizers
  * Fixes to multibyte decoding in the pipeline
  * Various correctness and benchmark tests

Not included in this PR is the WordPiece tokenizer for BERT models, which will be
added when we add embedding models. The imagegen tokenizers will also be removed in
a follow-up PR.
2026-02-19 15:55:45 -08:00

86 lines
2.4 KiB
Go

//go:build mlx
package base
import (
"encoding/json"
"fmt"
"log/slog"
"sync"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/tokenizer"
)
// Model is the interface that model implementations must satisfy. Values of
// this type are produced by the constructors passed to Register and are
// returned from New.
type Model interface {
	// Forward runs the model on inputs and returns the resulting array.
	// NOTE(review): caches presumably holds one entry per layer (compare
	// NumLayers) — confirm against the implementations.
	Forward(inputs *mlx.Array, caches []cache.Cache) *mlx.Array

	// Unembed maps x to an output array.
	// NOTE(review): presumably projects hidden states onto vocabulary
	// logits — TODO confirm.
	Unembed(x *mlx.Array) *mlx.Array

	// NumLayers reports how many layers the model has.
	NumLayers() int

	// Tokenizer returns the tokenizer associated with this model.
	Tokenizer() *tokenizer.Tokenizer

	// LoadWeights receives all tensors loaded from the manifest and assigns
	// them to model fields. Model-specific logic (MLA absorption, expert
	// stacking, quantized layer creation) happens here.
	LoadWeights(tensors map[string]*mlx.Array) error
}
// Constructor registry, keyed by architecture name. Entries are added by
// Register and looked up by New; mu must be held for any access to registry.
var (
	mu       sync.Mutex
	registry = map[string]func(*model.Root) (Model, error){}
)
// Register records fn as the constructor for models whose config.json names
// arch as their (first) architecture. It is meant to be called from init() in
// model packages.
//
// Register panics if fn is nil or if arch is already registered. Both cases
// are programming errors, and panicking here surfaces them at startup instead
// of as a nil-function call when a model is later loaded through New.
func Register(arch string, fn func(root *model.Root) (Model, error)) {
	if fn == nil {
		panic(fmt.Sprintf("model architecture %q registered with nil constructor", arch))
	}

	mu.Lock()
	defer mu.Unlock()

	if _, exists := registry[arch]; exists {
		panic(fmt.Sprintf("model architecture %q already registered", arch))
	}
	registry[arch] = fn
}
// New builds the model described by root's manifest. It reads config.json,
// takes the first entry of its "architectures" list, and invokes the
// constructor registered under that name. The returned model has its config
// parsed and its struct created, but its weights are not loaded yet.
func New(root *model.Root) (Model, error) {
	raw, err := root.Manifest.ReadConfig("config.json")
	if err != nil {
		return nil, fmt.Errorf("failed to read config.json: %w", err)
	}

	// Decode only the architecture list; every other key is ignored here.
	var cfg struct {
		Architectures []string `json:"architectures"`
	}
	if err = json.Unmarshal(raw, &cfg); err != nil {
		return nil, fmt.Errorf("failed to parse config.json: %w", err)
	}
	if len(cfg.Architectures) < 1 {
		return nil, fmt.Errorf("no architectures found in config.json")
	}

	arch := cfg.Architectures[0]
	slog.Info("Model architecture", "arch", arch)

	// Fetch the constructor while holding the registry lock, and release the
	// lock before the constructor itself runs.
	ctor, found := func() (func(*model.Root) (Model, error), bool) {
		mu.Lock()
		defer mu.Unlock()
		c, ok := registry[arch]
		return c, ok
	}()
	if !found {
		return nil, fmt.Errorf("unsupported architecture: %s", arch)
	}
	return ctor(root)
}
// Weights exposes m's LoadWeights method as a standalone function value.
// All weight assignment and post-processing (MLA absorption, expert stacking)
// lives in that method; Weights only hands it back.
func Weights(m Model) func(map[string]*mlx.Array) error {
	loader := m.LoadWeights // method value: the receiver m is bound right here
	return loader
}