diff --git a/x/imagegen/mlx/README.md b/x/imagegen/mlx/README.md index 3a2f7b3d8..5e7a1f42a 100644 --- a/x/imagegen/mlx/README.md +++ b/x/imagegen/mlx/README.md @@ -1,7 +1,5 @@ # MLX Memory Management -| This package will get consolidated with `x/ml/backend/mlx` in the future. - ## Automatic Tracking All arrays are automatically tracked when created. On `Eval()`, non-kept arrays are freed. diff --git a/x/kvcache/cache.go b/x/kvcache/cache.go deleted file mode 100644 index f0627584a..000000000 --- a/x/kvcache/cache.go +++ /dev/null @@ -1,77 +0,0 @@ -package kvcache - -import ( - "errors" - - "github.com/ollama/ollama/x/ml" - "github.com/ollama/ollama/x/model/input" -) - -var ( - ErrKvCacheFull = errors.New("could not find a kv cache slot") - ErrNotSupported = errors.New("model does not support operation") -) - -type Cache interface { - // ** used by model implementations ** - - // SetLayer sets the active layer of the cache - SetLayer(layer int) - - // Get returns the history of key and value tensors plus a mask - // - // The shape of the tensors is documented in the specific - // cache implementation used. - Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) - - // Put stores a batch of key and value in the cache - // - // The shape of the tensors is documented in the specific - // cache implementation used. - Put(ctx ml.Context, key, value ml.Tensor) - - // SetConfig controls optimizations (mostly backend-specific) that may transform - // the output of the cache to work better with specific kernels. If not called, - // the backend settings will be used. This works well when calling Attention. - // - // The config can be overridden by models, especially if they require vanilla - // output when implementing their own version of attention. To do this, pass - // an empty ml.CacheConfig. - // - // Most models will not need to use this. - SetConfig(ml.CacheConfig) - - // ** cache management ** - - // Init sets up runtime parameters. - // backend: Used to allocate cache data storage and execute management operations (such as defrag) - // dtype: The data type for storing cache entries - // maxSequences: The maximum number of sequences stored in the cache - across all batches - // capacity: The number of cache entries to store, per sequence - // maxBatch: The maximum number of tokens that can occur in a single batch - Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) - - // Close closes the cache and frees resources associated with it - Close() - - // StartForward is called before the start of the model's forward pass. - // For each token in the coming batch, there must be a corresponding - // entry in positions and seqs. reserve is to preallocate memory - // without actually storing data in the cache. - StartForward(ctx ml.Context, batch input.Batch, reserve bool) error - - // CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq - CopyPrefix(srcSeq, dstSeq int, len int32) - - // CanResume returns true if the cache can continue with the next token at - // the given position and sequence. Assumes that the caller has already - // verified the contents of the cache. - CanResume(seq int, pos int32) bool - - // Remove deletes tokens in the range [beginIndex, endIndex) from seq. Set - // endIndex to math.MaxInt32 to remove everything starting at beginIndex. - // - // If an error occurs, the entire context for the sequence should be - // removed by calling Remove(seq, 0, math.MaxInt32) - Remove(seq int, beginIndex, endIndex int32) error -} diff --git a/x/kvcache/causal.go b/x/kvcache/causal.go deleted file mode 100644 index 967fed674..000000000 --- a/x/kvcache/causal.go +++ /dev/null @@ -1,797 +0,0 @@ -package kvcache - -// import ( -// "errors" -// "fmt" -// "log/slog" -// "math" -// "slices" - -// "github.com/ollama/ollama/ml" -// "github.com/ollama/ollama/model/input" -// ) - -// type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) - -// // Causal cache stores K and V tensors according to their position in the -// // sequence. Returns the history and a mask for attending to past tokens -// // -// // The tensors are of shape embed dim, kv heads, batch size -// // The mask is of shape history size, batch size -// type Causal struct { -// DType ml.DType - -// // swaWindowSize is the number of tokens that will be included in the mask -// // during attention operations. swaMemorySize is the number of tokens that -// // will be retained in memory for partial prefix caching. Set to math.MaxInt32 -// // for unlimited or if sliding window attention is not being used. -// swaWindowSize int32 -// swaMemorySize int32 - -// chunkSize int32 - -// opts CausalOptions - -// // maxBatch is the largest batch that we might receive -// maxBatch int - -// // config controls mostly backend-specific optimizations -// config *ml.CacheConfig - -// // ** current forward pass ** - -// // size of the current batch -// curBatchSize int - -// // locations for data storage for this batch -// curLoc ml.Tensor - -// // mask of the cache as used by this batch -// curMask ml.Tensor - -// // the active layer for Get and Put -// curLayer int - -// // locations in the cache that are needed for this batch -// curCellRange cellRange - -// // curSequences is the sequences corresponding to this pass's entries in the cache -// curSequences []int - -// // curPositions is the positions corresponding to this pass's entries in the cache -// curPositions []int32 - -// // ** cache metadata ** - -// // for each possible location in the cache, stores the position and set of sequences -// // that reference the data there -// cells []cacheCell - -// // maps from sequence to the range of locations where it is stored in the cache -// cellRanges map[int]cellRange - -// // ** cache data storage ** - -// shiftFn shiftFn -// backend ml.Backend -// ctxs map[int]ml.Context -// keys, values map[int]ml.Tensor - -// kHeadDims, vHeadDims, numKVHeads map[int]int -// } - -// type cacheCell struct { -// pos int32 -// sequences []int -// } - -// type cellRange struct { -// min int -// max int -// } - -// func NewCausalCache(shift shiftFn) *Causal { -// return &Causal{ -// shiftFn: shift, -// ctxs: make(map[int]ml.Context), -// keys: make(map[int]ml.Tensor), -// values: make(map[int]ml.Tensor), -// kHeadDims: make(map[int]int), -// vHeadDims: make(map[int]int), -// numKVHeads: make(map[int]int), -// } -// } - -// func NewSWACache(windowSize int32, shift shiftFn) *Causal { -// return &Causal{ -// swaWindowSize: windowSize, -// shiftFn: shift, -// ctxs: make(map[int]ml.Context), -// keys: make(map[int]ml.Tensor), -// values: make(map[int]ml.Tensor), -// kHeadDims: make(map[int]int), -// vHeadDims: make(map[int]int), -// numKVHeads: make(map[int]int), -// } -// } - -// func NewSWAMemCache(windowSize int32, memorySize int32, shift shiftFn) *Causal { -// return &Causal{ -// swaWindowSize: windowSize, -// swaMemorySize: memorySize, -// shiftFn: shift, -// ctxs: make(map[int]ml.Context), -// keys: make(map[int]ml.Tensor), -// values: make(map[int]ml.Tensor), -// kHeadDims: make(map[int]int), -// vHeadDims: make(map[int]int), -// numKVHeads: make(map[int]int), -// } -// } - -// func NewChunkedAttentionCache(chunkSize int32, shift shiftFn) *Causal { -// return &Causal{ -// chunkSize: chunkSize, -// shiftFn: shift, -// ctxs: make(map[int]ml.Context), -// keys: make(map[int]ml.Tensor), -// values: make(map[int]ml.Tensor), -// kHeadDims: make(map[int]int), -// vHeadDims: make(map[int]int), -// numKVHeads: make(map[int]int), -// } -// } - -// func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) { -// if c.config == nil { -// var config ml.CacheConfig -// if cc, ok := backend.(ml.BackendCacheConfig); ok { -// config = cc.CacheConfig() -// } -// c.config = &config -// } - -// if c.config.CachePadding == 0 { -// c.config.CachePadding = 1 -// } - -// if c.config.MaskBatchPadding == 0 { -// c.config.MaskBatchPadding = 1 -// } - -// // TODO what types do we handle here? -// // if c.config.MaskDType == ml.DTypeOther { -// // c.config.MaskDType = ml.DTypeFloat32 -// // } - -// if c.swaWindowSize == 0 { -// c.swaWindowSize = math.MaxInt32 -// } -// if c.swaMemorySize == 0 { -// c.swaMemorySize = c.swaWindowSize -// } -// // We will allocate space in the cache for the stop token, which won't be part of a follow on -// // sequence, so allocate an extra token of storage to ensure that we can jump back without -// // causing a cache break. As an optimization, only do this when we have parallel sequences -// // because the extra token will live in the batch buffer and won't get overwritten if we -// // only have a single sequence. -// if c.swaMemorySize != math.MaxInt32 && maxSequences > 1 { -// c.swaMemorySize = max(c.swaMemorySize, c.swaWindowSize+1) -// } -// if int(c.swaMemorySize) >= capacity { -// c.swaMemorySize = math.MaxInt32 -// } - -// if c.swaMemorySize < c.swaWindowSize { -// panic(fmt.Errorf("sliding window memory (%v) must be at least as large as the window (%v)", c.swaMemorySize, c.swaWindowSize)) -// } - -// var cacheSize int -// if c.swaMemorySize == math.MaxInt32 { -// cacheSize = maxSequences * capacity -// } else { -// cacheSize = (maxSequences * int(c.swaMemorySize)) + maxBatch -// } -// cacheSize = roundUp(cacheSize, c.config.CachePadding) -// c.cells = make([]cacheCell, cacheSize) - -// c.DType = dtype -// c.cellRanges = make(map[int]cellRange) -// c.backend = backend -// c.maxBatch = maxBatch -// } - -// func (c *Causal) SetConfig(config ml.CacheConfig) { -// if c.config != nil { -// panic("config cannot be changed after being previously set, either by the model or backend") -// } - -// c.config = &config -// } - -// func (c *Causal) Close() { -// slog.Info("XXX Causal.Close called", "number of contexts", len(c.ctxs)) -// for _, ctx := range c.ctxs { -// ctx.Close() -// } -// } - -// func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error { -// slog.Info("XXX Causal.StartForward", "cell count", len(c.cells), "prior batch size", c.curBatchSize, "positions", len(batch.Positions), "reserve", reserve, "batch", batch) -// // panic("XXX Causal.StartForward") -// c.curBatchSize = len(batch.Positions) -// c.curSequences = batch.Sequences -// c.curPositions = batch.Positions -// c.opts.Except = nil - -// var locs []int32 -// if !reserve { -// c.updateSlidingWindow() - -// var err error -// locs, err = c.findLocs() -// if err != nil { -// return err -// } -// slog.Info("XXX Causal.StartForward", "findLocs len", len(locs)) - -// for i, pos := range batch.Positions { -// seq := batch.Sequences[i] -// loc := int(locs[i]) - -// c.cells[loc] = cacheCell{pos: pos, sequences: []int{seq}} - -// seqRange, ok := c.cellRanges[seq] -// if !ok { -// seqRange = newRange() -// } - -// seqRange.min = min(seqRange.min, loc) -// c.curCellRange.min = min(c.curCellRange.min, loc) - -// seqRange.max = max(seqRange.max, loc) -// c.curCellRange.max = max(c.curCellRange.max, loc) - -// c.cellRanges[seq] = seqRange -// } -// } else { -// // If we are reserving memory, don't update any of the cache metadata but set the size -// // to the worst case. -// locs = make([]int32, c.curBatchSize) -// for i := range locs { -// locs[i] = int32(i) -// } -// c.curCellRange.min = 0 -// c.curCellRange.max = len(c.cells) - 1 -// } - -// // XXX Building up the locs for what's already processed (if any) -// dummyLocs := []int{} -// c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding) -// c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1 - -// for i := range c.curBatchSize { -// enabled := !slices.Contains(c.opts.Except, i) -// for j := c.curCellRange.min; j <= c.curCellRange.max; j++ { -// if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) || -// (enabled && c.cells[j].pos > c.curPositions[i]) || -// c.chunkSize > 0 && c.cells[j].pos < c.curPositions[i]-c.curPositions[i]%c.chunkSize || -// c.cells[j].pos < c.curPositions[i]-c.swaWindowSize { -// // mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1)) -// } else { -// if len(dummyLocs) == 0 || dummyLocs[len(dummyLocs)-1] != i { -// dummyLocs = append(dummyLocs, i) -// } -// } -// } -// } -// slog.Info("XXX Causa.StartForward calculated locations", "locs", dummyLocs) - -// slog.Info("XXX Causal.StartForward", "locs", locs) -// c.curLoc = ctx.Input().FromInts(locs, len(locs)) -// c.curMask = c.buildMask(ctx) - -// return nil -// } - -// func newRange() cellRange { -// return cellRange{ -// min: math.MaxInt, -// max: 0, -// } -// } - -// // Returns a slice of locations where each token in the batch should be stored -// func (c *Causal) findLocs() ([]int32, error) { -// loc := make([]int32, 0, c.curBatchSize) - -// for i := range c.cells { -// if len(c.cells[i].sequences) == 0 { -// loc = append(loc, int32(i)) -// if len(loc) >= c.curBatchSize { -// return loc, nil -// } -// } -// } - -// return nil, fmt.Errorf("%w (cache: %v batch: %v)", ErrKvCacheFull, len(c.cells), c.curBatchSize) -// } - -// func (c *Causal) updateSlidingWindow() { -// c.curCellRange = newRange() - -// if c.swaMemorySize == math.MaxInt32 { -// for _, seq := range c.curSequences { -// if seqRange, ok := c.cellRanges[seq]; ok { -// c.curCellRange.min = min(c.curCellRange.min, seqRange.min) -// c.curCellRange.max = max(c.curCellRange.max, seqRange.max) -// } -// } - -// return -// } - -// type lowestPosition struct { -// pos int32 -// curBatch bool -// } - -// // create a map of unique sequences to the lowest position in that sequence -// lowestPos := make(map[int]lowestPosition) -// for i := range c.curPositions { -// seq := c.curSequences[i] - -// lowest, ok := lowestPos[seq] -// if !ok { -// lowest = lowestPosition{pos: c.curPositions[i], curBatch: true} -// } else if c.curPositions[i] < lowest.pos { -// lowest.pos = c.curPositions[i] -// } - -// lowestPos[seq] = lowest -// } - -// // for any sequences are not part of this batch, clean up any tokens -// // that are no longer needed after the processing of the previous -// // batch -// for seq, seqRange := range c.cellRanges { -// if _, ok := lowestPos[seq]; !ok { -// var last int32 -// for i := seqRange.min; i <= seqRange.max; i++ { -// if slices.Contains(c.cells[i].sequences, seq) { -// last = max(last, c.cells[i].pos) -// } -// } - -// lowestPos[seq] = lowestPosition{pos: last + 1, curBatch: false} -// } -// } - -// // delete any entries that are beyond the window of the oldest position in the sequence -// for seq, lowest := range lowestPos { -// oldRange, ok := c.cellRanges[seq] -// if !ok { -// continue -// } - -// newRange := newRange() - -// for i := oldRange.min; i <= oldRange.max; i++ { -// if slices.Contains(c.cells[i].sequences, seq) { -// if c.cells[i].pos < lowest.pos-c.swaMemorySize { -// c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq }) -// } else { -// newRange.min = min(newRange.min, i) -// newRange.max = max(newRange.max, i) -// } -// if lowest.curBatch && c.cells[i].pos >= lowest.pos-c.swaWindowSize { -// c.curCellRange.min = min(c.curCellRange.min, i) -// c.curCellRange.max = max(c.curCellRange.max, i) -// } -// } -// } - -// c.cellRanges[seq] = newRange -// } -// } - -// func roundDown(length, pad int) int { -// return (length / pad) * pad -// } - -// func roundUp(length, pad int) int { -// return ((length + pad - 1) / pad) * pad -// } - -// // Builds a mask of history x batch indicating whether for each token in the batch the -// // token in the history should apply. This is based on both the sequence and causality (the -// // position of the history is not ahead of the token in the batch). -// func (c *Causal) buildMask(ctx ml.Context) ml.Tensor { -// // Align and pad the two dimensions as required by the backend -// batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding) - -// c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding) -// c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1 - -// length := c.curCellRange.max - c.curCellRange.min + 1 - -// mask := make([]float32, batchSize*length) - -// for i := range c.curBatchSize { -// enabled := !slices.Contains(c.opts.Except, i) -// for j := c.curCellRange.min; j <= c.curCellRange.max; j++ { -// if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) || -// (enabled && c.cells[j].pos > c.curPositions[i]) || -// c.chunkSize > 0 && c.cells[j].pos < c.curPositions[i]-c.curPositions[i]%c.chunkSize || -// c.cells[j].pos < c.curPositions[i]-c.swaWindowSize { -// mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1)) -// } -// } -// } - -// // Mask out any padding tokens we added. For padding that we added to the cache history, this -// // has already been masked out because the sequence doesn't match. -// for i := c.curBatchSize * length; i < len(mask); i++ { -// mask[i] = float32(math.Inf(-1)) -// } - -// maskTensor := ctx.Input().FromFloats(mask, batchSize, length) - -// // if c.config.MaskDType != ml.DTypeFloat32 { -// // maskTensor = maskTensor.Cast(ctx, c.config.MaskDType) -// // } - -// slog.Info("XXX Causal.buildMask", "c.curBatchSize", c.curBatchSize, "c.config.MaskBatchPadding", c.config.MaskBatchPadding, "c.curCellRange.min", c.curCellRange.min, "c.curCellRange.max", c.curCellRange.max, "size", len(mask), "shape", []int{1, batchSize, length}) - -// return maskTensor -// } - -// func (c *Causal) SetLayer(layer int) { -// c.curLayer = layer -// } - -// type CausalOptions struct { -// // Enabled controls whether the causal mask is generated for a particular index in a batch -// Except []int -// } - -// // SetCausal disables causal mask generation for a particular range of indicies in -// // the current batch for subsequent calls to Get. The state resets for the next forward pass. -// func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) { -// if !slices.Equal(c.opts.Except, opts.Except) { -// c.opts = opts -// if ctx != nil { -// c.curMask = c.buildMask(ctx) -// } -// } -// } - -// func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) { -// key := c.keys[c.curLayer] -// value := c.values[c.curLayer] - -// kHeadDim := c.kHeadDims[c.curLayer] -// vHeadDim := c.vHeadDims[c.curLayer] -// numKVHeads := c.numKVHeads[c.curLayer] -// // rowSize := numKVHeads * c.curBatchSize -// // cachedSize := c.curMask.Dim(1) -// cachedSize := c.curLoc.Dim(0) -// // kCellSize := kHeadDim * numKVHeads -// // vCellSize := vHeadDim * numKVHeads - -// slog.Info("XXX Causal.Get full cache", "key", key) -// slog.Info("XXX Causal.Get full cache", "value", value) -// slog.Info("XXX Causal.Get full cache", "curloc", c.curLoc) -// slog.Info("XXX Causal.Get", "curMask", c.curMask) -// slog.Info("XXX Causal.Get", "kHeadDim", kHeadDim, "numKVHeads", numKVHeads, "cachedSize", cachedSize, "kHeadDim", kHeadDim) -// // panic("XXX") - -// // fmt.Fprintln(os.Stderr, key.ToString()) -// // panic("full cache value") - -// // TODO we should use TakeAxes to gather the cells from curLoc, but for now to be consistent with GGML, just grab a larger chunk and mask -// key = key.TakeAxes(ctx, c.curLoc, 0).Reshape(ctx, 1, numKVHeads, cachedSize, kHeadDim) -// // key = key.AsStrided(ctx, []int{1, numKVHeads, cachedSize, kHeadDim}, []int{}, rowSize*c.curCellRange.min) - -// // slog.Info("XXX Causal.Get after AsStrided", "key", key) -// // panic("XXX") - -// // if c.config.PermutedV { -// // panic("permuted") -// // // TODO not converted -// // vHeadDim := value.Dim(1) -// // elemSize := value.Stride(2) - -// // value = value.AsStrided(ctx, -// // []int{numKVHeads, vHeadDim, cachedSize}, -// // []int{value.Stride(0), value.Stride(1)}, -// // elemSize*c.curCellRange.min, -// // ) -// // } else { -// // vHeadDim := c.vHeadDims[c.curLayer] -// // rowSize := value.Stride(2) -// // slog.Info("XXX Causal.Get before AsStrided", "vHeadDim", vHeadDim, "rowSize", rowSize) -// // panic("XXX") - -// // TODO we should use TakeAxes to gather the cells from curLoc, but for now to be consistent with GGML, just grab a larger chunk and mask -// value = value.TakeAxes(ctx, c.curLoc, 0).Reshape(ctx, 1, numKVHeads, cachedSize, vHeadDim) -// // value = value.AsStrided(ctx, []int{1, numKVHeads, cachedSize, vHeadDim}, []int{}, rowSize*c.curCellRange.min) - -// // slog.Info("XXX Causal.Get after AsStrided", "value", value) -// // panic("XXX") - -// // } - -// // // TODO The mask changes from X,X to 1,X, and with the Row-order change -// // // the 1 becomes trailing and messes up later operations -// // // This isn't the right solution, but works around it... -// // if c.curMask.Dim(1) == 1 { -// // return key, value, c.curMask.Transpose(ctx, 1, 0, 2, 3) -// // } -// // fmt.Fprintln(os.Stderr, key.ToString()) -// // fmt.Fprintln(os.Stderr, value.ToString()) -// // panic("XXX") -// slog.Info("XXX Mask", "curLayer", c.curLayer, "shape", c.curMask.Shape()) - -// return key, value, c.curMask -// } - -// func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) { -// kHeadDim := key.Dim(3) -// vHeadDim := value.Dim(3) -// numKVHeads := key.Dim(1) -// batchSize := key.Dim(2) -// kCellSize := kHeadDim * numKVHeads -// vCellSize := vHeadDim * numKVHeads - -// // slog.Info("XXX Causal.Put", "key", key, "value", value) -// slog.Info("XXX Causal.Put", "kHeadDim", kHeadDim, "vHeadDim", vHeadDim, "numKVHeads", numKVHeads, "batchSize", batchSize) -// // panic("XXX") - -// if c.curBatchSize != batchSize { -// panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize)) -// } - -// // slog.Info("XXX", "c.ctxs", c.ctxs, "c.curLayer", c.curLayer, "backend", c.backend) -// if _, ok := c.ctxs[c.curLayer]; !ok { -// slog.Info("XXX Causal.Put creating new context", "c.curLayer", c.curLayer) -// c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer) -// } - -// if _, ok := c.keys[c.curLayer]; !ok { -// slog.Info("XXX Causal.Put allocating keys", "c.curLayer", c.curLayer, "shape", []int{len(c.cells), kCellSize}) - -// c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), kCellSize) -// c.kHeadDims[c.curLayer] = kHeadDim -// c.vHeadDims[c.curLayer] = vHeadDim -// c.numKVHeads[c.curLayer] = numKVHeads -// } - -// if _, ok := c.values[c.curLayer]; !ok { -// // if c.config.PermutedV { -// // c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, numKVHeads, vHeadDim, len(c.cells)) -// // } else { -// c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), vCellSize) -// // } -// } - -// key = key.Reshape(ctx, batchSize, 1, kCellSize) //.Contiguous(ctx, false) // TODO contiguous may not be needed - -// // slog.Info("XXX Causal.Put after reshape", "keyCache", keyCache) -// // panic("XXX") -// // curLoc := 0 // TODO c.curLoc is now a tensor -// // kSize := numKVHeads * kHeadDim -// // vSize := numKVHeads * vHeadDim -// // start := []int{int(curLoc), 0} -// // kStop := []int{int(curLoc + batchSize), int(kSize)} -// // vStop := []int{int(curLoc + batchSize), int(vSize)} -// // strides := []int{1, 1} - -// // slog.Info("XXX Causal.Put Key SliceUpdate", "keyCache", keyCache) -// // slog.Info("XXX Causal.Put Key SliceUpdate", "key", key) - -// // slog.Info("XXX Causal.Put Key SliceUpdate", "start", start, "kStop", kStop, "strides", strides) - -// // ctx.Forward(c.keys[c.curLayer].SliceUpdate(ctx, key, start, kStop, strides)) -// ctx.Forward(c.keys[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLoc}, key, []int{0})) -// // fmt.Fprintln(os.Stderr, keyCache.ToString()) -// // panic("input value") - -// // fmt.Fprintln(os.Stderr, t.ToString()) -// // panic("XXX") - -// // if c.config.PermutedV { -// // panic("permuted") -// // // TODO not adjusted -// // value = value.Reshape(ctx, vHeadDim*numKVHeads, 1, batchSize) -// // value = value.Transpose(ctx, 2, 0, 1, 3) - -// // valueCache := c.values[c.curLayer] -// // valueCache = valueCache.Reshape(ctx, 1, len(c.cells), vHeadDim*numKVHeads) - -// // ctx.Forward(valueCache.SliceUpdate(ctx, value, start, vStop, strides)) -// // } else { -// value = value.Reshape(ctx, batchSize, 1, vCellSize) //.Contiguous(ctx, false) // TODO contiguous may not be needed -// // slog.Info("XXX Causal.Put Value SliceUpdate", "valueCache", valueCache) -// // slog.Info("XXX Causal.Put Value SliceUpdate", "value", value) -// // slog.Info("XXX Causal.Put Value SliceUpdate", "start", start, "vStop", vStop, "strides", strides) - -// ctx.Forward(c.values[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLoc}, value, []int{0})) -// // } -// // fmt.Fprintln(os.Stderr, c.keys[c.curLayer].ToString()) -// // fmt.Fprintln(os.Stderr, c.values[c.curLayer].ToString()) -// // panic("XXX") - -// } - -// func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) { -// seqRange := newRange() - -// for i := range c.cells { -// // Remove the contents of dstSeq so that we only have the copied prefix, metadata will be reset at the end -// if slices.Contains(c.cells[i].sequences, dstSeq) { -// c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == dstSeq }) -// } - -// if slices.Contains(c.cells[i].sequences, srcSeq) && c.cells[i].pos < len { -// c.cells[i].sequences = append(c.cells[i].sequences, dstSeq) -// if i < seqRange.min { -// seqRange.min = i -// } -// if i > seqRange.max { -// seqRange.max = i -// } -// } -// } - -// c.cellRanges[dstSeq] = seqRange -// } - -// func (c *Causal) CanResume(seq int, pos int32) bool { -// if c.swaMemorySize == math.MaxInt32 { -// return true -// } - -// seqRange, ok := c.cellRanges[seq] -// if !ok { -// return false -// } - -// // for sliding window, check that the window of the new sequence is contained in -// // the window of what we are storing -// var first int32 = math.MaxInt32 -// var last int32 = -1 -// for i := seqRange.min; i <= seqRange.max; i++ { -// if slices.Contains(c.cells[i].sequences, seq) { -// first = min(first, c.cells[i].pos) -// last = max(last, c.cells[i].pos) -// } -// } - -// if last == -1 { -// return false -// } - -// posWindowStart := max(0, pos-c.swaWindowSize) -// return posWindowStart >= first && pos <= last+1 -// } - -// func (c *Causal) shift(seq int, beginIndex, offset int32) error { -// if c.shiftFn == nil { -// return ErrNotSupported -// } - -// seqRange := c.cellRanges[seq] - -// for start := seqRange.min; start <= seqRange.max; start += c.maxBatch { -// size := min(seqRange.max-start+1, c.maxBatch) -// offsets := make([]int32, size) - -// var batchFirst, batchLast int - -// batchFirst = -1 -// for i := range offsets { -// cell := c.cells[start+i] - -// if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex { -// offsets[i] = offset -// if batchFirst < 0 { -// batchFirst = i -// } -// batchLast = i -// } -// } - -// if batchFirst < 0 { -// continue -// } - -// offsets = offsets[batchFirst : batchLast+1] - -// slog.Info("XXX Causal.shift creating new temporary context") -// ctx := c.backend.NewContext() -// kShift := ctx.Input().FromInts(offsets, len(offsets)) - -// for i, key := range c.keys { -// if key == nil { -// continue -// } - -// kHeadDim := key.Dim(2) -// numKVHeads := key.Dim(1) -// rowSize := key.Stride(0) - -// key = key.AsStrided(ctx, -// []int{len(offsets), numKVHeads, kHeadDim}, -// []int{key.Stride(0), key.Stride(1)}, -// rowSize*(start+batchFirst), -// ) - -// roped, err := c.shiftFn(ctx, i, key, kShift) -// if err != nil { -// ctx.Close() -// return err -// } - -// ctx.Forward(roped.Copy(ctx, key)) -// } - -// ctx.Compute() -// ctx.Close() -// } - -// return nil -// } - -// func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error { -// // TODO(jessegross): We should check to see if removing the middle of the sequence will -// // cause the sliding window to encompass tokens that we no longer have. If so, then we -// // should return an error, which will trigger the runner to evaluate the full history and -// // rebuild the window. However, if we have multimodal inputs in our history, this reuse -// // results in use after free, so we don't do it for now. - -// var offset int32 -// if endIndex != math.MaxInt32 { -// offset = beginIndex - endIndex -// } - -// seqRange := newRange() - -// for i := range c.cells { -// if slices.Contains(c.cells[i].sequences, seq) { -// if c.cells[i].pos >= beginIndex && c.cells[i].pos < endIndex { -// c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq }) -// } else { -// if c.cells[i].pos >= endIndex { -// if slices.ContainsFunc(c.cells[i].sequences, func(s int) bool { return s != seq }) { -// return errors.New("shifting cells shared by multiple sequences not supported") -// } - -// c.cells[i].pos += offset -// } -// if i < seqRange.min { -// seqRange.min = i -// } -// if i > seqRange.max { -// seqRange.max = i -// } -// } -// } -// } - -// if seqRange == newRange() { -// delete(c.cellRanges, seq) -// return nil -// } - -// c.cellRanges[seq] = seqRange - -// if endIndex != math.MaxInt32 { -// err := c.shift(seq, endIndex+offset, offset) -// if err != nil { -// return err -// } -// } - -// return nil -// } diff --git a/x/kvcache/causal_test.go b/x/kvcache/causal_test.go deleted file mode 100644 index d7ac430b1..000000000 --- a/x/kvcache/causal_test.go +++ /dev/null @@ -1,973 +0,0 @@ -package kvcache - -// import ( -// "fmt" -// "math" -// "slices" -// "testing" - -// "github.com/ollama/ollama/ml" -// "github.com/ollama/ollama/model/input" -// ) - -// type testCase struct { -// name string -// in []float32 -// inShape []int -// seqs []int -// pos []int32 -// expected []float32 -// expectedShape []int -// expectedMask []float32 -// } - -// func runPermutedVariants(t *testing.T, fn func(t *testing.T, backend *testBackend)) { -// t.Helper() -// for _, permuted := range []bool{false, true} { -// t.Run(fmt.Sprintf("PermutedV=%t", permuted), func(t *testing.T) { -// fn(t, &testBackend{permutedV: permuted}) -// }) -// } -// } - -// func TestStore(t *testing.T) { -// runPermutedVariants(t, func(t *testing.T, backend *testBackend) { -// cache := NewCausalCache(nil) -// defer cache.Close() - -// cache.Init(backend, ml.DTypeF16, 1, 16, 16) - -// tests := []testCase{ -// { -// name: "FirstBatch", -// in: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234}, -// inShape: []int{2, 3, 4}, -// seqs: []int{0, 0, 0, 0}, -// pos: []int32{0, 1, 2, 3}, -// expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234}, -// expectedShape: []int{2, 3, 4}, -// expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0}, -// }, -// { -// name: "SecondBatch", -// in: []float32{115, 215, 125, 225, 135, 235}, -// inShape: []int{2, 3, 1}, -// seqs: []int{0}, -// pos: []int32{4}, -// expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234, 115, 215, 125, 225, 135, 235}, -// expectedShape: []int{2, 3, 5}, -// expectedMask: []float32{0, 0, 0, 0, 0}, -// }, -// } - -// testCache(t, backend, cache, tests) -// }) -// } - -// func TestSWA(t *testing.T) { -// runPermutedVariants(t, func(t *testing.T, backend *testBackend) { -// cache := NewSWACache(1, nil) -// defer cache.Close() - -// cache.Init(backend, ml.DTypeF16, 1, 16, 16) - -// x := float32(math.Inf(-1)) - -// tests := []testCase{ -// { -// name: "FirstBatch", -// in: []float32{1, 2, 3, 4}, -// inShape: []int{1, 1, 4}, -// seqs: []int{0, 0, 0, 0}, -// pos: []int32{0, 1, 2, 3}, -// expected: []float32{1, 2, 3, 4}, -// expectedShape: []int{1, 1, 4}, -// expectedMask: []float32{ -// 0, x, x, x, -// 0, 0, x, x, -// x, 0, 0, x, -// x, x, 0, 0, -// }, -// }, -// { -// name: "SecondBatch", -// in: []float32{5, 6}, -// inShape: []int{1, 1, 2}, -// seqs: []int{0, 0}, -// pos: []int32{4, 5}, -// expected: []float32{5, 6, 3, 4}, -// expectedShape: []int{1, 1, 4}, -// expectedMask: []float32{ -// 0, x, x, 0, -// 0, 0, x, x, -// }, -// }, -// } - -// testCache(t, backend, cache, tests) -// }) -// } - -// func TestSWASeparateBatches(t *testing.T) { -// runPermutedVariants(t, func(t *testing.T, backend *testBackend) { -// cache := NewSWACache(1, nil) -// defer cache.Close() - -// cache.Init(backend, ml.DTypeF16, 2, 16, 2) - -// x := float32(math.Inf(-1)) - -// tests := []testCase{ -// { -// name: "First seq 0", -// in: []float32{1, 2}, -// inShape: []int{1, 1, 2}, -// seqs: []int{0, 0}, -// pos: []int32{0, 1}, -// expected: []float32{1, 2}, -// expectedShape: []int{1, 1, 2}, -// expectedMask: []float32{ -// 0, x, -// 0, 0, -// }, -// }, -// { -// name: "Second seq 0", -// in: []float32{3, 4}, -// inShape: []int{1, 1, 2}, -// seqs: []int{0, 0}, -// pos: []int32{2, 3}, -// expected: []float32{2, 3, 4}, -// expectedShape: []int{1, 1, 3}, -// expectedMask: []float32{ -// 0, 0, x, -// x, 0, 0, -// }, -// }, -// { -// name: "First seq 1", -// in: []float32{5, 6}, -// inShape: []int{1, 1, 2}, -// seqs: []int{1, 1}, -// pos: []int32{0, 1}, -// expected: []float32{5, 6}, -// expectedShape: []int{1, 1, 2}, -// expectedMask: []float32{ -// 0, x, -// 0, 0, -// }, -// }, -// { -// name: "Second seq 1", -// in: []float32{7, 8}, -// inShape: []int{1, 1, 2}, -// seqs: []int{1, 1}, -// pos: []int32{2, 3}, -// expected: []float32{6, 3, 4, 7, 8}, -// expectedShape: []int{1, 1, 5}, -// expectedMask: []float32{ -// 0, x, x, 0, x, -// x, x, x, 0, 0, -// }, -// }, -// { -// name: "Third seq 0", -// in: []float32{9, 10}, -// inShape: []int{1, 1, 2}, -// seqs: []int{0, 0}, -// pos: []int32{4, 5}, -// expected: []float32{9, 10, 3, 4}, -// expectedShape: []int{1, 1, 4}, -// expectedMask: []float32{ -// 0, x, x, 0, -// 0, 0, x, x, -// }, -// }, -// } - -// testCache(t, backend, cache, tests) -// }) -// } - -// func TestSWAMem(t *testing.T) { -// runPermutedVariants(t, func(t *testing.T, backend *testBackend) { -// cache := NewSWAMemCache(1, 3, nil) -// defer cache.Close() - -// cache.Init(backend, ml.DTypeF16, 1, 16, 16) - -// x := float32(math.Inf(-1)) - -// tests := []testCase{ -// { -// name: "FirstBatch", -// in: []float32{1, 2, 3, 4}, -// inShape: []int{1, 1, 4}, -// seqs: []int{0, 0, 0, 0}, -// pos: []int32{0, 1, 2, 3}, -// expected: []float32{1, 2, 3, 4}, -// expectedShape: []int{1, 1, 4}, -// expectedMask: []float32{ -// 0, x, x, x, -// 0, 0, x, x, -// x, 0, 0, x, -// x, x, 0, 0, -// }, -// }, -// { -// name: "SecondBatch", -// in: []float32{5, 6}, -// inShape: []int{1, 1, 2}, -// seqs: []int{0, 0}, -// pos: []int32{4, 5}, -// expected: []float32{5, 2, 3, 4, 6}, -// expectedShape: []int{1, 1, 5}, -// expectedMask: []float32{ -// 0, x, x, 0, x, -// 0, x, x, x, 0, -// }, -// }, -// } - -// testCache(t, backend, cache, tests) -// }) -// } - -// func TestChunkedAttention(t *testing.T) { -// runPermutedVariants(t, func(t *testing.T, backend *testBackend) { -// cache := NewChunkedAttentionCache(2, nil) -// defer cache.Close() - -// cache.Init(backend, ml.DTypeF16, 1, 16, 16) - -// x := float32(math.Inf(-1)) - -// testCache( -// t, backend, cache, -// []testCase{ -// { -// name: "FirstBatch", -// in: []float32{1, 2, 3, 4}, -// inShape: []int{1, 1, 4}, -// seqs: []int{0, 0, 0, 0}, -// pos: []int32{0, 1, 2, 3}, -// expected: []float32{1, 2, 3, 4}, -// expectedShape: []int{1, 1, 4}, -// expectedMask: []float32{ -// 0, x, x, x, -// 0, 0, x, x, -// x, x, 0, x, -// x, x, 0, 0, -// }, -// }, -// { -// name: "SecondBatch", -// in: []float32{5, 6, 7}, -// inShape: []int{1, 1, 3}, -// seqs: []int{0, 0, 0}, -// pos: []int32{4, 5, 6}, -// expected: []float32{1, 2, 3, 4, 5, 6, 7}, -// expectedShape: []int{1, 1, 7}, -// expectedMask: []float32{ -// x, x, x, x, 0, x, x, -// x, x, x, x, 0, 0, x, -// x, x, x, x, x, x, 0, -// }, -// }, -// { -// name: "ThirdBatch", -// in: []float32{8, 9}, -// inShape: []int{1, 1, 2}, -// seqs: []int{0, 0}, -// pos: []int32{7, 8}, -// expected: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9}, -// expectedShape: []int{1, 1, 9}, -// expectedMask: []float32{ -// x, x, x, x, x, x, 0, 0, x, -// x, x, x, x, x, x, x, x, 0, -// }, -// }, -// }, -// ) -// }) -// } - -// func TestSequences(t *testing.T) { -// runPermutedVariants(t, func(t *testing.T, backend *testBackend) { -// cache := NewCausalCache(nil) -// defer cache.Close() - -// cache.Init(backend, ml.DTypeF16, 1, 16, 16) - -// tests := []testCase{ -// { -// name: "FirstBatch", -// in: []float32{1, 2, 3, 4}, -// inShape: []int{1, 1, 4}, -// seqs: []int{0, 0, 1, 1}, -// pos: []int32{0, 1, 0, 1}, -// expected: []float32{1, 2, 3, 4}, -// expectedShape: []int{1, 1, 4}, -// expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0}, -// }, -// { -// name: "SecondBatch", -// in: []float32{5, 6}, -// inShape: []int{1, 1, 2}, -// seqs: []int{0, 1}, -// pos: []int32{2, 2}, -// expected: []float32{1, 2, 3, 4, 5, 6}, -// expectedShape: []int{1, 1, 6}, -// expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0}, -// }, -// } - -// testCache(t, backend, cache, tests) -// }) -// } - -// func TestRemove(t *testing.T) { -// runPermutedVariants(t, func(t *testing.T, backend *testBackend) { -// cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { -// return key.Add(ctx, shift), nil -// }) -// defer cache.Close() - -// cache.Init(backend, ml.DTypeF16, 1, 16, 16) - -// x := float32(math.Inf(-1)) - -// tests := []testCase{ -// { -// name: "FirstBatch", -// in: []float32{1, 2, 3, 4}, -// inShape: []int{1, 1, 4}, -// seqs: []int{0, 0, 1, 1}, -// pos: []int32{0, 1, 0, 1}, -// expected: []float32{1, 2, 3, 4}, -// expectedShape: []int{1, 1, 4}, -// expectedMask: []float32{ -// 0, x, x, x, -// 0, 0, x, x, -// x, x, 0, x, -// x, x, 0, 0, -// }, -// }, -// } - -// testCache(t, backend, cache, tests) - -// err := cache.Remove(0, 1, math.MaxInt32) -// if err != nil { -// panic(err) -// } - -// tests = []testCase{ -// { -// name: "RemoveEnd", -// in: []float32{5, 6}, -// inShape: []int{1, 1, 2}, -// seqs: []int{0, 1}, -// pos: []int32{1, 2}, -// expected: []float32{1, 5, 3, 4, 6}, -// expectedShape: []int{1, 1, 5}, -// expectedMask: []float32{ -// 0, 0, x, x, x, -// x, x, 0, 0, 0, -// }, -// }, -// } - -// testCache(t, backend, cache, tests) - -// err = cache.Remove(0, 0, 1) -// if err != nil { -// panic(err) -// } - -// tests = []testCase{ -// { -// name: "RemoveMiddle", -// in: []float32{7, 8}, -// inShape: []int{1, 1, 2}, -// seqs: []int{0, 0}, -// pos: []int32{1, 2}, -// expected: []float32{7, 4, 3, 4, 6, 8}, -// expectedShape: []int{1, 1, 6}, -// expectedMask: []float32{ -// 0, 0, x, x, x, x, -// 0, 0, x, x, x, 0, -// }, -// }, -// } - -// testCache(t, backend, cache, tests) -// }) -// } - -// func TestCopy(t *testing.T) { -// runPermutedVariants(t, func(t *testing.T, backend *testBackend) { -// cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil }) -// defer cache.Close() - -// cache.Init(backend, ml.DTypeF16, 1, 16, 16) - -// tests := []testCase{ -// { -// name: "FirstBatch", -// in: []float32{1, 2, 3, 4}, -// inShape: []int{1, 1, 4}, -// seqs: []int{0, 0, 0, 0}, -// pos: []int32{0, 1, 2, 3}, -// expected: []float32{1, 2, 3, 4}, -// expectedShape: []int{1, 1, 4}, -// expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0}, -// }, -// } - -// testCache(t, backend, cache, tests) - -// cache.CopyPrefix(0, 1, 2) - -// tests = []testCase{ -// { -// name: "Copy", -// in: []float32{5, 6}, -// inShape: []int{1, 1, 2}, -// seqs: []int{1, 1}, -// pos: []int32{3, 4}, -// expected: []float32{1, 2, 3, 4, 5, 6}, -// expectedShape: []int{1, 1, 6}, -// expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0}, -// }, -// } - -// testCache(t, backend, cache, tests) -// }) -// } - -// func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) { -// for _, test := range tests { -// t.Run(test.name, func(t *testing.T) { -// context := backend.NewContext() -// defer context.Close() - -// err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs}, false) -// if err != nil { -// panic(err) -// } - -// cache.SetLayer(0) -// tensor := context.FromFloats(test.in, test.inShape...) -// cache.Put(context, tensor, tensor) - -// out, _, mask := cache.Get(context) - -// context.Forward(out, mask).Compute(out, mask) - -// if !slices.Equal(out.Floats(), test.expected) { -// t.Errorf("TestCache: have %v; want %v", out.Floats(), test.expected) -// } - -// if !slices.Equal(out.Shape(), test.expectedShape) { -// t.Errorf("TestCache: has shape %v; want %v", out.Shape(), test.expectedShape) -// } - -// if !slices.Equal(mask.Floats(), test.expectedMask) { -// t.Errorf("TestCache: have mask: have %v want %v", mask.Floats(), test.expectedMask) -// } -// }) -// } -// } - -// func TestCanResume(t *testing.T) { -// runPermutedVariants(t, func(t *testing.T, backend *testBackend) { -// windowSize := int32(4) -// cache := NewSWACache(windowSize, nil) -// defer cache.Close() - -// cache.Init(backend, ml.DTypeF16, 1, 16, 16) - -// context := backend.NewContext() -// defer context.Close() - -// err := cache.StartForward(context, input.Batch{ -// Positions: []int32{0, 1, 2, 3, 4}, -// Sequences: []int{0, 0, 0, 0, 0}, -// }, false) -// if err != nil { -// t.Fatalf("StartForward failed: %v", err) -// } - -// cache.SetLayer(0) -// tensor := context.FromFloats([]float32{1, 2, 3, 4, 5}, 1, 1, 5) -// cache.Put(context, tensor, tensor) - -// // with window size 4, nothing has slid out of the window yet -// if !cache.CanResume(0, 0) { -// t.Errorf("CanResume(0, 0) = false, want true (within window)") -// } -// if !cache.CanResume(0, 1) { -// t.Errorf("CanResume(0, 1) = false, want true (within window)") -// } -// if !cache.CanResume(0, 2) { -// t.Errorf("CanResume(0, 2) = false, want true (within window)") -// } -// if !cache.CanResume(0, 3) { -// t.Errorf("CanResume(0, 3) = false, want true (latest position)") -// } -// if !cache.CanResume(0, 4) { -// t.Errorf("CanResume(0, 4) = false, want true (latest position)") -// } - -// // shift window by adding position 5 -// err = cache.StartForward(context, input.Batch{ -// Positions: []int32{5}, -// Sequences: []int{0}, -// }, false) -// if err != nil { -// t.Fatalf("StartForward failed: %v", err) -// } - -// cache.SetLayer(0) -// tensor = context.FromFloats([]float32{6}, 1, 1, 1) -// cache.Put(context, tensor, tensor) - -// // only the latest position has overlapping windows -// if cache.CanResume(0, 0) { -// t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)") -// } -// if cache.CanResume(0, 1) { -// t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)") -// } -// if cache.CanResume(0, 2) { -// t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)") -// } -// if cache.CanResume(0, 3) { -// t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)") -// } -// if cache.CanResume(0, 4) { -// t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)") -// } -// if !cache.CanResume(0, 5) { -// t.Errorf("after shift: CanResume(0, 5) = false, want true (latest position)") -// } -// }) -// } - -// func TestCanResumeSWAMem(t *testing.T) { -// runPermutedVariants(t, func(t *testing.T, backend *testBackend) { -// windowSize := int32(4) -// memSize := int32(5) -// cache := NewSWAMemCache(windowSize, memSize, nil) -// defer cache.Close() - -// cache.Init(backend, ml.DTypeF16, 1, 16, 16) - -// context := backend.NewContext() -// defer context.Close() - -// err := cache.StartForward(context, input.Batch{ -// Positions: []int32{0, 1, 2, 3, 4, 5, 6}, -// Sequences: []int{0, 0, 0, 0, 0, 0, 0}, -// }, false) -// if err != nil { -// t.Fatalf("StartForward failed: %v", err) -// } - -// cache.SetLayer(0) -// tensor := context.FromFloats([]float32{1, 2, 3, 4, 5, 6, 7}, 1, 1, 7) -// cache.Put(context, tensor, tensor) - -// // shift window by adding position 7 -// err = cache.StartForward(context, input.Batch{ -// Positions: []int32{7}, -// Sequences: []int{0}, -// }, false) -// if err != nil { -// t.Fatalf("StartForward failed: %v", err) -// } - -// cache.SetLayer(0) -// tensor = context.FromFloats([]float32{8}, 1, 1, 1) -// cache.Put(context, tensor, tensor) - -// // only the latest position has overlapping windows -// if cache.CanResume(0, 0) { -// t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)") -// } -// if cache.CanResume(0, 1) { -// t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)") -// } -// if cache.CanResume(0, 2) { -// t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)") -// } -// if cache.CanResume(0, 3) { -// t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)") -// } -// if cache.CanResume(0, 4) { -// t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)") -// } -// if cache.CanResume(0, 5) { -// t.Errorf("after shift: CanResume(0, 5) = true, want false (outside window)") -// } -// if !cache.CanResume(0, 6) { -// t.Errorf("after shift: CanResume(0, 6) = false, want true (inside window)") -// } -// if !cache.CanResume(0, 7) { -// t.Errorf("after shift: CanResume(0, 7) = false, want true (latest position)") -// } -// }) -// } - -// type testBackend struct { -// ml.Backend -// permutedV bool -// } - -// func (b *testBackend) NewContext() ml.Context { -// return &testContext{} -// } - -// func (b *testBackend) NewContextSize(int) ml.Context { -// return &testContext{} -// } - -// func (b *testBackend) CacheConfig() ml.CacheConfig { -// return ml.CacheConfig{PermutedV: b.permutedV} -// } - -// type testContext struct { -// ml.Context -// } - -// func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor { -// total := 0 - -// if len(shape) > 0 { -// total = 1 -// for _, s := range shape { -// total *= s -// } -// } - -// return &testTensor{dtype: dtype, elementSize: 4, data: make([]float32, total), shape: shape} -// } - -// func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor { -// return c.Empty(dtype, shape...) -// } - -// func (c *testContext) FromFloats(s []float32, shape ...int) ml.Tensor { -// t := c.Empty(ml.DTypeF32, shape...).(*testTensor) - -// copy(t.data, s) - -// return t -// } - -// func (c *testContext) FromInts(s []int32, shape ...int) ml.Tensor { -// f := make([]float32, len(s)) -// for i := range f { -// f[i] = float32(s[i]) -// } - -// out := c.FromFloats(f, shape...) -// out.(*testTensor).dtype = ml.DTypeI32 - -// return out -// } - -// func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor { -// s := make([]float32, 0, int((stop-start)/step)) -// for i := start; i < stop; i += step { -// s = append(s, i) -// } - -// out := c.FromFloats(s, len(s)) -// out.(*testTensor).dtype = dtype -// return out -// } - -// func (c *testContext) Input() ml.Context { return c } -// func (c *testContext) Layer(int) ml.Context { return c } - -// func (c *testContext) Forward(...ml.Tensor) ml.Context { return c } - -// func (c *testContext) Compute(...ml.Tensor) {} - -// func (c *testContext) Reserve() {} - -// func (c *testContext) MaxGraphNodes() int { -// return 10 -// } - -// func (c *testContext) Close() {} - -// type testTensor struct { -// ml.Tensor - -// dtype ml.DType -// elementSize int -// data []float32 -// shape []int -// } - -// func (t *testTensor) Dim(n int) int { -// return t.shape[n] -// } - -// func (t *testTensor) Stride(n int) int { -// stride := t.elementSize -// for i := range n { -// stride *= t.shape[i] -// } - -// return stride -// } - -// func (t *testTensor) Shape() []int { -// return t.shape -// } - -// func (t *testTensor) DType() ml.DType { -// return t.dtype -// } - -// func (t *testTensor) Floats() []float32 { -// out := make([]float32, len(t.data)) -// copy(out, t.data) -// return out -// } - -// func (t *testTensor) Neg(ctx ml.Context) ml.Tensor { -// out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor) -// for i := range out.data { -// out.data[i] = -t.data[i] -// } -// return out -// } - -// func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor { -// out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor) - -// for i := range out.data { -// out.data[i] = t.data[i] + t2.(*testTensor).data[i] -// } - -// return out -// } - -// func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor { -// return &testTensor{ -// dtype: t.dtype, -// elementSize: t.elementSize, -// data: t.data, -// shape: shape, -// } -// } - -// func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor { -// offset /= t.elementSize - -// var s []int - -// switch len(shape) { -// case 1: -// s = []int{shape[0]} -// case 3: -// s = []int{shape[0], shape[2]} -// case 5: -// s = []int{shape[0], shape[2], shape[4]} -// default: -// panic("unsupported number of dimensions") -// } - -// context := &testContext{} - -// view := context.Empty(t.dtype, s...).(*testTensor) -// view.data = t.data[offset : offset+len(view.data)] - -// return view -// } - -// func (t *testTensor) Permute(ctx ml.Context, order ...int) ml.Tensor { -// if len(t.shape) > 4 || len(order) > 4 { -// panic("permute only supports up to 4 dimensions") -// } - -// if len(order) != len(t.shape) && len(order) != 4 { -// panic("invalid number of dimensions for permute") -// } - -// // ggml_permute expects 4 axes, so fill in any missing dimensions. -// orderFull := append(make([]int, 0, 4), order...) -// for len(orderFull) < 4 { -// orderFull = append(orderFull, len(orderFull)) -// } - -// seen := [4]bool{} - -// shape4 := [4]int{1, 1, 1, 1} -// for i := 0; i < len(t.shape) && i < 4; i++ { -// shape4[i] = t.shape[i] -// } - -// newShape4 := [4]int{1, 1, 1, 1} -// for axis := range 4 { -// dst := orderFull[axis] -// if dst < 0 || dst >= 4 { -// panic("invalid axis for permute") -// } -// if seen[dst] { -// panic("duplicate axis for permute") -// } -// seen[dst] = true -// newShape4[dst] = shape4[axis] -// } - -// total := len(t.data) -// newData := make([]float32, total) - -// if total > 0 { -// oldDims := shape4 -// newDims := newShape4 - -// oldStride := [4]int{1, 1, 1, 1} -// newStride := [4]int{1, 1, 1, 1} -// for i := 1; i < 4; i++ { -// oldStride[i] = oldStride[i-1] * oldDims[i-1] -// newStride[i] = newStride[i-1] * newDims[i-1] -// } - -// var coords [4]int -// var newCoords [4]int - -// for idx := range total { -// remainder := idx -// for axis := range 4 { -// dim := oldDims[axis] -// if dim == 0 { -// coords[axis] = 0 -// continue -// } -// coords[axis] = remainder % dim -// remainder /= dim -// } - -// for axis := range 4 { -// newCoords[orderFull[axis]] = coords[axis] -// } - -// newIndex := 0 -// for axis := range 4 { -// if newDims[axis] == 0 { -// continue -// } -// newIndex += newCoords[axis] * newStride[axis] -// } - -// newData[newIndex] = t.data[idx] -// } -// } - -// numDims := 4 -// for numDims > 1 && newShape4[numDims-1] <= 1 { -// numDims-- -// } - -// newShape := make([]int, numDims) -// copy(newShape, newShape4[:numDims]) - -// return &testTensor{ -// dtype: t.dtype, -// elementSize: t.elementSize, -// data: newData, -// shape: newShape, -// } -// } - -// func (t *testTensor) SetRows(ctx ml.Context, src ml.Tensor, idxs ml.Tensor) ml.Tensor { -// dst := t -// srcTensor := src.(*testTensor) -// idxTensor := idxs.(*testTensor) - -// shapeTo4D := func(shape []int) [4]int { -// out := [4]int{1, 1, 1, 1} -// for i := 0; i < len(shape) && i < 4; i++ { -// out[i] = shape[i] -// } -// return out -// } - -// computeStrides := func(shape [4]int) [4]int { -// out := [4]int{1, 1, 1, 1} -// for i := 1; i < 4; i++ { -// out[i] = out[i-1] * shape[i-1] -// } -// return out -// } - -// dstShape4D := shapeTo4D(dst.shape) -// srcShape4D := shapeTo4D(srcTensor.shape) -// idxShape4D := shapeTo4D(idxTensor.shape) - -// if dstShape4D[0] != srcShape4D[0] || dstShape4D[2] != srcShape4D[2] || dstShape4D[3] != srcShape4D[3] { -// panic("SetRows requires matching tensor shapes") -// } - -// if srcShape4D[1] != idxShape4D[0] { -// panic("SetRows rows/index mismatch") -// } - -// if srcShape4D[2]%idxShape4D[1] != 0 || srcShape4D[3]%idxShape4D[2] != 0 { -// panic("SetRows cannot broadcast indices") -// } - -// if idxShape4D[3] != 1 { -// panic("SetRows expects 1D or 2D index tensors") -// } - -// dstStride := computeStrides(dstShape4D) -// srcStride := computeStrides(srcShape4D) -// idxStride := computeStrides(idxShape4D) - -// numColumns := srcShape4D[0] -// numRows := srcShape4D[1] - -// for dim3Index := range dstShape4D[3] { -// for dim2Index := range dstShape4D[2] { -// idxDim2 := 0 -// idxDim3 := 0 -// if idxShape4D[1] > 0 { -// idxDim2 = dim2Index % idxShape4D[1] -// } -// if idxShape4D[2] > 0 { -// idxDim3 = dim3Index % idxShape4D[2] -// } - -// idxBase := idxDim3*idxStride[2] + idxDim2*idxStride[1] -// srcBase := dim3Index*srcStride[3] + dim2Index*srcStride[2] -// dstBase := dim3Index*dstStride[3] + dim2Index*dstStride[2] - -// for row := range numRows { -// idx := int(idxTensor.data[idxBase+row*idxStride[0]]) -// if idx < 0 || idx >= dstShape4D[1] { -// panic("SetRows index out of range") -// } - -// srcOffset := srcBase + row*srcStride[1] -// dstOffset := dstBase + idx*dstStride[1] - -// copy(dst.data[dstOffset:dstOffset+numColumns], srcTensor.data[srcOffset:srcOffset+numColumns]) -// } -// } -// } - -// return dst -// } - -// func (t *testTensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor { -// copy(t2.(*testTensor).data, t.data) -// return nil -// } diff --git a/x/kvcache/encoder.go b/x/kvcache/encoder.go deleted file mode 100644 index 19a3839ce..000000000 --- a/x/kvcache/encoder.go +++ /dev/null @@ -1,156 +0,0 @@ -package kvcache - -// import ( -// "fmt" - -// "github.com/ollama/ollama/ml" -// "github.com/ollama/ollama/model/input" -// ) - -// // Encoder cache stores K and V tensors that are position independent -// // -// // The tensors can be of any shape and will be returned as they were stored -// // The mask is currently always nil -// // -// // Not currently safe for multiple sequences -// type EncoderCache struct { -// // config controls mostly backend-specific optimizations -// config *ml.CacheConfig - -// // ** current forward pass ** - -// // the active layer for Get and Put -// curLayer int - -// // if something is stored during this pass, this -// // will be the position (but there is no guarantee -// // anything will be stored) -// curPos int32 - -// // curReserve indicates that this forward pass is only for -// // memory reservation and we should not update our metadata -// // based on it. -// curReserve bool - -// // ** cache metadata ** - -// // was something stored in the cache? -// encoderCached bool - -// // position of the cached data -// encoderPos int32 - -// // ** cache data storage ** -// backend ml.Backend -// ctxs map[int]ml.Context -// keys, values map[int]ml.Tensor -// } - -// func NewEncoderCache() *EncoderCache { -// return &EncoderCache{ -// ctxs: make(map[int]ml.Context), -// keys: make(map[int]ml.Tensor), -// values: make(map[int]ml.Tensor), -// } -// } - -// func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) { -// if c.config == nil { -// var config ml.CacheConfig -// if cc, ok := backend.(ml.BackendCacheConfig); ok { -// config = cc.CacheConfig() -// } -// c.config = &config -// } - -// if maxSequences > 1 { -// panic(fmt.Errorf("encoder cache does not support multiple sequences; requested: %v", maxSequences)) -// } - -// if c.config.CachePadding != 0 && c.config.CachePadding != 1 { -// panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding)) -// } - -// c.backend = backend -// } - -// func (c *EncoderCache) SetConfig(config ml.CacheConfig) { -// if c.config != nil { -// panic("config cannot be changed after being previously set, either by the model or backend") -// } - -// c.config = &config -// } - -// func (c *EncoderCache) Close() { -// for _, ctx := range c.ctxs { -// ctx.Close() -// } -// } - -// func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error { -// // We work with the most recent image -// if len(batch.Multimodal) > 0 { -// c.curPos = batch.Positions[batch.Multimodal[len(batch.Multimodal)-1].Index] -// } - -// c.curReserve = reserve - -// return nil -// } - -// func (c *EncoderCache) SetLayer(layer int) { -// c.curLayer = layer -// } - -// func (c *EncoderCache) EncoderCached() bool { -// return c.encoderCached -// } - -// func (c *EncoderCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) { -// return c.keys[c.curLayer], c.values[c.curLayer], nil -// } - -// func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) { -// if !c.curReserve { -// c.encoderPos = c.curPos -// c.encoderCached = true -// } - -// if c.config.PermutedV { -// value = value.Transpose(ctx, 1, 2, 0, 3) -// } - -// if _, ok := c.ctxs[c.curLayer]; !ok { -// c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer) -// } - -// if _, ok := c.keys[c.curLayer]; !ok { -// c.keys[c.curLayer] = c.ctxs[c.curLayer].Empty(key.DType(), key.Shape()...) -// } - -// if _, ok := c.values[c.curLayer]; !ok { -// c.values[c.curLayer] = c.ctxs[c.curLayer].Empty(value.DType(), value.Shape()...) -// } - -// ctx.Forward( -// key.Copy(ctx, c.keys[c.curLayer]), -// value.Copy(ctx, c.values[c.curLayer]), -// ) -// } - -// func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32) { -// panic("encoder cache does not support multiple sequences") -// } - -// func (c *EncoderCache) CanResume(seq int, pos int32) bool { -// return true -// } - -// func (c *EncoderCache) Remove(seq int, beginIndex, endIndex int32) error { -// if c.encoderPos >= beginIndex && c.encoderPos < endIndex { -// c.encoderCached = false -// } - -// return nil -// } diff --git a/x/kvcache/mlx.go b/x/kvcache/mlx.go deleted file mode 100644 index fa3865104..000000000 --- a/x/kvcache/mlx.go +++ /dev/null @@ -1,144 +0,0 @@ -//go:build mlx - -package kvcache - -import ( - "github.com/ollama/ollama/x/ml" - "github.com/ollama/ollama/x/model/input" -) - -// Causal cache stores K and V tensors according to their position in the -// sequence. Returns the history and a mask for attending to past tokens -type MLXCausal struct { - DType ml.DType - - // locations for data storage for this batch - curLocPut ml.Tensor - - // locations for data storage for this batch - curLocGet ml.Tensor - - // the active layer for Get and Put - curLayer int - - capacity int - - offset int - - backend ml.Backend - ctxs map[int]ml.Context - keys, values map[int]ml.Tensor - - // TODO is this needed per layer, or will it always be consistent? - kHeadDims, vHeadDims, numKVHeads map[int]int -} - -func NewMLXCausalCache() *MLXCausal { - return &MLXCausal{ - ctxs: make(map[int]ml.Context), - keys: make(map[int]ml.Tensor), - values: make(map[int]ml.Tensor), - kHeadDims: make(map[int]int), - vHeadDims: make(map[int]int), - numKVHeads: make(map[int]int), - } -} - -func (c *MLXCausal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) { - c.DType = dtype - c.capacity = capacity - c.backend = backend -} - -func (c *MLXCausal) SetConfig(config ml.CacheConfig) {} - -func (c *MLXCausal) SetLayer(layer int) { - c.curLayer = layer -} - -func (c *MLXCausal) Close() { - // slog.Info("XXX MLXCausal.Close called", "number of contexts", len(c.ctxs)) - for _, ctx := range c.ctxs { - ctx.Close() - } -} - -func (c *MLXCausal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error { - locsPut := make([]int32, len(batch.Positions)) - for i := c.offset; i < len(batch.Positions); i++ { - locsPut[i-c.offset] = int32(i) - } - c.offset += len(batch.Positions) - locsGet := make([]int32, c.offset) - for i := range c.offset { - locsGet[i] = int32(i) - } - c.curLocGet = ctx.Input().FromInts(locsGet, len(locsGet)) - c.curLocPut = ctx.Input().FromInts(locsPut, len(locsPut)) - // slog.Info("XXX MLXCausal.StartForward", "offset", c.offset, "put", locsPut, "get", locsGet) - - return nil -} -func (c *MLXCausal) Put(ctx ml.Context, key, value ml.Tensor) { - kHeadDim := key.Dim(3) - vHeadDim := value.Dim(3) - numKVHeads := key.Dim(1) - batchSize := key.Dim(2) - kCellSize := kHeadDim * numKVHeads - vCellSize := vHeadDim * numKVHeads - // slog.Info("XXX Causal.Put", "kHeadDim", kHeadDim, "vHeadDim", vHeadDim, "numKVHeads", numKVHeads, "batchSize", batchSize, "kCellSize", kCellSize, "vCellSize", vCellSize) - - if _, ok := c.ctxs[c.curLayer]; !ok { - // slog.Info("XXX Causal.Put creating new context", "c.curLayer", c.curLayer) - c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer) - } - - if _, ok := c.keys[c.curLayer]; !ok { - // slog.Info("XXX MLXCausal.Put allocating keys and values", "c.curLayer", c.curLayer, "shape", []int{c.capacity, kCellSize}) - c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, kCellSize) - c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, vCellSize) - c.kHeadDims[c.curLayer] = kHeadDim - c.vHeadDims[c.curLayer] = vHeadDim - c.numKVHeads[c.curLayer] = numKVHeads - } - key = key.Reshape(ctx, batchSize, 1, kCellSize) - - // slog.Info("XXX MLXCausal.Put ", "c.keys[c.curLayer]", c.keys[c.curLayer]) - // slog.Info("XXX MLXCausal.Put ", "c.curLocPut", c.curLocPut) - // slog.Info("XXX MLXCausal.Put ", "key", key) - ctx.Forward(c.keys[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, key, []int{0})) - value = value.Reshape(ctx, batchSize, 1, vCellSize) - ctx.Forward(c.values[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, value, []int{0})) - -} - -func (c *MLXCausal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) { - key := c.keys[c.curLayer] - value := c.values[c.curLayer] - - kHeadDim := c.kHeadDims[c.curLayer] - vHeadDim := c.vHeadDims[c.curLayer] - numKVHeads := c.numKVHeads[c.curLayer] - // rowSize := numKVHeads * c.curBatchSize - // cachedSize := c.curMask.Dim(1) - cachedSize := c.curLocGet.Dim(0) - // kCellSize := kHeadDim * numKVHeads - // vCellSize := vHeadDim * numKVHeads - // slog.Info("XXX MLXCausal.Get", "shape", []int{1, numKVHeads, cachedSize, kHeadDim}) - - key = key.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, kHeadDim) - value = value.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, vHeadDim) - return key, value, nil -} - -func (c *MLXCausal) CopyPrefix(srcSeq, dstSeq int, len int32) { - panic("not implemented") -} - -func (c *MLXCausal) CanResume(seq int, pos int32) bool { - panic("not implemented") -} - -func (c *MLXCausal) Remove(seq int, beginIndex, endIndex int32) error { - panic("not implemented") -} diff --git a/x/kvcache/wrapper.go b/x/kvcache/wrapper.go deleted file mode 100644 index 69e07dc96..000000000 --- a/x/kvcache/wrapper.go +++ /dev/null @@ -1,110 +0,0 @@ -package kvcache - -// import ( -// "math" - -// "github.com/ollama/ollama/ml" -// "github.com/ollama/ollama/model/input" -// ) - -// // Wrapper cache is a container for multiple types of caches, -// // such as for the encoding and decoding portions of a model. -// type WrapperCache struct { -// // caches we are wrapping -// caches []Cache - -// // cache to be used for this layer -// curType int -// } - -// func NewWrapperCache(caches ...Cache) *WrapperCache { -// return &WrapperCache{ -// caches: caches, -// } -// } - -// func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) { -// for _, cache := range c.caches { -// cache.Init(backend, dtype, maxSequences, capacity, maxBatch) -// } -// } - -// func (c *WrapperCache) SetConfig(config ml.CacheConfig) { -// for _, cache := range c.caches { -// cache.SetConfig(config) -// } -// } - -// func (c *WrapperCache) Close() { -// for _, cache := range c.caches { -// cache.Close() -// } -// } - -// func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error { -// for i, cache := range c.caches { -// err := cache.StartForward(ctx, batch, reserve) -// if err != nil { -// // unwind on error - Remove with endIndex set to math.MaxInt32 does not fail -// for j := i - 1; j >= 0; j-- { -// for k := range batch.Positions { -// _ = c.caches[j].Remove(batch.Sequences[k], batch.Positions[k], math.MaxInt32) -// } -// } -// return err -// } -// } - -// c.curType = 0 -// return nil -// } - -// func (c *WrapperCache) SetLayer(layer int) { -// for _, cache := range c.caches { -// cache.SetLayer(layer) -// } -// } - -// func (c *WrapperCache) SetLayerType(layerType int) { -// c.curType = layerType -// } - -// func (c *WrapperCache) UnderlyingCache() Cache { -// return c.caches[c.curType] -// } - -// func (c *WrapperCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) { -// return c.caches[c.curType].Get(ctx) -// } - -// func (c *WrapperCache) Put(ctx ml.Context, key, value ml.Tensor) { -// c.caches[c.curType].Put(ctx, key, value) -// } - -// func (c *WrapperCache) CopyPrefix(srcSeq, dstSeq int, len int32) { -// for _, cache := range c.caches { -// cache.CopyPrefix(srcSeq, dstSeq, len) -// } -// } - -// func (c *WrapperCache) CanResume(seq int, pos int32) bool { -// for _, cache := range c.caches { -// if !cache.CanResume(seq, pos) { -// return false -// } -// } - -// return true -// } - -// func (c *WrapperCache) Remove(seq int, beginIndex, endIndex int32) error { -// // If the one of these fails, the caller is supposed to retry with endIndex set to math.MaxInt32, which should not fail -// for _, cache := range c.caches { -// err := cache.Remove(seq, beginIndex, endIndex) -// if err != nil { -// return err -// } -// } - -// return nil -// } diff --git a/x/ml/backend.go b/x/ml/backend.go deleted file mode 100644 index 31ff3541e..000000000 --- a/x/ml/backend.go +++ /dev/null @@ -1,433 +0,0 @@ -package ml - -import ( - "fmt" - "log/slog" - "os" - - "github.com/ollama/ollama/fs" -) - -type Backend interface { - // Close frees all memory associated with this backend - // Close() - - // Load(ctx context.Context, progress func(float32)) error - - // BackendMemory returns the memory allocations that were made for this model - // BackendMemory() BackendMemory - - Config() fs.Config - Get(name string) Tensor - NewContext() Context - // NewContextSize(size int) Context - - // Enumerate the devices available for inference via this backend - // BackendDevices() []DeviceInfo -} - -// BackendCacheConfig should be implemented by backends that need special output -// from the cache to meet specific requirements. It is frequently implemented in -// conjunction with ScaledDotProductAttention. -type BackendCacheConfig interface { - CacheConfig() CacheConfig -} - -// CacheConfig controls optimizations (mostly backend-specific) that may transform -// the output the cache to work better with specific kernels. -type CacheConfig struct { - // CachePadding specifies the multiple for the number of tokens of cache history - // that will be returned from cache Get for k, v and mask. The capacity of the - // cache itself will also be increased to a multiple of this size if needed. - CachePadding int - - // PermutedV performs Permute(ctx, 1, 2, 0, 3) on v tensors stored via Put - // and return the permuted version via Get. This uses the cache copy operation - // to avoid a Contiguous call on the permuted tensor. - PermutedV bool - - // MaskDType specifies the data type for generating the mask. If unset it will - // default to DTypeF32. - MaskDType DType - - // MaskBatchPadding specifies the multiple for the batch size dimension in the mask. - // Any position that does not correspond to an actual token will be filled with -Inf. - MaskBatchPadding int -} - -// BackendParams controls how the backend loads and executes models -type BackendParams struct { - // AllocMemory causes the backend to allocate memory for the model. If - // false, this is only being used for discovering the required amount of - // memory and cannot load the model for running. - AllocMemory bool - - // NumThreads sets the number of threads to use if running on the CPU - NumThreads int - - // GPULayers is the set of layers to offload to GPUs - GPULayers GPULayersList - - // FlashAttention indicates that we should use a fused flash attention kernel - FlashAttention bool -} - -var backends = make(map[string]func(string, BackendParams) (Backend, error)) - -func RegisterBackend(name string, f func(string, BackendParams) (Backend, error)) { - if _, ok := backends[name]; ok { - panic("backend: backend already registered") - } - - backends[name] = f -} - -func NewBackend(modelPath string, params BackendParams) (Backend, error) { - be := os.Getenv("OLLAMA_BACKEND") - if be == "" { - be = "mlx" - slog.Info("Defaulting to " + be + ". Set OLLAMA_BACKEND to override") - } - slog.Info("Loading new engine", "backend", be) - if backend, ok := backends[be]; ok { - return backend(modelPath, params) - } - - return nil, fmt.Errorf("unsupported backend") -} - -type Context interface { - Empty(dtype DType, shape ...int) Tensor - Zeros(dtype DType, shape ...int) Tensor - // FromBytes(dtype DType, s []byte, shape ...int) Tensor - FromFloats(s []float32, shape ...int) Tensor - FromInts(s []int32, shape ...int) Tensor - RandomNormal(shape []int, dtype DType, loc, scale float32, key Tensor) Tensor - - // Arange creates a 1D tensor with values within an interval (start, stop] increased by step. - Arange(start, stop, step float32, dtype DType) Tensor - - Forward(...Tensor) Context - - // SetBatchSize provides a hint on the batch size to optimize processing - // Uses heuristics if not set - // SetBatchSize(int) - - Compute(...Tensor) - // ComputeWithNotify(func(), ...Tensor) // notify callback once compute has begun - - // Reserve is analogous to Compute but rather than executing a - // graph, simply preallocates memory. Typically called with a - // worst case graph to ensure all resources are available for - // for future inference. - // Reserve() - - // MaxGraphNodes() int - Close() - - // Input returns a context appropriate for creating tensors that are - // inputs to the model (which includes things like output locations) - Input() Context - - // Layer returns a context appropriate for creating intermediate tensors - Layer(int) Context - - // Load a tensor from "filename" safetensors file, and compare with the input tensor - // Returns error if the shape is inconsistent, or similarity measures are below 99% - CompareWith(filename string, tensors map[string]Tensor, abortOnError bool) error -} - -type RoPEOptions struct { - Base *float32 - Freqs Tensor -} - -func WithRoPEBase(base float32) func(*RoPEOptions) { - return func(opts *RoPEOptions) { - opts.Base = &base - } -} - -func WithRoPEFreqs(freqs Tensor) func(*RoPEOptions) { - return func(opts *RoPEOptions) { - opts.Freqs = freqs - } -} - -type Tensor interface { - ToString() string - RoPE(ctx Context, dims int, traditional bool, scale float32, offset int, options ...func(*RoPEOptions)) Tensor - ScaledDotProductAttention(ctx Context, keys, values Tensor, scale float64, maskMode string, mask Tensor, sinks Tensor) Tensor - TakeAxes(ctx Context, indicies Tensor, axes int) Tensor - // TakeAxes(ctx Context, axes int, indicies ...int) Tensor - - Dim(n int) int - Stride(n int) int - - Shape() []int - DType() DType - // Cast(ctx Context, dtype DType) Tensor - - // Bytes() []byte - Floats() []float32 - Ints() []int32 - - // FromBytes([]byte) - // FromFloats([]float32) - // FromInts([]int32) - - Add(ctx Context, t2 Tensor) Tensor - Sub(ctx Context, t2 Tensor) Tensor - // Mul(ctx Context, t2 Tensor) Tensor - // Div(ctx Context, t2 Tensor) Tensor - - Max(ctx Context, axes []int, keepDims bool) Tensor - Min(ctx Context, axes []int, keepDims bool) Tensor - - Matmul(ctx Context, a2 Tensor) Tensor - // Mulmat(ctx Context, t2 Tensor) Tensor - // MulmatFullPrec(ctx Context, t2 Tensor) Tensor - // MulmatID(ctx Context, t2, ids Tensor) Tensor - // AddID(ctx Context, t2, ids Tensor) Tensor - - Softmax(ctx Context) Tensor - L2Norm(ctx Context, eps float32) Tensor - LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor - RMSNorm(ctx Context, weight Tensor, eps float32) Tensor - Scale(ctx Context, s float64) Tensor - // SumRows(ctx Context) Tensor - - AvgPool2D(ctx Context, k, s int, p float32) Tensor - Conv2D(ctx Context, weight Tensor, stride0, stride1, padding0, padding1, dilation0, dilation1, groups int) Tensor - Conv3D(ctx Context, weight Tensor, stride0, stride1, stride2, padding0, padding1, padding2, dilation0, dilation1, dilation2, groups int) Tensor - - // IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor - - // Sin(ctx Context) Tensor - // Cos(ctx Context) Tensor - // Tanh(ctx Context) Tensor - GELU(ctx Context, up ...Tensor) Tensor - // QuickGELU(ctx Context, up ...Tensor) Tensor - // SILU(ctx Context, up ...Tensor) Tensor - // RELU(ctx Context, up ...Tensor) Tensor - // Sigmoid(ctx Context) Tensor - - // AlphaLimitSILU is a variant of SILU that clamps the input to the range [-limit, limit] - // SILUAlphaLimit(ctx Context, up Tensor, alpha, limit float32) Tensor - - Reshape(ctx Context, shape ...int) Tensor - AsStrided(ctx Context, shape, strides []int, offset int) Tensor - Transpose(ctx Context, shape ...int) Tensor - Contiguous(ctx Context, allowColMajor bool) Tensor - - // Pad(ctx Context, shape ...int) Tensor - - // Stack(ctx Context, dim int, s ...Tensor) Tensor - - // Repeat repeats the tensor n times along dimension dim - // Repeat(ctx Context, dim, n int) Tensor - // Concat(ctx Context, t2 Tensor, dim int) Tensor - // Rows(ctx Context, t2 Tensor) Tensor - - // TODO these probably aren't actually needed - false starts on trying to wire up cache - // SliceUpdate(ctx Context, update Tensor, start, stop, strides []int) Tensor - // SliceUpdateDynamic(ctx Context, update, start Tensor, axes []int) Tensor - // PutAlongAxis(ctx Context, indicies, values Tensor, axis int) Tensor - - Scatter(ctx Context, indicies []Tensor, updates Tensor, axes []int) Tensor - - Copy(ctx Context, t2 Tensor) Tensor - // Duplicate(ctx Context) Tensor - - // Slice(ctx Context, dim, low, high, step int) Tensor - // Chunk(ctx Context, dim int, size int) []Tensor - // ChunkSections(ctx Context, dim int, sections ...int) []Tensor - - // TopK(ctx Context, k int) Tensor - // Argsort(ctx Context) Tensor - // Mean(ctx Context) Tensor - // Variance(ctx Context) Tensor - // Stddev(ctx Context) Tensor - // Sqr(ctx Context) Tensor - // Sqrt(ctx Context) Tensor - - // Interpolate(ctx Context, dims [4]int, samplingMode SamplingMode) Tensor -} - -// ScaledDotProductAttention implements a fused attention -// operation equivalent to following code on a tensor named -// query: -// -// query = query.Permute(ctx, 0, 2, 1, 3) -// key = key.Permute(ctx, 0, 2, 1, 3) -// value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) -// -// kq := key.MulmatFullPrec(ctx, query) -// -// kq = kq.Scale(ctx, scale) -// -// if mask != nil { -// kq = kq.Add(ctx, mask) -// } -// -// kq = kq.Softmax(ctx) -// -// kqv := value.Mulmat(ctx, kq) -// return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) -// type ScaledDotProductAttention interface { -// ScaledDotProductAttention(ctx Context, key, value, mask, sinks Tensor, vmla Tensor, scale float64) Tensor -// } - -// type number interface { -// ~int | ~int8 | ~int16 | ~int32 | ~int64 | -// ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | -// ~float32 | ~float64 | -// ~complex64 | ~complex128 -// } - -// func mul[T number](s ...T) T { -// p := T(1) -// for _, v := range s { -// p *= v -// } - -// return p -// } - -// type DumpOptions func(*dumpOptions) - -// // DumpWithPrecision sets the number of decimal places to print. Applies to float32 and float64. -// func DumpWithPrecision(n int) DumpOptions { -// return func(opts *dumpOptions) { -// opts.Precision = n -// } -// } - -// // DumpWithThreshold sets the threshold for printing the entire tensor. If the number of elements -// // is less than or equal to this value, the entire tensor will be printed. Otherwise, only the -// // beginning and end of each dimension will be printed. -// func DumpWithThreshold(n int) DumpOptions { -// return func(opts *dumpOptions) { -// opts.Threshold = n -// } -// } - -// // DumpWithEdgeItems sets the number of elements to print at the beginning and end of each dimension. -// func DumpWithEdgeItems(n int) DumpOptions { -// return func(opts *dumpOptions) { -// opts.EdgeItems = n -// } -// } - -// type dumpOptions struct { -// Precision, Threshold, EdgeItems int -// } - -// func Dump(ctx Context, t Tensor, optsFuncs ...DumpOptions) string { -// opts := dumpOptions{Precision: 4, Threshold: 1000, EdgeItems: 3} -// for _, optsFunc := range optsFuncs { -// optsFunc(&opts) -// } - -// if mul(t.Shape()...) <= opts.Threshold { -// opts.EdgeItems = math.MaxInt -// } - -// switch t.DType() { -// case DTypeFloat32: -// return dump[[]float32](ctx, t, opts.EdgeItems, func(f float32) string { -// return strconv.FormatFloat(float64(f), 'f', opts.Precision, 32) -// }) -// case DTypeFloat16: // TODO other types... -// f32 := ctx.Input().Empty(DTypeFloat32, t.Shape()...) -// f32 = t.Copy(ctx, f32) -// return dump[[]float32](ctx, f32, opts.EdgeItems, func(f float32) string { -// return strconv.FormatFloat(float64(f), 'f', opts.Precision, 32) -// }) -// case DTypeInt32: -// return dump[[]int32](ctx, t, opts.EdgeItems, func(i int32) string { -// return strconv.FormatInt(int64(i), 10) -// }) -// default: -// return "" -// } -// } - -// func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string) string { -// if t.Bytes() == nil { -// ctx.Compute(t) -// } - -// s := make(S, mul(t.Shape()...)) -// if err := binary.Read(bytes.NewBuffer(t.Bytes()), binary.LittleEndian, &s); err != nil { -// panic(err) -// } - -// shape := t.Shape() -// slices.Reverse(shape) - -// var sb strings.Builder -// var f func([]int, int) -// f = func(dims []int, stride int) { -// prefix := strings.Repeat(" ", len(shape)-len(dims)+1) -// sb.WriteString("[") -// defer func() { sb.WriteString("]") }() -// for i := 0; i < dims[0]; i++ { -// if i >= items && i < dims[0]-items { -// sb.WriteString("..., ") -// // skip to next printable element -// skip := dims[0] - 2*items -// if len(dims) > 1 { -// stride += mul(append(dims[1:], skip)...) -// fmt.Fprint(&sb, strings.Repeat("\n", len(dims)-1), prefix) -// } -// i += skip - 1 -// } else if len(dims) > 1 { -// f(dims[1:], stride) -// stride += mul(dims[1:]...) -// if i < dims[0]-1 { -// fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix) -// } -// } else { -// text := fn(s[stride+i]) -// if len(text) > 0 && text[0] != '-' { -// sb.WriteString(" ") -// } - -// sb.WriteString(text) -// if i < dims[0]-1 { -// sb.WriteString(", ") -// } -// } -// } -// } -// f(shape, 0) - -// return sb.String() -// } - -type DType int - -const ( - DTypeBool DType = iota - DTypeUint8 - DTypeUint16 - DTypeUint32 - DTypeUint64 - DTypeInt8 - DTypeInt16 - DTypeInt32 - DTypeInt64 - DTypeFloat16 - DTypeFloat32 - DTypeFloat64 - DTypeBfloat16 - DTypeComplex64 -) - -type SamplingMode int - -const ( - SamplingModeNearest SamplingMode = iota - SamplingModeBilinear -) diff --git a/x/ml/backend/backend.go b/x/ml/backend/backend.go deleted file mode 100644 index b9dd4a13b..000000000 --- a/x/ml/backend/backend.go +++ /dev/null @@ -1,3 +0,0 @@ -package backend - -// _ "github.com/ollama/ollama/x/ml/backend/mlx" diff --git a/x/ml/backend/mlx/CMakeLists.txt b/x/ml/backend/mlx/CMakeLists.txt deleted file mode 100644 index e71a6567a..000000000 --- a/x/ml/backend/mlx/CMakeLists.txt +++ /dev/null @@ -1,57 +0,0 @@ -include(FetchContent) - -set(MLX_C_BUILD_EXAMPLES OFF) - -set(MLX_BUILD_GGUF OFF) -set(MLX_BUILD_SAFETENSORS ON) - -function(set_target_output_directory _target) - if(TARGET ${_target}) - set_target_properties(${_target} PROPERTIES - RUNTIME_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR} - LIBRARY_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR} - ARCHIVE_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR} - ) - endif() -endfunction() - -# Check for Metal support (macOS only) -if(CMAKE_SYSTEM_NAME MATCHES "Darwin") - execute_process( - COMMAND - zsh "-c" - "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'" - OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY) - - if(NOT MLX_METAL_VERSION) - message(STATUS "`xcrun metal` error. Setting MLX_BUILD_METAL=OFF") - set(MLX_BUILD_METAL OFF) - endif() -else() - # On Linux, disable Metal backend - message(STATUS "Non-macOS platform detected. Setting MLX_BUILD_METAL=OFF") - set(MLX_BUILD_METAL OFF) -endif() - -# Map CMAKE_CUDA_ARCHITECTURES to MLX_CUDA_ARCHITECTURES if not explicitly set -if(NOT MLX_CUDA_ARCHITECTURES AND CMAKE_CUDA_ARCHITECTURES) - set(MLX_CUDA_ARCHITECTURES ${CMAKE_CUDA_ARCHITECTURES}) - message(STATUS "Using CMAKE_CUDA_ARCHITECTURES for MLX: ${MLX_CUDA_ARCHITECTURES}") -endif() - -# Enable CUDA backend if CUDA architectures are specified and CUDA compiler is available -if(MLX_CUDA_ARCHITECTURES AND CMAKE_CUDA_COMPILER) - set(MLX_BUILD_CUDA ON CACHE BOOL "Build CUDA backend for MLX" FORCE) - message(STATUS "Enabling MLX CUDA backend with architectures: ${MLX_CUDA_ARCHITECTURES}") -elseif(MLX_CUDA_ARCHITECTURES) - message(WARNING "MLX_CUDA_ARCHITECTURES specified but CUDA compiler not found, CUDA backend will be disabled") -endif() - -FetchContent_Declare( - mlx-c - GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git" - GIT_TAG v0.4.1) -FetchContent_MakeAvailable(mlx-c) - -set_target_output_directory(mlx) -set_target_output_directory(mlxc) diff --git a/x/ml/backend/mlx/mlx.go b/x/ml/backend/mlx/mlx.go deleted file mode 100644 index 1b647685e..000000000 --- a/x/ml/backend/mlx/mlx.go +++ /dev/null @@ -1,1278 +0,0 @@ -//go:build mlx - -package mlx - -/* -#cgo CPPFLAGS: -I${SRCDIR}/../../../../build/_deps/mlx-c-src -#cgo LDFLAGS: -L${SRCDIR}/../../../../build/lib/ollama/ -lmlxc -lmlx -#cgo LDFLAGS: -framework Accelerate -#cgo LDFLAGS: -Wl,-rpath,${SRCDIR}/../../../../build/lib/ollama/ -#include -#include "mlx/c/mlx.h" -static inline size_t stride(const mlx_array a, int i) {return mlx_array_strides(a)[i];} - -extern void goStackTrace(); -static void error_handler(const char *msg, void* data) { - fprintf(stderr, "MLX error: %s\n", msg); - goStackTrace(); - exit(-1); // TODO adjust so this can become a return code on the current thread instead of exit -} -static void set_error_handler() {mlx_set_error_handler(&error_handler, NULL, NULL);} -static void* mlx_array_data_float16_asvoid(const mlx_array a) {return (void*)mlx_array_data_float16(a);} -typedef const char cchar_t; -*/ -import "C" - -import ( - "encoding/json" - "fmt" - "log/slog" - "math" - "os" - "path/filepath" - "reflect" - "runtime" - "runtime/debug" - "sync" - "unsafe" - - "github.com/ollama/ollama/convert" - "github.com/ollama/ollama/fs" - "github.com/ollama/ollama/x/ml" - "github.com/x448/float16" -) - -func init() { - ml.RegisterBackend("mlx", New) - C.set_error_handler() -} - -//export goStackTrace -func goStackTrace() { - debug.PrintStack() -} - -type SafetensorsIndexMetadata struct { - TotalSize uint64 `json:"total_size"` -} -type SafetensorsIndex struct { - Metadata SafetensorsIndexMetadata `json:"metadata"` - WeightMap map[string]string `json:"weight_map"` -} - -type Backend struct { - meta fs.Config - tensors map[string]*Array -} - -func New(modelPath string, params ml.BackendParams) (ml.Backend, error) { - // TODO assumes modelPath is actually a directory for now... - kv, tokenizer, err := convert.LoadModelMetadata(os.DirFS(modelPath)) - if err != nil { - return nil, fmt.Errorf("unable to load model: %w", err) - } - - b := &Backend{ - meta: kv.KV(tokenizer), - } - - err = b.LoadSafeTensors(modelPath) - if err != nil { - return nil, fmt.Errorf("safetensors load failed: %w", err) - } - return b, nil -} - -func (b *Backend) LoadSafeTensors(dir string) error { - if _, err := os.Stat(dir); err != nil { - return fmt.Errorf("failed to stat dir: %w", err) - } - // other variations to try? - stFilename := filepath.Join(dir, "model.safetensors.index.json") - if _, err := os.Stat(stFilename); err != nil { - return fmt.Errorf("failed to stat %s: %w", stFilename, err) - } - - fp, err := os.Open(stFilename) - if err != nil { - return fmt.Errorf("failed to open safetensor index: %s: %w", stFilename, err) - } - decoder := json.NewDecoder(fp) - var index SafetensorsIndex - if err := decoder.Decode(&index); err != nil { - return fmt.Errorf("decode error: %s: %w", stFilename, err) - } - slog.Info("XXX parsed metadata", "size", index.Metadata.TotalSize, "weights", len(index.WeightMap)) - filenames := map[string]struct{}{} - for _, filename := range index.WeightMap { - filenames[filename] = struct{}{} - } - stream := C.mlx_default_cpu_stream_new() - - b.tensors = map[string]*Array{} - - for filename := range filenames { - filepath := filepath.Join(dir, filename) - if _, err := os.Stat(filepath); err != nil { - return fmt.Errorf("failed to stat %s: %w", filepath, err) - } - slog.Info("Loading tensors from", "filename", filename) - cFilename := C.CString(filepath) - defer C.free(unsafe.Pointer(cFilename)) - data := C.mlx_map_string_to_array_new() // TODO is this needed or just var it? - metadata := C.mlx_map_string_to_string_new() - defer C.mlx_map_string_to_array_free(data) - defer C.mlx_map_string_to_string_free(metadata) - - if C.mlx_load_safetensors(&data, &metadata, cFilename, stream) != 0 { - // TODO with the current error handling, this will never happen - return fmt.Errorf("load failed") - } - - it := C.mlx_map_string_to_array_iterator_new(data) - // defer C.mlx_array_free(shaped) - // TODO confusing how memory management works with this... - for { - var key *C.cchar_t - var value C.mlx_array - if C.mlx_map_string_to_array_iterator_next(&key, &value, it) != 0 { - break - } - k := C.GoString((*C.char)(key)) - b.tensors[k] = &Array{ - name: k, - a: value, - } - // slog.Info("XXX read", "tensor", b.tensors[k], "type", b.tensors[k].TypeString()) - } - } - - return nil -} - -func (b *Backend) Get(name string) ml.Tensor { - var t ml.Tensor - var ok bool - if t, ok = b.tensors[name]; !ok { - // slog.Warn("unable to locate", "tensor", name) - return nil - } - // slog.Info("Fetching", "tensor", name, "type", b.tensors[name].TypeString()) - return t -} - -func (b *Backend) NewContext() ml.Context { - // slog.Info("MLX.NewContext") - return &Context{ - stream: C.mlx_default_gpu_stream_new(), - } -} - -func (b *Backend) Config() fs.Config { - return b.meta -} - -type Context struct { - stream C.mlx_stream - - mu sync.Mutex - arrays []C.mlx_array // TODO should we do some bookkeeping to ensure none of these Arrays are still lingering? -} - -func (c *Context) Close() { - // C.mlx_synchronize(c.stream) // ??? - C.mlx_stream_free(c.stream) - - c.mu.Lock() - defer c.mu.Unlock() - for _, a := range c.arrays { - slog.Info("XXX freeing", "array", a) - C.mlx_array_free(a) - } -} - -func (c *Context) Compute(tensors ...ml.Tensor) { - // TODO - for the zero tensor case this feels like it might not be correct... - needSync := true - sync := func() { - if needSync { - C.mlx_synchronize(c.stream) - needSync = false - } - } - - vec := C.mlx_vector_array_new() - defer C.mlx_vector_array_free(vec) - for _, t := range tensors { - C.mlx_vector_array_append_value(vec, t.(*Array).a) - t.(*Array).sync = sync - } - C.mlx_async_eval(vec) -} - -func (c *Context) Forward(tensors ...ml.Tensor) ml.Context { - vec := C.mlx_vector_array_new() - defer C.mlx_vector_array_free(vec) - needSync := true - sync := func() { - if needSync { - C.mlx_synchronize(c.stream) - needSync = false - } - } - - for _, t := range tensors { - t.(*Array).sync = sync - C.mlx_vector_array_append_value(vec, t.(*Array).a) - } - C.mlx_async_eval(vec) - return c -} - -func (c *Context) Input() ml.Context { - return c -} - -// func (c *Context) Output() ml.Context { -// return c -// } - -func (c *Context) Layer(_ int) ml.Context { - return c -} - -func (c *Context) RandomNormal(shape []int, dtype ml.DType, loc, scale float32, key ml.Tensor) ml.Tensor { - var r C.mlx_array - var k C.mlx_array - if key != nil { - k = key.(*Array).a - } - sh := make([]C.int, len(shape)) - for i := range shape { - sh[i] = C.int(shape[i]) - } - C.mlx_random_normal( - &r, - &sh[0], - C.size_t(len(shape)), - C.mlx_dtype(dtype), - C.float(loc), - C.float(scale), - k, - c.stream, - ) - return newArray(c, r) -} - -func (c *Context) CompareWith(filepath string, tensors map[string]ml.Tensor, abortOnError bool) (err error) { - minCosine := float32(0.96) // TODO too low... - fileTensors := map[string]*Array{} - defer func() { - if err != nil { - for k, v := range tensors { - fmt.Fprintln(os.Stderr, "input tensor "+k+"\n"+v.ToString()) - if fv, ok := fileTensors[k]; ok { - fmt.Fprintln(os.Stderr, " file tensor "+k+"\n"+fv.ToString()) - } else { - fmt.Fprintln(os.Stderr, " file tensor "+k+" missing!\n") - } - } - } - if abortOnError { - if err != nil { - panic(fmt.Sprintf("%s", err)) - } - } - }() - if _, err = os.Stat(filepath); err != nil { - filepath += ".safetensors" - if _, err = os.Stat(filepath); err != nil { - err = fmt.Errorf("failed to stat %s: %w", filepath, err) - return - } - err = nil - } - // slog.Info("Loading tensors from", "filename", filepath) - cFilename := C.CString(filepath) - defer C.free(unsafe.Pointer(cFilename)) - data := C.mlx_map_string_to_array_new() // TODO is this needed or just var it? - metadata := C.mlx_map_string_to_string_new() - defer C.mlx_map_string_to_array_free(data) - defer C.mlx_map_string_to_string_free(metadata) - - stream := C.mlx_default_cpu_stream_new() - - if C.mlx_load_safetensors(&data, &metadata, cFilename, stream) != 0 { - // TODO with the current error handling, this will never happen - err = fmt.Errorf("load failed") - return - } - - it := C.mlx_map_string_to_array_iterator_new(data) - allTensors := []ml.Tensor{} - for _, t := range tensors { - allTensors = append(allTensors, t) - } - - for { - var key *C.cchar_t - var value C.mlx_array - defer C.mlx_array_free(value) - if C.mlx_map_string_to_array_iterator_next(&key, &value, it) != 0 { - break - } - k := C.GoString((*C.char)(key)) - var r C.mlx_array - defer C.mlx_array_free(r) - C.mlx_astype( - &r, - value, - C.MLX_FLOAT32, - stream, - ) - - fileTensors[k] = &Array{ - name: k, - a: r, - } - // slog.Info("XXX read", "tensor", t, "type", t.TypeString()) - allTensors = append(allTensors, fileTensors[k]) - } - c.Forward(allTensors...) - for k, t := range tensors { - a, ok := fileTensors[k] - if !ok { - err = fmt.Errorf("tensor named %s not found in file", k) - return - } - if !reflect.DeepEqual(a.Shape(), t.Shape()) { - err = fmt.Errorf("mismatched shapes: file: %v vs. input %v", a.Shape(), t.Shape()) - return - } - // slog.Info("XXX shapes match", "shape", t.Shape()) - // TODO handle int types... - tDType := t.DType() - if tDType != ml.DTypeFloat16 && tDType != ml.DTypeFloat32 { - var r C.mlx_array - defer C.mlx_array_free(r) - C.mlx_astype( - &r, - t.(*Array).a, - C.MLX_FLOAT32, - stream, - ) - t = &Array{ - a: r, - } - c.Forward(t) - } - - af := a.Floats() - tf := t.Floats() - cos := cosineSimilarity(af, tf) - diff := a.Sub(c, t) - min := diff.Min(c, nil, true) - max := diff.Max(c, nil, true) - c.Forward(min, max) - minf := min.Floats() - maxf := max.Floats() - if cos < minCosine { - err = fmt.Errorf("%s shapes match, but not similar enough: %v min_difference=%v max_difference=%v", k, cos, minf, maxf) - return - } - - slog.Info("XXX tensors are similar", k, cos, "shape", t.Shape(), "min_difference", minf, "max_difference", maxf) - } - err = nil - - return -} - -func dotProduct[V float32 | float64](v1, v2 []V) V { - var result V = 0 - if len(v1) != len(v2) { - return result - } - - for i := 0; i < len(v1); i++ { - result += v1[i] * v2[i] - } - return result -} - -func magnitude[V float32 | float64](v []V) V { - var result V = 0 - for _, val := range v { - result += val * val - } - return V(math.Sqrt(float64(result))) -} - -func cosineSimilarity[V float32 | float64](v1, v2 []V) V { - mag1 := magnitude(v1) - mag2 := magnitude(v2) - - if mag1 == 0 || mag2 == 0 { - return 0 - } - - return dotProduct(v1, v2) / (magnitude(v1) * magnitude(v2)) -} - -func euclideanDistance[V float32 | float64](v1, v2 []V) V { - if len(v1) != len(v2) { - return V(math.Inf(1)) - } - - var sum V = 0 - for i := 0; i < len(v1); i++ { - diff := v1[i] - v2[i] - sum += diff * diff - } - - return V(math.Sqrt(float64(sum))) -} - -func manhattanDistance[V float32 | float64](v1, v2 []V) V { - if len(v1) != len(v2) { - return V(math.Inf(1)) - } - - var sum V = 0 - for i := 0; i < len(v1); i++ { - sum += V(math.Abs(float64(v1[i] - v2[i]))) - } - - return sum -} - -type Array struct { - name string - a C.mlx_array - c *Context - - sync func() -} - -func newArray(ctx *Context, a C.mlx_array) *Array { - // TODO measure impact and if this slows things down, make it conditional on some debugging flag at load time - var name string - _, f, l, ok := runtime.Caller(2) - if ok { - name = fmt.Sprintf("%s:%d", f, l) - } - - t := &Array{ - name: name, - a: a, - c: ctx, - } - // DEBUG memory allocation problems... - // slog.Info("XXX Allocated", "array", t, "a", a) - ctx.mu.Lock() - defer ctx.mu.Unlock() - ctx.arrays = append(ctx.arrays, a) - return t -} - -// FromFloats implements ml.Context. -func (c *Context) FromFloats(s []float32, shape ...int) ml.Tensor { - u16s := make([]float16.Float16, len(s)) - for i := range u16s { - u16s[i] = float16.Fromfloat32(s[i]) - } - cshape := make([]C.int, len(shape)) - for i, dim := range shape { - cshape[i] = C.int(dim) - } - return newArray(c, - C.mlx_array_new_data( - unsafe.Pointer(&u16s[0]), - &cshape[0], - C.int(len(cshape)), - C.MLX_FLOAT16, - ), - ) -} - -func (a *Array) Floats() []float32 { - if a.sync != nil { - a.sync() - } - l := (int)(C.mlx_array_size(a.a)) - - switch C.mlx_array_dtype(a.a) { - case C.MLX_BFLOAT16: - panic("bfloat16 not yet implemented") - case C.MLX_FLOAT16: - data := C.mlx_array_data_float16_asvoid(a.a) - if data == nil { - panic("nil data, wasn't eval'd") - } - u16s := unsafe.Slice((*uint16)(data), l) - f32s := make([]float32, len(u16s)) - for i := range u16s { - f32s[i] = float16.Frombits(u16s[i]).Float32() - } - return f32s - case C.MLX_FLOAT32: - data := C.mlx_array_data_float32(a.a) - if data == nil { - panic("nil data, wasn't eval'd") - } - f32s := unsafe.Slice((*float32)(data), l) - return f32s - default: - panic(fmt.Sprintf("unsupported dtype for Floats: %d", C.mlx_array_dtype(a.a))) - } -} - -// FromInts implements ml.Context. -func (c *Context) FromInts(s []int32, shape ...int) ml.Tensor { - cshape := make([]C.int, len(shape)) - for i, dim := range shape { - cshape[i] = C.int(dim) - } - return newArray(c, - C.mlx_array_new_data( - unsafe.Pointer(&s[0]), - &cshape[0], - C.int(len(cshape)), - C.MLX_INT32, - ), - ) -} - -func (a *Array) Ints() []int32 { - if a.sync != nil { - a.sync() - } - l := (int)(C.mlx_array_size(a.a)) - - switch C.mlx_array_dtype(a.a) { - case C.MLX_INT32: - data := C.mlx_array_data_int32(a.a) - if data == nil { - panic("nil data, wasn't eval'd") - } - i32s := unsafe.Slice((*int32)(data), l) - return i32s - - // TODO other types via conversion? - default: - panic(fmt.Sprintf("unsupported dtype for Ints: %d", C.mlx_array_dtype(a.a))) - } -} - -func (c *Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor { - sh := make([]C.int, len(shape)) - for i, s := range shape { - sh[i] = (C.int)(s) - } - - var r C.mlx_array - C.mlx_zeros( - &r, - &sh[0], - (C.size_t)(len(sh)), - C.mlx_dtype(dtype), - c.stream, - ) - return newArray(c, r) -} - -func (c *Context) Empty(dtype ml.DType, shape ...int) ml.Tensor { - // TODO more efficient impl? - return c.Zeros(dtype, shape...) -} - -func (a *Array) DType() ml.DType { - return (ml.DType)(C.mlx_array_dtype(a.a)) -} - -func (a *Array) Dim(n int) int { - return int(C.mlx_array_dim(a.a, C.int(n))) -} - -func (a *Array) Stride(n int) int { - return (int)(C.stride(a.a, (C.int)(n))) -} - -func (c *Context) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor { - var r C.mlx_array - C.mlx_arange( - &r, - C.double(start), - C.double(stop), - C.double(step), - (C.mlx_dtype)(dtype), - c.stream, - ) - - return newArray(c, r) -} - -// Scale implements ml.Tensor. -func (a *Array) Scale(ctx ml.Context, s float64) ml.Tensor { - scale := C.mlx_array_new_float(C.float(s)) - var r C.mlx_array - C.mlx_multiply( - &r, - a.a, - scale, - ctx.(*Context).stream, - ) - return newArray(ctx.(*Context), r) -} - -func (a *Array) Softmax(ctx ml.Context) ml.Tensor { - var r C.mlx_array - C.mlx_softmax( - &r, - a.a, - false, // TODO - precise? - ctx.(*Context).stream, - ) - return newArray(ctx.(*Context), r) -} - -func (a *Array) SliceUpdate(ctx ml.Context, update ml.Tensor, start, stop, strides []int) ml.Tensor { - cStart := make([]C.int, len(start)) - for i := range start { - cStart[i] = C.int(start[i]) - } - cStop := make([]C.int, len(stop)) - for i := range stop { - cStop[i] = C.int(stop[i]) - } - cStrides := make([]C.int, len(strides)) - for i := range strides { - cStrides[i] = C.int(strides[i]) - } - var r C.mlx_array - C.mlx_slice_update( - &r, - a.a, - update.(*Array).a, - (*C.int)(unsafe.Pointer(&cStart[0])), - C.size_t(len(cStart)), - (*C.int)(unsafe.Pointer(&cStop[0])), - C.size_t(len(cStop)), - (*C.int)(unsafe.Pointer(&cStrides[0])), - C.size_t(len(cStrides)), - ctx.(*Context).stream, - ) - // Release the old array and replace with the new one to ensure the same underlying buffer is used - a.c.mu.Lock() - defer a.c.mu.Unlock() - for i := range a.c.arrays { - if a.c.arrays[i] == a.a { - C.mlx_array_free(a.a) - a.a = r - a.c.arrays = append(a.c.arrays[:i], a.c.arrays[i+1:]...) - return a - } - } - panic("unable to locate array in context") -} - -func (a *Array) SliceUpdateDynamic(ctx ml.Context, update, start ml.Tensor, axes []int) ml.Tensor { - cAxes := make([]C.int, len(axes)) - for i := range axes { - cAxes[i] = C.int(axes[i]) - } - - var r C.mlx_array - C.mlx_slice_update_dynamic( - &r, - a.a, - update.(*Array).a, - start.(*Array).a, - (*C.int)(unsafe.Pointer(&cAxes[0])), - C.size_t(len(cAxes)), - ctx.(*Context).stream, - ) - // Release the old array and replace with the new one to ensure the same underlying buffer is used - a.c.mu.Lock() - defer a.c.mu.Unlock() - for i := range a.c.arrays { - if a.c.arrays[i] == a.a { - C.mlx_array_free(a.a) - a.a = r - a.c.arrays = append(a.c.arrays[:i], a.c.arrays[i+1:]...) - return a - } - } - panic("unable to locate array in context") - -} - -func (a *Array) PutAlongAxis(ctx ml.Context, indicies, values ml.Tensor, axis int) ml.Tensor { - var r C.mlx_array - C.mlx_put_along_axis( - &r, - a.a, - indicies.(*Array).a, - values.(*Array).a, - C.int(axis), - ctx.(*Context).stream, - ) - // Release the old array and replace with the new one to ensure the same underlying buffer is used - a.c.mu.Lock() - defer a.c.mu.Unlock() - for i := range a.c.arrays { - if a.c.arrays[i] == a.a { - C.mlx_array_free(a.a) - a.a = r - a.c.arrays = append(a.c.arrays[:i], a.c.arrays[i+1:]...) - return a - } - } - panic("unable to locate array in context") -} - -func (a *Array) Scatter(ctx ml.Context, indicies []ml.Tensor, updates ml.Tensor, axes []int) ml.Tensor { - - cAxes := make([]C.int, len(axes)) - for i := range axes { - cAxes[i] = C.int(axes[i]) - } - var cAxes0 *C.int - if len(cAxes) > 0 { - cAxes0 = (*C.int)(unsafe.Pointer(&cAxes[0])) - } - indiciesVec := C.mlx_vector_array_new() - defer C.mlx_vector_array_free(indiciesVec) - for _, ind := range indicies { - C.mlx_vector_array_append_value(indiciesVec, ind.(*Array).a) - } - - var r C.mlx_array - C.mlx_scatter( - &r, - a.a, - indiciesVec, - updates.(*Array).a, - cAxes0, - C.size_t(len(cAxes)), - ctx.(*Context).stream, - ) - // Release the old array and replace with the new one to ensure the same underlying buffer is used - a.c.mu.Lock() - defer a.c.mu.Unlock() - for i := range a.c.arrays { - if a.c.arrays[i] == a.a { - C.mlx_array_free(a.a) - a.a = r - a.c.arrays[i] = r - return a - } - } - panic("unable to locate array in context") - -} - -func (a *Array) Copy(ctx ml.Context, a2 ml.Tensor) ml.Tensor { - C.mlx_copy( - &a2.(*Array).a, - a.a, - ctx.(*Context).stream, - ) - // TODO - view? - return newArray(ctx.(*Context), a2.(*Array).a) -} - -func (a *Array) Add(ctx ml.Context, a2 ml.Tensor) ml.Tensor { - var r C.mlx_array - C.mlx_add( - &r, - a.a, - a2.(*Array).a, - ctx.(*Context).stream, - ) - return newArray(ctx.(*Context), r) -} - -func (a *Array) Sub(ctx ml.Context, a2 ml.Tensor) ml.Tensor { - var r C.mlx_array - C.mlx_subtract( - &r, - a.a, - a2.(*Array).a, - ctx.(*Context).stream, - ) - return newArray(ctx.(*Context), r) -} - -func (a *Array) Max(ctx ml.Context, axes []int, keepDims bool) ml.Tensor { - var r C.mlx_array - cAxes := make([]C.int, len(axes)) - for i := range axes { - cAxes[i] = C.int(axes[i]) - } - var cAxes0 *C.int - if len(cAxes) > 0 { - cAxes0 = (*C.int)(unsafe.Pointer(&cAxes[0])) - C.mlx_max_axes( - &r, - a.a, - cAxes0, - C.size_t(len(cAxes)), - C._Bool(keepDims), - ctx.(*Context).stream, - ) - } else { - C.mlx_max( - &r, - a.a, - C._Bool(keepDims), - ctx.(*Context).stream, - ) - - } - - return newArray(ctx.(*Context), r) -} - -func (a *Array) Min(ctx ml.Context, axes []int, keepDims bool) ml.Tensor { - var r C.mlx_array - cAxes := make([]C.int, len(axes)) - for i := range axes { - cAxes[i] = C.int(axes[i]) - } - var cAxes0 *C.int - if len(cAxes) > 0 { - cAxes0 = (*C.int)(unsafe.Pointer(&cAxes[0])) - C.mlx_min_axes( - &r, - a.a, - cAxes0, - C.size_t(len(cAxes)), - C._Bool(keepDims), - ctx.(*Context).stream, - ) - } else { - C.mlx_min( - &r, - a.a, - C._Bool(keepDims), - ctx.(*Context).stream, - ) - } - - return newArray(ctx.(*Context), r) -} - -func (a *Array) Matmul(ctx ml.Context, a2 ml.Tensor) ml.Tensor { - var r C.mlx_array - C.mlx_matmul( - &r, - a.a, - a2.(*Array).a, - ctx.(*Context).stream, - ) - return newArray(ctx.(*Context), r) -} - -func (a *Array) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor { - // slog.Info("MLX.RMSNorm", "a", a, "w", w) - var r C.mlx_array - C.mlx_fast_rms_norm( - &r, - a.a, - w.(*Array).a, - C.float(eps), - ctx.(*Context).stream, - ) - return newArray(ctx.(*Context), r) -} - -func (a *Array) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor { - var r C.mlx_array - C.mlx_fast_layer_norm( - &r, - a.a, - w.(*Array).a, - b.(*Array).a, - C.float(eps), - ctx.(*Context).stream, - ) - return newArray(ctx.(*Context), r) -} - -func (a *Array) L2Norm(ctx ml.Context, eps float32) ml.Tensor { - // TODO implement - panic("NOT YET IMPLEMENTED") -} - -func (t Array) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor { - panic("NOT YET IMPLEMENTED") -} - -// RoPE implements Rotary Positional Encoding -// -// dims (int) – The feature dimensions to be rotated. If the input feature is larger than dims then the rest is left unchanged. -// traditional (bool) – If set to True choose the traditional implementation which rotates consecutive dimensions. -// scale (float) – The scale used to scale the positions. -// offset (int) – The position offset to start at. TODO MLX-C does not yet expose Offset as an Array -// WithBase (float, optional) – The base used to compute angular frequency for each dimension in the positional encodings. Exactly one of base and freqs must be None. -// WithFreqs (array, optional) – Optional frequencies to use with RoPE. If set, the base parameter must be None. Default: None. -func (a *Array) RoPE(ctx ml.Context, dims int, traditional bool, scale float32, offset int, options ...func(*ml.RoPEOptions)) ml.Tensor { - opts := ml.RoPEOptions{} - - // Apply any provided options - for _, option := range options { - option(&opts) - } - var r C.mlx_array - var base C.mlx_optional_float - var freqs C.mlx_array - - if opts.Base != nil { - base.value = C.float(*opts.Base) - base.has_value = true - } - if opts.Freqs != nil { - freqs = opts.Freqs.(*Array).a - } - C.mlx_fast_rope( - &r, - a.a, - C.int(dims), - C._Bool(traditional), - base, - C.float(scale), - C.int(offset), - freqs, - ctx.(*Context).stream, - ) - return newArray(ctx.(*Context), r) -} - -// A fast implementation of multi-head attention: O = softmax(Q @ K.T, dim=-1) @ V. -// -// Supports: -// - Multi-Head Attention -// - Grouped Query Attention -// - Multi-Query Attention -// -// Note: -// - The softmax operation is performed in float32 regardless of the input precision. -// - For Grouped Query Attention and Multi-Query Attention, the k and v inputs should not be pre-tiled to match q. -// -// In the following the dimensions are given by: -// - B: The batch size. -// - N_q: The number of query heads. -// - N_kv: The number of key and value heads. -// - T_q: The number of queries per example. -// - T_kv: The number of keys and values per example. -// - D: The per-head dimension. -// -// Parameters: -// - [subject array] queries (array) – Queries with shape [B, N_q, T_q, D]. -// - keys (array) – with shape [B, N_kv, T_kv, D]. -// - values (array) – with shape [B, N_kv, T_kv, D]. -// - scale (float) – Scale for queries (typically 1.0 / sqrt(q.shape(-1)). -// - mask (str or array, optional) – The mask to apply to the query-key scores. -// The mask can be an array or a string indicating the mask type. The only supported string type is "causal". -// If the mask is an array it can be a boolean or additive mask. The mask can have at most 4 dimensions and -// must be broadcast-compatible with the shape [B, N, T_q, T_kv]. If an additive mask is given its type must -// promote to the promoted type of q, k, and v. -// - sinks (array, optional) – An optional array of attention sinks. Default: None. - -func (queries *Array) ScaledDotProductAttention(ctx ml.Context, keys, values ml.Tensor, scale float64, maskMode string, mask ml.Tensor, sinks ml.Tensor) ml.Tensor { - var r C.mlx_array - var s C.mlx_array - if sinks != nil { - s = sinks.(*Array).a - } - maskModeC := C.CString(maskMode) - defer C.free(unsafe.Pointer(maskModeC)) - var maskArr C.mlx_array - if mask != nil { - maskArr = mask.(*Array).a - } - - C.mlx_fast_scaled_dot_product_attention( - &r, - queries.a, - keys.(*Array).a, - values.(*Array).a, - C.float(scale), - maskModeC, - maskArr, - s, - ctx.(*Context).stream, - ) - return newArray(ctx.(*Context), r) -} - -func (a *Array) TakeAxes(ctx ml.Context, indicies ml.Tensor, axes int) ml.Tensor { - var r C.mlx_array - - C.mlx_take_axis(&r, a.a, indicies.(*Array).a, C.int(axes), ctx.(*Context).stream) - return newArray(ctx.(*Context), r) - -} - -// TODO not sure if we'll want this variation taking raw ints instead of a tensor... -// func (a *Array) TakeAxes(ctx ml.Context, axes int, indicies ...int) ml.Tensor { -// var i C.mlx_array -// var r C.mlx_array - -// if indicies != nil { -// shape := []C.int{C.int(len(indicies))} -// cindicies := make([]int32, len(indicies)) -// for i, v := range indicies { -// cindicies[i] = int32(v) -// } -// i = C.mlx_array_new_data( -// unsafe.Pointer(&cindicies[0]), -// &shape[0], -// C.int(len(shape)), -// C.MLX_INT32, -// ) -// } -// C.mlx_take_axis(&r, a.a, i, C.int(axes), ctx.(*Context).stream) -// return newArray(ctx.(*Context), r) - -// } - -func (a *Array) GELU(ctx ml.Context, up ...ml.Tensor) ml.Tensor { - // TODO precise vs fast, and compile - // x * mx.sigmoid(1.702 * x) - u16s := []float16.Float16{float16.Fromfloat32(1.702)} - cshape := []C.int{1} - f := C.mlx_array_new_data(unsafe.Pointer(&u16s[0]), &cshape[0], 1, C.MLX_FLOAT16) - defer C.mlx_array_free(f) - var r1, r2, r3 C.mlx_array - C.mlx_multiply(&r1, a.a, f, ctx.(*Context).stream) - defer C.mlx_array_free(r1) - C.mlx_sigmoid(&r2, r1, ctx.(*Context).stream) - defer C.mlx_array_free(r2) - C.mlx_multiply(&r3, a.a, r2, ctx.(*Context).stream) - - if len(up) > 0 { - var r4 C.mlx_array - defer C.mlx_array_free(r3) - C.mlx_multiply(&r4, r3, up[0].(*Array).a, ctx.(*Context).stream) - return newArray(ctx.(*Context), r4) - } - - return newArray(ctx.(*Context), r3) -} - -// Create a view into the array with the given shape and strides. -// -// The resulting array will always be as if the provided array was row -// contiguous regardless of the provided arrays storage order and current -// strides. -// -// Note that this function should be used with caution as it changes the shape -// and strides of the array directly. This can lead to the resulting array -// pointing to invalid memory locations which can result into crashes. -// -// Parameters: -// - shape (list(int), optional) – The shape of the resulting array. If None it defaults to a.shape(). -// - strides (list(int), optional) – The strides of the resulting array. If None it defaults to the -// reverse exclusive cumulative product of a.shape(). -// - offset (int) – Skip that many elements from the beginning of the input array. -func (a *Array) AsStrided(ctx ml.Context, shape, strides []int, offset int) ml.Tensor { - var r C.mlx_array - sh := make([]C.int, len(shape)) - st := make([]C.int64_t, len(strides)) - var sh0 *C.int - var st0 *C.int64_t - for i, s := range shape { - sh[i] = C.int(s) - } - for i, s := range strides { - st[i] = C.int64_t(s) - } - if len(sh) > 0 { - sh0 = (*C.int)(unsafe.Pointer(&sh[0])) - } - if len(st) > 0 { - st0 = (*C.int64_t)(unsafe.Pointer(&st[0])) - } - - C.mlx_as_strided( - &r, - a.a, - sh0, - C.size_t(len(sh)), - st0, - C.size_t(len(st)), - C.size_t(offset), - ctx.(*Context).stream, - ) - return newArray(ctx.(*Context), r) - -} - -func (a *Array) Reshape(ctx ml.Context, shape ...int) ml.Tensor { - cshape := make([]C.int, len(shape)) - for i, dim := range shape { - cshape[i] = C.int(dim) - } - var r C.mlx_array - C.mlx_reshape(&r, a.a, &cshape[0], C.size_t(len(cshape)), ctx.(*Context).stream) - return newArray(ctx.(*Context), r) -} - -func (a *Array) Transpose(ctx ml.Context, shape ...int) ml.Tensor { - ndim := min(C.mlx_array_ndim(a.a), C.size_t(len(shape))) - var r C.mlx_array - sh := make([]C.int, ndim) - for i := range ndim { - sh[i] = (C.int)(shape[i]) - if int(sh[i]) >= int(ndim) { - slog.Error("Permute error", "tensor", a, "shape", shape) - panic("invalid pemute call") - } - } - if len(sh) > 0 { - C.mlx_transpose_axes( - &r, - a.a, - &sh[0], - ndim, - ctx.(*Context).stream, - ) - } else { - C.mlx_transpose( - &r, - a.a, - ctx.(*Context).stream, - ) - } - return newArray(ctx.(*Context), r) -} - -func (a *Array) Contiguous(ctx ml.Context, allowColMajor bool) ml.Tensor { - var r C.mlx_array - C.mlx_contiguous( - &r, - a.a, - (C._Bool)(allowColMajor), - ctx.(*Context).stream, - ) - return newArray(ctx.(*Context), r) -} - -// Conv2D implements ml.Tensor. -// GGML API -// input: [N, IC, IH, IW] -// weight: [OC,IC, KH, KW] -// result: [N, OC, OH, OW] -// -// MLX: -// input: (N, KH, KW, C_in) -// weight: (C_out, IH, IW, C_in) -// result: XXX - -func (input *Array) Conv2D(ctx ml.Context, weight ml.Tensor, stride0, stride1, padding0, padding1, dilation0, dilation1, groups int) ml.Tensor { - var r C.mlx_array - C.mlx_conv2d( - &r, - input.a, - weight.(*Array).a, - C.int(stride0), - C.int(stride1), - C.int(padding0), - C.int(padding1), - C.int(dilation0), - C.int(dilation1), - C.int(groups), - ctx.(*Context).stream, - ) - return newArray(ctx.(*Context), r) -} - -func (input *Array) Conv3D(ctx ml.Context, weight ml.Tensor, stride0, stride1, stride2, padding0, padding1, padding2, dilation0, dilation1, dilation2, groups int) ml.Tensor { - var r C.mlx_array - C.mlx_conv3d( - &r, - input.a, - weight.(*Array).a, - C.int(stride0), - C.int(stride1), - C.int(stride2), - C.int(padding0), - C.int(padding1), - C.int(padding2), - C.int(dilation0), - C.int(dilation1), - C.int(dilation2), - C.int(groups), - ctx.(*Context).stream, - ) - return newArray(ctx.(*Context), r) -} - -func (a *Array) ToString() string { - str := C.mlx_string_new() - C.mlx_array_tostring(&str, a.a) - s := C.mlx_string_data(str) - defer C.mlx_string_free(str) - return C.GoString(s) -} - -func (a *Array) LogValue() slog.Value { - - dims := int(C.mlx_array_ndim(a.a)) - strides := make([]int, dims) - for i := range strides { - strides[i] = int(C.stride(a.a, (C.int)(i))) - } - - return slog.GroupValue( - slog.String("name", a.name), - slog.String("type", a.TypeString()), - slog.Any("shape", a.Shape()), - slog.Any("strides", strides), - // slog.String("values", C.GoString(s)), - ) -} - -func (a *Array) Shape() []int { - shape := make([]int, C.mlx_array_ndim(a.a)) - for i := range shape { - shape[i] = int(C.mlx_array_dim(a.a, C.int(i))) - } - - return shape -} - -func (a *Array) TypeString() string { - switch C.mlx_array_dtype(a.a) { - case C.MLX_BOOL: - return "bool" - case C.MLX_UINT8: - return "uint8" - case C.MLX_UINT16: - return "uint16" - case C.MLX_UINT32: - return "uint32" - case C.MLX_UINT64: - return "uint64" - case C.MLX_INT8: - return "int8" - case C.MLX_INT16: - return "int16" - case C.MLX_INT32: - return "int32" - case C.MLX_INT64: - return "int64" - case C.MLX_FLOAT16: - return "float16" - case C.MLX_FLOAT32: - return "float32" - case C.MLX_BFLOAT16: - return "bfloat16" - case C.MLX_COMPLEX64: - return "complex64" - default: - return "unknown" - } -} diff --git a/x/ml/backend/mlx/mlx_test.go b/x/ml/backend/mlx/mlx_test.go deleted file mode 100644 index 7699c1524..000000000 --- a/x/ml/backend/mlx/mlx_test.go +++ /dev/null @@ -1,314 +0,0 @@ -//go:build mlx - -package mlx - -import ( - "log/slog" - "os" - "reflect" - "strings" - "testing" - - "github.com/ollama/ollama/api" - "github.com/ollama/ollama/runner/common" - "github.com/ollama/ollama/sample" - "github.com/ollama/ollama/x/ml" - "github.com/ollama/ollama/x/model" - "github.com/ollama/ollama/x/model/input" - _ "github.com/ollama/ollama/x/model/models/gemma3" -) - -func init() { - logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) - slog.SetDefault(logger) -} - -func TestLoadModel(t *testing.T) { - dir := "/Users/daniel/Models/gemma-3-4b-it/" - b := &Backend{} - err := b.LoadSafeTensors(dir) - if err != nil { - t.Fatalf("load failed: %s", err) - } -} - -func TestFromInts(t *testing.T) { - b := &Backend{} - c := b.NewContext() - defer c.Close() - data := []int32{1, 2, 3, 4, 5, 6} - a := c.FromInts(data, 2, 3) - slog.Info("", "array", a) - t.Log(a.ToString()) - if !reflect.DeepEqual(a.Shape(), []int{2, 3}) { - t.Fatalf("incorrect shape: %v", a.Shape()) - } -} - -func TestFromFloats(t *testing.T) { - b := &Backend{} - c := b.NewContext() - defer c.Close() - data := []float32{1, 2, 3, 4, 5, 6} - a := c.FromFloats(data, 2, 3) - slog.Info("", "array", a) - t.Log(a.ToString()) - if !reflect.DeepEqual(a.Shape(), []int{2, 3}) { - t.Fatalf("incorrect shape: %v", a.Shape()) - } - res := a.Floats() - if !reflect.DeepEqual(res, data) { - t.Fatalf("incorrect results: %v", res) - } -} - -func TestAdd(t *testing.T) { - b := &Backend{} - c := b.NewContext() - defer c.Close() - t1 := c.Arange(0, 24, 1, ml.DTypeFloat16) - t2 := c.Arange(0, 24, 1, ml.DTypeFloat16) - exp := c.Arange(0, 48, 2, ml.DTypeFloat16) - t3 := t1.Add(c, t2) - c.Compute(t3, exp) - t3f := t3.Floats() - if !reflect.DeepEqual(t3f, exp.Floats()) { - t.Fatalf("incorrect result: %v", t3f) - } -} - -func TestReshapeTranspose(t *testing.T) { - b := &Backend{} - c := b.NewContext() - defer c.Close() - t1 := c.Arange(0, 24, 1, ml.DTypeFloat16).Reshape(c, 2, 3, 4).Transpose(c, 0, 2, 1).Contiguous(c, false) - c.Compute(t1) - t1f := t1.Floats() - exp := []float32{ - 0, 4, 8, - 1, 5, 9, - 2, 6, 10, - 3, 7, 11, - 12, 16, 20, - 13, 17, 21, - 14, 18, 22, - 15, 19, 23, - } - if !reflect.DeepEqual(t1f, exp) { - t.Fatalf("incorrect results: %v", t1f) - } -} - -func prod(vals ...int) int { - r := 1 - for _, v := range vals { - r *= v - } - return r -} -func TestMatmul(t *testing.T) { - // TODO create scenarios... - b := &Backend{} - c := b.NewContext() - defer c.Close() - s1 := []int{1, 3, 2, 4} - t1 := c.Arange(0, float32(prod(s1...)), 1, ml.DTypeFloat16).Reshape(c, s1...) - s2 := []int{4, 2} - t2 := c.Arange(0, float32(prod(s2...)), 1, ml.DTypeFloat16).Reshape(c, s2...) - t3 := t1.Matmul(c, t2) - exp := []float32{ - 28, 34, - 76, 98, - - 124, 162, - 172, 226, - - 220, 290, - 268, 354, - } - c.Compute(t3) - t3f := t3.Floats() - if !reflect.DeepEqual(t3f, exp) { - t.Fatalf("incorrect result: %v", t3f) - } -} - -func TestRows(t *testing.T) { - b := &Backend{} - c := b.NewContext() - defer c.Close() - t1 := c.Arange(0, 12, 1, ml.DTypeFloat32).Reshape(c, 1, 4, 3) - outputs := c.Zeros(ml.DTypeInt32, 1) - t2 := t1.TakeAxes(c, outputs, 1) - c.Forward(t1, t2).Compute(t1, t2) - t.Log(t1.ToString()) - t.Log(t2.ToString()) - f := t2.Floats() - t.Logf("Result: %v", f) -} - -func TestCaching(t *testing.T) { - // Validate the caching algorithm - b := &Backend{} - c := b.NewContext() - defer c.Close() - batchSize := 3 - headDim := 4 - numKVHeads := 2 - // Make cache twice the size of one test batch - cells := batchSize * 2 - cellSize := numKVHeads * headDim - shape := []int{1, numKVHeads, batchSize, headDim} - stop := float32(1) - for _, x := range shape { - stop *= float32(x) - } - // Create the cache - cache := c.Zeros(ml.DTypeFloat16, cells, cellSize) - t.Logf("Empty Cache shape%v\n"+cache.ToString(), []int{cells, cellSize}) - - // Input tensor - t1 := c.Arange(0, stop, 1, ml.DTypeFloat16).Reshape(c, shape...) - t.Logf("Initial Data shape%v\n"+t1.ToString(), shape) - - // Reshape to copy into the cache - /* - From MLX python/src/indexing.cpp mlx_scatter_args_array - // The update shape must broadcast with indices.shape + [1] + src.shape[1:] - auto up_shape = indices.shape(); - up_shape.insert(up_shape.end(), src.shape().begin() + 1, src.shape().end()); - up = broadcast_to(up, up_shape); - up_shape.insert(up_shape.begin() + indices.ndim(), 1); - up = reshape(up, up_shape); - */ - numRows := 3 - up := t1.Reshape(c, numRows, 1, cellSize) // The shape has to look like this for scatter to work properly - t.Logf("Data reshaped for cache input shape%v\n"+up.ToString(), []int{batchSize, numKVHeads * headDim}) - - // Simulate cells 1,3,5 are available - indicies := []ml.Tensor{c.FromInts([]int32{1, 3, 5}, numRows)} - t.Logf("Indicies shape%v\n"+indicies[0].ToString(), []int{numRows}) - axis := []int{0} // The 1,3,5 of the indicies are in reference to axis 0 in the cache shape - cache.Scatter(c, indicies, up, axis) - - c.Forward(cache) - // Cache should contain the data now - t.Log("Cache after put\n" + cache.ToString()) - - // Retrieve cache content and verify it matches - out := cache.TakeAxes(c, indicies[0], 0).Reshape(c, shape...) - t.Logf("Output shape%v\n"+out.ToString(), out.Shape()) - - t1f := t1.Floats() - outf := out.Floats() - if !reflect.DeepEqual(t1f, outf) { - t.Fatalf("mismatched in->out\n%v\n ->\n%v", t1f, outf) - } -} - -func TestGemma3(t *testing.T) { - // Why is the sky blue - inputs := []int32{2, 105, 2364, 107, 36425, 563, 506, 7217, 3730, 106, 107, 105, 4368} - limit := 50 - - // TODO generalize this - dir := "/Users/daniel/Models/gemma-3-4b-it/" - - m, err := model.New(dir, ml.BackendParams{}) - if err != nil { - t.Fatalf("unable to load model: %s", err) - } - b := m.Backend() - ctx := b.NewContext() - defer ctx.Close() - - batch := input.Batch{ - Inputs: ctx.FromInts(inputs[:], 1, len(inputs)), - Positions: make([]int32, len(inputs)), - Sequences: make([]int, len(inputs)), - Outputs: ctx.FromInts([]int32{int32(len(inputs) - 1)}, 1), - Offset: 0, - } - for i := range len(inputs) { - batch.Positions[i] = int32(i) - } - offset := len(inputs) - - cache := m.Config().Cache - if cache != nil { - numSlots := 1 - batchSize := 512 - numCtx := 4096 - - // Note: this is inconsistent with mlx-py, but trying to be consistent with the GGML cache impl to get things working - // cache.SetConfig(ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeBfloat16, MaskBatchPadding: 64}) - cache.SetConfig(ml.CacheConfig{CachePadding: 0, MaskDType: ml.DTypeBfloat16, MaskBatchPadding: 0}) - - cache.Init(b, ml.DTypeBfloat16, numSlots, int(numCtx), batchSize) - err := cache.StartForward(ctx, batch, false) - if err != nil { - t.Fatalf("failed cache.StartForward: %s", err) - } - } - opts := api.DefaultOptions() - var grammar *sample.GrammarSampler - sampler := sample.NewSampler( - opts.Temperature, - opts.TopK, - opts.TopP, - opts.MinP, - opts.Seed, - grammar, - ) - - t.Log("Starting Forward pass loop") - pendingResponses := []string{} - for { - out, err := m.Forward(ctx, batch) - if err != nil { - t.Fatalf("failed forward pass: %s", err) - } - ctx.Forward(out) - outputs := out.Floats() - t.Logf("finished forward pass! length:%d", len(outputs)) - // sample a token - logits := outputs - token, err := sampler.Sample(logits) - if err != nil { - t.Fatalf("unable to sample token: %s", err) - } - t.Logf("Sampled token: %v", token) - if m.(model.TextProcessor).Is(token, model.SpecialEOS) { - t.Log("hit EOS") - break - } - piece, err := m.(model.TextProcessor).Decode([]int32{token}) - if err != nil { - t.Fatalf("unable to decode token: %s", err) - } - - pendingResponses = append(pendingResponses, piece) - sequence := strings.Join(pendingResponses, "") - if ok, stop := common.FindStop(sequence, opts.Stop); ok { - t.Logf("hit stop token: %v", stop) - break - } - t.Logf("RESULTS: %s", sequence) - batch = input.Batch{ - Inputs: ctx.FromInts([]int32{token}, 1, 1), - Positions: make([]int32, 1), - Sequences: make([]int, 1), - Outputs: ctx.FromInts([]int32{0}, 1), - Offset: offset, - } - offset++ - batch.Positions[0] = 0 - err = cache.StartForward(ctx, batch, false) - if err != nil { - t.Fatalf("failed cache.StartForward: %s", err) - } - if offset > limit { - break - } - } -} diff --git a/x/ml/backend/mlx/quant.go b/x/ml/backend/mlx/quant.go deleted file mode 100644 index 724f43253..000000000 --- a/x/ml/backend/mlx/quant.go +++ /dev/null @@ -1,335 +0,0 @@ -//go:build mlx - -package mlx - -/* -#include -#include - -#include "mlx/c/array.h" -#include "mlx/c/ops.h" - -// Derived from https://github.com/ml-explore/mlx/blob/main/mlx/io/gguf_quants.cpp - -void unpack_32_4(uint8_t* data, int8_t* dst) { - memset(dst, 0, 16); - for (int j = 0; j < 16; ++j) { - uint8_t x = (data[j + 2] & 0x0F); // j+2 to skip scale bytes. - if (j % 2 != 0) { - x <<= 4; - } - dst[j / 2] += x; - } - // Last 16 weights are in the higher bits - for (int j = 0; j < 16; ++j) { - uint8_t x = (data[j + 2] >> 4); - if (j % 2 != 0) { - x <<= 4; - } - dst[8 + j / 2] += x; - } -} - -// Extracts (weight, scales, biases) from Q4_0 tensors. -// Data layout is: |16 bit scale|32 x 4bit weights|. -void extract_q4_0_data( - uint8_t* data, - mlx_array* weights_arr, - mlx_array* scales_arr, - mlx_array* biases_arr) { - const uint64_t bytes_per_block = 18; // 2 bytes scale, 32x0.5 byte weights - uint8_t* weights = mlx_array_data_uint8(*weights_arr); - float16_t* scales = mlx_array_data_float16(*scales_arr); - float16_t* biases = mlx_array_data_float16(*biases_arr); - for (int64_t i = 0; i < mlx_array_size(*scales_arr); i++) { - scales[i] = *((float16_t*)data); - biases[i] = -8 * scales[i]; - unpack_32_4(data, weights); - weights += 16; - data += bytes_per_block; - } -} - -// Extracts (weight, scales, biases) from Q4_1 tensors. -// Data layout is: |16 bit scale|16 bit bias|32 x 4bit weights|. -void extract_q4_1_data( - uint8_t* data, - mlx_array* weights_arr, - mlx_array* scales_arr, - mlx_array* biases_arr) { - const uint64_t bytes_per_block = 20; // 2 bytes scale, 2 bytes bias, 32x0.5 byte weights - uint8_t* weights = mlx_array_data_uint8(*weights_arr); - float16_t* scales = mlx_array_data_float16(*scales_arr); - float16_t* biases = mlx_array_data_float16(*biases_arr); - for (int64_t i = 0; i < mlx_array_size(*scales_arr); i++) { - scales[i] = *((float16_t*)data); - biases[i] = *((float16_t*)(data) + 1); - unpack_32_4(data, weights); - weights += 16; - data += bytes_per_block; - } -} - -// Extracts (weight, scales, biases) from Q8_0 tensors. -// Data layout is: |16 bit scale|32 x 8bit weights|. -void extract_q8_0_data( - uint8_t* data, - mlx_array* weights_arr, - mlx_array* scales_arr, - mlx_array* biases_arr) { - const uint64_t weights_per_block = 32; - const uint64_t bytes_per_block = 34; // 2 bytes scale, 32x1 byte weights - uint8_t* weights = mlx_array_data_uint8(*weights_arr); - float16_t* scales = mlx_array_data_float16(*scales_arr); - float16_t* biases = mlx_array_data_float16(*biases_arr); - for (int64_t i = 0; i < mlx_array_size(*scales_arr); i++) { - uint8_t* block_data = data + i * bytes_per_block; - scales[i] = *((float16_t*)block_data); - biases[i] = -128 * scales[i]; - for (int64_t j = 0; j < weights_per_block; ++j) { - uint8_t x = block_data[j + 2]; // j+2 to skip the scale bytes. - // Original data is in int8_t, so we add a bias of -128 and invert the - // first bit. - x ^= 1 << 7; - weights[i * weights_per_block + j] = x; - } - } -} - -// Drived from ggml-quants.c - -#define QK_K 256 - -// 6-bit quantization -// weight is represented as x = a * q -// 16 blocks of 16 elements each -// Effectively 6.5625 bits per weight -typedef struct { - uint8_t ql[QK_K/2]; // quants, lower 4 bits - uint8_t qh[QK_K/4]; // quants, upper 2 bits - int8_t scales[QK_K/16]; // scales, quantized with 8 bits - uint16_t d; // super-block scale -} block_q6_K; - -void dequant_row_q6_K(const void * restrict vx, void * restrict vy, int k) { - const int64_t nb = k / QK_K; - block_q6_K *x = (block_q6_K *)vx; - float16_t* y = (float16_t *)vy; - - for (int i = 0; i < nb; i++) { - float16_t d = 0.0; - memcpy(&d, &x[i].d, sizeof(d)); - - const uint8_t * restrict ql = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict sc = x[i].scales; - - for (int n = 0; n < QK_K; n += 128) { - for (int l = 0; l < 32; ++l) { - int is = l/16; - const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; - const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; - const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; - const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; - y[l + 0] = d * sc[is + 0] * q1; - y[l + 32] = d * sc[is + 2] * q2; - y[l + 64] = d * sc[is + 4] * q3; - y[l + 96] = d * sc[is + 6] * q4; - } - y += 128; - ql += 64; - qh += 32; - sc += 8; - } - } -} - -#define K_SCALE_SIZE 12 -#define GGML_COMMON_AGGR_U -#define GGML_COMMON_AGGR_S - -// 4-bit quantization -// 8 blocks of 32 elements each -// weight is represented as x = a * q + b -// Effectively 4.5 bits per weight -typedef struct { - union { - struct { - uint16_t d; // super-block scale for quantized scales - uint16_t dmin; // super-block scale for quantized mins - } GGML_COMMON_AGGR_S; - uint16_t dm; - } GGML_COMMON_AGGR_U; - uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits - uint8_t qs[QK_K/2]; // 4--bit quants -} block_q4_K; - -static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) { - if (j < 4) { - *d = q[j] & 63; *m = q[j + 4] & 63; - } else { - *d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4); - *m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4); - } -} - -void dequant_row_q4_K(const void * restrict vx, void * restrict vy, int k) { - block_q4_K *x = (block_q4_K *)vx; - float16_t* y = (float16_t *)vy; - const int nb = k / QK_K; - - for (int i = 0; i < nb; i++) { - const uint8_t * q = x[i].qs; - float16_t d = 0.0; - memcpy(&d, &x[i].d, sizeof(d)); - float16_t min = 0.0; - memcpy(&min, &x[i].dmin, sizeof(d)); - - int is = 0; - uint8_t sc, m; - for (int j = 0; j < QK_K; j += 64) { - get_scale_min_k4(is + 0, x[i].scales, &sc, &m); - const float16_t d1 = d * sc; const float16_t m1 = min * m; - get_scale_min_k4(is + 1, x[i].scales, &sc, &m); - const float16_t d2 = d * sc; const float16_t m2 = min * m; - for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1; - for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2; - q += 32; is += 2; - } - } -} - - - -*/ -import "C" - -import ( - "fmt" - "unsafe" - - "github.com/x448/float16" -) - -func gguf_load_quantized(data unsafe.Pointer, name string, final_shape []C.int, dtype uint32, stream C.mlx_stream) (r C.mlx_array, err error) { - shape := append([]C.int{}, final_shape...) - var weights_per_byte C.int - if dtype == 2 || dtype == 3 { - weights_per_byte = 2 - } else if dtype == 8 { - weights_per_byte = 1 - } else { - return r, fmt.Errorf("unsupported tensor type %d", dtype) - } - - weights_per_block := C.int(32) - if shape[len(shape)-1]%weights_per_block != 0 { - return r, fmt.Errorf("[load_gguf] tensor has incompatible last dim shape: %d", shape[len(shape)-1]) - } - - weights_shape := append([]C.int{}, shape...) - weights_shape[len(weights_shape)-1] /= (weights_per_byte * 4) - w_nbytes := C.int(unsafe.Sizeof(uint32(0))) - for i := range weights_shape { - w_nbytes *= weights_shape[i] - } - w_data := make([]byte, w_nbytes) - cbytes := C.CBytes(w_data) - defer C.free(cbytes) - weights := C.mlx_array_new_data( - cbytes, - &weights_shape[0], - C.int(len(weights_shape)), - C.MLX_UINT32, - ) - - // For scales and bias - shape[len(shape)-1] = shape[len(shape)-1] / weights_per_block - sb_nbytes := C.int(unsafe.Sizeof(float16.Float16(0))) - for i := range shape { - sb_nbytes *= shape[i] - } - - s_data := make([]byte, sb_nbytes) - cbytes = C.CBytes(s_data) - defer C.free(cbytes) - scales := C.mlx_array_new_data( - cbytes, - &shape[0], - C.int(len(shape)), - C.MLX_FLOAT16, - ) - b_data := make([]byte, sb_nbytes) - cbytes = C.CBytes(b_data) - defer C.free(cbytes) - biases := C.mlx_array_new_data( - cbytes, - &shape[0], - C.int(len(shape)), - C.MLX_FLOAT16, - ) - var bits C.int - switch dtype { - case 2: - C.extract_q4_0_data((*C.uint8_t)(data), &weights, &scales, &biases) - bits = 4 - case 3: - C.extract_q4_1_data((*C.uint8_t)(data), &weights, &scales, &biases) - bits = 4 - case 8: - C.extract_q8_0_data((*C.uint8_t)(data), &weights, &scales, &biases) - bits = 8 - } - groupSize := C.mlx_optional_int{value: 32, has_value: true} - bitsOpt := C.mlx_optional_int{value: bits, has_value: true} - var dtypeOpt C.mlx_optional_dtype // has_value defaults to false - C.mlx_dequantize( - &r, - weights, - scales, - biases, - groupSize, - bitsOpt, - nil, // TODO mode - dtypeOpt, - stream, - ) - C.mlx_array_free(weights) - C.mlx_array_free(scales) - C.mlx_array_free(biases) - - return r, nil -} - -func load_k_quantized(data unsafe.Pointer, name string, shape []C.int, dtype uint32, stream C.mlx_stream) (r C.mlx_array, err error) { - size := 1 - for _, d := range shape { - size *= int(d) - } - fdata := make([]float16.Float16, size) - switch dtype { - case 14: - C.dequant_row_q6_K( - data, - unsafe.Pointer(&fdata[0]), - C.int(size), - ) - - case 12: - C.dequant_row_q4_K( - data, - unsafe.Pointer(&fdata[0]), - C.int(size), - ) - default: - return r, fmt.Errorf("unsupported K quant") - } - - r = C.mlx_array_new_data( - unsafe.Pointer(&fdata[0]), - &shape[0], - C.int(len(shape)), - C.MLX_FLOAT16, - ) - return r, nil -} diff --git a/x/ml/device.go b/x/ml/device.go deleted file mode 100644 index f892b512d..000000000 --- a/x/ml/device.go +++ /dev/null @@ -1,643 +0,0 @@ -package ml - -import ( - "context" - "encoding/binary" - "encoding/json" - "fmt" - "hash/maphash" - "io" - "log/slog" - "math" - "net/http" - "runtime" - "slices" - "sort" - "strconv" - "strings" - "time" - - "github.com/ollama/ollama/format" - "github.com/ollama/ollama/logutil" -) - -// GPULayers is a set of layers to be allocated on a single GPU -type GPULayers struct { - DeviceID - - // Layers is a set of layer indicies to load - Layers []int -} - -// FirstLayer returns the smallest layer index scheduled on this GPU, or MaxInt when empty. -func (g GPULayers) FirstLayer() int { - if len(g.Layers) == 0 { - return math.MaxInt - } - - first := g.Layers[0] - for i := 1; i < len(g.Layers); i++ { - if g.Layers[i] < first { - first = g.Layers[i] - } - } - - return first -} - -func (g GPULayers) String() string { - if len(g.Layers) == 0 { - return "" - } - - slices.Sort(g.Layers) - - contiguous := true - base := g.Layers[0] - for i := range g.Layers { - if g.Layers[i] != base+i { - contiguous = false - break - } - } - - if contiguous { - return fmt.Sprintf("ID:%v Layers:%v(%v..%v)", g.ID, len(g.Layers), g.Layers[0], g.Layers[len(g.Layers)-1]) - } else { - return fmt.Sprintf("ID:%v Layers:%v%v", g.ID, len(g.Layers), g.Layers) - } -} - -// GPULayersList is a set of layer allocations across multiple GPUs -type GPULayersList []GPULayers - -func (l GPULayersList) Len() int { return len(l) } -func (l GPULayersList) Swap(i, j int) { l[i], l[j] = l[j], l[i] } - -// Sort by the ordering of the layers offloaded -func (l GPULayersList) Less(i, j int) bool { - li := l[i].FirstLayer() - lj := l[j].FirstLayer() - - return li < lj -} - -func (l GPULayersList) String() string { - if l.Sum() > 0 { - return fmt.Sprintf("%v%v", l.Sum(), []GPULayers(l)) - } else { - return fmt.Sprintf("%v", []GPULayers(l)) - } -} - -// Sum is the total number of layers assigned across all GPUs -func (l GPULayersList) Sum() int { - var sum int - - for _, g := range l { - sum += len(g.Layers) - } - - return sum -} - -var h maphash.Hash - -// Hash is an identifier of this layer assignment -func (l GPULayersList) Hash() uint64 { - h.Reset() - for _, g := range l { - if len(g.Layers) > 0 { - h.WriteString(g.ID + g.Library) - for _, l := range g.Layers { - binary.Write(&h, binary.NativeEndian, int64(l)) - } - } - } - - return h.Sum64() -} - -// ErrNoMem is returned when panicing due to insufficient memory. It includes -// the attempted memory allocation. -type ErrNoMem struct { - BackendMemory -} - -func (e ErrNoMem) Error() string { - return fmt.Sprintf("insufficient memory - required allocations: %+v", e.BackendMemory) -} - -// Minimal unique device identification -type DeviceID struct { - // ID is an identifier for the device for matching with system - // management libraries. The ID is only unique for other devices - // using the same Library. - // This ID represents a "post filtered" view of the enumerated devices - // if the ID is numeric - ID string `json:"id"` - - // Library identifies which library is used for the device (e.g. CUDA, ROCm, etc.) - Library string `json:"backend,omitempty"` -} - -// DeviceMemory provides a breakdown of the memory needed -// per device, such as a CPU or GPU. -type DeviceMemory struct { - DeviceID - - // Name is the name of the device as labeled by the backend. It - // may not be persistent across instances of the runner. - Name string - - // Weights is the per-layer memory needed for the model weights. - Weights []uint64 - - // Cache is the per-layer memory needed for the KV cache. - Cache []uint64 - - // Graph is the size of the compute graph. It is not per-layer. - Graph uint64 -} - -func sumMemory(mem []uint64) uint64 { - var sum uint64 - - for _, m := range mem { - sum += m - } - - return sum -} - -// Size returns the total size of the memory required by this device -func (m DeviceMemory) Size() uint64 { - return sumMemory(m.Weights) + sumMemory(m.Cache) + m.Graph -} - -func memoryPresent(mem []uint64) bool { - return slices.ContainsFunc(mem, func(m uint64) bool { return m != 0 }) -} - -func (m DeviceMemory) LogValue() slog.Value { - var attrs []slog.Attr - if memoryPresent(m.Weights) { - attrs = append(attrs, slog.Any("Weights", m.Weights)) - } - - if memoryPresent(m.Cache) { - attrs = append(attrs, slog.Any("Cache", m.Cache)) - } - - if m.Graph != 0 { - attrs = append(attrs, slog.Any("Graph", m.Graph)) - } - - if len(attrs) > 0 && m.ID != "" { - attrs = append([]slog.Attr{slog.String("ID", m.ID)}, attrs...) - } - - return slog.GroupValue(attrs...) -} - -// BackendMemory provides the amount of memory required to load the model -// per device based on the BackendParams. In some cases, not all required -// allocations will be known at this point. However, the size of the most recent -// allocation is guaranteed to be provided so that if it failed, the caller can -// accommodate that to make forward progress. -type BackendMemory struct { - // InputWeights are always located on the CPU and cannot be moved - InputWeights uint64 - - // CPU model components are located in system memory. This does not - // include unified memory allocated through the GPU. - CPU DeviceMemory - - // GPU model components are located on one or more GPUs. - GPUs []DeviceMemory -} - -func (m BackendMemory) LogValue() slog.Value { - var attrs []slog.Attr - if m.InputWeights != 0 { - attrs = append(attrs, slog.Any("InputWeights", m.InputWeights)) - } - - attrs = append(attrs, slog.Any(m.CPU.Name, m.CPU)) - for _, g := range m.GPUs { - attrs = append(attrs, slog.Any(g.Name, g)) - } - - return slog.GroupValue(attrs...) -} - -// Log prints a high level summary of the memory -func (m BackendMemory) Log(level slog.Level) { - var total uint64 - - for _, gpu := range m.GPUs { - if sum := sumMemory(gpu.Weights); sum > 0 { - slog.Log(context.TODO(), level, "model weights", "device", gpu.Name, "size", format.HumanBytes2(sum)) - total += sum - } - } - if sum := m.InputWeights + sumMemory(m.CPU.Weights); sum > 0 { - slog.Log(context.TODO(), level, "model weights", "device", m.CPU.Name, "size", format.HumanBytes2(sum)) - total += sum - } - - for _, gpu := range m.GPUs { - if sum := sumMemory(gpu.Cache); sum > 0 { - slog.Log(context.TODO(), level, "kv cache", "device", gpu.Name, "size", format.HumanBytes2(sum)) - total += sum - } - } - if sum := sumMemory(m.CPU.Cache); sum > 0 { - slog.Log(context.TODO(), level, "kv cache", "device", m.CPU.Name, "size", format.HumanBytes2(sum)) - total += sum - } - - for _, gpu := range m.GPUs { - if sum := gpu.Graph; sum > 0 { - slog.Log(context.TODO(), level, "compute graph", "device", gpu.Name, "size", format.HumanBytes2(sum)) - total += sum - } - } - if sum := m.CPU.Graph; sum > 0 { - slog.Log(context.TODO(), level, "compute graph", "device", m.CPU.Name, "size", format.HumanBytes2(sum)) - total += sum - } - - if total > 0 { - slog.Log(context.TODO(), level, "total memory", "size", format.HumanBytes2(total)) - } -} - -type DeviceInfo struct { - DeviceID - - // Name is the name of the device as labeled by the backend. It - // may not be persistent across instances of the runner. - Name string `json:"name"` - - // Description is the longer user-friendly identification of the device - Description string `json:"description"` - - // FilterID is populated with the unfiltered device ID if a numeric ID is used - // so the device can be included. - FilterID string `json:"filter_id,omitempty"` - - // Integrated is set true for integrated GPUs, false for Discrete GPUs - Integrated bool `json:"integration,omitempty"` - - // PCIID is the bus, device and domain ID of the device for deduplication - // when discovered by multiple backends - PCIID string `json:"pci_id,omitempty"` - - // TotalMemory is the total amount of memory the device can use for loading models - TotalMemory uint64 `json:"total_memory"` - - // FreeMemory is the amount of memory currently available on the device for loading models - FreeMemory uint64 `json:"free_memory,omitempty"` - - // ComputeMajor is the major version of capabilities of the device - // if unsupported by the backend, -1 will be returned - ComputeMajor int - - // ComputeMinor is the minor version of capabilities of the device - // if unsupported by the backend, -1 will be returned - ComputeMinor int - - // Driver Information - DriverMajor int `json:"driver_major,omitempty"` - DriverMinor int `json:"driver_minor,omitempty"` - - // Where backends were loaded from - LibraryPath []string -} - -type SystemInfo struct { - // ThreadCount is the optimal number of threads to use for inference - ThreadCount int `json:"threads,omitempty"` - - // TotalMemory is the total amount of system memory - TotalMemory uint64 `json:"total_memory,omitempty"` - - // FreeMemory is the amount of memory currently available on the system for loading models - FreeMemory uint64 `json:"free_memory,omitempty"` - - // FreeSwap is the amount of system swap space reported as available - FreeSwap uint64 `json:"free_swap,omitempty"` -} - -func (d DeviceInfo) Compute() string { - // AMD gfx is encoded into the major minor in hex form - if strings.EqualFold(d.Library, "ROCm") { - return fmt.Sprintf("gfx%x%02x", d.ComputeMajor, d.ComputeMinor) - } - return strconv.Itoa(d.ComputeMajor) + "." + strconv.Itoa(d.ComputeMinor) -} - -func (d DeviceInfo) Driver() string { - return strconv.Itoa(d.DriverMajor) + "." + strconv.Itoa(d.DriverMinor) -} - -// MinimumMemory reports the amount of memory that should be set aside -// on the device for overhead (e.g. VRAM consumed by context structures independent -// of model allocations) -func (d DeviceInfo) MinimumMemory() uint64 { - if d.Library == "Metal" { - return 512 * format.MebiByte - } - return 457 * format.MebiByte -} - -// Sort by Free Space. -// iGPUs are reported first, thus Reverse() yields the largest discrete GPU first -type ByFreeMemory []DeviceInfo - -func (a ByFreeMemory) Len() int { return len(a) } -func (a ByFreeMemory) Swap(i, j int) { a[i], a[j] = a[j], a[i] } -func (a ByFreeMemory) Less(i, j int) bool { - if a[i].Integrated && !a[j].Integrated { - return true - } else if !a[i].Integrated && a[j].Integrated { - return false - } - return a[i].FreeMemory < a[j].FreeMemory -} - -// ByPerformance groups devices by similar speed -func ByPerformance(l []DeviceInfo) [][]DeviceInfo { - resp := [][]DeviceInfo{} - scores := []bool{} - for _, info := range l { - found := false - requested := info.Integrated - for i, score := range scores { - if score == requested { - resp[i] = append(resp[i], info) - found = true - break - } - } - if !found { - scores = append(scores, requested) - resp = append(resp, []DeviceInfo{info}) - } - } - return resp -} - -func ByLibrary(l []DeviceInfo) [][]DeviceInfo { - resp := [][]DeviceInfo{} - libs := []string{} - for _, info := range l { - found := false - requested := info.Library - for i, lib := range libs { - if lib == requested { - resp[i] = append(resp[i], info) - found = true - break - } - } - if !found { - libs = append(libs, requested) - resp = append(resp, []DeviceInfo{info}) - } - } - return resp -} - -func LibraryPaths(l []DeviceInfo) []string { - gpuLibs := []string{LibOllamaPath} - for _, gpu := range l { - for _, dir := range gpu.LibraryPath { - needed := true - for _, existing := range gpuLibs { - if dir == existing { - needed = false - break - } - } - if needed { - gpuLibs = append(gpuLibs, dir) - } - } - } - return gpuLibs -} - -type DeviceComparison int - -const ( - UniqueDevice DeviceComparison = iota - SameBackendDevice // The device is the same, and the library/backend is the same - DuplicateDevice // The same physical device but different library/backend (overlapping device) -) - -func (a DeviceInfo) Compare(b DeviceInfo) DeviceComparison { - if a.PCIID != b.PCIID { - return UniqueDevice - } - // If PCIID is empty, we have to use ID + library for uniqueness - if a.PCIID == "" && a.DeviceID != b.DeviceID { - return UniqueDevice - } - if a.Library == b.Library { - return SameBackendDevice - } - return DuplicateDevice -} - -// For a SameBackendDevice, return true if b is better than a -// e.g. newer GPU library version -func (a DeviceInfo) IsBetter(b DeviceInfo) bool { - aLib := a.LibraryPath[len(a.LibraryPath)-1] - bLib := b.LibraryPath[len(b.LibraryPath)-1] - if aLib == bLib { - return false - } - aLibSplit := strings.SplitN(aLib, "_", 2) - bLibSplit := strings.SplitN(bLib, "_", 2) - if len(aLibSplit) < 2 || len(bLibSplit) < 2 { - return false - } - if aLibSplit[0] != bLibSplit[0] { - slog.Debug("unexpected libraries", "a", aLib, "b", bLib) - return false - } - if aLibSplit[1] == bLibSplit[1] { - return false - } - cmp := []string{aLibSplit[1], bLibSplit[1]} - sort.Sort(sort.Reverse(sort.StringSlice(cmp))) - return cmp[0] == bLibSplit[1] -} - -// For each GPU, check if it does NOT support flash attention -func FlashAttentionSupported(l []DeviceInfo) bool { - for _, gpu := range l { - supportsFA := gpu.Library == "cpu" || - gpu.Name == "Metal" || gpu.Library == "Metal" || - (gpu.Library == "CUDA" && gpu.DriverMajor >= 7 && !(gpu.ComputeMajor == 7 && gpu.ComputeMinor == 2)) || - gpu.Library == "ROCm" || - gpu.Library == "Vulkan" - - if !supportsFA { - return false - } - } - return true -} - -// Given the list of GPUs this instantiation is targeted for, -// figure out the visible devices environment variables -// Set mustFilter true to enable filtering of CUDA devices -func GetVisibleDevicesEnv(l []DeviceInfo, mustFilter bool) map[string]string { - if len(l) == 0 { - return nil - } - env := map[string]string{} - for _, d := range l { - d.updateVisibleDevicesEnv(env, mustFilter) - } - return env -} - -// NeedsInitValidation returns true if the device in question has the potential -// to crash at inference time and requires deeper validation before we include -// it in the supported devices list. -func (d DeviceInfo) NeedsInitValidation() bool { - // ROCm: rocblas will crash on unsupported devices. - // CUDA: verify CC is supported by the version of the library - return d.Library == "ROCm" || d.Library == "CUDA" -} - -// Set the init validation environment variable -func (d DeviceInfo) AddInitValidation(env map[string]string) { - env["GGML_CUDA_INIT"] = "1" // force deep initialization to trigger crash on unsupported GPUs -} - -// PreferredLibrary returns true if this library is preferred over the other input -// library -// Used to filter out Vulkan in favor of CUDA or ROCm -func (d DeviceInfo) PreferredLibrary(other DeviceInfo) bool { - // TODO in the future if we find Vulkan is better than ROCm on some devices - // that implementation can live here. - - if d.Library == "CUDA" || d.Library == "ROCm" { - return true - } - return false -} - -func (d DeviceInfo) updateVisibleDevicesEnv(env map[string]string, mustFilter bool) { - var envVar string - switch d.Library { - case "ROCm": - // ROCm must be filtered as it can crash the runner on unsupported devices - envVar = "ROCR_VISIBLE_DEVICES" - if runtime.GOOS != "linux" { - envVar = "HIP_VISIBLE_DEVICES" - } - case "CUDA": - if !mustFilter { - // By default we try to avoid filtering CUDA devices because ROCm also - // looks at the CUDA env var, and gets confused in mixed vendor environments. - return - } - envVar = "CUDA_VISIBLE_DEVICES" - default: - // Vulkan is not filtered via env var, but via scheduling decisions - return - } - v, existing := env[envVar] - if existing { - v = v + "," - } - if d.FilterID != "" { - v = v + d.FilterID - } else { - v = v + d.ID - } - env[envVar] = v -} - -type BaseRunner interface { - // GetPort returns the localhost port number the runner is running on - GetPort() int - - // HasExited indicates if the runner is no longer running. This can be used during - // bootstrap to detect if a given filtered device is incompatible and triggered an assert - HasExited() bool -} - -type RunnerDiscovery interface { - BaseRunner - - // GetDeviceInfos will perform a query of the underlying device libraries - // for device identification and free VRAM information - // During bootstrap scenarios, this routine may take seconds to complete - GetDeviceInfos(ctx context.Context) []DeviceInfo -} - -type FilteredRunnerDiscovery interface { - RunnerDiscovery - - // GetActiveDeviceIDs returns the filtered set of devices actively in - // use by this runner for running models. If the runner is a bootstrap runner, no devices - // will be active yet so no device IDs are returned. - // This routine will not query the underlying device and will return immediately - GetActiveDeviceIDs() []DeviceID -} - -func GetDevicesFromRunner(ctx context.Context, runner BaseRunner) ([]DeviceInfo, error) { - var moreDevices []DeviceInfo - port := runner.GetPort() - tick := time.Tick(10 * time.Millisecond) - for { - select { - case <-ctx.Done(): - return nil, fmt.Errorf("failed to finish discovery before timeout") - case <-tick: - r, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://127.0.0.1:%d/info", port), nil) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - r.Header.Set("Content-Type", "application/json") - - resp, err := http.DefaultClient.Do(r) - if err != nil { - // slog.Warn("failed to send request", "error", err) - if runner.HasExited() { - return nil, fmt.Errorf("runner crashed") - } - continue - } - defer resp.Body.Close() - - if resp.StatusCode == http.StatusNotFound { - // old runner, fall back to bootstrapping model - return nil, fmt.Errorf("llamarunner free vram reporting not supported") - } - - body, err := io.ReadAll(resp.Body) - if err != nil { - slog.Warn("failed to read response", "error", err) - continue - } - if resp.StatusCode != 200 { - logutil.Trace("runner failed to discover free VRAM", "status", resp.StatusCode, "response", body) - return nil, fmt.Errorf("runner error: %s", string(body)) - } - - if err := json.Unmarshal(body, &moreDevices); err != nil { - slog.Warn("unmarshal encode response", "error", err) - continue - } - return moreDevices, nil - } - } -} diff --git a/x/ml/nn/attention.go b/x/ml/nn/attention.go deleted file mode 100644 index c4a16a302..000000000 --- a/x/ml/nn/attention.go +++ /dev/null @@ -1,103 +0,0 @@ -package nn - -import ( - "fmt" - - "github.com/ollama/ollama/x/kvcache" - "github.com/ollama/ollama/x/ml" -) - -// Attention implements scaled dot-product attention for transformer models: -// Attention(Q, K, V) = softmax(QK^T/√d_k)V -// -// Parameters: -// - ctx: Context for tensor operations -// - query: Query tensor (Q) with shape [d_k, heads, seq_len_q] -// - key: Key tensor (K) with shape [d_k, kv_heads, seq_len_k], can be nil to read from cache only -// - value: Value tensor (V) with shape [d_v, kv_heads, seq_len_k], can be nil to read from cache only -// - scale: Scaling factor, typically 1/√d_k where d_k is the key dimension -// - cache: KV cache to store key/value and get past history, can be nil to only use provided key/value -// -// Returns: -// -// Attention output with shape [d_v, heads, seq_len_q] -func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor { - return AttentionWithVMLA(ctx, query, key, value, nil, nil, scale, cache) -} - -func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor { - return AttentionWithVMLA(ctx, query, key, value, sinks, nil, scale, cache) -} - -func AttentionWithVMLA(ctx ml.Context, query, key, value, sinks ml.Tensor, vmla ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor { - ctx.Forward(query) - - if key != nil && value != nil { - if query.Dim(0) != key.Dim(0) { - panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0))) - } - - if key.Dim(1) != value.Dim(1) { - panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1))) - } - - if key.Dim(2) != value.Dim(2) { - panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2))) - } - - ctx.Forward(key, value) - if cache != nil { - cache.Put(ctx, key, value) - } - } else if cache == nil { - panic("key & value tensors must be provided if cache is nil") - } - - // ctx.CompareWith("/tmp/test", map[string]ml.Tensor{"q": query, "k": key, "v": value}, true) - // panic("after cache get") // - // 2025/12/10 16:02:33 INFO XXX tensors are similar q=0.9999869465827942 shape="[1 8 13 256]" min_difference=[-0.07926178] max_difference=[0.07012844] - // 2025/12/10 16:02:33 INFO XXX tensors are similar k=0.9999891519546509 shape="[1 4 13 256]" min_difference=[-0.21365738] max_difference=[0.19916534] - // 2025/12/10 16:02:33 INFO XXX tensors are similar v=0.9999960660934448 shape="[1 4 13 256]" min_difference=[-0.32923126] max_difference=[0.32646942] - - // var mask ml.Tensor - if cache != nil { - key, value, _ = cache.Get(ctx) - } - // ctx.CompareWith("/tmp/test", map[string]ml.Tensor{"q": query.Contiguous(ctx, false), "k": key.Contiguous(ctx, false), "v": value.Contiguous(ctx, false)}, true) - // panic("after cache get") // - // 2025/12/10 15:34:03 INFO XXX tensors are similar q=0.9999869465827942 shape="[1 8 13 256]" min_difference=[-0.07926178] max_difference=[0.07012844] - // 2025/12/10 15:34:03 INFO XXX tensors are similar k=0.9999881982803345 shape="[1 4 13 256]" min_difference=[-0.25] max_difference=[0.25] - // 2025/12/10 15:34:03 INFO XXX tensors are similar v=0.9999913573265076 shape="[1 4 13 256]" min_difference=[-0.5] max_difference=[0.5] - - // Only use the fast SDPA implementation if we have a cache, since that's what - // will do any expected backend-specific transformations for us - - if cache != nil { - // TODO what to do with vmla? - // return query.Transpose(ctx, 0, 2, 1, 3).ScaledDotProductAttention(ctx, key.Transpose(ctx, 0, 2, 1, 3), value.Transpose(ctx, 0, 2, 1, 3), scale, "array", mask, sinks) - return query.ScaledDotProductAttention(ctx, key, value, scale, "causal", nil, sinks) - - // TODO these two produce identical output, but not similar enough - 92.9% - should be 99.999% - } else { - panic("else case not supported") - // TODO transpose shapes are wrong - // key = key.Transpose(ctx, 0, 2, 1, 3) - // value = value.Transpose(ctx, 1, 2, 0, 3).Contiguous(ctx, false) - - // kq := query.Matmul(ctx, key) - - // kq = kq.Scale(ctx, scale) - // if mask != nil { - // kq = kq.Add(ctx, mask) - // } - // kq = kq.Softmax(ctx) - - // kqv := kq.Matmul(ctx, value) - - // if vmla != nil { - // kqv = kqv.Matmul(ctx, vmla) - // } - - // return kqv.Transpose(ctx, 0, 2, 1, 3).Contiguous(ctx, false) - } -} diff --git a/x/ml/nn/convolution.go b/x/ml/nn/convolution.go deleted file mode 100644 index 7c4b5a520..000000000 --- a/x/ml/nn/convolution.go +++ /dev/null @@ -1,30 +0,0 @@ -package nn - -import "github.com/ollama/ollama/x/ml" - -type Conv2D struct { - Weight ml.Tensor `gguf:"weight"` - Bias ml.Tensor `gguf:"bias"` -} - -func (m *Conv2D) Forward(ctx ml.Context, t ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor { - t = m.Weight.Conv2D(ctx, t, s0, s1, p0, p1, d0, d1, 1) - if m.Bias != nil { - // Bias shape is (out_channels,) while t shape is (width, height, out_channels, batch) - t = t.Add(ctx, m.Bias.Reshape(ctx, 1, 1, -1)) - } - return t -} - -type Conv3D struct { - Weight ml.Tensor `gguf:"weight"` - Bias ml.Tensor `gguf:"bias"` -} - -func (m *Conv3D) Forward(ctx ml.Context, t ml.Tensor, s0, s1, s2, p0, p1, p2, d0, d1, d2, g int) ml.Tensor { - t = m.Weight.Conv3D(ctx, t, s0, s1, s2, p0, p1, p2, d0, d1, d2, g) - if m.Bias != nil { - t = t.Add(ctx, m.Bias) - } - return t -} diff --git a/x/ml/nn/embedding.go b/x/ml/nn/embedding.go deleted file mode 100644 index b00aa2ff1..000000000 --- a/x/ml/nn/embedding.go +++ /dev/null @@ -1,11 +0,0 @@ -package nn - -import "github.com/ollama/ollama/x/ml" - -type Embedding struct { - Weight ml.Tensor `gguf:"weight"` -} - -func (m *Embedding) Forward(ctx ml.Context, hiddenState ml.Tensor) ml.Tensor { - return m.Weight.TakeAxes(ctx, hiddenState, 0) -} diff --git a/x/ml/nn/linear.go b/x/ml/nn/linear.go deleted file mode 100644 index 6d108e095..000000000 --- a/x/ml/nn/linear.go +++ /dev/null @@ -1,32 +0,0 @@ -package nn - -import "github.com/ollama/ollama/x/ml" - -type Linear struct { - Weight ml.Tensor `gguf:"weight"` - Bias ml.Tensor `gguf:"bias"` -} - -func (m *Linear) Forward(ctx ml.Context, t ml.Tensor) ml.Tensor { - t = t.Matmul(ctx, m.Weight.Transpose(ctx)) - if m.Bias != nil { - t = t.Add(ctx, m.Bias) - } - - return t -} - -type LinearBatch struct { - Weight ml.Tensor `gguf:"weight"` - Bias ml.Tensor `gguf:"bias"` -} - -func (m *LinearBatch) Forward(ctx ml.Context, t, indices ml.Tensor) ml.Tensor { - panic("not yet ported") - // t = m.Weight.MulmatID(ctx, t, indices) - // if m.Bias != nil { - // t = t.AddID(ctx, m.Bias, indices) - // } - - // return t -} diff --git a/x/ml/nn/normalization.go b/x/ml/nn/normalization.go deleted file mode 100644 index 621245ab4..000000000 --- a/x/ml/nn/normalization.go +++ /dev/null @@ -1,29 +0,0 @@ -package nn - -import ( - "github.com/ollama/ollama/x/ml" -) - -type LayerNorm struct { - Weight ml.Tensor `gguf:"weight"` - Bias ml.Tensor `gguf:"bias"` -} - -func (m *LayerNorm) Forward(ctx ml.Context, t ml.Tensor, eps float32) ml.Tensor { - return t.LayerNorm(ctx, m.Weight, m.Bias, eps) -} - -type RMSNorm struct { - Weight ml.Tensor `gguf:"weight"` -} - -func (m *RMSNorm) Forward(ctx ml.Context, t ml.Tensor, eps float32) ml.Tensor { - // slog.Info("RMSNorm", "eps", eps) - // fmt.Fprintln(os.Stderr, t.ToString()) - // fmt.Fprintln(os.Stderr, m.Weight.ToString()) - - // TODO this is probably model specific, not generalized... - w := m.Weight.Add(ctx, ctx.FromFloats([]float32{1.0}, 1)) - - return t.RMSNorm(ctx, w, eps) -} diff --git a/x/ml/nn/pooling/pooling.go b/x/ml/nn/pooling/pooling.go deleted file mode 100644 index 2dae6dc43..000000000 --- a/x/ml/nn/pooling/pooling.go +++ /dev/null @@ -1,41 +0,0 @@ -package pooling - -import ( - "github.com/ollama/ollama/x/ml" -) - -type Type uint32 - -const ( - TypeNone Type = iota - TypeMean - TypeCLS - TypeLast -) - -func (t Type) String() string { - switch t { - case TypeMean: - return "Mean" - case TypeCLS: - return "CLS" - case TypeLast: - return "Last" - default: - return "Unknown" - } -} - -func (t Type) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor { - switch t { - // case TypeMean: - // hiddenStates = hiddenStates.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false).Mean(ctx) - // return hiddenStates.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false) - // case TypeCLS: - // return hiddenStates.Slice(ctx, 1, 0, 1, 1) - // case TypeLast: - // return hiddenStates.Slice(ctx, 1, hiddenStates.Dim(1)-1, hiddenStates.Dim(1), 1) - default: - panic("unknown pooling type") - } -} diff --git a/x/ml/nn/rope/rope.go b/x/ml/nn/rope/rope.go deleted file mode 100644 index e868aa614..000000000 --- a/x/ml/nn/rope/rope.go +++ /dev/null @@ -1,72 +0,0 @@ -package rope - -import "github.com/ollama/ollama/x/ml" - -// Options contains optional parameters for RoPE function -type Options struct { - Type int - Factors ml.Tensor - - // YaRN options - YaRN struct { - OriginalContextLength int - ExtrapolationFactor, - AttentionFactor, - BetaFast, - BetaSlow float32 - } - - // MRoPE options - MRoPE struct { - Sections []int - } -} - -// WithTypeNeoX sets RoPE type to NeoX -func WithTypeNeoX() func(*Options) { - return func(opts *Options) { - opts.Type = 2 - } -} - -// WithFactors sets custom rope factors -func WithFactors(factors ml.Tensor) func(*Options) { - return func(opts *Options) { - if factors != nil { - opts.Factors = factors - } - } -} - -// WithOriginalContextLength sets a custom context length -func WithOriginalContextLength(n int) func(*Options) { - return func(opts *Options) { - opts.YaRN.OriginalContextLength = n - } -} - -func WithExtrapolationFactor(extrapolationFactor float32) func(*Options) { - return func(opts *Options) { - opts.YaRN.ExtrapolationFactor = extrapolationFactor - } -} - -func WithAttentionFactor(attentionFactor float32) func(*Options) { - return func(opts *Options) { - opts.YaRN.AttentionFactor = attentionFactor - } -} - -func WithMRoPE(sections []int) func(*Options) { - return func(opts *Options) { - opts.Type |= 1 << 3 - opts.MRoPE.Sections = sections - } -} - -func WithInterleaveMRoPE(sections []int) func(*Options) { - return func(opts *Options) { - opts.Type |= 1<<3 | 1<<5 - opts.MRoPE.Sections = sections - } -} diff --git a/x/ml/path.go b/x/ml/path.go deleted file mode 100644 index ac93af403..000000000 --- a/x/ml/path.go +++ /dev/null @@ -1,56 +0,0 @@ -package ml - -import ( - "os" - "path/filepath" - "runtime" -) - -// LibPath is a path to lookup dynamic libraries -// in development it's usually 'build/lib/ollama' -// in distribution builds it's 'lib/ollama' on Windows -// '../lib/ollama' on Linux and the executable's directory on macOS -// note: distribution builds, additional GPU-specific libraries are -// found in subdirectories of the returned path, such as -// 'cuda_v12', 'rocm', etc. -var LibOllamaPath string = func() string { - exe, err := os.Executable() - if err != nil { - return "" - } - - if eval, err := filepath.EvalSymlinks(exe); err == nil { - exe = eval - } - - var libPath string - switch runtime.GOOS { - case "windows": - libPath = filepath.Join(filepath.Dir(exe), "lib", "ollama") - case "linux": - libPath = filepath.Join(filepath.Dir(exe), "..", "lib", "ollama") - case "darwin": - libPath = filepath.Dir(exe) - } - - cwd, err := os.Getwd() - if err != nil { - return "" - } - - paths := []string{ - libPath, - - // build paths for development - filepath.Join(filepath.Dir(exe), "build", "lib", "ollama"), - filepath.Join(cwd, "build", "lib", "ollama"), - } - - for _, p := range paths { - if _, err := os.Stat(p); err == nil { - return p - } - } - - return filepath.Dir(exe) -}() diff --git a/x/model/bytepairencoding.go b/x/model/bytepairencoding.go deleted file mode 100644 index acb58743b..000000000 --- a/x/model/bytepairencoding.go +++ /dev/null @@ -1,282 +0,0 @@ -package model - -import ( - "cmp" - "fmt" - "iter" - "log/slog" - "slices" - "strings" - - "github.com/dlclark/regexp2" - heap "github.com/emirpasic/gods/v2/trees/binaryheap" - "github.com/ollama/ollama/logutil" -) - -type BytePairEncoding struct { - vocab *Vocabulary - regexps []*regexp2.Regexp -} - -var _ TextProcessor = (*BytePairEncoding)(nil) - -func NewBytePairEncoding(vocab *Vocabulary, pretokenizers ...string) BytePairEncoding { - if len(pretokenizers) == 0 { - // set default byte-level pretokenizer if none provided, e.g. - // https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs#L44 - pretokenizers = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`} - } - - return BytePairEncoding{ - vocab: vocab, - regexps: slices.Collect(func(yield func(*regexp2.Regexp) bool) { - for _, p := range pretokenizers { - if !yield(regexp2.MustCompile(p, regexp2.RE2)) { - return - } - } - }), - } -} - -func (bpe BytePairEncoding) Vocabulary() *Vocabulary { - return bpe.vocab -} - -func (bpe BytePairEncoding) Is(id int32, special Special) bool { - return bpe.vocab.Is(id, special) -} - -func (bpe *BytePairEncoding) split(s string) iter.Seq[string] { - parts := []string{s} - for _, re := range bpe.regexps { - parts = slices.Collect(func(yield func(string) bool) { - for _, part := range parts { - r := []rune(part) - var offset int - for m, _ := re.FindRunesMatch(r); m != nil; m, _ = re.FindNextMatch(m) { - if offset-m.Index != 0 { - if !yield(string(r[:m.Index])) { - return - } - } - - if !yield(m.String()) { - return - } - - offset = m.Index + m.Length - } - - if offset < len(r) { - if !yield(string(r[offset:])) { - return - } - } - } - }) - } - - return slices.Values(parts) -} - -// fragment is a string fragment and their corresponding token IDs -type fragment struct { - value string - ids []int32 -} - -// pair is a pair of runes and its rank -type pair struct { - a, b int - rank int - value string -} - -type merge struct { - p, n int - runes []rune -} - -func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) { - fragments := []fragment{{value: s}} - for _, special := range bpe.vocab.SpecialVocabulary() { - // TODO: process special tokens concurrently - id := bpe.vocab.Encode(special) - for i := 0; i < len(fragments); i++ { - frag := fragments[i] - if len(frag.ids) > 0 { - continue - } - - var middle []fragment - switch i := strings.Index(frag.value, special); { - case i < 0: - middle = append(middle, frag) - case i > 0: - middle = append(middle, fragment{value: frag.value[:i]}) - fallthrough - default: - middle = append(middle, fragment{value: special, ids: []int32{id}}) - if rest := frag.value[i+len(special):]; rest != "" { - middle = append(middle, fragment{value: rest}) - } - } - - fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...) - } - } - - var ids []int32 - for _, frag := range fragments { - if len(frag.ids) > 0 { - ids = append(ids, frag.ids...) - continue - } - - for split := range bpe.split(frag.value) { - // TODO: process splits concurrently - var sb strings.Builder - for _, b := range []byte(split) { - r := rune(b) - switch { - case r == 0x00ad: - r = 0x0143 - case r <= 0x0020: - r = r + 0x0100 - case r >= 0x007f && r <= 0x00a0: - r = r + 0x00a2 - } - - sb.WriteRune(r) - } - - // short circuit if the fragment is in the vocabulary - if id := bpe.vocab.Encode(sb.String()); id >= 0 { - ids = append(ids, id) - continue - } - - runes := []rune(sb.String()) - merges := make([]merge, len(runes)) - for r := range runes { - merges[r] = merge{ - p: r - 1, - n: r + 1, - runes: []rune{runes[r]}, - } - } - - pairwise := func(a, b int) *pair { - if a < 0 || b >= len(runes) { - return nil - } - - left, right := string(merges[a].runes), string(merges[b].runes) - rank := bpe.vocab.Merge(left, right) - if rank < 0 { - return nil - } - - return &pair{ - a: a, - b: b, - rank: rank, - value: left + right, - } - } - - pairs := heap.NewWith(func(i, j *pair) int { - return cmp.Compare(i.rank, j.rank) - }) - - for i := range len(runes) - 1 { - if pair := pairwise(i, i+1); pair != nil { - pairs.Push(pair) - } - } - - for !pairs.Empty() { - pair, _ := pairs.Pop() - - left, right := merges[pair.a], merges[pair.b] - if len(left.runes) == 0 || len(right.runes) == 0 || - string(left.runes)+string(right.runes) != pair.value { - continue - } - - if id := bpe.vocab.Encode(pair.value); id < 0 { - continue - } - - merges[pair.a].runes = append(left.runes, right.runes...) - merges[pair.b].runes = nil - - merges[pair.a].n = right.n - if right.n < len(merges) { - merges[right.n].p = pair.a - } - - if pair := pairwise(merges[pair.a].p, pair.a); pair != nil { - pairs.Push(pair) - } - - if pair := pairwise(pair.a, merges[pair.a].n); pair != nil { - pairs.Push(pair) - } - } - - for _, merge := range merges { - if len(merge.runes) > 0 { - // TODO: handle the edge case where the rune isn't in the vocabulary - if id := bpe.vocab.Encode(string(merge.runes)); id >= 0 { - ids = append(ids, id) - } - } - } - } - } - - if addSpecial { - ids = bpe.vocab.addSpecials(ids) - } - - logutil.Trace("encoded", "string", s, "ids", ids) - return ids, nil -} - -type lazyIdsString struct { - ids []int32 -} - -func (l lazyIdsString) LogValue() slog.Value { - return slog.AnyValue(fmt.Sprint(l.ids)) -} - -func (bpe BytePairEncoding) Decode(ids []int32) (string, error) { - var sb strings.Builder - for _, id := range ids { - for _, r := range bpe.vocab.Decode(id) { - switch { - case r == 0x0100: - // this produces 0x00 aka NULL - continue - case r == 0x0143: - r = 0x00ad - case r > 0x0100 && r <= 0x0120: - r = r - 0x0100 - case r > 0x0120 && r <= 0x0142: - r = r - 0x00a2 - } - - // NOTE: not using WriteRune here because it writes the UTF-8 - // encoding of the rune which is _not_ what we want - if err := sb.WriteByte(byte(r)); err != nil { - return "", err - } - } - } - - logutil.Trace("decoded", "string", sb.String(), "from", lazyIdsString{ids: ids}) - return sb.String(), nil -} diff --git a/x/model/bytepairencoding_test.go b/x/model/bytepairencoding_test.go deleted file mode 100644 index 2a7041284..000000000 --- a/x/model/bytepairencoding_test.go +++ /dev/null @@ -1,322 +0,0 @@ -package model - -import ( - "bufio" - "encoding/json" - "math" - "os" - "path/filepath" - "slices" - "strconv" - "strings" - "testing" - - "github.com/google/go-cmp/cmp" -) - -func llama(t testing.TB) BytePairEncoding { - t.Helper() - - f, err := os.Open(filepath.Join("..", "..", "model", "testdata", "llama3.2", "encoder.json")) - if err != nil { - t.Fatal(err) - } - defer f.Close() - - vocab := make(map[string]int32) - if err := json.NewDecoder(f).Decode(&vocab); err != nil { - t.Fatal(err) - } - - types := make([]int32, len(vocab)) - tokens := make([]string, len(vocab)) - for token, id := range vocab { - tokens[id] = token - types[id] = 1 - } - - for _, token := range []string{"<|begin_of_text|>", "<|end_of_text|>"} { - if _, ok := vocab[token]; !ok { - tokens = append(tokens, token) //nolint:makezero - types = append(types, 3) //nolint:makezero - vocab[token] = int32(len(vocab)) - } - } - - f, err = os.Open(filepath.Join("..", "..", "model", "testdata", "llama3.2", "vocab.bpe")) - if err != nil { - t.Fatal(err) - } - defer f.Close() - - merges := make([]string, 0, 50000) - - scanner := bufio.NewScanner(f) - for scanner.Scan() { - if !strings.HasPrefix(scanner.Text(), "#") { - merges = append(merges, scanner.Text()) - } - } - - return NewBytePairEncoding( - &Vocabulary{ - Values: tokens, - Types: types, - Merges: merges, - }, - "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", - ) -} - -func TestLlama(t *testing.T) { - tokenizer := llama(t) - - t.Run("simple", func(t *testing.T) { - t.Parallel() - - ids, err := tokenizer.Encode("hello world", true) - if err != nil { - t.Error(err) - } - - if diff := cmp.Diff([]int32{15339, 1917}, ids); diff != "" { - t.Errorf("no match (-theirs +ours):\n%s", diff) - } - - s, err := tokenizer.Decode([]int32{15339, 1917}) - if err != nil { - t.Fatal(err) - } - - if s != "hello world" { - t.Errorf("got %q, want hello world", s) - } - - ids, err = tokenizer.Encode("hello <|end_of_text|>", true) - if err != nil { - t.Error(err) - } - - if diff := cmp.Diff([]int32{15339, 220, 128001}, ids); diff != "" { - t.Errorf("no match (-theirs +ours):\n%s", diff) - } - }) - - t.Run("simple repeated", func(t *testing.T) { - t.Parallel() - - cases := map[string][]int32{ - strings.Repeat("0", 1): {15}, - strings.Repeat("0", 2): {410}, - strings.Repeat("0", 3): {931}, - strings.Repeat("0", 4): {931, 15}, - strings.Repeat("0", 5): {931, 410}, - strings.Repeat("0", 6): {931, 931}, - strings.Repeat("0", 7): {931, 931, 15}, - strings.Repeat("0", 8): {931, 931, 410}, - strings.Repeat("0", 9): {931, 931, 931}, - strings.Repeat("0", 10): {931, 931, 931, 15}, - strings.Repeat("0", 11): {931, 931, 931, 410}, - strings.Repeat("0", 12): {931, 931, 931, 931}, - strings.Repeat("0", 13): {931, 931, 931, 931, 15}, - strings.Repeat("0", 14): {931, 931, 931, 931, 410}, - strings.Repeat("0", 15): {931, 931, 931, 931, 931}, - strings.Repeat("0", 16): {931, 931, 931, 931, 931, 15}, - strings.Repeat("0", 17): {931, 931, 931, 931, 931, 410}, - } - - for s, want := range cases { - ids, err := tokenizer.Encode(s, true) - if err != nil { - t.Error(err) - } - - if diff := cmp.Diff(want, ids); diff != "" { - t.Errorf("%q no match (-theirs +ours):\n%s", s, diff) - } - } - }) - - t.Run("basic roundtrip", func(t *testing.T) { - t.Parallel() - - cases := []string{ - "hello", - "hello ", - "hello ", - " hello", - " hello ", - " hello ", - "hello world", - "请考试我的软件!12345", - } - - for _, want := range cases { - ids, err := tokenizer.Encode(want, true) - if err != nil { - t.Error(err) - } - - if got, err := tokenizer.Decode(ids); err != nil { - t.Fatal(err) - } else if got != want { - t.Errorf("got %q, want %q", got, want) - } - } - }) - - t.Run("special", func(t *testing.T) { - t.Parallel() - - cases := map[string][]int32{ - "<|begin_of_text|>A B!": {128000, 32, 426, 0}, - "<|begin_of_text|>A<|end_of_text|>B!": {128000, 32, 128001, 33, 0}, - "<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!": {128000, 32, 128001, 33, 128000, 0}, - "<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!<|end_of_text|>": {128000, 32, 128001, 33, 128000, 0, 128001}, - } - - for s, want := range cases { - ids, err := tokenizer.Encode(s, true) - if err != nil { - t.Fatal(err) - } - - if diff := cmp.Diff(want, ids); diff != "" { - t.Errorf("no match (-theirs +ours):\n%s", diff) - } - } - }) - - t.Run("split", func(t *testing.T) { - t.Parallel() - - cases := map[string][]string{ - "Hello World!": {"Hello", " World", "!"}, - "I'm don't won't": {"I", "'m", " don", "'t", " won", "'t"}, - "In 2024 there are 366 days": {"In", " ", "202", "4", " there", " are", " ", "366", " days"}, - "Hello!! ...world": {"Hello", "!!", " ...", "world"}, - "Hello World": {"Hello", " ", " World"}, - "Hello\nWorld": {"Hello", "\n", "World"}, - "Hello, WORLD!! How's it going?": {"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?"}, - } - - for s, want := range cases { - got := slices.Collect(tokenizer.split(s)) - if diff := cmp.Diff(want, got); diff != "" { - t.Errorf("no match (-theirs +ours):\n%s", diff) - } - } - }) - - t.Run("roundtriping 0x00-0xFF", func(t *testing.T) { - t.Parallel() - - for b := 0x00; b <= 0xFF; b++ { - input := string(rune(b)) - ids, err := tokenizer.Encode(input, false) - if err != nil { - t.Errorf("failed to encode rune 0x%02X: %v", b, err) - continue - } - - decoded, err := tokenizer.Decode(ids) - if err != nil { - t.Errorf("failed to decode rune 0x%02X: %v", b, err) - continue - } - - if b == 0x00 { - if len(decoded) != 0 { - t.Errorf("Decode(Encode(0x00)) should be empty, got %v", ids) - } - continue - } - - if decoded != input { - t.Errorf("rune 0x%02X failed roundtrip: got %q, want %q", b, decoded, input) - } - } - }) -} - -func BenchmarkBytePairEncoding(b *testing.B) { - tokenizer := llama(b) - bts, err := os.ReadFile(filepath.Join("testdata", "war-and-peace.txt")) - if err != nil { - b.Fatal(err) - } - - for i := range 8 { - n := min(int(math.Pow10(i)), len(bts)) - bts := bts[:n] - b.Run("encode"+strconv.Itoa(n), func(b *testing.B) { - b.ResetTimer() - for b.Loop() { - _, err := tokenizer.Encode(string(bts), true) - if err != nil { - b.Fatal(err) - } - } - }) - - b.Run("decode"+strconv.Itoa(n), func(b *testing.B) { - ids, err := tokenizer.Encode(string(bts), true) - if err != nil { - b.Fatal(err) - } - - b.ResetTimer() - for b.Loop() { - _, err := tokenizer.Decode(ids) - if err != nil { - b.Fatal(err) - } - } - }) - - b.Run("split"+strconv.Itoa(n), func(b *testing.B) { - b.ResetTimer() - for b.Loop() { - slices.Collect(tokenizer.split(string(bts))) - } - }) - } -} - -func TestSplit(t *testing.T) { - cases := []struct { - name string - patterns, - want []string - }{ - { - name: "default", - want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " 123", " 一二三"}, - }, - { - name: "unicode", - patterns: []string{ - "\\p{N}{1,3}", - `[一-龥぀-ゟ゠-ヿ]+`, - "[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+", - }, - want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " ", "123", " ", "一二三"}, - }, - { - name: "individual digits", - patterns: []string{ - "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", - }, - want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " ", "1", "2", "3", " 一二三"}, - }, - } - - for _, tt := range cases { - t.Run(tt.name, func(t *testing.T) { - tokenizer := NewBytePairEncoding(nil, tt.patterns...) - if diff := cmp.Diff(tt.want, slices.Collect(tokenizer.split("Hello, WORLD!! How's it going? 123 一二三"))); diff != "" { - t.Errorf("no match (-theirs +ours):\n%s", diff) - } - }) - } -} diff --git a/x/model/input/input.go b/x/model/input/input.go deleted file mode 100644 index 05857e20a..000000000 --- a/x/model/input/input.go +++ /dev/null @@ -1,76 +0,0 @@ -package input - -import "github.com/ollama/ollama/x/ml" - -// Multimodal is a multimodal embedding or a component of one. -// For example, it could be a row of an image that can be processed -// independently. -type Multimodal struct { - // Tensor is the embedding data. Implementations may chose what to - // store here or it may be nil if not needed. However, any ml.Tensor - // objects must be stored here and not in Data. - Tensor ml.Tensor - - // Data is implementation-specific opaque data, such as metadata on how - // to layout Tensor. It may be nil if not needed. It may also store larger - // objects such as complete images if they are to be processed later. - Data any -} - -// Input represents one token in the input stream -type Input struct { - // Token is a single element of text. - Token int32 - - // Multimodal is represents a non-text element such as an - // image (or part of one if the image can be processed in pieces). - // It may be used either together with Token or on its own. - Multimodal []Multimodal - - // MultimodalHash is a unique representation of the data - // stored in Multimodal, used for caching and comparing - // equality. - MultimodalHash uint64 - - // SameBatch forces the following number of tokens to be processed - // in a single batch, breaking and extending batches as needed. - // Useful for things like images that must be processed in one - // shot. - SameBatch int -} - -// MultimodalIndex is a multimodal element (such as an image) -// together with an index into the slice of Inputs with the -// corresponding token. Note that the index is not the same -// as the position - to find that use the index with the -// Positions slice. -type MultimodalIndex struct { - Index int - Multimodal []Multimodal -} - -// Batch contains the inputs for a model forward pass -type Batch struct { - // Inputs is the input tokens, including placeholders for multimodal inputs. - Inputs ml.Tensor - - // Outputs are the set of indicies into Inputs for which output data should - // be returned. - Outputs ml.Tensor - - // TODO maybe not the optimal way to handle this - // Offset of final tensor in the final batch - Offset int - - // Positions is the position for each Input, relative to its sequence. Equal - // in length to Inputs. - Positions []int32 - - // Sequences is the sequence for each Input. Equal in length to Inputs. - Sequences []int - - // Multimodal is a set of multimodal embeddings previously created by - // EncodeMultimodal, along with an index into Inputs. Unused for text-only - // models or for batches without multimodal elements. - Multimodal []MultimodalIndex -} diff --git a/x/model/model.go b/x/model/model.go deleted file mode 100644 index 60c3d1487..000000000 --- a/x/model/model.go +++ /dev/null @@ -1,333 +0,0 @@ -package model - -import ( - "errors" - "fmt" - _ "image/jpeg" - _ "image/png" - "log/slog" - "os" - "reflect" - "strconv" - "strings" - - _ "golang.org/x/image/bmp" - _ "golang.org/x/image/tiff" - _ "golang.org/x/image/webp" - - "github.com/ollama/ollama/fs" - fsggml "github.com/ollama/ollama/fs/ggml" - "github.com/ollama/ollama/logutil" - "github.com/ollama/ollama/x/kvcache" - "github.com/ollama/ollama/x/ml" - _ "github.com/ollama/ollama/x/ml/backend" - "github.com/ollama/ollama/x/ml/nn/pooling" - "github.com/ollama/ollama/x/model/input" -) - -var ( - ErrNoVisionModel = errors.New("this model is missing data required for image input") - ErrUnsupportedModel = errors.New("model not supported") - ErrUnsupportedTokenizer = errors.New("tokenizer not supported") -) - -// Model implements a specific model architecture, defining the forward pass and any model-specific configuration -type Model interface { - Forward(ml.Context, input.Batch) (ml.Tensor, error) - - Backend() ml.Backend - Config() config -} - -// MultimodalProcessor must be implemented by multimodal models. -type MultimodalProcessor interface { - // EncodeMultimodal processes a single input (such as an image) and - // generates an output (typically an embedding) that can be used by the model. - // - // The return value is one or more tensors, each with optional model-specific - // opaque metadata. Typically, the tensors might be views into an embedding - // with each view representing a chunk of data that can be processed independently - // in different batches. - // - // The result may be cached by the runner. - EncodeMultimodal(ml.Context, []byte) ([]input.Multimodal, error) - - // PostTokenize is called after tokenization to allow the model to edit the - // input stream to correctly arrange multimodal elements. - // - // The input is a slice of tokens with the results of EncodeMultimodal interleaved - // in the order that the user provided them. Each element of the slice will be - // either a single token or single multimodal object. - // - // The model must ensure that inputs are stored according to how they will be - // processed and stored in the cache. For example, Llava-style models should insert - // placeholder tokens equal to the feature size of the corresponding image with - // the image itself attached to and split across these tokens. When Forward is called - // a partial subset of these tokens may be submitted according to the batch size. - // - // This function is also responsible for updating MultimodalHash for any Multimodal - // that is modified to ensure that there is a unique hash value that accurately - // represents the contents. - PostTokenize([]*input.Input) ([]*input.Input, error) -} - -// Base implements the common fields and methods for all models -type Base struct { - b ml.Backend - config -} - -type config struct { - Cache kvcache.Cache -} - -// Backend returns the underlying backend that will run the model -func (m *Base) Backend() ml.Backend { - return m.b -} - -func (m *Base) Config() config { - return m.config -} - -var models = make(map[string]func(fs.Config) (Model, error)) - -// Register registers a model constructor for the given architecture -func Register(name string, f func(fs.Config) (Model, error)) { - if _, ok := models[name]; ok { - panic("model: model already registered") - } - - models[name] = f -} - -// New initializes a new model instance with the provided configuration based on the metadata in the model file -func New(modelPath string, params ml.BackendParams) (Model, error) { - b, err := ml.NewBackend(modelPath, params) - if err != nil { - return nil, err - } - - m, err := modelForArch(b.Config()) - if err != nil { - return nil, err - } - - base := Base{b: b, config: m.Config()} - v := reflect.ValueOf(m) - v.Elem().Set(populateFields(base, v.Elem())) - return m, nil -} - -func NewTextProcessor(s string) (TextProcessor, error) { - r, err := os.Open(s) - if err != nil { - return nil, err - } - defer r.Close() - - meta, err := fsggml.Decode(r, -1) - if err != nil { - return nil, err - } - - m, err := modelForArch(meta.KV()) - if err != nil { - return nil, err - } - - tp, ok := m.(TextProcessor) - if !ok { - return nil, ErrUnsupportedTokenizer - } - return tp, nil -} - -func modelForArch(c fs.Config) (Model, error) { - arch := c.Architecture() - if pooling.Type(c.Uint("pooling_type")) != pooling.TypeNone { - arch = arch + "_embed" - } - - f, ok := models[arch] - if !ok { - return nil, ErrUnsupportedModel - } - - return f(c) -} - -func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value { - t := v.Type() - - if t.Kind() == reflect.Struct { - allNil := true - for i := range t.NumField() { - tt := t.Field(i).Type - vv := v.Field(i) - if !vv.CanSet() { - continue - } - - // make a copy - tagsCopy := tags - if tag := t.Field(i).Tag.Get("gguf"); tag != "" { - tagsCopy = append(tagsCopy, parseTag(tag)) - } - - if tt == reflect.TypeOf((*Base)(nil)).Elem() { - vv.Set(reflect.ValueOf(base)) - } else if tt == reflect.TypeOf((*ml.Tensor)(nil)).Elem() { - var fn func([]Tag, string, string) [][]string - fn = func(tags []Tag, prefix, suffix string) (fullNames [][]string) { - if len(tags) > 0 { - var names []string - if tags[0].name != "" { - for _, n := range append([]string{tags[0].name}, tags[0].alternatives...) { - names = append(names, prefix+n+suffix) - } - } - childNames := fn(tags[1:], tags[0].prefix, tags[0].suffix) - if len(names) == 0 { - // current tag has no name, use child names only - fullNames = append(fullNames, childNames...) - } else if len(childNames) == 0 { - // current tag has names but no children, create branches for each name - for _, name := range names { - fullNames = append(fullNames, []string{name}) - } - } else { - // merge each name with each child - for _, name := range names { - for _, childName := range childNames { - fullNames = append(fullNames, append([]string{name}, childName...)) - } - } - } - } - - return fullNames - } - - names := fn(tagsCopy, "", "") - for _, name := range names { - if tensor := base.Backend().Get(strings.Join(name, ".")); tensor != nil { - logutil.Trace("found tensor", "", tensor) - vv.Set(reflect.ValueOf(tensor)) - break - } - } - } else if tt.Kind() == reflect.Pointer || tt.Kind() == reflect.Interface { - setPointer(base, vv, tagsCopy) - } else if tt.Kind() == reflect.Slice || tt.Kind() == reflect.Array { - for i := range vv.Len() { - vvv := vv.Index(i) - if vvv.Kind() == reflect.Pointer || vvv.Kind() == reflect.Interface { - setPointer(base, vvv, append(tagsCopy, Tag{name: strconv.Itoa(i)})) - } else { - vvv.Set(populateFields(base, vvv, append(tagsCopy, Tag{name: strconv.Itoa(i)})...)) - } - } - } - - if !canNil(tt) || !vv.IsNil() { - allNil = false - } - } - - if allNil { - return reflect.Zero(t) - } - } - - return v -} - -func setPointer(base Base, v reflect.Value, tags []Tag) { - vv := v - if v.Kind() == reflect.Interface { - if v.IsNil() { - return - } - - vv = vv.Elem() - } - - vv = reflect.Indirect(vv) - if v.IsNil() { - vv = reflect.New(v.Type().Elem()).Elem() - } - - if f := populateFields(base, vv, tags...); f.CanAddr() { - v.Set(f.Addr()) - } -} - -type Tag struct { - name, - // prefix and suffix are applied to child tags - prefix, - suffix string - alternatives []string -} - -func parseTag(s string) (tag Tag) { - parts := strings.Split(s, ",") - if len(parts) > 0 { - tag.name = parts[0] - - for _, part := range parts[1:] { - if value, ok := strings.CutPrefix(part, "alt:"); ok && tag.name == "" { - // elevate alternative to primary if no primary given - tag.name = value - slog.Warn("gguf tag has alt: but no primary name", "tag", s) - } else if ok { - tag.alternatives = append(tag.alternatives, value) - } - if value, ok := strings.CutPrefix(part, "pre:"); ok { - tag.prefix = value - } - if value, ok := strings.CutPrefix(part, "suf:"); ok { - tag.suffix = value - } - } - } - - return -} - -func canNil(t reflect.Type) bool { - return t.Kind() == reflect.Chan || - t.Kind() == reflect.Func || - t.Kind() == reflect.Interface || - t.Kind() == reflect.Map || - t.Kind() == reflect.Pointer || - t.Kind() == reflect.Slice -} - -func Forward(ctx ml.Context, m Model, batch input.Batch) (ml.Tensor, error) { - if len(batch.Positions) != len(batch.Sequences) { - return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences)) - } - - if len(batch.Positions) < 1 { - return nil, errors.New("batch size cannot be less than 1") - } - - cache := m.Config().Cache - if cache != nil { - err := cache.StartForward(ctx, batch, false) - if err != nil { - return nil, err - } - } - - t, err := m.Forward(ctx, batch) - if err != nil { - return nil, err - } - - ctx.Forward(t) - - return t, nil -} diff --git a/x/model/models/gemma3/embed.go b/x/model/models/gemma3/embed.go deleted file mode 100644 index 229cbcb50..000000000 --- a/x/model/models/gemma3/embed.go +++ /dev/null @@ -1,58 +0,0 @@ -//go:build mlx - -package gemma3 - -import ( - "github.com/ollama/ollama/fs" - "github.com/ollama/ollama/x/ml" - "github.com/ollama/ollama/x/ml/nn" - "github.com/ollama/ollama/x/ml/nn/pooling" - "github.com/ollama/ollama/x/model" - "github.com/ollama/ollama/x/model/input" -) - -type embedModel struct { - model.Base - model.SentencePiece - - *TextModel - poolingType pooling.Type - - Dense [2]*nn.Linear `gguf:"dense"` -} - -func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { - hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache) - hiddenStates = m.poolingType.Forward(ctx, hiddenStates) - for _, dense := range m.Dense { - hiddenStates = dense.Forward(ctx, hiddenStates) - } - hiddenStates = hiddenStates.L2Norm(ctx, 1e-12) - return hiddenStates, nil -} - -func newEmbedModel(c fs.Config) (model.Model, error) { - m := &embedModel{ - SentencePiece: model.NewSentencePiece( - &model.Vocabulary{ - Values: c.Strings("tokenizer.ggml.tokens"), - Scores: c.Floats("tokenizer.ggml.scores"), - Types: c.Ints("tokenizer.ggml.token_type"), - AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), - BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, - AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), - EOS: append( - []int32{ - int32(c.Uint("tokenizer.ggml.eos_token_id")), - int32(c.Uint("tokenizer.ggml.eot_token_id", 106)), - }, - c.Ints("tokenizer.ggml.eos_token_ids")..., - ), - }, - ), - TextModel: newTextModel(c), - poolingType: pooling.Type(c.Uint("pooling_type", 0)), - } - - return m, nil -} diff --git a/x/model/models/gemma3/model.go b/x/model/models/gemma3/model.go deleted file mode 100644 index 23f78f207..000000000 --- a/x/model/models/gemma3/model.go +++ /dev/null @@ -1,157 +0,0 @@ -//go:build mlx - -package gemma3 - -import ( - "bytes" - "image" - "math" - "slices" - - "github.com/ollama/ollama/fs" - "github.com/ollama/ollama/x/kvcache" - "github.com/ollama/ollama/x/ml" - "github.com/ollama/ollama/x/ml/nn" - "github.com/ollama/ollama/x/model" - "github.com/ollama/ollama/x/model/input" -) - -type Model struct { - model.Base - model.SentencePiece - - *VisionModel `gguf:"vision_tower.vision_model"` - *TextModel `gguf:"language_model.model"` - - *MultiModalProjector `gguf:"multi_modal_projector"` - - ImageProcessor -} - -var _ model.MultimodalProcessor = (*Model)(nil) - -type MultiModalProjector struct { - SoftEmbNorm *nn.RMSNorm `gguf:"mm_soft_emb_norm"` - InputProjection *nn.Linear `gguf:"mm_input_projection_weight"` // TODO .weight vs _weight - - tokensPerImage int -} - -func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, imageSize, patchSize int, eps float32) ml.Tensor { - l := visionOutputs.Dim(0) - - visionOutputs = visionOutputs.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false) - patchesPerImage := imageSize / patchSize - visionOutputs = visionOutputs.Reshape(ctx, patchesPerImage, patchesPerImage, l) - - kernelSize := patchesPerImage / int(math.Sqrt(float64(p.tokensPerImage))) - visionOutputs = visionOutputs.AvgPool2D(ctx, kernelSize, kernelSize, 0) - visionOutputs = visionOutputs.Reshape(ctx, visionOutputs.Dim(0)*visionOutputs.Dim(1), l) - visionOutputs = visionOutputs.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false) - visionOutputs = p.SoftEmbNorm.Forward(ctx, visionOutputs, eps) - - // TODO: inputProjection must be transposed since they're incompatible with visionOutputs - visionOutputs = visionOutputs.Matmul(ctx, p.InputProjection.Weight.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false)) - return visionOutputs -} - -func New(c fs.Config) (model.Model, error) { - // slog.Info("XXX Config", "c", c) - m := Model{ - SentencePiece: model.NewSentencePiece( - &model.Vocabulary{ - Values: c.Strings("tokenizer.ggml.tokens"), - Scores: c.Floats("tokenizer.ggml.scores"), - Types: c.Ints("tokenizer.ggml.token_type"), - AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), - BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, - AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), - EOS: append( - []int32{ - int32(c.Uint("tokenizer.ggml.eos_token_id")), - int32(c.Uint("tokenizer.ggml.eot_token_id", 106)), - }, - c.Ints("tokenizer.ggml.eos_token_ids")..., - ), - }, - ), - ImageProcessor: newImageProcessor(c), - VisionModel: newVisionModel(c), - TextModel: newTextModel(c), - MultiModalProjector: &MultiModalProjector{ - tokensPerImage: int(c.Uint("mm_tokens_per_image", 256)), - }, - } - - // slidingWindowLen := int32(c.Uint("attention.sliding_window")) - // m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift)) - - // TODO need to implement sliding window... - m.Cache = kvcache.NewMLXCausalCache() - - return &m, nil -} - -func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) { - if len(m.VisionModel.Layers) == 0 { - return nil, model.ErrNoVisionModel - } - - image, _, err := image.Decode(bytes.NewReader(multimodalData)) - if err != nil { - return nil, err - } - - f32s, err := m.ImageProcessor.ProcessImage(image) - if err != nil { - return nil, err - } - - pixelValues := ctx.Input().FromFloats(f32s, - m.ImageProcessor.imageSize, - m.ImageProcessor.imageSize, - m.ImageProcessor.numChannels, - ) - - visionOutputs := m.VisionModel.Forward(ctx, pixelValues) - visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.imageSize, m.patchSize, m.VisionModel.eps) - return []input.Multimodal{{Tensor: visionOutputs}}, nil -} - -func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) { - var result []*input.Input - - for _, inp := range inputs { - if len(inp.Multimodal) == 0 { - result = append(result, inp) - } else { - inputMultimodal := inp.Multimodal[0].Tensor - - result = append(result, - &input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n" - &input.Input{Token: 255999}, // """ - &input.Input{Multimodal: []input.Multimodal{{Tensor: inputMultimodal}}, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder - ) - - // add image token placeholders - result = append(result, slices.Repeat([]*input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...) - - result = append(result, - &input.Input{Token: 256000}, // - &input.Input{Token: 108}, // "\n\n" - ) - } - } - - return result, nil -} - -func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { - hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache) - return m.Output.Forward(ctx, hiddenStates), nil -} - -func init() { - model.Register("gemma3", New) - model.Register("gemma3_embed", newEmbedModel) -} diff --git a/x/model/models/gemma3/model_text.go b/x/model/models/gemma3/model_text.go deleted file mode 100644 index d7686542a..000000000 --- a/x/model/models/gemma3/model_text.go +++ /dev/null @@ -1,211 +0,0 @@ -//go:build mlx - -package gemma3 - -import ( - "math" - - "github.com/ollama/ollama/fs" - "github.com/ollama/ollama/x/kvcache" - "github.com/ollama/ollama/x/ml" - "github.com/ollama/ollama/x/ml/nn" - "github.com/ollama/ollama/x/model/input" -) - -type TextConfig struct { - hiddenSize, numHeads, numKVHeads int - attnKeyLen int - eps, ropeScale float32 - ropeLocalBase, ropeGlobalBase float32 - largeModelScaling bool -} - -type TextModel struct { - TokenEmbedding *nn.Embedding `gguf:"embed_tokens"` - Layers []TextLayer `gguf:"layers"` - OutputNorm *nn.RMSNorm `gguf:"norm"` - Output *nn.Linear `gguf:"embed_tokens"` - - *TextConfig -} - -const ( - gemmaGlobalCacheCount = 6 - gemma27BLayerCount = 62 -) - -// const ( -// cacheTypeSWA = iota -// cacheTypeCausal -// ) - -func newTextModel(c fs.Config) *TextModel { - numBlocks := int(c.Uint("block_count")) - - m := TextModel{ - Layers: make([]TextLayer, numBlocks), - TextConfig: &TextConfig{ - hiddenSize: int(c.Uint("embedding_length")), // 2560 -- config.json: text_config.hidden_size - numHeads: int(c.Uint("attention.head_count")), // 8 -- hard coded in python implementation for the model, 4 in some places, then overridden as 8 - numKVHeads: int(c.Uint("attention.head_count_kv")), // 4 -- same as above - attnKeyLen: int(c.Uint("attention.key_length", 256)), //256 -- rope settings, hardcoded in model definition python - eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06), // 1e-06 - hardcoded in model definition python - ropeLocalBase: c.Float("rope.local.freq_base", 10000.0), // 10000 - hardcoded in python - ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0), // 1e+06 - hardcoded in python - ropeScale: 1, // 1 - default is 1, implied in python code - // vocabSize: vocabSize, // 262144 - // attnValLen: int(c.Uint("attention.value_length", 256)), //256 - // NOTE: the rope.scaling.factor is set incorrectly in the official QAT weights - // (8 instead of 1) - // ropeScale: c.Float("rope.scaling.factor", 1.0), - }, - } - if numBlocks == gemma27BLayerCount { - m.largeModelScaling = true - } - - return &m -} - -type TextSelfAttention struct { - Query *nn.Linear `gguf:"q_proj"` - QueryNorm *nn.RMSNorm `gguf:"q_norm"` - Key *nn.Linear `gguf:"k_proj"` - KeyNorm *nn.RMSNorm `gguf:"k_norm"` - Value *nn.Linear `gguf:"v_proj"` - Output *nn.Linear `gguf:"o_proj"` -} - -func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState ml.Tensor, offset int, cache kvcache.Cache, opts *TextConfig) ml.Tensor { - B := hiddenState.Dim(0) - L := hiddenState.Dim(1) - ropeBase := opts.ropeLocalBase - if (layer+1)%gemmaGlobalCacheCount == 0 { - ropeBase = opts.ropeGlobalBase - } - - q := sa.Query.Forward(ctx, hiddenState) - k := sa.Key.Forward(ctx, hiddenState) - v := sa.Value.Forward(ctx, hiddenState) - q = q.Reshape(ctx, B, L, opts.numHeads, -1).Transpose(ctx, 0, 2, 1, 3) - k = k.Reshape(ctx, B, L, opts.numKVHeads, -1).Transpose(ctx, 0, 2, 1, 3) - v = v.Reshape(ctx, B, L, opts.numKVHeads, -1).Transpose(ctx, 0, 2, 1, 3).Contiguous(ctx, false) - q = sa.QueryNorm.Forward(ctx, q, opts.eps) - k = sa.KeyNorm.Forward(ctx, k, opts.eps) - traditional := false - q = q.RoPE(ctx, opts.attnKeyLen, traditional, opts.ropeScale, offset, ml.WithRoPEBase(ropeBase)) - k = k.RoPE(ctx, opts.attnKeyLen, traditional, opts.ropeScale, offset, ml.WithRoPEBase(ropeBase)) - - // TODO - this is wrong somehow so commenting out - // if opts.largeModelScaling { - // q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads))) - // } else { - // q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen))) - // } - - scaleFactor := math.Pow(256, -0.5) - - kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache) - kqv = kqv.Transpose(ctx, 0, 2, 1, 3).Reshape(ctx, B, L, -1) - return sa.Output.Forward(ctx, kqv) -} - -func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - // ropeBase := m.TextConfig.ropeLocalBase - // if (layer+1)%gemmaGlobalCacheCount == 0 { - // ropeBase = m.TextConfig.ropeGlobalBase - // } - // q = q.RoPE(ctx, opts.attnKeyLen, traditional, opts.ropeScale, offset, ml.WithRoPEBase(ropeBase)) - panic("not yet implemented") - // return key.RoPE(ctx, shift, m.TextConfig.attnKeyLen, ropeBase, 1/m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil -} - -type TextMLP struct { - Up *nn.Linear `gguf:"up_proj"` - Down *nn.Linear `gguf:"down_proj"` - Gate *nn.Linear `gguf:"gate_proj"` -} - -func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextConfig) ml.Tensor { - hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState)) - return mlp.Down.Forward(ctx, hiddenState) -} - -type TextLayer struct { - AttentionNorm *nn.RMSNorm `gguf:"input_layernorm"` - SelfAttention *TextSelfAttention `gguf:"self_attn"` - PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_layernorm"` - MLPNorm *nn.RMSNorm `gguf:"pre_feedforward_layernorm"` - MLP *TextMLP `gguf:"mlp"` - PostMLPNorm *nn.RMSNorm `gguf:"post_feedforward_layernorm"` -} - -func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, outputs ml.Tensor, offset int, cache kvcache.Cache, opts *TextConfig) ml.Tensor { - residual := hiddenState - hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps) - hiddenState = l.SelfAttention.Forward(ctx, layer, hiddenState, offset, cache, opts) - hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps) - - // In the final layer (outputs != nil), optimize by pruning to just the token positions - // we need logits for. - if outputs != nil { - hiddenState = hiddenState.TakeAxes(ctx, outputs, 1) - residual = residual.TakeAxes(ctx, outputs, 1) - } - - hiddenState = hiddenState.Add(ctx, residual) - residual = hiddenState - hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps) - hiddenState = l.MLP.Forward(ctx, hiddenState, opts) // TODO this is where it goes bad most likely... - hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps) - return hiddenState.Add(ctx, residual) -} - -func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor { - hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs) - hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize))) - - // set image embeddings - // var except []int - // for _, image := range batch.Multimodal { - // visionOutputs := image.Multimodal[0].Tensor - // ctx.Forward(visionOutputs.Copy(ctx, hiddenState.AsStrided(ctx, - // []int{visionOutputs.Dim(0) * visionOutputs.Dim(1)}, - // []int{image.Index * hiddenState.Stride(1)}, 0))) - - // for i := range visionOutputs.Dim(1) { - // except = append(except, image.Index+i) - // } - // } - - for i, layer := range m.Layers { - // gemma alternates between the sliding window (local) and causal (global) - // kv cache every 6 layers - if cache != nil { - // cacheType := cacheTypeSWA - // if (i+1)%gemmaGlobalCacheCount == 0 { - // cacheType = cacheTypeCausal - // } - cache.SetLayer(i) - - // TODO this needs to come back - // wc := cache.(*kvcache.WrapperCache) - // wc.SetLayerType(cacheType) - - // if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok { - // causal.SetCausal(ctx, kvcache.CausalOptions{Except: except}) - // } - } - - var offset int - var lastLayerOutputs ml.Tensor - if i == len(m.Layers)-1 { - offset = batch.Offset - lastLayerOutputs = batch.Outputs - } - - hiddenState = layer.Forward(ctx, i, hiddenState, lastLayerOutputs, offset, cache, m.TextConfig) - } - hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps) - return hiddenState -} diff --git a/x/model/models/gemma3/model_vision.go b/x/model/models/gemma3/model_vision.go deleted file mode 100644 index bffb3cb58..000000000 --- a/x/model/models/gemma3/model_vision.go +++ /dev/null @@ -1,121 +0,0 @@ -//go:build mlx - -package gemma3 - -import ( - "math" - - "github.com/ollama/ollama/fs" - "github.com/ollama/ollama/x/ml" - "github.com/ollama/ollama/x/ml/nn" -) - -var batchSize int = 1 - -type VisionSelfAttention struct { - Query *nn.Linear `gguf:"self_attn.q_proj"` - Key *nn.Linear `gguf:"self_attn.k_proj"` - Value *nn.Linear `gguf:"self_attn.v_proj"` - Output *nn.Linear `gguf:"self_attn.out_proj"` -} - -func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor { - headDim := opts.hiddenSize / opts.numHeads - - query := sa.Query.Forward(ctx, hiddenState) - key := sa.Key.Forward(ctx, hiddenState) - value := sa.Value.Forward(ctx, hiddenState) - - query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize) - key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize) - value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize) - - attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), nil) - attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize) - - hiddenState = sa.Output.Forward(ctx, attention) - return hiddenState -} - -type VisionMLP struct { - FC1 *nn.Linear `gguf:"fc1"` - FC2 *nn.Linear `gguf:"fc2"` -} - -func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor { - hiddenState = mlp.FC1.Forward(ctx, hiddenState).GELU(ctx) - hiddenState = mlp.FC2.Forward(ctx, hiddenState) - return hiddenState -} - -type VisionEncoderLayer struct { - LayerNorm1 *nn.LayerNorm `gguf:"layer_norm1"` - SelfAttention *VisionSelfAttention - - LayerNorm2 *nn.LayerNorm `gguf:"layer_norm2"` - MLP *VisionMLP `gguf:"mlp"` -} - -func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor { - residual := hiddenState - - // self attention - hiddenState = e.LayerNorm1.Forward(ctx, hiddenState, opts.eps) - hiddenState = e.SelfAttention.Forward(ctx, hiddenState, opts) - hiddenState = hiddenState.Add(ctx, residual) - residual = hiddenState - - // feed forward - hiddenState = e.LayerNorm2.Forward(ctx, hiddenState, opts.eps) - hiddenState = e.MLP.Forward(ctx, hiddenState, opts) - return hiddenState.Add(ctx, residual) -} - -type VisionModelOptions struct { - hiddenSize, numHeads int - imageSize, patchSize int - eps float32 -} - -type VisionModel struct { - PatchEmbedding *nn.Conv2D `gguf:"embeddings.patch_embedding"` - PositionEmbedding *nn.Embedding `gguf:"embeddings.position_embedding"` - PostLayerNorm *nn.LayerNorm `gguf:"post_layernorm"` - - Layers []VisionEncoderLayer `gguf:"encoder.layers"` - - *VisionModelOptions -} - -func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor { - numPatches := (m.imageSize / m.patchSize) * (m.imageSize / m.patchSize) - - hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1) - hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize) - hiddenState = hiddenState.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false) - - positionIDs := ctx.Arange(0, float32(numPatches), 1, ml.DTypeInt32) - hiddenState = hiddenState.Add(ctx, m.PositionEmbedding.Forward(ctx, positionIDs)) - - for _, layer := range m.Layers { - hiddenState = layer.Forward(ctx, hiddenState, m.VisionModelOptions) - } - - hiddenState = m.PostLayerNorm.Forward(ctx, hiddenState, m.eps) - return hiddenState -} - -func newVisionModel(c fs.Config) *VisionModel { - return &VisionModel{ - Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count")), - VisionModelOptions: &VisionModelOptions{ - hiddenSize: int(c.Uint("vision.embedding_length")), - numHeads: int(c.Uint("vision.attention.head_count")), - - imageSize: int(c.Uint("vision.image_size")), - patchSize: int(c.Uint("vision.patch_size")), - - eps: c.Float("vision.attention.layer_norm_epsilon"), - }, - } -} diff --git a/x/model/models/gemma3/process_image.go b/x/model/models/gemma3/process_image.go deleted file mode 100644 index 09d0727d0..000000000 --- a/x/model/models/gemma3/process_image.go +++ /dev/null @@ -1,60 +0,0 @@ -//go:build mlx - -package gemma3 - -import ( - "image" - - "github.com/ollama/ollama/fs" - "github.com/ollama/ollama/model/imageproc" -) - -type ImageProcessor struct { - imageSize, patchSize, numChannels int -} - -func newImageProcessor(c fs.Config) ImageProcessor { - return ImageProcessor{ - imageSize: int(c.Uint("vision.image_size")), - patchSize: int(c.Uint("vision.patch_size")), - numChannels: int(c.Uint("vision.num_channels")), - } -} - -func (p *ImageProcessor) pack(img image.Image, mean, std [3]float32) []float32 { - var pixelVals, rVals, gVals, bVals []float32 - - bounds := img.Bounds() - for y := bounds.Min.Y; y < bounds.Max.Y; y++ { - for x := bounds.Min.X; x < bounds.Max.X; x++ { - c := img.At(x, y) - r, g, b, _ := c.RGBA() - rVal := float32(r>>8) / 255.0 - gVal := float32(g>>8) / 255.0 - bVal := float32(b>>8) / 255.0 - - rVal = (rVal - mean[0]) / std[0] - gVal = (gVal - mean[1]) / std[1] - bVal = (bVal - mean[2]) / std[2] - - rVals = append(rVals, rVal) - gVals = append(gVals, gVal) - bVals = append(bVals, bVal) - } - } - - pixelVals = append(pixelVals, rVals...) - pixelVals = append(pixelVals, gVals...) - pixelVals = append(pixelVals, bVals...) - - return pixelVals -} - -func (p ImageProcessor) ProcessImage(img image.Image) ([]float32, error) { - outputSize := image.Point{p.imageSize, p.imageSize} - newImage := imageproc.Composite(img) - newImage = imageproc.Resize(newImage, outputSize, imageproc.ResizeBilinear) - - data := p.pack(newImage, imageproc.ImageNetStandardMean, imageproc.ImageNetStandardSTD) - return data, nil -} diff --git a/x/model/models/models.go b/x/model/models/models.go deleted file mode 100644 index a2542707f..000000000 --- a/x/model/models/models.go +++ /dev/null @@ -1,3 +0,0 @@ -package models - -// _ "github.com/ollama/ollama/x/model/models/gemma3" diff --git a/x/model/sentencepiece.go b/x/model/sentencepiece.go deleted file mode 100644 index 2c178ec0c..000000000 --- a/x/model/sentencepiece.go +++ /dev/null @@ -1,249 +0,0 @@ -package model - -import ( - "container/heap" - "fmt" - "log/slog" - "strconv" - "strings" - - "github.com/ollama/ollama/logutil" -) - -const spmWhitespaceSep = "▁" - -type SentencePiece struct { - maxTokenLen int - vocab *Vocabulary -} - -var _ TextProcessor = (*SentencePiece)(nil) - -func (spm SentencePiece) Vocabulary() *Vocabulary { - return spm.vocab -} - -func NewSentencePiece(vocab *Vocabulary) SentencePiece { - logutil.Trace("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5]) - - counter := map[int]int{} - var maxTokenLen int - for cnt := range vocab.Types { - switch vocab.Types[cnt] { - case TOKEN_TYPE_NORMAL, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_UNUSED: - maxTokenLen = max(maxTokenLen, len(vocab.Values[cnt])) - fallthrough - default: - counter[int(vocab.Types[cnt])] += 1 - } - } - - logutil.Trace("Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL], - "user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE], - "max token len", maxTokenLen) - - return SentencePiece{ - maxTokenLen: maxTokenLen, - vocab: vocab, - } -} - -func (spm SentencePiece) Is(id int32, special Special) bool { - return spm.vocab.Is(id, special) -} - -func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) { - fragments := []fragment{{value: s}} - for _, special := range spm.vocab.SpecialVocabulary() { - id := spm.vocab.Encode(special) - for i := 0; i < len(fragments); i++ { - frag := fragments[i] - if len(frag.ids) > 0 { - continue - } - - var middle []fragment - switch i := strings.Index(frag.value, special); { - case i < 0: - middle = append(middle, frag) - case i > 0: - middle = append(middle, fragment{value: frag.value[:i]}) - fallthrough - default: - middle = append(middle, fragment{value: special, ids: []int32{id}}) - if rest := frag.value[i+len(special):]; rest != "" { - middle = append(middle, fragment{value: rest}) - } - } - - fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...) - } - } - - var ids []int32 - for _, frag := range fragments { - if len(frag.ids) > 0 { - ids = append(ids, frag.ids...) - continue - } - - text := strings.ReplaceAll(frag.value, " ", spmWhitespaceSep) - - if id := spm.vocab.Encode(text); id >= 0 { - ids = append(ids, id) - continue - } - - q := &queue{} - heap.Init(q) - - runes := []rune(text) - merges := make([]merge, len(runes)) - for r := range runes { - merges[r] = merge{ - p: r - 1, - n: r + 1, - runes: []rune{runes[r]}, - } - } - - pairwise := func(a, b int) *candidate { - if a < 0 || b >= len(runes) { - return nil - } - - left, right := string(merges[a].runes), string(merges[b].runes) - if id := spm.vocab.Encode(left + right); id >= 0 { - return &candidate{ - a: a, - b: b, - score: spm.vocab.Scores[id], - size: len(left) + len(right), - } - } - - return nil - } - - for i := range len(runes) - 1 { - if pair := pairwise(i, i+1); pair != nil { - heap.Push(q, pair) - } - } - - for q.Len() > 0 { - pair := heap.Pop(q).(*candidate) - left, right := merges[pair.a], merges[pair.b] - - if string(left.runes) == "" || string(right.runes) == "" || len(string(left.runes))+len(string(right.runes)) != pair.size { - continue - } - - merges[pair.a].runes = append(left.runes, right.runes...) - merges[pair.b].runes = nil - merges[pair.a].n = right.n - if right.n < len(merges) { - merges[right.n].p = pair.a - } - - if pair := pairwise(merges[pair.a].p, pair.a); pair != nil { - heap.Push(q, pair) - } - - if pair := pairwise(pair.a, merges[pair.a].n); pair != nil { - heap.Push(q, pair) - } - } - - for _, merge := range merges { - if token := string(merge.runes); token != "" { - id := spm.vocab.Encode(token) - - if id >= 0 { - ids = append(ids, id) - continue - } - - // Fallback to byte tokenization - var result []int32 - for _, b := range []byte(token) { - byteToken := fmt.Sprintf("<0x%02X>", b) - unknownID := spm.vocab.Encode(byteToken) - if unknownID >= 0 { - result = append(result, unknownID) - } else { - slog.Debug("unknown byte token", "byte", b, "token", byteToken) - } - } - - ids = append(ids, result...) - } - } - } - - if addSpecial { - ids = spm.vocab.addSpecials(ids) - } - - logutil.Trace("encoded", "string", s, "ids", ids) - return ids, nil -} - -type candidate struct { - a, b int - score float32 - size int -} - -type queue []*candidate - -func (q queue) Len() int { return len(q) } - -func (q queue) Less(i, j int) bool { - return (q[i].score > q[j].score) || (q[i].score == q[j].score && q[i].a < q[j].a) -} - -func (q queue) Swap(i, j int) { q[i], q[j] = q[j], q[i] } - -func (q *queue) Push(x interface{}) { - item := x.(*candidate) - *q = append(*q, item) -} - -func (q *queue) Pop() interface{} { - old := *q - n := len(old) - item := old[n-1] - *q = old[0 : n-1] - return item -} - -func (spm SentencePiece) Decode(ids []int32) (string, error) { - var sb strings.Builder - for _, id := range ids { - data := spm.vocab.Decode(id) - data = strings.ReplaceAll(data, spmWhitespaceSep, " ") - - // For tokenizers that use byte tokens like "<0xEA>" - // convert them to the partial unicode character - // so they are buffered correctly by the runner instead - // of being sent back to the api as "<0xEA>" - if len(data) == 6 && strings.HasPrefix(data, "<0x") && strings.HasSuffix(data, ">") { - byteVal, err := strconv.ParseUint(data[1:5], 0, 8) - if err != nil { - return "", fmt.Errorf("failed to parse hex byte: %v", err) - } - - if err := sb.WriteByte(byte(byteVal)); err != nil { - return "", err - } - } else { - if _, err := sb.WriteString(data); err != nil { - return "", err - } - } - } - - logutil.Trace("decoded", "ids", ids, "string", sb.String()) - return sb.String(), nil -} diff --git a/x/model/sentencepiece_test.go b/x/model/sentencepiece_test.go deleted file mode 100644 index 7ab158af7..000000000 --- a/x/model/sentencepiece_test.go +++ /dev/null @@ -1,172 +0,0 @@ -package model - -import ( - "log/slog" - "os" - "path/filepath" - "slices" - "testing" - - "google.golang.org/protobuf/proto" - - "github.com/ollama/ollama/convert/sentencepiece" -) - -func loadSentencePieceVocab(t *testing.T) SentencePiece { - t.Helper() - - bts, err := os.ReadFile(filepath.Join("..", "..", "model", "testdata", "gemma2", "tokenizer.model")) - if err != nil { - t.Fatal(err) - } - - var spm sentencepiece.ModelProto - if err := proto.Unmarshal(bts, &spm); err != nil { - t.Fatal(err) - } - - var v Vocabulary - - for _, piece := range spm.GetPieces() { - v.Values = append(v.Values, piece.GetPiece()) - v.Scores = append(v.Scores, piece.GetScore()) - switch t := piece.GetType(); t { - case sentencepiece.ModelProto_SentencePiece_UNKNOWN, - sentencepiece.ModelProto_SentencePiece_CONTROL, - sentencepiece.ModelProto_SentencePiece_UNUSED, - sentencepiece.ModelProto_SentencePiece_BYTE: - v.Types = append(v.Types, int32(t)) - default: - tt := int32(sentencepiece.ModelProto_SentencePiece_NORMAL) - // todo parse the special tokens file - // - this will roundtrip correctly but the and - // tokens aren't processed - v.Types = append(v.Types, tt) - } - } - - return NewSentencePiece(&v) -} - -func TestSentencePieceEncode(t *testing.T) { - logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) - slog.SetDefault(logger) - - tokenizer := loadSentencePieceVocab(t) - - t.Run("basic roundtrip", func(t *testing.T) { - t.Parallel() - - cases := []string{ - "hello", - "hello ", - "hello ", - " hello", - " hello ", - " hello ", - "hello world", - "请考试我的软件!12345", - "你好", - "Hello 你好 world!", - "Special characters: !@#$%^&*()_+-=[]{}|;':\",./<>?", - "Multilingual: 你好 こんにちは Привет Hola مرحبا", - "Numbers and symbols: 123456789 +- */", - "Special tokens: text ", - "Code snippets: func main() { fmt.Println(\"Hello World\") }", - "Long text: " + "Lorem ipsum dolor sit amet, consectetur adipiscing elit. " + - "Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " + - "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris.", - } - - for _, want := range cases { - ids, err := tokenizer.Encode(want, true) - if err != nil { - t.Fatal(err) - } - - if got, err := tokenizer.Decode(ids); err != nil { - t.Fatal(err) - } else if got != want { - t.Errorf("got %q, want %q [%#v]", got, want, ids) - } - } - }) - - t.Run("special tokens", func(t *testing.T) { - type candidate struct { - token string - ids []int32 - } - - cases := []candidate{ - {"", []int32{2}}, - {"", []int32{1}}, - } - - for _, want := range cases { - ids, err := tokenizer.Encode(want.token, true) - if err != nil { - t.Fatal(err) - } - if !slices.Equal(ids, want.ids) { - t.Errorf("got %#v, want %#v", ids, want.ids) - } - } - }) -} - -func TestSentencePieceDecodeByteTokens(t *testing.T) { - vocab := &Vocabulary{ - Values: []string{ - "normal", - "<0xEA>", - "<0x41>", - "<0xC3>", - "<0xA3>", - }, - Types: []int32{ - TOKEN_TYPE_NORMAL, - TOKEN_TYPE_BYTE, - TOKEN_TYPE_BYTE, - TOKEN_TYPE_BYTE, - TOKEN_TYPE_BYTE, - }, - Scores: []float32{0, 0, 0, 0, 0}, - } - - spm := NewSentencePiece(vocab) - - tests := []struct { - name string - ids []int32 - expected string - }{ - { - name: "single byte token", - ids: []int32{1}, - expected: "\xea", - }, - { - name: "ASCII byte token", - ids: []int32{2}, - expected: "A", - }, - { - name: "multiple byte tokens forming UTF-8 character", - ids: []int32{3, 4}, - expected: "ã", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := spm.Decode(tt.ids) - if err != nil { - t.Errorf("failed to decode token IDs %v: %v", tt.ids, err) - } - if result != tt.expected { - t.Errorf("got %q, want %q", result, tt.expected) - } - }) - } -} diff --git a/x/model/textprocessor.go b/x/model/textprocessor.go deleted file mode 100644 index 4a36f2352..000000000 --- a/x/model/textprocessor.go +++ /dev/null @@ -1,17 +0,0 @@ -package model - -const ( - TOKEN_TYPE_NORMAL = iota + 1 - TOKEN_TYPE_UNKNOWN - TOKEN_TYPE_CONTROL - TOKEN_TYPE_USER_DEFINED - TOKEN_TYPE_UNUSED - TOKEN_TYPE_BYTE -) - -type TextProcessor interface { - Encode(s string, addSpecial bool) ([]int32, error) - Decode([]int32) (string, error) - Is(int32, Special) bool - Vocabulary() *Vocabulary -} diff --git a/x/model/vocabulary.go b/x/model/vocabulary.go deleted file mode 100644 index d977c4957..000000000 --- a/x/model/vocabulary.go +++ /dev/null @@ -1,112 +0,0 @@ -package model - -import ( - "log/slog" - "slices" - "sync" -) - -type Special int32 - -const ( - SpecialBOS Special = iota - SpecialEOS -) - -type Vocabulary struct { - Values []string - Types []int32 - Scores []float32 - Merges []string - - BOS, EOS []int32 - AddBOS, AddEOS bool - - specialOnce sync.Once - special []string - - valuesOnce sync.Once - values map[string]int32 - - mergeOnce sync.Once - merge map[string]int32 -} - -func (v *Vocabulary) Is(id int32, special Special) bool { - switch special { - case SpecialBOS: - return slices.Contains(v.BOS, id) - case SpecialEOS: - return slices.Contains(v.EOS, id) - default: - return false - } -} - -func (v *Vocabulary) addSpecials(ids []int32) []int32 { - if v.AddBOS && len(v.BOS) > 0 { - if len(ids) > 0 && slices.Contains(v.BOS, ids[0]) { - slog.Warn("adding bos token to prompt which already has it", "id", v.BOS) - } - - slog.Debug("adding bos token to prompt", "id", v.BOS[0]) - ids = append([]int32{v.BOS[0]}, ids...) - } - - if v.AddEOS && len(v.EOS) > 0 { - if len(ids) > 0 && slices.Contains(v.BOS, ids[len(ids)-1]) { - slog.Warn("adding eos token to prompt which already has it", "id", v.EOS) - } - - slog.Debug("adding eos token to prompt", "id", v.EOS[0]) - ids = append(ids, v.EOS[0]) - } - - return ids -} - -func (v *Vocabulary) Encode(s string) int32 { - v.valuesOnce.Do(func() { - v.values = make(map[string]int32, len(v.Values)) - for i, value := range v.Values { - v.values[value] = int32(i) - } - }) - - if id, ok := v.values[s]; ok { - return id - } - - return -1 -} - -func (v *Vocabulary) Decode(id int32) string { - return v.Values[id] -} - -func (v *Vocabulary) SpecialVocabulary() []string { - v.specialOnce.Do(func() { - for i := range v.Values { - if v.Types[i] == TOKEN_TYPE_CONTROL || v.Types[i] == TOKEN_TYPE_USER_DEFINED { - v.special = append(v.special, v.Values[i]) - } - } - }) - - return v.special -} - -func (v *Vocabulary) Merge(left, right string) int { - v.mergeOnce.Do(func() { - v.merge = make(map[string]int32, len(v.Merges)) - for i, merge := range v.Merges { - v.merge[merge] = int32(i) - } - }) - - if id, ok := v.merge[left+" "+right]; ok { - return int(id) - } - - return -1 -} diff --git a/x/model/vocabulary_test.go b/x/model/vocabulary_test.go deleted file mode 100644 index ccfc39e69..000000000 --- a/x/model/vocabulary_test.go +++ /dev/null @@ -1,107 +0,0 @@ -package model - -import ( - "testing" - - "github.com/google/go-cmp/cmp" -) - -func TestSpecialVocabulary(t *testing.T) { - vocab := &Vocabulary{ - Values: []string{"<|startoftext|>", "<|endoftext|>", "<|tool_call_start|>", "<|tool_call_end|>", "hi"}, - Types: []int32{TOKEN_TYPE_CONTROL, TOKEN_TYPE_CONTROL, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_NORMAL}, - } - - specialVocab := vocab.SpecialVocabulary() - - if len(specialVocab) != 4 { - t.Errorf("expected 4 special tokens, got %d", len(specialVocab)) - } -} - -func TestAddSpecialVocabulary(t *testing.T) { - cases := []struct { - name string - vocab *Vocabulary - input []int32 - want []int32 - }{ - { - name: "add bos", - vocab: &Vocabulary{ - BOS: []int32{0}, - EOS: []int32{1}, - AddBOS: true, - AddEOS: false, - }, - input: []int32{2, 3, 4}, - want: []int32{0, 2, 3, 4}, - }, - { - // TODO(mxyng): this is to match previous behaviour - name: "add bos when already present", - vocab: &Vocabulary{ - BOS: []int32{0}, - EOS: []int32{1}, - AddBOS: true, - AddEOS: false, - }, - input: []int32{0, 2, 3, 4}, - want: []int32{0, 0, 2, 3, 4}, - }, - { - name: "add eos", - vocab: &Vocabulary{ - BOS: []int32{0}, - EOS: []int32{1}, - AddBOS: false, - AddEOS: true, - }, - input: []int32{2, 3, 4}, - want: []int32{2, 3, 4, 1}, - }, - { - // TODO(mxyng): this is to match previous behaviour - name: "add eos when already present", - vocab: &Vocabulary{ - BOS: []int32{0}, - EOS: []int32{1}, - AddBOS: false, - AddEOS: true, - }, - input: []int32{2, 3, 4, 1}, - want: []int32{2, 3, 4, 1, 1}, - }, - { - name: "add both", - vocab: &Vocabulary{ - BOS: []int32{0}, - EOS: []int32{1}, - AddBOS: true, - AddEOS: true, - }, - input: []int32{2, 3, 4}, - want: []int32{0, 2, 3, 4, 1}, - }, - { - name: "add bos to empty inputs", - vocab: &Vocabulary{ - BOS: []int32{0}, - EOS: []int32{1}, - AddBOS: true, - AddEOS: false, - }, - input: []int32{}, - want: []int32{0}, - }, - } - - for _, tt := range cases { - t.Run(tt.name, func(t *testing.T) { - got := tt.vocab.addSpecials(tt.input) - if diff := cmp.Diff(tt.want, got); diff != "" { - t.Errorf("no match (-want +got):\n%s", diff) - } - }) - } -} diff --git a/x/model/wordpiece.go b/x/model/wordpiece.go deleted file mode 100644 index e552bce0d..000000000 --- a/x/model/wordpiece.go +++ /dev/null @@ -1,171 +0,0 @@ -package model - -import ( - "fmt" - "iter" - "strings" - "unicode" - - "github.com/ollama/ollama/logutil" -) - -type WordPiece struct { - vocab *Vocabulary - lowercase bool -} - -// ggmlPrefix is the prefix used by GGML vocabularies to indicate word boundaries. -// this differs from original word piece which uses "##" to indicate subwords. -const ggmlPrefix = "▁" - -var wordPieceReplacer = strings.NewReplacer( - " .", ".", - " ?", "?", - " !", "!", - " ,", ",", - " ' ", "'", - " n't", "n't", - " 'm", "'m", - " do not", " don't", - " 's", "'s", - " 've", "'ve", - " 're", "'re", -) - -// Decode implements TextProcessor. -func (wpm WordPiece) Decode(ids []int32) (string, error) { - var sb strings.Builder - for i, id := range ids { - if id < 0 || int(id) >= len(wpm.vocab.Values) { - return "", fmt.Errorf("invalid token id: %d", id) - } - - var separator string - piece := wpm.vocab.Values[id] - if i > 0 && - (strings.HasPrefix(piece, ggmlPrefix) || - (strings.HasPrefix(piece, "[") && strings.HasSuffix(piece, "]"))) { - separator = " " - } - - sb.WriteString(wordPieceReplacer.Replace(separator + strings.TrimPrefix(piece, ggmlPrefix))) - } - - return sb.String(), nil -} - -// words splits a string into words, treating CJK characters as separate words. -// TODO: this is specifically for BERT and may need to be adjusted or refactored for other models. -func (wpm WordPiece) words(s string) iter.Seq[string] { - return func(yield func(string) bool) { - runes := make([]rune, 0, len(s)*3) - for _, r := range s { - switch { - case r >= 0x4E00 && r <= 0x9FFF, - r >= 0x3400 && r <= 0x4DBF, - r >= 0x20000 && r <= 0x2A6DF, - r >= 0x2A700 && r <= 0x2B73F, - r >= 0x2B740 && r <= 0x2B81F, - r >= 0x2B820 && r <= 0x2CEAF, - r >= 0xF900 && r <= 0xFAFF, - r >= 0x2F800 && r <= 0x2FA1F: - runes = append(runes, ' ', r, ' ') - default: - runes = append(runes, r) - } - } - - for w := range strings.FieldsFuncSeq(string(runes), unicode.IsSpace) { - // split on but keep punctuation - var start int - for start < len(w) { - end := strings.IndexFunc(w[start:], unicode.IsPunct) - if end < 0 { - end = len(w) - start - } else if end == 0 { - end = 1 - } - - if !yield(w[start : start+end]) { - return - } - - start += end - } - } - } -} - -// Encode implements TextProcessor. -func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) { - var ids []int32 - - // TODO: use [UNK] from config - unk := wpm.vocab.Encode("[UNK]") - for word := range wpm.words(s) { - var start int - var pieces []int32 - for start < len(word) { - end := len(word) - - var piece int32 - for start < end { - subword := word[start:end] - if start == 0 { - subword = ggmlPrefix + subword - } - - if wpm.lowercase { - subword = strings.ToLower(subword) - } - piece = wpm.vocab.Encode(subword) - if piece >= 0 { - break - } - - end-- - } - - if piece < 0 { - // Unknown token - pieces = pieces[:0] - break - } - - pieces = append(pieces, piece) - start = end - } - - if len(pieces) > 0 { - ids = append(ids, pieces...) - } else { - ids = append(ids, unk) - } - } - - if addSpecial { - ids = wpm.vocab.addSpecials(ids) - } - - logutil.Trace("encoded", "string", s, "ids", ids) - return ids, nil -} - -// Is implements TextProcessor. -func (wpm WordPiece) Is(id int32, special Special) bool { - return wpm.vocab.Is(id, special) -} - -// Vocabulary implements TextProcessor. -func (wpm WordPiece) Vocabulary() *Vocabulary { - return wpm.vocab -} - -var _ TextProcessor = (*WordPiece)(nil) - -func NewWordPiece(vocab *Vocabulary, lowercase bool) WordPiece { - return WordPiece{ - vocab: vocab, - lowercase: lowercase, - } -} diff --git a/x/model/wordpiece_test.go b/x/model/wordpiece_test.go deleted file mode 100644 index c03bb17a7..000000000 --- a/x/model/wordpiece_test.go +++ /dev/null @@ -1,53 +0,0 @@ -package model - -import ( - "slices" - "testing" - - "github.com/google/go-cmp/cmp" -) - -func TestWordPiece(t *testing.T) { - wpm := NewWordPiece( - &Vocabulary{ - Values: []string{"[UNK]", "[CLS]", "[SEP]", "▁hello", "▁world", "s", "▁!", "▁@", "▁#"}, - AddBOS: true, - AddEOS: true, - BOS: []int32{1}, - EOS: []int32{2}, - }, - true, // lowercase - ) - - ids, err := wpm.Encode("Hello world!", true) - if err != nil { - t.Fatal(err) - } - - if diff := cmp.Diff([]int32{1, 3, 4, 6, 2}, ids); diff != "" { - t.Errorf("unexpected ids (-want +got):\n%s", diff) - } - - words, err := wpm.Decode(ids) - if err != nil { - t.Fatal(err) - } - - if diff := cmp.Diff("[CLS] hello world! [SEP]", words); diff != "" { - t.Errorf("unexpected words (-want +got):\n%s", diff) - } -} - -func TestWordPieceWords(t *testing.T) { - var wpm WordPiece - - basic := slices.Collect(wpm.words("Hey friend! How are you?!?")) - if diff := cmp.Diff([]string{"Hey", "friend", "!", "How", "are", "you", "?", "!", "?"}, basic); diff != "" { - t.Errorf("unexpected words (-want +got):\n%s", diff) - } - - chinese := slices.Collect(wpm.words("野口里佳 Noguchi Rika")) - if diff := cmp.Diff([]string{"野", "口", "里", "佳", "Noguchi", "Rika"}, chinese); diff != "" { - t.Errorf("unexpected words (-want +got):\n%s", diff) - } -}