mirror of
https://github.com/ollama/ollama.git
synced 2026-04-17 15:53:27 +02:00
60 lines
1.6 KiB
Go
package qwen3next
|
|
|
|
import (
|
|
"math"
|
|
|
|
"github.com/ollama/ollama/kvcache"
|
|
"github.com/ollama/ollama/ml"
|
|
)
|
|
|
|
// Compile-time assertions that *HybridCache satisfies both the base
// kvcache.Cache interface and the CheckpointCache extension.
var (
	_ kvcache.Cache           = (*HybridCache)(nil)
	_ kvcache.CheckpointCache = (*HybridCache)(nil)
)
|
|
|
|
// HybridCache adapts the shared recurrent cache base for Qwen3-Next naming.
// It embeds *kvcache.Recurrent so all of its methods are promoted, and adds
// delta-state-named accessors plus a stricter Remove on top.
type HybridCache struct {
	// Recurrent holds the shared conv/recurrent cache implementation.
	*kvcache.Recurrent
}
|
|
|
|
func NewHybridCache(
|
|
shift func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error),
|
|
convDim, convChannels, deltaStateSize int,
|
|
) *HybridCache {
|
|
base := kvcache.NewRecurrentCache(kvcache.RecurrentConfig{
|
|
Shift: shift,
|
|
ConvDim: convDim,
|
|
ConvChannels: convChannels,
|
|
RecurrentStateSize: deltaStateSize,
|
|
CheckpointLogPrefix: "qwen3next",
|
|
})
|
|
return &HybridCache{Recurrent: base}
|
|
}
|
|
|
|
// DeltaState returns the delta state for current batch sequences as
|
|
// [headVDim, headVDim*numVHeads, nSeqs].
|
|
func (c *HybridCache) DeltaState(ctx ml.Context, layer int, headVDim, numVHeads int) (ml.Tensor, error) {
|
|
return c.RecurrentState(ctx, layer, headVDim, headVDim*numVHeads)
|
|
}
|
|
|
|
// UpdateDeltaState writes a new delta state for current batch sequences.
// It delegates directly to the embedded recurrent cache's
// UpdateRecurrentState; "delta state" is this package's name for the
// recurrent state.
func (c *HybridCache) UpdateDeltaState(ctx ml.Context, layer int, newState ml.Tensor) {
	c.UpdateRecurrentState(ctx, layer, newState)
}
|
|
|
|
// seqTokens re-exposes the promoted SeqTokens under an unexported name —
// presumably to satisfy an unexported interface elsewhere in this package;
// TODO(review): confirm at call sites.
func (c *HybridCache) seqTokens() int {
	return c.SeqTokens()
}
|
|
|
|
// numSeqs re-exposes the promoted NumSeqs under an unexported name,
// mirroring seqTokens — presumably for an unexported in-package interface;
// TODO(review): confirm at call sites.
func (c *HybridCache) numSeqs() int {
	return c.NumSeqs()
}
|
|
|
|
// Keep qwen3next behavior for partial mid-sequence removals.
|
|
func (c *HybridCache) Remove(seq int, beginIndex, endIndex int32) error {
|
|
if beginIndex > 0 && endIndex != math.MaxInt32 {
|
|
return kvcache.ErrNotSupported
|
|
}
|
|
return c.Recurrent.Remove(seq, beginIndex, endIndex)
|
|
}
|