mirror of
https://github.com/ollama/ollama.git
synced 2026-04-17 15:53:27 +02:00
* Prefer ROCm v6 on Windows. Avoid building with v7; more changes are needed. * MLX: add header vendoring and remove the Go build tag. This switches to a vendoring approach for the mlx-c headers, so Go can build without cmake having to run first. This enables building the new MLX-based code by default. Every time cmake runs, the headers are refreshed, so they are easy to keep in sync when we bump MLX versions. Basic Windows and Linux support are verified. * ci: harden against flaky choco repo servers. CI sometimes fails because choco does not actually install the cache. Since the cache only speeds up the build, we can proceed without it. * Review comments.
166 lines
3.6 KiB
Go
166 lines
3.6 KiB
Go
package safetensors
|
|
|
|
import (
|
|
"os"
|
|
"path/filepath"
|
|
"testing"
|
|
|
|
"github.com/ollama/ollama/x/imagegen/mlx"
|
|
)
|
|
|
|
func TestLoadModelWeights(t *testing.T) {
|
|
// Skip if no model available
|
|
modelDir := "../weights/gpt-oss-20b"
|
|
if _, err := os.Stat(modelDir); os.IsNotExist(err) {
|
|
t.Skip("model weights not available")
|
|
}
|
|
|
|
mw, err := LoadModelWeights(modelDir)
|
|
if err != nil {
|
|
t.Fatalf("LoadModelWeights: %v", err)
|
|
}
|
|
defer mw.ReleaseAll()
|
|
|
|
// Check we found tensors
|
|
tensors := mw.ListTensors()
|
|
if len(tensors) == 0 {
|
|
t.Fatal("no tensors found")
|
|
}
|
|
t.Logf("found %d tensors", len(tensors))
|
|
|
|
// Check HasTensor
|
|
if !mw.HasTensor(tensors[0]) {
|
|
t.Errorf("HasTensor(%q) = false", tensors[0])
|
|
}
|
|
if mw.HasTensor("nonexistent.weight") {
|
|
t.Error("HasTensor returned true for nonexistent tensor")
|
|
}
|
|
}
|
|
|
|
func TestGetTensor(t *testing.T) {
|
|
modelDir := "../weights/gpt-oss-20b"
|
|
if _, err := os.Stat(modelDir); os.IsNotExist(err) {
|
|
t.Skip("model weights not available")
|
|
}
|
|
|
|
mw, err := LoadModelWeights(modelDir)
|
|
if err != nil {
|
|
t.Fatalf("LoadModelWeights: %v", err)
|
|
}
|
|
defer mw.ReleaseAll()
|
|
|
|
tensors := mw.ListTensors()
|
|
if len(tensors) == 0 {
|
|
t.Skip("no tensors")
|
|
}
|
|
|
|
// Load first tensor
|
|
arr, err := mw.GetTensor(tensors[0])
|
|
if err != nil {
|
|
t.Fatalf("GetTensor(%q): %v", tensors[0], err)
|
|
}
|
|
|
|
// Verify it has a shape
|
|
shape := arr.Shape()
|
|
if len(shape) == 0 {
|
|
t.Error("tensor has no shape")
|
|
}
|
|
t.Logf("%s: shape=%v dtype=%v", tensors[0], shape, arr.Dtype())
|
|
}
|
|
|
|
func TestLoadWithDtype(t *testing.T) {
|
|
modelDir := "../weights/gpt-oss-20b"
|
|
if _, err := os.Stat(modelDir); os.IsNotExist(err) {
|
|
t.Skip("model weights not available")
|
|
}
|
|
|
|
mw, err := LoadModelWeights(modelDir)
|
|
if err != nil {
|
|
t.Fatalf("LoadModelWeights: %v", err)
|
|
}
|
|
defer mw.ReleaseAll()
|
|
|
|
// Load all tensors as bfloat16
|
|
if err := mw.Load(mlx.DtypeBFloat16); err != nil {
|
|
t.Fatalf("Load: %v", err)
|
|
}
|
|
|
|
// Get a tensor from cache
|
|
tensors := mw.ListTensors()
|
|
arr, err := mw.Get(tensors[0])
|
|
if err != nil {
|
|
t.Fatalf("Get: %v", err)
|
|
}
|
|
|
|
// Verify dtype (unless it was already bf16)
|
|
t.Logf("%s: dtype=%v", tensors[0], arr.Dtype())
|
|
}
|
|
|
|
func TestLookupTensor(t *testing.T) {
|
|
modelDir := "../weights/gpt-oss-20b"
|
|
if _, err := os.Stat(modelDir); os.IsNotExist(err) {
|
|
t.Skip("model weights not available")
|
|
}
|
|
|
|
mw, err := LoadModelWeights(modelDir)
|
|
if err != nil {
|
|
t.Fatalf("LoadModelWeights: %v", err)
|
|
}
|
|
defer mw.ReleaseAll()
|
|
|
|
// HasTensor returns false for nonexistent
|
|
if mw.HasTensor("nonexistent") {
|
|
t.Error("HasTensor should return false for nonexistent")
|
|
}
|
|
|
|
// HasTensor returns true for existing tensor
|
|
tensors := mw.ListTensors()
|
|
if !mw.HasTensor(tensors[0]) {
|
|
t.Error("HasTensor should return true for existing tensor")
|
|
}
|
|
}
|
|
|
|
func TestParseSafetensorHeader(t *testing.T) {
|
|
modelDir := "../weights/gpt-oss-20b"
|
|
if _, err := os.Stat(modelDir); os.IsNotExist(err) {
|
|
t.Skip("model weights not available")
|
|
}
|
|
|
|
// Find a safetensors file
|
|
entries, err := os.ReadDir(modelDir)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
var stFile string
|
|
for _, e := range entries {
|
|
if filepath.Ext(e.Name()) == ".safetensors" {
|
|
stFile = filepath.Join(modelDir, e.Name())
|
|
break
|
|
}
|
|
}
|
|
if stFile == "" {
|
|
t.Skip("no safetensors file found")
|
|
}
|
|
|
|
header, err := parseSafetensorHeader(stFile)
|
|
if err != nil {
|
|
t.Fatalf("parseSafetensorHeader: %v", err)
|
|
}
|
|
|
|
if len(header) == 0 {
|
|
t.Error("header is empty")
|
|
}
|
|
|
|
// Check a tensor has valid info
|
|
for name, info := range header {
|
|
if info.Dtype == "" {
|
|
t.Errorf("%s: empty dtype", name)
|
|
}
|
|
if len(info.Shape) == 0 {
|
|
t.Errorf("%s: empty shape", name)
|
|
}
|
|
break // just check one
|
|
}
|
|
}
|