mirror of
https://github.com/ollama/ollama.git
synced 2026-04-20 07:54:25 +02:00
ggml: ensure tensor size is valid (#14406)
When quantizing tensors during model creation validate that the resulting sizes match what is expected based on the shape.
This commit is contained in:
@@ -173,6 +173,7 @@ func TestQuantizeModel(t *testing.T) {
|
||||
tensors []*fsggml.Tensor
|
||||
newType string
|
||||
expectedTensorTypes map[string]fsggml.TensorType
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "f16_q4_k",
|
||||
@@ -253,6 +254,36 @@ func TestQuantizeModel(t *testing.T) {
|
||||
"output.weight": fsggml.TensorTypeQ8_0,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "f32_short_data",
|
||||
kv: map[string]any{
|
||||
"general.architecture": "foo",
|
||||
},
|
||||
tensors: []*fsggml.Tensor{
|
||||
{
|
||||
Name: "blk.0.attn.weight", Kind: uint32(fsggml.TensorTypeF32),
|
||||
Offset: uint64(0), Shape: []uint64{512, 2},
|
||||
WriterTo: bytes.NewReader(make([]byte, 32)),
|
||||
},
|
||||
},
|
||||
newType: "Q4_K",
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "f16_short_data",
|
||||
kv: map[string]any{
|
||||
"general.architecture": "foo",
|
||||
},
|
||||
tensors: []*fsggml.Tensor{
|
||||
{
|
||||
Name: "blk.0.attn.weight", Kind: uint32(fsggml.TensorTypeF16),
|
||||
Offset: uint64(0), Shape: []uint64{512, 2},
|
||||
WriterTo: bytes.NewReader(make([]byte, 32)),
|
||||
},
|
||||
},
|
||||
newType: "Q4_K",
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
@@ -264,6 +295,9 @@ func TestQuantizeModel(t *testing.T) {
|
||||
}
|
||||
defer fp.Close()
|
||||
meta, err := fsggml.Decode(fp, -1)
|
||||
if tt.expectErr && err != nil {
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
@@ -283,6 +317,12 @@ func TestQuantizeModel(t *testing.T) {
|
||||
}
|
||||
|
||||
err = quantize(fp, tmp, meta, ftype, progress)
|
||||
if tt.expectErr {
|
||||
if err == nil {
|
||||
t.Fatal("expected quantize to return an error")
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("error during quantize: %s", err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user