mlx: fix vision capability + min version (#15106)

This commit is contained in:
Patrick Devine
2026-03-27 17:09:28 -07:00
committed by GitHub
parent 3824e380a8
commit 9e7cb9697e
7 changed files with 208 additions and 22 deletions

View File

@@ -301,7 +301,7 @@ Weigh anchor!
ParameterSize: "7B",
QuantizationLevel: "FP16",
},
Requires: "0.14.0",
Requires: "0.19.0",
}, false, &b); err != nil {
t.Fatal(err)
}
@@ -310,10 +310,17 @@ Weigh anchor!
architecture test
parameters 7B
quantization FP16
requires 0.14.0
requires 0.19.0
`
if diff := cmp.Diff(expect, b.String()); diff != "" {
trimLinePadding := func(s string) string {
lines := strings.Split(s, "\n")
for i, line := range lines {
lines[i] = strings.TrimRight(line, " \t\r")
}
return strings.Join(lines, "\n")
}
if diff := cmp.Diff(trimLinePadding(expect), trimLinePadding(b.String())); diff != "" {
t.Errorf("unexpected output (-want +got):\n%s", diff)
}
})
@@ -1912,7 +1919,7 @@ func TestShowInfoImageGen(t *testing.T) {
QuantizationLevel: "Q8",
},
Capabilities: []model.Capability{model.CapabilityImage},
Requires: "0.14.0",
Requires: "0.19.0",
}, false, &b)
if err != nil {
t.Fatal(err)
@@ -1922,7 +1929,7 @@ func TestShowInfoImageGen(t *testing.T) {
" architecture ZImagePipeline \n" +
" parameters 10.3B \n" +
" quantization Q8 \n" +
" requires 0.14.0 \n" +
" requires 0.19.0 \n" +
"\n" +
" Capabilities\n" +
" image \n" +

View File

@@ -1225,9 +1225,11 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
modelDetails.ParameterSize = format.HumanNumber(uint64(paramCount))
}
}
// Get torch_dtype directly from config.json for quantization level
if dtype, err := xserver.GetSafetensorsDtype(name); err == nil && dtype != "" {
modelDetails.QuantizationLevel = dtype
// Older manifests may not have file_type populated for safetensors models.
if modelDetails.QuantizationLevel == "" {
if dtype, err := xserver.GetSafetensorsDtype(name); err == nil && dtype != "" {
modelDetails.QuantizationLevel = dtype
}
}
}

View File

@@ -26,6 +26,7 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/openai"
"github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/types/model"
@@ -547,6 +548,38 @@ func TestRoutes(t *testing.T) {
}
}
func TestGetModelInfo_SafetensorsUsesStoredFileType(t *testing.T) {
t.Setenv("OLLAMA_MODELS", t.TempDir())
cfgData, err := json.Marshal(model.ConfigV2{
ModelFormat: "safetensors",
FileType: "mxfp8",
Capabilities: []string{"completion"},
})
if err != nil {
t.Fatalf("failed to marshal config: %v", err)
}
configLayer, err := manifest.NewLayer(bytes.NewReader(cfgData), "application/vnd.docker.container.image.v1+json")
if err != nil {
t.Fatalf("failed to create config layer: %v", err)
}
name := model.ParseName("show-safetensors")
if err := manifest.WriteManifest(name, configLayer, nil); err != nil {
t.Fatalf("failed to write manifest: %v", err)
}
resp, err := GetModelInfo(api.ShowRequest{Model: name.String()})
if err != nil {
t.Fatalf("GetModelInfo() error = %v", err)
}
if resp.Details.QuantizationLevel != "mxfp8" {
t.Fatalf("QuantizationLevel = %q, want %q", resp.Details.QuantizationLevel, "mxfp8")
}
}
func casingShuffle(s string) string {
rr := []rune(s)
for i := range rr {

View File

@@ -26,7 +26,7 @@ import (
)
// MinOllamaVersion is the minimum Ollama version required for safetensors models.
const MinOllamaVersion = "0.14.0"
const MinOllamaVersion = "0.19.0"
// ModelfileConfig holds configuration extracted from a Modelfile.
type ModelfileConfig struct {
@@ -132,12 +132,7 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
if isSafetensors {
modelType = "safetensors model"
spinnerKey = "create"
capabilities = []string{"completion"}
// Check if model supports thinking based on architecture
if supportsThinking(opts.ModelDir) {
capabilities = append(capabilities, "thinking")
}
capabilities = inferSafetensorsCapabilities(opts.ModelDir)
// Set parser and renderer name based on architecture
parserName = getParserName(opts.ModelDir)
@@ -188,6 +183,21 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
return nil
}
func inferSafetensorsCapabilities(modelDir string) []string {
capabilities := []string{"completion"}
// Qwen3.5 multimodal checkpoints use ConditionalGeneration architectures.
if supportsVision(modelDir) {
capabilities = append(capabilities, "vision")
}
if supportsThinking(modelDir) {
capabilities = append(capabilities, "thinking")
}
return capabilities
}
// newLayerCreator returns a LayerCreator callback for creating config/JSON layers.
func newLayerCreator() create.LayerCreator {
return func(r io.Reader, mediaType, name string) (create.LayerInfo, error) {
@@ -338,6 +348,7 @@ func newManifestWriter(opts CreateOptions, capabilities []string, parserName, re
// Create config blob with version requirement
configData := model.ConfigV2{
ModelFormat: "safetensors",
FileType: strings.ToLower(strings.TrimSpace(opts.Quantize)),
Capabilities: caps,
Requires: MinOllamaVersion,
Parser: resolveParserName(opts.Modelfile, parserName),
@@ -485,6 +496,34 @@ func supportsThinking(modelDir string) bool {
return false
}
// supportsVision checks if the model supports image input based on its architecture.
// Qwen3.5 multimodal checkpoints are published as ConditionalGeneration architectures.
func supportsVision(modelDir string) bool {
configPath := filepath.Join(modelDir, "config.json")
data, err := os.ReadFile(configPath)
if err != nil {
return false
}
var cfg struct {
Architectures []string `json:"architectures"`
ModelType string `json:"model_type"`
}
if err := json.Unmarshal(data, &cfg); err != nil {
return false
}
for _, arch := range cfg.Architectures {
archLower := strings.ToLower(arch)
if strings.Contains(archLower, "qwen3") && strings.Contains(archLower, "conditionalgeneration") {
return true
}
}
typeLower := strings.ToLower(cfg.ModelType)
return strings.Contains(typeLower, "qwen3") && strings.Contains(typeLower, "conditionalgeneration")
}
// getParserName returns the parser name for a model based on its architecture.
// This reads the config.json from the model directory and determines the appropriate parser.
func getParserName(modelDir string) string {

View File

@@ -3,11 +3,15 @@ package client
import (
"encoding/json"
"os"
"path/filepath"
"slices"
"strings"
"testing"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/create"
)
func TestModelfileConfig(t *testing.T) {
@@ -120,8 +124,8 @@ func TestMinOllamaVersion(t *testing.T) {
if MinOllamaVersion == "" {
t.Error("MinOllamaVersion should not be empty")
}
if MinOllamaVersion != "0.14.0" {
t.Errorf("MinOllamaVersion = %q, want %q", MinOllamaVersion, "0.14.0")
if MinOllamaVersion != "0.19.0" {
t.Errorf("MinOllamaVersion = %q, want %q", MinOllamaVersion, "0.19.0")
}
}
@@ -289,6 +293,52 @@ func TestCreateOptions_Defaults(t *testing.T) {
}
}
func TestInferSafetensorsCapabilities(t *testing.T) {
tests := []struct {
name string
configJSON string
want []string
}{
{
name: "qwen3.5 text model",
configJSON: `{
"architectures": ["Qwen3_5ForCausalLM"],
"model_type": "qwen3"
}`,
want: []string{"completion", "thinking"},
},
{
name: "qwen3.5 multimodal model",
configJSON: `{
"architectures": ["Qwen3_5ForConditionalGeneration"],
"model_type": "qwen3"
}`,
want: []string{"completion", "vision", "thinking"},
},
{
name: "non-qwen conditional generation model",
configJSON: `{
"architectures": ["SomeOtherForConditionalGeneration"],
"model_type": "other"
}`,
want: []string{"completion"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
dir := t.TempDir()
if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte(tt.configJSON), 0o644); err != nil {
t.Fatal(err)
}
if got := inferSafetensorsCapabilities(dir); !slices.Equal(got, tt.want) {
t.Fatalf("inferSafetensorsCapabilities() = %#v, want %#v", got, tt.want)
}
})
}
}
func TestQuantizeSupported(t *testing.T) {
// This just verifies the function exists and returns a boolean
// The actual value depends on build tags (mlx vs non-mlx)
@@ -339,3 +389,43 @@ func TestCreateModelfileLayersIncludesParameters(t *testing.T) {
t.Fatalf("temperature = %v, want %v", got["temperature"], float64(0.7))
}
}
func TestNewManifestWriter_PopulatesFileTypeFromQuantize(t *testing.T) {
t.Setenv("OLLAMA_MODELS", t.TempDir())
opts := CreateOptions{
ModelName: "test-quantized",
ModelDir: t.TempDir(),
Quantize: "MXFP8",
}
writer := newManifestWriter(opts, []string{"completion"}, "qwen3", "qwen3")
if err := writer(opts.ModelName, create.LayerInfo{}, nil); err != nil {
t.Fatalf("newManifestWriter() error = %v", err)
}
name := model.ParseName(opts.ModelName)
mf, err := manifest.ParseNamedManifest(name)
if err != nil {
t.Fatalf("ParseNamedManifest() error = %v", err)
}
configPath, err := manifest.BlobsPath(mf.Config.Digest)
if err != nil {
t.Fatalf("BlobsPath() error = %v", err)
}
data, err := os.ReadFile(configPath)
if err != nil {
t.Fatalf("ReadFile() error = %v", err)
}
var cfg model.ConfigV2
if err := json.Unmarshal(data, &cfg); err != nil {
t.Fatalf("Unmarshal() error = %v", err)
}
if cfg.FileType != "mxfp8" {
t.Fatalf("FileType = %q, want %q", cfg.FileType, "mxfp8")
}
}

View File

@@ -15,6 +15,10 @@ import (
"github.com/ollama/ollama/types/model"
)
func canonicalQuantType(quantType string) string {
return strings.ToLower(strings.TrimSpace(quantType))
}
// modelConfig represents the HuggingFace config.json structure
type modelConfig struct {
Architectures []string `json:"architectures"`
@@ -256,7 +260,7 @@ func getTensorInfoFromManifest(mf *manifest.Manifest) ([]api.Tensor, error) {
}
if info.QuantType != "" {
quantType := strings.ToUpper(info.QuantType)
quantType := canonicalQuantType(info.QuantType)
shape := make([]uint64, len(info.Shape))
for i, s := range info.Shape {
@@ -323,8 +327,8 @@ func GetSafetensorsDtype(name model.Name) (string, error) {
if err != nil {
continue
}
if info.QuantType != "" {
return strings.ToUpper(info.QuantType), nil
if quantType := canonicalQuantType(info.QuantType); quantType != "" {
return quantType, nil
}
// Only check the first tensor blob
break

View File

@@ -705,8 +705,8 @@ func TestGetTensorInfoFromManifest_Quantized(t *testing.T) {
if tensor.Name != "model.layers.0.mlp.up_proj.weight" {
t.Errorf("Name = %v, want model.layers.0.mlp.up_proj.weight", tensor.Name)
}
if tensor.Type != "INT4" {
t.Errorf("Type = %v, want INT4", tensor.Type)
if tensor.Type != "int4" {
t.Errorf("Type = %v, want int4", tensor.Type)
}
// Shape should be unpacked: 320 * 8 = 2560
if len(tensor.Shape) != 2 || tensor.Shape[0] != 2560 || tensor.Shape[1] != 2560 {
@@ -1196,6 +1196,17 @@ func TestGetTensorInfoFromManifest_Packed(t *testing.T) {
if !packedNames["model.layers.0.mlp.experts.0.gate_proj.weight"] {
t.Error("missing packed tensor: model.layers.0.mlp.experts.0.gate_proj.weight")
}
packedTypes := make(map[string]string)
for _, r := range result[1:] {
packedTypes[r.Name] = r.Type
}
if packedTypes["model.layers.0.mlp.experts.0.down_proj.weight"] != "int8" {
t.Errorf("down_proj.Type = %v, want int8", packedTypes["model.layers.0.mlp.experts.0.down_proj.weight"])
}
if packedTypes["model.layers.0.mlp.experts.0.gate_proj.weight"] != "int4" {
t.Errorf("gate_proj.Type = %v, want int4", packedTypes["model.layers.0.mlp.experts.0.gate_proj.weight"])
}
}
func TestReadSafetensorsHeader(t *testing.T) {