mirror of
https://github.com/ollama/ollama.git
synced 2026-04-17 15:53:27 +02:00
* create: Clean up experimental paths
This cleans up the experimental features, and adds both unit and integration test coverage to verify no regressions.
* create: preserve config and layer names when creating from safetensors models
When creating a model FROM an existing safetensors model, ModelFormat,
Capabilities, and layer Name fields were lost. ModelFormat stayed empty
because it's only set from GGML layers (which safetensors models lack),
and layer names weren't copied in parseFromModel. This caused derived
models to fail loading ("config.json not found in manifest").
* review comments
352 lines
9.5 KiB
Go
352 lines
9.5 KiB
Go
package create
|
|
|
|
import (
|
|
"fmt"
|
|
"io"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
|
|
"github.com/ollama/ollama/x/safetensors"
|
|
)
|
|
|
|
// qwen35ImportTransform rewrites Qwen3.5 safetensors tensors during import:
// renaming/splitting tensors, shifting norm weights, transposing conv1d
// kernels, and casting float dtypes to BF16.
type qwen35ImportTransform struct {
	// shouldShiftNormWeights indicates the source's norm weights must be
	// shifted by +1 during import (applied in rewriteTensorData).
	shouldShiftNormWeights bool
	// rewriteLanguageModel indicates tensor names should be normalized under
	// a language_model.* prefix; set when the source architecture name
	// contains "ConditionalGeneration".
	rewriteLanguageModel bool
}
|
|
|
|
// qwen35SourceInfo captures properties detected while scanning the source
// model's safetensors shards (see qwen35InspectSource).
type qwen35SourceInfo struct {
	// hasPrequantizedWeights is true when any tensor name ends in ".scales",
	// indicating the source already carries quantized weights.
	hasPrequantizedWeights bool
	// shouldShiftNormWeights is true when an "mtp." tensor is present or a
	// 3D conv1d weight has a last dimension != 1.
	shouldShiftNormWeights bool
}
|
|
|
|
func newQwen35ImportTransform(modelDir string, cfg sourceModelConfig) (tensorImportTransform, error) {
|
|
sourceInfo, err := qwen35InspectSource(modelDir)
|
|
if err != nil {
|
|
return qwen35ImportTransform{}, err
|
|
}
|
|
if sourceInfo.hasPrequantizedWeights {
|
|
return noopImportTransform{}, nil
|
|
}
|
|
|
|
return qwen35ImportTransform{
|
|
shouldShiftNormWeights: sourceInfo.shouldShiftNormWeights,
|
|
rewriteLanguageModel: strings.Contains(cfg.Architecture(), "ConditionalGeneration"),
|
|
}, nil
|
|
}
|
|
|
|
func qwen35InspectSource(modelDir string) (qwen35SourceInfo, error) {
|
|
entries, err := os.ReadDir(modelDir)
|
|
if err != nil {
|
|
return qwen35SourceInfo{}, err
|
|
}
|
|
|
|
var info qwen35SourceInfo
|
|
for _, entry := range entries {
|
|
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".safetensors") {
|
|
continue
|
|
}
|
|
|
|
extractor, err := safetensors.OpenForExtraction(filepath.Join(modelDir, entry.Name()))
|
|
if err != nil {
|
|
return qwen35SourceInfo{}, err
|
|
}
|
|
|
|
for _, name := range extractor.ListTensors() {
|
|
if strings.HasSuffix(name, ".scales") {
|
|
extractor.Close()
|
|
info.hasPrequantizedWeights = true
|
|
return info, nil
|
|
}
|
|
// This should change when MTP is supported
|
|
if strings.Contains(name, "mtp.") {
|
|
info.shouldShiftNormWeights = true
|
|
continue
|
|
}
|
|
if info.shouldShiftNormWeights || !strings.Contains(name, "conv1d.weight") {
|
|
continue
|
|
}
|
|
|
|
td, err := extractor.GetTensor(name)
|
|
if err != nil {
|
|
extractor.Close()
|
|
return qwen35SourceInfo{}, err
|
|
}
|
|
if len(td.Shape) == 3 && td.Shape[2] != 1 {
|
|
info.shouldShiftNormWeights = true
|
|
}
|
|
}
|
|
|
|
extractor.Close()
|
|
}
|
|
|
|
return info, nil
|
|
}
|
|
|
|
func (t qwen35ImportTransform) skipTensor(name string) bool {
|
|
return strings.Contains(name, "mtp.")
|
|
}
|
|
|
|
// qwen35ShouldKeepBF16ForDirectNonAffine reports whether the named tensor
// should stay in BF16 when importing directly into a non-affine quantized
// format (NVFP4/MXFP4/MXFP8): embeddings, the LM head, the low-rank
// linear_attn projections, and the routing gates.
func qwen35ShouldKeepBF16ForDirectNonAffine(name string) bool {
	for _, suffix := range []string{
		"embed_tokens.weight",
		"lm_head.weight",
		".linear_attn.in_proj_a.weight",
		".linear_attn.in_proj_b.weight",
		".linear_attn.in_proj_ba.weight",
		".mlp.shared_expert_gate.weight",
	} {
		if strings.HasSuffix(name, suffix) {
			return true
		}
	}
	// Router gates stay BF16; the "_proj" guard excludes expert projection
	// weights whose names also end in ".mlp.gate.weight".
	return strings.HasSuffix(name, ".mlp.gate.weight") && !strings.Contains(name, "_proj")
}
|
|
|
|
func (t qwen35ImportTransform) quantizationType(name string, shape []int32, quantize string) string {
|
|
if strings.HasPrefix(name, "vision_tower.") {
|
|
return ""
|
|
}
|
|
|
|
stackedExpert := isStackedExpertWeight(name)
|
|
if strings.HasSuffix(name, ".bias") || strings.HasSuffix(name, ".scale") || strings.HasSuffix(name, ".qbias") ||
|
|
strings.HasSuffix(name, ".biases") || strings.HasSuffix(name, ".scales") {
|
|
return ""
|
|
}
|
|
if !stackedExpert && !strings.HasSuffix(name, ".weight") {
|
|
return ""
|
|
}
|
|
if strings.Contains(name, "norm") || strings.Contains(name, "ln_") || strings.Contains(name, "layernorm") {
|
|
return ""
|
|
}
|
|
if len(shape) != 2 && !(len(shape) == 3 && stackedExpert) {
|
|
return ""
|
|
}
|
|
|
|
var elems int64 = 1
|
|
for _, d := range shape {
|
|
elems *= int64(d)
|
|
}
|
|
if elems < 1024 {
|
|
return ""
|
|
}
|
|
|
|
quantNorm := normalizeQuantType(quantize)
|
|
groupSize := int32(32)
|
|
switch quantNorm {
|
|
case "nvfp4":
|
|
groupSize = 16
|
|
case "int4", "int8":
|
|
groupSize = 64
|
|
}
|
|
if shape[len(shape)-1]%groupSize != 0 {
|
|
return ""
|
|
}
|
|
|
|
// Match the working HF-FP8 import policy for direct NVFP4/MXFP4/MXFP8 imports:
|
|
// keep embeddings, LM head, low-rank linear_attn projections, and routing
|
|
// gates in BF16 rather than forcing them into a non-affine quantized format.
|
|
if (quantNorm == "nvfp4" || quantNorm == "mxfp4" || quantNorm == "mxfp8") && qwen35ShouldKeepBF16ForDirectNonAffine(name) {
|
|
return ""
|
|
}
|
|
|
|
return quantNorm
|
|
}
|
|
|
|
func (t qwen35ImportTransform) rewriteTensorData(td *safetensors.TensorData) (*safetensors.TensorData, error) {
|
|
if td == nil {
|
|
return td, nil
|
|
}
|
|
|
|
shiftNorm := t.shouldShiftNormWeights && qwen35ShouldShiftNormKey(td.Name)
|
|
transposeConv := strings.Contains(td.Name, "conv1d.weight") && len(td.Shape) == 3 && td.Shape[2] != 1
|
|
castToBF16 := qwen35NeedsCastToBF16(td.Name, td.Dtype)
|
|
if !shiftNorm && !transposeConv && !castToBF16 {
|
|
return td, nil
|
|
}
|
|
|
|
raw, err := io.ReadAll(td.Reader())
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read tensor %s: %w", td.Name, err)
|
|
}
|
|
|
|
values, err := DecodeFloatTensor(td.Dtype, raw)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to decode tensor %s: %w", td.Name, err)
|
|
}
|
|
|
|
shape := append([]int32(nil), td.Shape...)
|
|
if transposeConv {
|
|
values, shape = qwen35TransposeConv1D(values, shape)
|
|
}
|
|
if shiftNorm {
|
|
for i := range values {
|
|
values[i] += 1.0
|
|
}
|
|
}
|
|
|
|
targetDtype := td.Dtype
|
|
if castToBF16 {
|
|
targetDtype = "BF16"
|
|
}
|
|
|
|
out, err := EncodeFloatTensor(targetDtype, values)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to encode tensor %s: %w", td.Name, err)
|
|
}
|
|
|
|
return safetensors.NewTensorDataFromBytes(td.Name, targetDtype, shape, out), nil
|
|
}
|
|
|
|
func (t qwen35ImportTransform) transformTensor(td *safetensors.TensorData) ([]*safetensors.TensorData, error) {
|
|
if td == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
name := t.canonicalTensorName(td.Name)
|
|
|
|
// Phase 1: rename/split into intermediate tensors
|
|
var intermediates []*safetensors.TensorData
|
|
stripped := strings.TrimSuffix(name, ".weight")
|
|
switch {
|
|
case strings.HasSuffix(stripped, ".mlp.experts.gate_up_proj"):
|
|
prefix := strings.TrimSuffix(stripped, ".mlp.experts.gate_up_proj")
|
|
raw, err := io.ReadAll(td.Reader())
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read tensor %s: %w", td.Name, err)
|
|
}
|
|
gateRaw, upRaw, splitShape, err := qwen35SplitAxis1Raw(raw, td.Dtype, td.Shape)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to split tensor %s: %w", td.Name, err)
|
|
}
|
|
intermediates = []*safetensors.TensorData{
|
|
safetensors.NewTensorDataFromBytes(prefix+".mlp.switch_mlp.gate_proj.weight", td.Dtype, splitShape, gateRaw),
|
|
safetensors.NewTensorDataFromBytes(prefix+".mlp.switch_mlp.up_proj.weight", td.Dtype, splitShape, upRaw),
|
|
}
|
|
case strings.HasSuffix(stripped, ".mlp.experts.down_proj"):
|
|
newName := strings.TrimSuffix(stripped, ".mlp.experts.down_proj") + ".mlp.switch_mlp.down_proj.weight"
|
|
intermediates = []*safetensors.TensorData{td.WithName(newName)}
|
|
default:
|
|
intermediates = []*safetensors.TensorData{td.WithName(name)}
|
|
}
|
|
|
|
// Phase 2: rewrite all intermediates
|
|
results := make([]*safetensors.TensorData, 0, len(intermediates))
|
|
for _, inter := range intermediates {
|
|
rewritten, err := t.rewriteTensorData(inter)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
results = append(results, rewritten)
|
|
}
|
|
return results, nil
|
|
}
|
|
|
|
func (t qwen35ImportTransform) canonicalTensorName(name string) string {
|
|
// Vision tensors: normalize to vision_tower.* prefix
|
|
switch {
|
|
case strings.HasPrefix(name, "model.visual."):
|
|
return "vision_tower." + strings.TrimPrefix(name, "model.visual.")
|
|
case strings.HasPrefix(name, "vision_tower."):
|
|
return name
|
|
}
|
|
|
|
// Language model tensors: normalize to language_model.model.* prefix
|
|
if !t.rewriteLanguageModel {
|
|
return name
|
|
}
|
|
switch {
|
|
case strings.HasPrefix(name, "model.language_model"):
|
|
return "language_model.model" + strings.TrimPrefix(name, "model.language_model")
|
|
case strings.HasPrefix(name, "language_model."):
|
|
return name
|
|
default:
|
|
return "language_model." + name
|
|
}
|
|
}
|
|
|
|
// qwen35ShouldShiftNormKey reports whether the tensor named key is a norm
// weight that gets the +1 shift during import (see rewriteTensorData).
func qwen35ShouldShiftNormKey(key string) bool {
	switch {
	case strings.HasSuffix(key, ".input_layernorm.weight"),
		strings.HasSuffix(key, ".post_attention_layernorm.weight"),
		strings.HasSuffix(key, "model.norm.weight"),
		strings.HasSuffix(key, ".q_norm.weight"),
		strings.HasSuffix(key, ".k_norm.weight"):
		return true
	default:
		return false
	}
}
|
|
|
|
// qwen35NeedsCastToBF16 reports whether a tensor of the given name and dtype
// should be re-encoded as BF16 on import. Plain float dtypes are cast;
// A_log tensors are exempt and keep their original dtype.
func qwen35NeedsCastToBF16(name, dtype string) bool {
	if strings.HasSuffix(name, "A_log") {
		return false
	}
	d := strings.ToUpper(dtype)
	return d == "F16" || d == "F32" || d == "F64"
}
|
|
|
|
// qwen35TransposeConv1D swaps the last two axes of a 3D tensor stored in
// row-major order, returning the new values and the permuted shape
// [d0, d2, d1]. Non-3D inputs are returned unchanged.
func qwen35TransposeConv1D(values []float32, shape []int32) ([]float32, []int32) {
	if len(shape) != 3 {
		return values, shape
	}

	d0, d1, d2 := int(shape[0]), int(shape[1]), int(shape[2])
	out := make([]float32, len(values))
	for i := 0; i < d0; i++ {
		base := i * d1 * d2
		for k := 0; k < d2; k++ {
			// Gather column k of the (d1 x d2) slab into row k of the
			// transposed (d2 x d1) slab.
			row := (i*d2 + k) * d1
			for j := 0; j < d1; j++ {
				out[row+j] = values[base+j*d2+k]
			}
		}
	}

	return out, []int32{shape[0], shape[2], shape[1]}
}
|
|
|
|
func qwen35SplitAxis1Raw(raw []byte, dtype string, shape []int32) ([]byte, []byte, []int32, error) {
|
|
if len(shape) != 3 {
|
|
return nil, nil, nil, fmt.Errorf("expected 3D tensor, got shape %v", shape)
|
|
}
|
|
if shape[1]%2 != 0 {
|
|
return nil, nil, nil, fmt.Errorf("axis 1 dim %d is not even", shape[1])
|
|
}
|
|
|
|
elemSize, err := DTypeSize(dtype)
|
|
if err != nil {
|
|
return nil, nil, nil, err
|
|
}
|
|
|
|
d0, d1, d2 := int(shape[0]), int(shape[1]), int(shape[2])
|
|
perExpertBytes := d1 * d2 * elemSize
|
|
if len(raw) != d0*perExpertBytes {
|
|
return nil, nil, nil, fmt.Errorf("raw byte length %d does not match shape %v and dtype %s", len(raw), shape, dtype)
|
|
}
|
|
|
|
halfD1 := d1 / 2
|
|
halfExpertBytes := halfD1 * d2 * elemSize
|
|
gateRaw := make([]byte, d0*halfExpertBytes)
|
|
upRaw := make([]byte, d0*halfExpertBytes)
|
|
for e := range d0 {
|
|
src := e * perExpertBytes
|
|
dst := e * halfExpertBytes
|
|
copy(gateRaw[dst:dst+halfExpertBytes], raw[src:src+halfExpertBytes])
|
|
copy(upRaw[dst:dst+halfExpertBytes], raw[src+halfExpertBytes:src+perExpertBytes])
|
|
}
|
|
|
|
return gateRaw, upRaw, []int32{shape[0], int32(halfD1), shape[2]}, nil
|
|
}
|