package qwen3next

import (
	"math"
	"slices"

	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model/input"
)

var _ kvcache.Cache = (*HybridCache)(nil)

// HybridCache stores:
//   - a standard causal KV cache for full attention layers
//   - per-sequence conv state for linear attention layers
//   - per-sequence delta state for linear attention layers
//
// Conv state shape (per layer, per sequence): [convKernelSize-1, convChannels]
// Delta state shape (per layer, per sequence): [headVDim, headVDim * numVHeads]
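//
// Illustrative wiring (a sketch only; the real call sites live in the model
// and runner code, and the names below are placeholders):
//
//	cache := NewHybridCache(shiftFn, convKernelSize-1, convChannels, headVDim*headVDim*numVHeads)
//	cache.Init(backend, dtype, maxSequences, capacity, maxBatch)
//	// Per forward pass: StartForward once, then the linear attention layers
//	// read and write state via ConvState/UpdateConvState and
//	// DeltaState/UpdateDeltaState.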
type HybridCache struct {
	kv *kvcache.Causal

	backend      ml.Backend
	dtype        ml.DType
	maxSequences int

	// Conv state dimensions
	convDim      int // convKernelSize - 1
	convChannels int // d_inner + 2 * num_k_heads * head_k_dim

	// Delta state dimensions
	deltaStateSize int // headVDim * headVDim * numVHeads

	// slot mapping for recurrent state (copy-on-write)
	slotForSeq map[int]int
	refCount   []int
	freeSlots  []int

	// per-layer conv state buffers (allocated lazily)
	convCtxs   map[int]ml.Context
	convStates map[int]ml.Tensor // [convDim*convChannels, maxSlots]

	// per-layer delta state buffers (allocated lazily)
	deltaCtxs   map[int]ml.Context
	deltaStates map[int]ml.Tensor // [deltaStateSize, maxSlots]

	// recurrent checkpoints (per slot)
	checkpointCount     int
	checkpointMinPos    int32
	checkpointInterval  int32
	checkpointCtxSize   int
	checkpoints         map[int]*slotCheckpointStore
	pendingRestore      map[int]checkpointRestore
	curCheckpointPos    []int32
	curCheckpointSlots  map[int]int
	reserveCheckpoints  bool
	checkpointConvCtxs  map[int]ml.Context
	checkpointDeltaCtxs map[int]ml.Context
	checkpointReserved  map[int]struct{}

	// current forward batch (derived in StartForward)
	curSeqs       []int
	curSlots      []int
	curSlotsInput ml.Tensor
	curSeqTokens  int

	// track if EnsureWritable has been called for this forward pass
	writableEnsured bool
	writableError   error
}

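// NewHybridCache constructs a HybridCache around a causal KV cache; the shift
// callback is passed through to the underlying kvcache.Causal.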
func NewHybridCache(
	shift func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error),
	convDim, convChannels, deltaStateSize int,
) *HybridCache {
	return &HybridCache{
		kv:                  kvcache.NewCausalCache(shift),
		convDim:             convDim,
		convChannels:        convChannels,
		deltaStateSize:      deltaStateSize,
		slotForSeq:          make(map[int]int),
		convCtxs:            make(map[int]ml.Context),
		convStates:          make(map[int]ml.Tensor),
		deltaCtxs:           make(map[int]ml.Context),
		deltaStates:         make(map[int]ml.Tensor),
		checkpointCount:     checkpointCountDefault,
		checkpointMinPos:    checkpointMinPosDefault,
		checkpointInterval:  checkpointIntervalDefault,
		checkpoints:         make(map[int]*slotCheckpointStore),
		pendingRestore:      make(map[int]checkpointRestore),
		curCheckpointSlots:  make(map[int]int),
		checkpointConvCtxs:  make(map[int]ml.Context),
		checkpointDeltaCtxs: make(map[int]ml.Context),
		checkpointReserved:  make(map[int]struct{}),
	}
}

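// Init sets up the slot allocator and checkpoint bookkeeping for maxSequences
// sequences and initializes the wrapped causal cache.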
func (c *HybridCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
	c.backend = backend
	c.dtype = dtype
	c.maxSequences = maxSequences
	c.checkpoints = make(map[int]*slotCheckpointStore)
	c.pendingRestore = make(map[int]checkpointRestore)
	c.curCheckpointPos = c.curCheckpointPos[:0]
	c.curCheckpointSlots = make(map[int]int)
	c.checkpointReserved = make(map[int]struct{})
	c.checkpointCtxSize = c.checkpointCount * c.maxSequences
	if c.checkpointCtxSize < 8 {
		c.checkpointCtxSize = 8
	}

	// initialize slot allocator
	c.refCount = make([]int, maxSequences)
	c.freeSlots = c.freeSlots[:0]
	for i := maxSequences - 1; i >= 0; i-- {
		c.freeSlots = append(c.freeSlots, i)
	}

	c.kv.Init(backend, dtype, maxSequences, capacity, maxBatch)
}

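// Close releases the per-layer state contexts and the wrapped causal cache.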
func (c *HybridCache) Close() {
	for _, ctx := range c.convCtxs {
		ctx.Close()
	}
	for _, ctx := range c.deltaCtxs {
		ctx.Close()
	}
	for _, ctx := range c.checkpointConvCtxs {
		ctx.Close()
	}
	for _, ctx := range c.checkpointDeltaCtxs {
		ctx.Close()
	}
	c.kv.Close()
}

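// SetConfig forwards the backend cache configuration to the wrapped causal cache.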
func (c *HybridCache) SetConfig(config ml.CacheConfig) {
	c.kv.SetConfig(config)
}

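// SetLayer selects the layer for subsequent Get/Put calls on the causal cache.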
func (c *HybridCache) SetLayer(layer int) {
	c.kv.SetLayer(layer)
}

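// Get returns the cached tensors for the current layer of the causal cache.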
func (c *HybridCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
	return c.kv.Get(ctx)
}

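// Put stores keys and values for the current layer in the causal cache.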
func (c *HybridCache) Put(ctx ml.Context, key, value ml.Tensor) {
	c.kv.Put(ctx, key, value)
}

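// StartForward begins a forward pass. It forwards the batch to the causal
// cache, then derives the recurrent layout: every sequence in the batch must
// contribute the same number of tokens, otherwise kvcache.ErrNotSupported is
// returned. New sequences get a freshly zeroed recurrent slot. When reserve
// is true (memory estimation), placeholder slots are used instead.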
func (c *HybridCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
	if err := c.kv.StartForward(ctx, batch, reserve); err != nil {
		return err
	}

	// Derive equal-length sequence layout for recurrent layers
	seqCounts := make(map[int]int)
	c.curSeqs = c.curSeqs[:0]
	for _, s := range batch.Sequences {
		if _, ok := seqCounts[s]; !ok {
			c.curSeqs = append(c.curSeqs, s)
		}
		seqCounts[s]++
	}

	if len(c.curSeqs) == 0 {
		return nil
	}

	nTokens := len(batch.Sequences)
	nSeqs := len(c.curSeqs)
	want := nTokens / nSeqs
	for _, s := range c.curSeqs {
		if seqCounts[s] != want {
			return kvcache.ErrNotSupported
		}
	}

	c.curSeqTokens = want

	// When reserving memory for estimation, use fake slot assignments
	if reserve {
		c.curSlots = c.curSlots[:0]
		slots := make([]int32, nSeqs)
		for i := range nSeqs {
			c.curSlots = append(c.curSlots, i)
			slots[i] = int32(i)
		}
		c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
		c.reserveCheckpoints = true
		c.planCheckpoints(batch)
		return nil
	}

	// Ensure slots exist for sequences in this batch
	c.curSlots = c.curSlots[:0]
	var newSlots []int
	for _, s := range c.curSeqs {
		slot, ok := c.slotForSeq[s]
		if !ok {
			var err error
			slot, err = c.allocSlot()
			if err != nil {
				return err
			}
			c.slotForSeq[s] = slot
			c.refCount[slot] = 1
			newSlots = append(newSlots, slot)
		}
		c.curSlots = append(c.curSlots, slot)
	}

	// Zero state for newly allocated slots
	if len(newSlots) > 0 {
		c.zeroSlots(ctx, newSlots)
	}

	// Create a tensor for the current slots
	slots := make([]int32, len(c.curSlots))
	for i, v := range c.curSlots {
		slots[i] = int32(v)
	}
	c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))

	// Reset writable state for new forward pass
	c.writableEnsured = false
	c.writableError = nil
	c.reserveCheckpoints = false
	c.planCheckpoints(batch)

	return nil
}

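// allocSlot pops a free recurrent slot, returning kvcache.ErrKvCacheFull when
// none remain.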
func (c *HybridCache) allocSlot() (int, error) {
	if len(c.freeSlots) == 0 {
		return 0, kvcache.ErrKvCacheFull
	}
	slot := c.freeSlots[len(c.freeSlots)-1]
	c.freeSlots = c.freeSlots[:len(c.freeSlots)-1]
	return slot, nil
}

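// freeSlot returns a slot to the free list; out-of-range slots are ignored.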
func (c *HybridCache) freeSlot(slot int) {
	if slot >= 0 && slot < c.maxSequences {
		c.freeSlots = append(c.freeSlots, slot)
	}
}

// zeroSlots zeros the recurrent state for the given slots across all layers.
func (c *HybridCache) zeroSlots(ctx ml.Context, slots []int) {
	if len(slots) == 0 {
		return
	}

	inputCtx := ctx.Input()

	slotIndices := make([]int32, len(slots))
	for i, s := range slots {
		slotIndices[i] = int32(s)
	}
	slotsTensor := inputCtx.FromInts(slotIndices, len(slotIndices))

	// Zero conv states
	if len(c.convStates) > 0 {
		zeros := inputCtx.Zeros(ml.DTypeF32, c.convDim*c.convChannels, len(slots))
		for _, buf := range c.convStates {
			ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
		}
	}

	// Zero delta states
	if len(c.deltaStates) > 0 {
		zeros := inputCtx.Zeros(ml.DTypeF32, c.deltaStateSize, len(slots))
		for _, buf := range c.deltaStates {
			ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
		}
	}
}

// EnsureWritable ensures sequences have private slots (copy-on-write).
func (c *HybridCache) EnsureWritable(ctx ml.Context) error {
	for i, seq := range c.curSeqs {
		slot, ok := c.slotForSeq[seq]
		if !ok {
			continue
		}

		if slot < 0 || slot >= len(c.refCount) {
			continue
		}

		if c.refCount[slot] <= 1 {
			continue
		}

		newSlot, err := c.allocSlot()
		if err != nil {
			return err
		}
		c.refCount[slot]--
		c.refCount[newSlot] = 1
		c.slotForSeq[seq] = newSlot
		c.curSlots[i] = newSlot

		c.copyRecurrentState(ctx, slot, newSlot)
		c.copyCheckpoints(ctx, slot, newSlot)
	}

	// Rebuild current slots tensor
	slots := make([]int32, len(c.curSlots))
	for i, v := range c.curSlots {
		slots[i] = int32(v)
	}
	c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))

	return nil
}

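// copyRecurrentState copies the conv and delta state rows for srcSlot into
// dstSlot across all allocated layers.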
func (c *HybridCache) copyRecurrentState(ctx ml.Context, srcSlot, dstSlot int) {
	src := ctx.Input().FromInts([]int32{int32(srcSlot)}, 1)
	dst := ctx.Input().FromInts([]int32{int32(dstSlot)}, 1)

	for _, buf := range c.convStates {
		rows := buf.Rows(ctx, src)
		rowsF32 := rows.Cast(ctx, ml.DTypeF32)
		ctx.Forward(buf.SetRows(ctx, rowsF32, dst))
	}

	for _, buf := range c.deltaStates {
		rows := buf.Rows(ctx, src)
		rowsF32 := rows.Cast(ctx, ml.DTypeF32)
		ctx.Forward(buf.SetRows(ctx, rowsF32, dst))
	}
}

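// CopyPrefix copies the KV prefix in the causal cache and shares srcSeq's
// recurrent slot with dstSeq by bumping its reference count (copy-on-write);
// any slot dstSeq previously held is released.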
func (c *HybridCache) CopyPrefix(srcSeq, dstSeq int, prefixLen int32) {
	c.kv.CopyPrefix(srcSeq, dstSeq, prefixLen)

	// Copy-on-write for recurrent state
	if dstSlot, ok := c.slotForSeq[dstSeq]; ok {
		if c.validSlot(dstSlot) {
			c.refCount[dstSlot]--
			if c.refCount[dstSlot] <= 0 {
				c.refCount[dstSlot] = 0
				c.freeSlot(dstSlot)
			}
		}
		delete(c.slotForSeq, dstSeq)
	}

	srcSlot, ok := c.slotForSeq[srcSeq]
	if !ok {
		return
	}

	if c.validSlot(srcSlot) {
		c.slotForSeq[dstSeq] = srcSlot
		c.refCount[srcSlot]++
	}
}

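// CanResume reports whether seq can resume generation at pos: the causal
// cache must be able to resume there and, for pos > 0, a recurrent checkpoint
// must be available.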
func (c *HybridCache) CanResume(seq int, pos int32) bool {
	if !c.kv.CanResume(seq, pos) {
		return false
	}
	if pos == 0 {
		return true
	}
	return c.hasCheckpoint(seq, pos)
}

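// Remove drops cached state for seq in [beginIndex, endIndex). A removal that
// starts at 0 invalidates seq's recurrent state and releases its slot
// reference. Truncating a suffix (beginIndex > 0, endIndex == math.MaxInt32)
// is only supported when a matching checkpoint restore is pending for seq;
// any other interior removal returns kvcache.ErrNotSupported.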
func (c *HybridCache) Remove(seq int, beginIndex, endIndex int32) error {
	if beginIndex > 0 && endIndex != math.MaxInt32 {
		return kvcache.ErrNotSupported
	}

	if beginIndex > 0 {
		restore, ok := c.pendingRestore[seq]
		if !ok || restore.pos+1 != beginIndex {
			return kvcache.ErrNotSupported
		}
		if !c.restoreComplete(restore) {
			return kvcache.ErrNotSupported
		}
		// If the recurrent slot is shared, detach it before applying a restore.
		if slot, ok := c.slotForSeq[seq]; ok && c.validSlot(slot) && c.refCount[slot] > 1 {
			newSlot, err := c.allocSlot()
			if err != nil {
				return err
			}
			ctx := c.backend.NewContext()
			c.copyRecurrentState(ctx, slot, newSlot)
			c.copyCheckpoints(ctx, slot, newSlot)
			if len(c.convStates) > 0 || len(c.deltaStates) > 0 {
				ctx.Compute()
			}
			ctx.Close()

			c.refCount[slot]--
			c.refCount[newSlot] = 1
			c.slotForSeq[seq] = newSlot

			restore.slot = newSlot
			c.pendingRestore[seq] = restore
		}
	}

	if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil {
		return err
	}

	if beginIndex > 0 {
		restore := c.pendingRestore[seq]
		delete(c.pendingRestore, seq)
		return c.applyCheckpointRestore(restore)
	}

	// Removal invalidates recurrent state
	slot, ok := c.slotForSeq[seq]
	delete(c.pendingRestore, seq)
	if !ok {
		return nil
	}

	if !c.validSlot(slot) {
		delete(c.slotForSeq, seq)
		return nil
	}

	c.refCount[slot]--
	if c.refCount[slot] <= 0 {
		c.refCount[slot] = 0
		c.clearCheckpoints(slot)
		c.freeSlot(slot)
	}
	delete(c.slotForSeq, seq)

	return nil
}

func (c *HybridCache) validSlot(slot int) bool {
	return slot >= 0 && slot < len(c.refCount)
}

func (c *HybridCache) slotsTensor() ml.Tensor {
	return c.curSlotsInput
}

// contiguousSlots returns the starting slot if current slots are contiguous and ordered.
func (c *HybridCache) contiguousSlots() (int, bool) {
	if len(c.curSlots) == 0 {
		return 0, false
	}
	start := c.curSlots[0]
	for i, s := range c.curSlots {
		if s != start+i {
			return 0, false
		}
	}
	return start, true
}

func (c *HybridCache) seqTokens() int {
	return c.curSeqTokens
}

func (c *HybridCache) numSeqs() int {
	return len(c.curSeqs)
}

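// convBuffer returns the conv state buffer for layer, lazily allocating it on
// first use.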
func (c *HybridCache) convBuffer(ctx ml.Context, layer int) ml.Tensor {
	if buf, ok := c.convStates[layer]; ok {
		return buf
	}

	if _, ok := c.convCtxs[layer]; !ok {
		c.convCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
	}

	// Recurrent state must stay in F32 (ssm_conv kernels are F32-only).
	buf := c.convCtxs[layer].Zeros(ml.DTypeF32, c.convDim*c.convChannels, c.maxSequences)
	c.convStates[layer] = buf
	return buf
}

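// deltaBuffer returns the delta state buffer for layer, lazily allocating it
// on first use.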
func (c *HybridCache) deltaBuffer(ctx ml.Context, layer int) ml.Tensor {
	if buf, ok := c.deltaStates[layer]; ok {
		return buf
	}

	if _, ok := c.deltaCtxs[layer]; !ok {
		c.deltaCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
	}

	// Recurrent delta state must stay in F32.
	buf := c.deltaCtxs[layer].Zeros(ml.DTypeF32, c.deltaStateSize, c.maxSequences)
	c.deltaStates[layer] = buf
	return buf
}

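// ensureWritableOnce runs EnsureWritable at most once per forward pass, and
// only when a sequence in the batch shares its slot; any error is recorded in
// writableError for the state accessors to return.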
func (c *HybridCache) ensureWritableOnce(ctx ml.Context) {
	if !c.writableEnsured {
		needsWritable := false
		for _, seq := range c.curSeqs {
			slot, ok := c.slotForSeq[seq]
			if !ok {
				continue
			}
			if slot >= 0 && slot < len(c.refCount) && c.refCount[slot] > 1 {
				needsWritable = true
				break
			}
		}

		if needsWritable {
			if err := c.EnsureWritable(ctx); err != nil {
				c.writableError = err
			}
		}
		c.writableEnsured = true
	}
}

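// A linear attention layer typically uses the accessors below in a
// read-modify-write pattern each forward pass (illustrative sketch; the real
// call sites live in the layer implementation):
//
//	conv, err := cache.ConvState(ctx, layer)                        // [convDim, convChannels, nSeqs]
//	delta, err := cache.DeltaState(ctx, layer, headVDim, numVHeads) // [headVDim, headVDim*numVHeads, nSeqs]
//	// ... run the recurrent scan for this batch ...
//	cache.UpdateConvState(ctx, layer, newConv)
//	cache.UpdateDeltaState(ctx, layer, newDelta)
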
// ConvState returns the conv state for current batch sequences as [convDim, convChannels, nSeqs].
func (c *HybridCache) ConvState(ctx ml.Context, layer int) (ml.Tensor, error) {
	c.ensureWritableOnce(ctx)

	if c.writableError != nil {
		return nil, c.writableError
	}

	buf := c.convBuffer(ctx, layer)
	cur := buf.Rows(ctx, c.slotsTensor())
	return cur.Reshape(ctx, c.convDim, c.convChannels, c.numSeqs()), nil
}

// UpdateConvState writes a new conv state for current batch sequences.
func (c *HybridCache) UpdateConvState(ctx ml.Context, layer int, newState ml.Tensor) {
	buf := c.convBuffer(ctx, layer)
	src := newState.Reshape(ctx, c.convDim*c.convChannels, c.numSeqs())
	srcF32 := src.Cast(ctx, ml.DTypeF32)
	if start, ok := c.contiguousSlots(); ok {
		// Fast path: contiguous slots allow a single view + copy
		offset := start * buf.Stride(1)
		view := buf.View(ctx, offset, c.convDim*c.convChannels, buf.Stride(1), c.numSeqs())
		ctx.Forward(srcF32.Copy(ctx, view))
	} else {
		ctx.Forward(buf.SetRows(ctx, srcF32, c.slotsTensor()))
	}

	c.captureConvCheckpoint(ctx, layer, srcF32)
}

// DeltaState returns the delta state for current batch sequences as [headVDim, headVDim*numVHeads, nSeqs].
func (c *HybridCache) DeltaState(ctx ml.Context, layer int, headVDim, numVHeads int) (ml.Tensor, error) {
	c.ensureWritableOnce(ctx)

	if c.writableError != nil {
		return nil, c.writableError
	}

	buf := c.deltaBuffer(ctx, layer)
	cur := buf.Rows(ctx, c.slotsTensor())
	return cur.Reshape(ctx, headVDim, headVDim*numVHeads, c.numSeqs()), nil
}

// UpdateDeltaState writes a new delta state for current batch sequences.
func (c *HybridCache) UpdateDeltaState(ctx ml.Context, layer int, newState ml.Tensor) {
	buf := c.deltaBuffer(ctx, layer)
	src := newState.Reshape(ctx, c.deltaStateSize, c.numSeqs())
	srcF32 := src.Cast(ctx, ml.DTypeF32)
	if start, ok := c.contiguousSlots(); ok {
		// Fast path: contiguous slots allow a single view + copy
		offset := start * buf.Stride(1)
		view := buf.View(ctx, offset, c.deltaStateSize, buf.Stride(1), c.numSeqs())
		ctx.Forward(srcF32.Copy(ctx, view))
	} else {
		ctx.Forward(buf.SetRows(ctx, srcF32, c.slotsTensor()))
	}

	c.captureDeltaCheckpoint(ctx, layer, srcF32)
}

// IsSupportedForBatch returns true if the current batch layout supports recurrent layers.
func (c *HybridCache) IsSupportedForBatch() bool {
	return c.curSeqTokens > 0 && len(c.curSeqs) > 0
}

// Seqs returns the ordered unique sequences for the current forward pass.
func (c *HybridCache) Seqs() []int {
	return slices.Clone(c.curSeqs)
}