Files
ollama-ollama/x/mlxrunner/cache/recurrent.go
Daniel Hiltgen 10e51c5177 MLX: add header vendoring and remove go build tag (#14642)
* prefer rocm v6 on windows

Avoid building with v7 - more changes are needed

* MLX: add header vendoring and remove go build tag

This switches to using a vendoring approach for the mlx-c headers so that Go
can build without requiring a cmake first.  This enables building the new MLX
based code by default.  Every time cmake runs, the headers are refreshed, so we
can easily keep them in sync when we bump mlx versions.  Basic Windows
and Linux support are verified.

* ci: harden for flaky choco repo servers

CI sometimes fails due to choco not actually installing cache.  Since it just speeds up the build, we can proceed without.

* review comments
2026-03-09 17:24:45 -07:00

160 lines
3.7 KiB
Go

package cache
import "github.com/ollama/ollama/x/mlxrunner/mlx"
// RecurrentCache stores state for linear-recurrent layers.
//
// Conv state shape: [B, convTail, convDim]
// Delta state shape: [B, numVHeads, headVDim, headKDim]
type RecurrentCache struct {
convState *mlx.Array
deltaState *mlx.Array
offset int
convTail int
convDim int
numVHeads int
headVDim int
headKDim int
}
func (c *RecurrentCache) setStateRaw(old, v *mlx.Array) *mlx.Array {
if v == nil || !v.Valid() {
return old
}
if old == v {
return old
}
mlx.Pin(v)
if old != nil && old != v {
mlx.Unpin(old)
}
return v
}
func (c *RecurrentCache) setStateDetached(old, v *mlx.Array, ensureContiguous bool) *mlx.Array {
if v == nil || !v.Valid() {
return old
}
if old == v {
return old
}
root := v
if ensureContiguous {
root = mlx.Contiguous(v, false)
}
detached := root.Clone()
mlx.Pin(detached)
if old != nil && old != detached {
mlx.Unpin(old)
}
return detached
}
func snapshotPinned(a *mlx.Array) *mlx.Array {
if a == nil || !a.Valid() {
return nil
}
snap := mlx.Copy(a)
mlx.Eval(snap)
mlx.Pin(snap)
return snap
}
func NewRecurrentCache(convTail, convDim, numVHeads, headVDim, headKDim int32) *RecurrentCache {
return &RecurrentCache{
convTail: int(convTail),
convDim: int(convDim),
numVHeads: int(numVHeads),
headVDim: int(headVDim),
headKDim: int(headKDim),
}
}
func (c *RecurrentCache) ensure(batch int, dtype mlx.DType) {
if batch <= 0 {
batch = 1
}
needConv := c.convState == nil || !c.convState.Valid() || c.convState.DType() != dtype ||
c.convState.Dim(0) != batch || c.convState.Dim(1) != c.convTail || c.convState.Dim(2) != c.convDim
needDelta := c.deltaState == nil || !c.deltaState.Valid() || c.deltaState.DType() != dtype ||
c.deltaState.Dim(0) != batch || c.deltaState.Dim(1) != c.numVHeads || c.deltaState.Dim(2) != c.headVDim || c.deltaState.Dim(3) != c.headKDim
if !needConv && !needDelta {
return
}
if needConv {
c.convState = c.setStateRaw(c.convState, mlx.Zeros(dtype, batch, c.convTail, c.convDim))
}
if needDelta {
c.deltaState = c.setStateRaw(c.deltaState, mlx.Zeros(dtype, batch, c.numVHeads, c.headVDim, c.headKDim))
}
}
func (c *RecurrentCache) ConvState(batch int, dtype mlx.DType) *mlx.Array {
c.ensure(batch, dtype)
return c.convState
}
func (c *RecurrentCache) SetConvState(v *mlx.Array) {
c.convState = c.setStateDetached(c.convState, v, true)
}
func (c *RecurrentCache) DeltaState(batch int, dtype mlx.DType) *mlx.Array {
c.ensure(batch, dtype)
return c.deltaState
}
func (c *RecurrentCache) SetDeltaState(v *mlx.Array) {
c.deltaState = c.setStateDetached(c.deltaState, v, false)
}
func (c *RecurrentCache) Advance(n int) {
c.offset += n
}
func (c *RecurrentCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
return keys, values
}
func (c *RecurrentCache) State() (*mlx.Array, *mlx.Array) {
return c.convState, c.deltaState
}
func (c *RecurrentCache) CanTrim() bool { return false }
func (c *RecurrentCache) Trim(n int) int {
// Recurrent state is not directly trimmable. Divergent prefixes must drop the cache.
_ = n
return 0
}
func (c *RecurrentCache) Clone() Cache {
clone := &RecurrentCache{
offset: c.offset,
convTail: c.convTail,
convDim: c.convDim,
numVHeads: c.numVHeads,
headVDim: c.headVDim,
headKDim: c.headKDim,
convState: snapshotPinned(c.convState),
deltaState: snapshotPinned(c.deltaState),
}
return clone
}
func (c *RecurrentCache) Free() {
mlx.Unpin(c.convState, c.deltaState)
c.convState, c.deltaState = nil, nil
c.offset = 0
}
func (c *RecurrentCache) Offset() int { return c.offset }
func (c *RecurrentCache) Len() int { return c.offset }