ggml: ensure tensor size is valid (#14406)

When quantizing tensors during model creation validate that the resulting sizes match what is expected based on the shape.
This commit is contained in:
Bruce MacDonald
2026-02-24 21:52:44 -04:00
committed by GitHub
parent f4f0a4a471
commit 9d902d63ce
4 changed files with 96 additions and 10 deletions

View File

@@ -33,6 +33,9 @@ func (q quantizer) WriteTo(w io.Writer) (int64, error) {
slog.Warn("file read error", "tensor", q.from.Name, "file", q.Name(), "error", err)
return 0, fmt.Errorf("unable to read tensor %s from %s: %s", q.from.Name, q.Name(), err)
}
if uint64(len(data)) < q.from.Size() {
return 0, fmt.Errorf("tensor %s data size %d is less than expected %d from shape %v", q.from.Name, len(data), q.from.Size(), q.from.Shape)
}
var f32s []float32
newType := fsggml.TensorType(q.to.Kind)
if fsggml.TensorType(q.from.Kind) == fsggml.TensorTypeF32 {

View File

@@ -173,6 +173,7 @@ func TestQuantizeModel(t *testing.T) {
tensors []*fsggml.Tensor
newType string
expectedTensorTypes map[string]fsggml.TensorType
expectErr bool
}{
{
name: "f16_q4_k",
@@ -253,6 +254,36 @@ func TestQuantizeModel(t *testing.T) {
"output.weight": fsggml.TensorTypeQ8_0,
},
},
{
name: "f32_short_data",
kv: map[string]any{
"general.architecture": "foo",
},
tensors: []*fsggml.Tensor{
{
Name: "blk.0.attn.weight", Kind: uint32(fsggml.TensorTypeF32),
Offset: uint64(0), Shape: []uint64{512, 2},
WriterTo: bytes.NewReader(make([]byte, 32)),
},
},
newType: "Q4_K",
expectErr: true,
},
{
name: "f16_short_data",
kv: map[string]any{
"general.architecture": "foo",
},
tensors: []*fsggml.Tensor{
{
Name: "blk.0.attn.weight", Kind: uint32(fsggml.TensorTypeF16),
Offset: uint64(0), Shape: []uint64{512, 2},
WriterTo: bytes.NewReader(make([]byte, 32)),
},
},
newType: "Q4_K",
expectErr: true,
},
}
for _, tt := range cases {
@@ -264,6 +295,9 @@ func TestQuantizeModel(t *testing.T) {
}
defer fp.Close()
meta, err := fsggml.Decode(fp, -1)
if tt.expectErr && err != nil {
return
}
if err != nil {
t.Fatal(err.Error())
}
@@ -283,6 +317,12 @@ func TestQuantizeModel(t *testing.T) {
}
err = quantize(fp, tmp, meta, ftype, progress)
if tt.expectErr {
if err == nil {
t.Fatal("expected quantize to return an error")
}
return
}
if err != nil {
t.Fatalf("error during quantize: %s", err)
}