diff --git a/model/model.go b/model/model.go
index 42fe7f25c..db398bf1b 100644
--- a/model/model.go
+++ b/model/model.go
@@ -47,6 +47,12 @@ type Validator interface {
 	Validate() error
 }
 
+// PostLoader is an optional interface that models can implement to run
+// initialization steps after backend weights have been loaded.
+type PostLoader interface {
+	PostLoad() error
+}
+
 // MultimodalProcessor must be implemented by multimodal models.
 type MultimodalProcessor interface {
 	// EncodeMultimodal processes a single input (such as an image) and
diff --git a/model/models/gemma4/model.go b/model/models/gemma4/model.go
index 39628ce2a..8c04046a5 100644
--- a/model/models/gemma4/model.go
+++ b/model/models/gemma4/model.go
@@ -131,9 +131,6 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
 		return nil, model.ErrNoVisionModel
 	}
 
-	// Initialize clamp values from model tensors (lazy, once, after model is fully loaded)
-	m.VisionModel.InitClamp(m.MultiModalProjector)
-
 	t0 := time.Now()
 	img, _, err := image.Decode(bytes.NewReader(multimodalData))
 	if err != nil {
@@ -162,6 +159,18 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
 	return []input.Multimodal{{Tensor: visionOutputs}}, nil
 }
 
+// Compile-time check that *Model implements model.PostLoader. The runner
+// discovers PostLoad through a type assertion, so a signature mismatch would
+// otherwise silently skip clamp initialization.
+var _ model.PostLoader = (*Model)(nil)
+
+// PostLoad implements model.PostLoader. It initializes the clamp values from
+// the model tensors after the model is fully loaded.
+func (m *Model) PostLoad() error {
+	m.VisionModel.InitClamp(m.MultiModalProjector)
+	return nil
+}
+
 func (m *Model) encodeAudioMultimodal(ctx ml.Context, data []byte) ([]input.Multimodal, error) {
 	if m.AudioModel == nil || m.audioOpts == nil {
 		return nil, model.ErrNoVisionModel
diff --git a/model/models/gemma4/model_vision.go b/model/models/gemma4/model_vision.go
index b0ba2f090..87f646cd3 100644
--- a/model/models/gemma4/model_vision.go
+++ b/model/models/gemma4/model_vision.go
@@ -80,8 +80,6 @@ func (l *ClippableLinear) loadClampFromScalars() {
 }
 
 func (l *ClippableLinear) Forward(ctx ml.Context, x ml.Tensor) ml.Tensor {
-	l.loadClampFromScalars()
-
 	if l.hasClamp {
 		x = x.Clamp(ctx, l.inMin, l.inMax)
 	}
diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go
index 49e4a5ed6..ccf646539 100644
--- a/runner/ollamarunner/runner.go
+++ b/runner/ollamarunner/runner.go
@@ -1258,6 +1258,13 @@ func (s *Server) loadModel() {
 		panic(fmt.Errorf("failed to load model: %v", err))
 	}
 
+	// Run the model's optional post-load initialization now that loading succeeded.
+	if postLoader, ok := s.model.(model.PostLoader); ok {
+		if err := postLoader.PostLoad(); err != nil {
+			panic(fmt.Errorf("failed to finalize model initialization: %v", err))
+		}
+	}
+
 	s.status = llm.ServerStatusReady
 	s.ready.Done()
 }