mirror of
https://github.com/ollama/ollama.git
synced 2026-04-24 01:35:49 +02:00
Compare commits
21 Commits
parth/samp
...
brucemacd/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
159821594c | ||
|
|
cbeb2aab4f | ||
|
|
96df15edfc | ||
|
|
c001b98087 | ||
|
|
23fc8e92eb | ||
|
|
4059a297a6 | ||
|
|
66b2539238 | ||
|
|
ef27d52e79 | ||
|
|
b2a465296d | ||
|
|
5d097277ef | ||
|
|
071a9872cb | ||
|
|
0bd0454ea7 | ||
|
|
01aa788722 | ||
|
|
ead27aa9fe | ||
|
|
b816ff86c9 | ||
|
|
e5d84fb90b | ||
|
|
dd66712e31 | ||
|
|
f66216e399 | ||
|
|
f4f0992b6e | ||
|
|
1feff61977 | ||
|
|
5e0b904e88 |
@@ -86,9 +86,9 @@ if(CMAKE_CUDA_COMPILER)
|
|||||||
)
|
)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
set(WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX "^gfx(906|908|90a):xnack[+-]$"
|
set(WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX "^gfx(906|908|90a|1200|1201):xnack[+-]$"
|
||||||
CACHE STRING
|
CACHE STRING
|
||||||
"Regular expression describing AMDGPU_TARGETS not supported on Windows. Override to force building these targets. Default \"^gfx(906|908|90a):xnack[+-]$\"."
|
"Regular expression describing AMDGPU_TARGETS not supported on Windows. Override to force building these targets. Default \"^gfx(906|908|90a|1200|1201):xnack[+-]$\"."
|
||||||
)
|
)
|
||||||
|
|
||||||
check_language(HIP)
|
check_language(HIP)
|
||||||
@@ -97,7 +97,7 @@ if(CMAKE_HIP_COMPILER)
|
|||||||
|
|
||||||
find_package(hip REQUIRED)
|
find_package(hip REQUIRED)
|
||||||
if(NOT AMDGPU_TARGETS)
|
if(NOT AMDGPU_TARGETS)
|
||||||
list(FILTER AMDGPU_TARGETS INCLUDE REGEX "^gfx(900|94[012]|101[02]|1030|110[012])$")
|
list(FILTER AMDGPU_TARGETS INCLUDE REGEX "^gfx(900|94[012]|101[02]|1030|110[012]|120[01])$")
|
||||||
elseif(WIN32 AND WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX)
|
elseif(WIN32 AND WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX)
|
||||||
list(FILTER AMDGPU_TARGETS EXCLUDE REGEX ${WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX})
|
list(FILTER AMDGPU_TARGETS EXCLUDE REGEX ${WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX})
|
||||||
endif()
|
endif()
|
||||||
|
|||||||
@@ -56,7 +56,7 @@
|
|||||||
"name": "ROCm 6",
|
"name": "ROCm 6",
|
||||||
"inherits": [ "ROCm" ],
|
"inherits": [ "ROCm" ],
|
||||||
"cacheVariables": {
|
"cacheVariables": {
|
||||||
"AMDGPU_TARGETS": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-"
|
"AMDGPU_TARGETS": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx1200;gfx1201;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
|||||||
@@ -285,6 +285,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
|||||||
- [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt)
|
- [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt)
|
||||||
- [HTML UI](https://github.com/rtcfirefly/ollama-ui)
|
- [HTML UI](https://github.com/rtcfirefly/ollama-ui)
|
||||||
- [Saddle](https://github.com/jikkuatwork/saddle)
|
- [Saddle](https://github.com/jikkuatwork/saddle)
|
||||||
|
- [TagSpaces](https://www.tagspaces.org) (A platform for file based apps, [utilizing Ollama](https://docs.tagspaces.org/ai/) for the generation of tags and descriptions)
|
||||||
- [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama)
|
- [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama)
|
||||||
- [Chatbot UI v2](https://github.com/mckaywrigley/chatbot-ui)
|
- [Chatbot UI v2](https://github.com/mckaywrigley/chatbot-ui)
|
||||||
- [Typescript UI](https://github.com/ollama-interface/Ollama-Gui?tab=readme-ov-file)
|
- [Typescript UI](https://github.com/ollama-interface/Ollama-Gui?tab=readme-ov-file)
|
||||||
@@ -394,6 +395,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
|||||||
- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)
|
- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)
|
||||||
- [Ellama](https://github.com/zeozeozeo/ellama) (Friendly native app to chat with an Ollama instance)
|
- [Ellama](https://github.com/zeozeozeo/ellama) (Friendly native app to chat with an Ollama instance)
|
||||||
- [screenpipe](https://github.com/mediar-ai/screenpipe) Build agents powered by your screen history
|
- [screenpipe](https://github.com/mediar-ai/screenpipe) Build agents powered by your screen history
|
||||||
|
- [Ollamb](https://github.com/hengkysteen/ollamb) (Simple yet rich in features, cross-platform built with Flutter and designed for Ollama. Try the [web demo](https://hengkysteen.github.io/demo/ollamb/).)
|
||||||
|
- [Writeopia](https://github.com/Writeopia/Writeopia) (Text editor with integration with Ollama)
|
||||||
|
|
||||||
### Cloud
|
### Cloud
|
||||||
|
|
||||||
@@ -433,7 +436,9 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
|||||||
- [SwollamaCLI](https://github.com/marcusziade/Swollama) bundled with the Swollama Swift package. [Demo](https://github.com/marcusziade/Swollama?tab=readme-ov-file#cli-usage)
|
- [SwollamaCLI](https://github.com/marcusziade/Swollama) bundled with the Swollama Swift package. [Demo](https://github.com/marcusziade/Swollama?tab=readme-ov-file#cli-usage)
|
||||||
- [aichat](https://github.com/sigoden/aichat) All-in-one LLM CLI tool featuring Shell Assistant, Chat-REPL, RAG, AI tools & agents, with access to OpenAI, Claude, Gemini, Ollama, Groq, and more.
|
- [aichat](https://github.com/sigoden/aichat) All-in-one LLM CLI tool featuring Shell Assistant, Chat-REPL, RAG, AI tools & agents, with access to OpenAI, Claude, Gemini, Ollama, Groq, and more.
|
||||||
- [PowershAI](https://github.com/rrg92/powershai) PowerShell module that brings AI to terminal on Windows, including support for Ollama
|
- [PowershAI](https://github.com/rrg92/powershai) PowerShell module that brings AI to terminal on Windows, including support for Ollama
|
||||||
|
- [DeepShell](https://github.com/Abyss-c0re/deepshell) Your self-hosted AI assistant. Interactive Shell, Files and Folders analysis.
|
||||||
- [orbiton](https://github.com/xyproto/orbiton) Configuration-free text editor and IDE with support for tab completion with Ollama.
|
- [orbiton](https://github.com/xyproto/orbiton) Configuration-free text editor and IDE with support for tab completion with Ollama.
|
||||||
|
- [orca-cli](https://github.com/molbal/orca-cli) Ollama Registry CLI Application - Browse, pull and download models from Ollama Registry in your terminal.
|
||||||
|
|
||||||
### Apple Vision Pro
|
### Apple Vision Pro
|
||||||
|
|
||||||
|
|||||||
@@ -111,6 +111,7 @@ func GetCPUDetails() ([]CPU, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
defer file.Close()
|
||||||
return linuxCPUDetails(file)
|
return linuxCPUDetails(file)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -168,13 +169,11 @@ func linuxCPUDetails(file io.Reader) ([]CPU, error) {
|
|||||||
for id, s := range socketByID {
|
for id, s := range socketByID {
|
||||||
s.CoreCount = len(coreBySocket[id])
|
s.CoreCount = len(coreBySocket[id])
|
||||||
s.ThreadCount = 0
|
s.ThreadCount = 0
|
||||||
for _, tc := range threadsByCoreBySocket[id] {
|
|
||||||
s.ThreadCount += tc
|
|
||||||
}
|
|
||||||
|
|
||||||
// This only works if HT is enabled, consider a more reliable model, maybe cache size comparisons?
|
// This only works if HT is enabled, consider a more reliable model, maybe cache size comparisons?
|
||||||
efficiencyCoreCount := 0
|
efficiencyCoreCount := 0
|
||||||
for _, threads := range threadsByCoreBySocket[id] {
|
for _, threads := range threadsByCoreBySocket[id] {
|
||||||
|
s.ThreadCount += threads
|
||||||
if threads == 1 {
|
if threads == 1 {
|
||||||
efficiencyCoreCount++
|
efficiencyCoreCount++
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,7 +20,13 @@ Please refer to the [GPU docs](./gpu.md).
|
|||||||
|
|
||||||
## How can I specify the context window size?
|
## How can I specify the context window size?
|
||||||
|
|
||||||
By default, Ollama uses a context window size of 2048 tokens. This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context length to 8K, use: `OLLAMA_CONTEXT_LENGTH=8192 ollama serve`.
|
By default, Ollama uses a context window size of 2048 tokens.
|
||||||
|
|
||||||
|
This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context window to 8K, use:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
OLLAMA_CONTEXT_LENGTH=8192 ollama serve
|
||||||
|
```
|
||||||
|
|
||||||
To change this when using `ollama run`, use `/set parameter`:
|
To change this when using `ollama run`, use `/set parameter`:
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ cat ~/.ollama/logs/server.log
|
|||||||
On **Linux** systems with systemd, the logs can be found with this command:
|
On **Linux** systems with systemd, the logs can be found with this command:
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
journalctl -u ollama --no-pager
|
journalctl -u ollama --no-pager --follow --pager-end
|
||||||
```
|
```
|
||||||
|
|
||||||
When you run Ollama in a **container**, the logs go to stdout/stderr in the container:
|
When you run Ollama in a **container**, the logs go to stdout/stderr in the container:
|
||||||
|
|||||||
@@ -413,7 +413,7 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
|
|||||||
}, offset, nil
|
}, offset, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialOffload, fullOffload uint64) {
|
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) {
|
||||||
embedding := f.KV().EmbeddingLength()
|
embedding := f.KV().EmbeddingLength()
|
||||||
heads := f.KV().HeadCount()
|
heads := f.KV().HeadCount()
|
||||||
headsKV := f.KV().HeadCountKV()
|
headsKV := f.KV().HeadCountKV()
|
||||||
@@ -426,7 +426,10 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
|
|||||||
layers := f.Tensors().GroupLayers()
|
layers := f.Tensors().GroupLayers()
|
||||||
|
|
||||||
bytesPerElement := kvCacheBytesPerElement(kvCacheType)
|
bytesPerElement := kvCacheBytesPerElement(kvCacheType)
|
||||||
kv = uint64(float64(context*f.KV().BlockCount()*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
|
kv = make([]uint64, f.KV().BlockCount())
|
||||||
|
for i := range kv {
|
||||||
|
kv[i] = uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
|
||||||
|
}
|
||||||
|
|
||||||
switch f.KV().Architecture() {
|
switch f.KV().Architecture() {
|
||||||
case "llama":
|
case "llama":
|
||||||
@@ -460,16 +463,14 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
|
|||||||
case "mllama":
|
case "mllama":
|
||||||
var visionTokens, tiles uint64 = 1601, 4
|
var visionTokens, tiles uint64 = 1601, 4
|
||||||
|
|
||||||
if crossAttentionLayers, ok := f.KV()["mllama.attention.cross_attention_layers"].(*array); ok {
|
crossAttentionLayers := f.KV().Uints("attention.cross_attention_layers")
|
||||||
kv = headsKV *
|
for i := range kv {
|
||||||
(embeddingHeadsK + embeddingHeadsV) * // one for K, one for V
|
if slices.Contains(crossAttentionLayers, uint32(i)) {
|
||||||
(2* // sizeof(float16)
|
kv[i] = headsKV * (embeddingHeadsK + embeddingHeadsV) *
|
||||||
(f.KV().BlockCount()-uint64(crossAttentionLayers.size))* // num non-cross attention layers
|
|
||||||
context +
|
|
||||||
4 * // sizeof(float32)
|
4 * // sizeof(float32)
|
||||||
uint64(crossAttentionLayers.size)* // num cross attention layers
|
|
||||||
visionTokens *
|
visionTokens *
|
||||||
tiles)
|
tiles
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fullOffload = max(
|
fullOffload = max(
|
||||||
@@ -505,6 +506,20 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
|
|||||||
4*embeddingHeadsK*context*8+
|
4*embeddingHeadsK*context*8+
|
||||||
embedding*embeddingHeadsK*heads*9/16,
|
embedding*embeddingHeadsK*heads*9/16,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Gemma2 also has sliding window attention but we only have an optimized implementation in the Ollama
|
||||||
|
// engine. Gemma3 always uses the Ollama engine.
|
||||||
|
if f.KV().Architecture() == "gemma3" {
|
||||||
|
const gemma3GlobalCacheCount = 6
|
||||||
|
slidingWindow := (uint64(numParallel) * uint64(f.KV().Uint("attention.sliding_window"))) + batch
|
||||||
|
for i := range kv {
|
||||||
|
// Every 6th layer is a global layer, which is the full context size that has already been set. The other
|
||||||
|
// layers are the smaller local (sliding) layers.
|
||||||
|
if (i+1)%gemma3GlobalCacheCount != 0 {
|
||||||
|
kv[i] = uint64(float64(slidingWindow*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
case "command-r":
|
case "command-r":
|
||||||
fullOffload = max(
|
fullOffload = max(
|
||||||
4*batch*(embedding+vocab),
|
4*batch*(embedding+vocab),
|
||||||
|
|||||||
@@ -119,10 +119,10 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity
|
|||||||
}
|
}
|
||||||
|
|
||||||
var cacheSize int
|
var cacheSize int
|
||||||
if c.windowSize == math.MaxInt32 || capacity < int(c.windowSize)+maxBatch {
|
if c.windowSize == math.MaxInt32 || capacity < int(c.windowSize) {
|
||||||
cacheSize = maxSequences * capacity
|
cacheSize = maxSequences * capacity
|
||||||
} else {
|
} else {
|
||||||
cacheSize = maxSequences * (int(c.windowSize) + maxBatch)
|
cacheSize = (maxSequences * int(c.windowSize)) + maxBatch
|
||||||
}
|
}
|
||||||
cacheSize = roundUp(cacheSize, c.config.CachePadding)
|
cacheSize = roundUp(cacheSize, c.config.CachePadding)
|
||||||
c.cells = make([]cacheCell, cacheSize)
|
c.cells = make([]cacheCell, cacheSize)
|
||||||
|
|||||||
@@ -362,7 +362,6 @@ func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *testContext) Input() ml.Context { return c }
|
func (c *testContext) Input() ml.Context { return c }
|
||||||
func (c *testContext) Output() ml.Context { return c }
|
|
||||||
func (c *testContext) Layer(int) ml.Context { return c }
|
func (c *testContext) Layer(int) ml.Context { return c }
|
||||||
|
|
||||||
func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }
|
func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }
|
||||||
@@ -463,7 +462,7 @@ func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0
|
|||||||
panic("not implemented")
|
panic("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, dim, ropeType uint32, base, scale float32) ml.Tensor {
|
func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, config ml.RoPEConfig) ml.Tensor {
|
||||||
panic("not implemented")
|
panic("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -166,6 +166,10 @@ func (c *Context) KvCacheDefrag() {
|
|||||||
C.llama_kv_cache_defrag(c.c)
|
C.llama_kv_cache_defrag(c.c)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Context) KvCacheCanShift() bool {
|
||||||
|
return bool(C.llama_kv_cache_can_shift(c.c))
|
||||||
|
}
|
||||||
|
|
||||||
// Get the embeddings for a sequence id
|
// Get the embeddings for a sequence id
|
||||||
func (c *Context) GetEmbeddingsSeq(seqId int) []float32 {
|
func (c *Context) GetEmbeddingsSeq(seqId int) []float32 {
|
||||||
e := unsafe.Pointer(C.llama_get_embeddings_seq(c.c, C.int(seqId)))
|
e := unsafe.Pointer(C.llama_get_embeddings_seq(c.c, C.int(seqId)))
|
||||||
|
|||||||
103
llama/patches/0022-add-rdna4-support.patch
Normal file
103
llama/patches/0022-add-rdna4-support.patch
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||||
|
From: Saman <saman.khatir@amd.com>
|
||||||
|
Date: Wed, 19 Mar 2025 14:02:26 -0700
|
||||||
|
Subject: [PATCH] add rdna4 support
|
||||||
|
|
||||||
|
---
|
||||||
|
ggml/src/ggml-cuda/common.cuh | 6 ++++--
|
||||||
|
ggml/src/ggml-cuda/mmq.cu | 2 +-
|
||||||
|
ggml/src/ggml-cuda/mmq.cuh | 4 ++--
|
||||||
|
ggml/src/ggml-cuda/mmvq.cu | 4 ++--
|
||||||
|
ggml/src/ggml-cuda/vendors/hip.h | 4 ++++
|
||||||
|
5 files changed, 13 insertions(+), 7 deletions(-)
|
||||||
|
|
||||||
|
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
|
||||||
|
index adf0d3ec..b24593fc 100644
|
||||||
|
--- a/ggml/src/ggml-cuda/common.cuh
|
||||||
|
+++ b/ggml/src/ggml-cuda/common.cuh
|
||||||
|
@@ -61,11 +61,13 @@
|
||||||
|
#define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000
|
||||||
|
#define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a
|
||||||
|
#define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA
|
||||||
|
+#define GGML_CUDA_CC_RDNA4 (GGML_CUDA_CC_OFFSET_AMD + 0x1200) // RX 9000
|
||||||
|
|
||||||
|
#define GGML_CUDA_CC_IS_RDNA(cc) (cc >= GGML_CUDA_CC_RDNA1)
|
||||||
|
#define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2)
|
||||||
|
#define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
|
||||||
|
-#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3)
|
||||||
|
+#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3 && cc < GGML_CUDA_CC_RDNA4)
|
||||||
|
+#define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4)
|
||||||
|
#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA)
|
||||||
|
#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
|
||||||
|
|
||||||
|
@@ -386,7 +388,7 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
|
||||||
|
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
||||||
|
#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(RDNA2)
|
||||||
|
c = __builtin_amdgcn_sdot4(a, b, c, false);
|
||||||
|
-#elif defined(RDNA3)
|
||||||
|
+#elif defined(RDNA3) || defined(RDNA4)
|
||||||
|
c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
|
||||||
|
#elif defined(__gfx1010__) || defined(__gfx900__)
|
||||||
|
int tmp1;
|
||||||
|
diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu
|
||||||
|
index 10f2ebb1..933d945c 100644
|
||||||
|
--- a/ggml/src/ggml-cuda/mmq.cu
|
||||||
|
+++ b/ggml/src/ggml-cuda/mmq.cu
|
||||||
|
@@ -149,5 +149,5 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
|
||||||
|
return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||||
|
}
|
||||||
|
|
||||||
|
- return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||||
|
+ return (!GGML_CUDA_CC_IS_RDNA4(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||||
|
}
|
||||||
|
diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh
|
||||||
|
index 0451c65f..66ce2bc9 100644
|
||||||
|
--- a/ggml/src/ggml-cuda/mmq.cuh
|
||||||
|
+++ b/ggml/src/ggml-cuda/mmq.cuh
|
||||||
|
@@ -2577,9 +2577,9 @@ static __device__ void mul_mat_q_process_tile(
|
||||||
|
|
||||||
|
template <ggml_type type, int mmq_x, int nwarps, bool need_check>
|
||||||
|
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
||||||
|
-#if defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
|
||||||
|
+#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
|
||||||
|
__launch_bounds__(WARP_SIZE*nwarps, 2)
|
||||||
|
-#endif // defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
|
||||||
|
+#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
|
||||||
|
#else
|
||||||
|
#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
|
||||||
|
__launch_bounds__(WARP_SIZE*nwarps, 1)
|
||||||
|
diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
|
||||||
|
index 4fb466ca..23ae7abc 100644
|
||||||
|
--- a/ggml/src/ggml-cuda/mmvq.cu
|
||||||
|
+++ b/ggml/src/ggml-cuda/mmvq.cu
|
||||||
|
@@ -62,13 +62,13 @@ static __global__ void mul_mat_vec_q(
|
||||||
|
|
||||||
|
constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
|
||||||
|
|
||||||
|
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
|
||||||
|
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4))
|
||||||
|
constexpr int nwarps = 1;
|
||||||
|
constexpr int rows_per_cuda_block = 1;
|
||||||
|
#else
|
||||||
|
constexpr int nwarps = ncols_y <= 4 ? 4 : 2;
|
||||||
|
constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
|
||||||
|
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
|
||||||
|
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3) && !defined(RDNA4)
|
||||||
|
|
||||||
|
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
|
||||||
|
const int row0 = rows_per_cuda_block*blockIdx.x;
|
||||||
|
diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h
|
||||||
|
index 81964611..a62544b5 100644
|
||||||
|
--- a/ggml/src/ggml-cuda/vendors/hip.h
|
||||||
|
+++ b/ggml/src/ggml-cuda/vendors/hip.h
|
||||||
|
@@ -150,6 +150,10 @@
|
||||||
|
#define CDNA
|
||||||
|
#endif
|
||||||
|
|
||||||
|
+#if defined(__gfx1200__) || defined(__gfx1201__)
|
||||||
|
+#define RDNA4
|
||||||
|
+#endif
|
||||||
|
+
|
||||||
|
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
|
||||||
|
defined(__gfx1150__) || defined(__gfx1151__)
|
||||||
|
#define RDNA3
|
||||||
@@ -15,12 +15,12 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// This algorithm looks for a complete fit to determine if we need to unload other models
|
// This algorithm looks for a complete fit to determine if we need to unload other models
|
||||||
func PredictServerFit(allGpus discover.GpuInfoList, f *ggml.GGML, adapters, projectors []string, opts api.Options) (bool, uint64) {
|
func PredictServerFit(allGpus discover.GpuInfoList, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (bool, uint64) {
|
||||||
// Split up the GPUs by type and try them
|
// Split up the GPUs by type and try them
|
||||||
var estimatedVRAM uint64
|
var estimatedVRAM uint64
|
||||||
for _, gpus := range allGpus.ByLibrary() {
|
for _, gpus := range allGpus.ByLibrary() {
|
||||||
var layerCount int
|
var layerCount int
|
||||||
estimate := EstimateGPULayers(gpus, f, projectors, opts)
|
estimate := EstimateGPULayers(gpus, f, projectors, opts, numParallel)
|
||||||
layerCount, estimatedVRAM = estimate.Layers, estimate.VRAMSize
|
layerCount, estimatedVRAM = estimate.Layers, estimate.VRAMSize
|
||||||
if opts.NumGPU < 0 {
|
if opts.NumGPU < 0 {
|
||||||
if layerCount > 0 && layerCount >= int(f.KV().BlockCount()+1) {
|
if layerCount > 0 && layerCount >= int(f.KV().BlockCount()+1) {
|
||||||
@@ -71,7 +71,7 @@ type MemoryEstimate struct {
|
|||||||
|
|
||||||
// Given a model and one or more GPU targets, predict how many layers and bytes we can load, and the total size
|
// Given a model and one or more GPU targets, predict how many layers and bytes we can load, and the total size
|
||||||
// The GPUs provided must all be the same Library
|
// The GPUs provided must all be the same Library
|
||||||
func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []string, opts api.Options) MemoryEstimate {
|
func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []string, opts api.Options, numParallel int) MemoryEstimate {
|
||||||
// Graph size for a partial offload, applies to all GPUs
|
// Graph size for a partial offload, applies to all GPUs
|
||||||
var graphPartialOffload uint64
|
var graphPartialOffload uint64
|
||||||
|
|
||||||
@@ -137,13 +137,19 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
kv, graphPartialOffload, graphFullOffload := f.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), kvct)
|
kv, graphPartialOffload, graphFullOffload := f.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), numParallel, kvct)
|
||||||
|
|
||||||
// KV is proportional to the number of layers
|
if len(kv) > 0 {
|
||||||
layerSize += kv / f.KV().BlockCount()
|
layerSize += kv[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
var kvTotal uint64
|
||||||
|
for _, kvLayer := range kv {
|
||||||
|
kvTotal += kvLayer
|
||||||
|
}
|
||||||
|
|
||||||
if graphPartialOffload == 0 {
|
if graphPartialOffload == 0 {
|
||||||
graphPartialOffload = f.KV().GQA() * kv / 6
|
graphPartialOffload = f.KV().GQA() * kvTotal / 6
|
||||||
}
|
}
|
||||||
if graphFullOffload == 0 {
|
if graphFullOffload == 0 {
|
||||||
graphFullOffload = graphPartialOffload
|
graphFullOffload = graphPartialOffload
|
||||||
@@ -217,7 +223,7 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
|
|||||||
// Some models have inconsistent layer sizes
|
// Some models have inconsistent layer sizes
|
||||||
if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok {
|
if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok {
|
||||||
layerSize = blk.Size()
|
layerSize = blk.Size()
|
||||||
layerSize += kv / f.KV().BlockCount()
|
layerSize += kv[i]
|
||||||
memoryWeights += blk.Size()
|
memoryWeights += blk.Size()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -315,7 +321,7 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
|
|||||||
layersRequested: opts.NumGPU,
|
layersRequested: opts.NumGPU,
|
||||||
layersModel: int(f.KV().BlockCount()) + 1,
|
layersModel: int(f.KV().BlockCount()) + 1,
|
||||||
availableList: availableList,
|
availableList: availableList,
|
||||||
kv: kv,
|
kv: kvTotal,
|
||||||
allocationsList: allocationsList,
|
allocationsList: allocationsList,
|
||||||
memoryWeights: memoryWeights,
|
memoryWeights: memoryWeights,
|
||||||
memoryLayerOutput: memoryLayerOutput,
|
memoryLayerOutput: memoryLayerOutput,
|
||||||
@@ -374,7 +380,7 @@ func (m MemoryEstimate) LogValue() slog.Value {
|
|||||||
slog.Group(
|
slog.Group(
|
||||||
"weights",
|
"weights",
|
||||||
// memory of the weights
|
// memory of the weights
|
||||||
"total", format.HumanBytes2(m.memoryWeights),
|
"total", format.HumanBytes2(m.memoryWeights+m.memoryLayerOutput),
|
||||||
// memory of repeating layers
|
// memory of repeating layers
|
||||||
"repeating", format.HumanBytes2(m.memoryWeights),
|
"repeating", format.HumanBytes2(m.memoryWeights),
|
||||||
// memory of non-repeating layers
|
// memory of non-repeating layers
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ func TestEstimateGPULayers(t *testing.T) {
|
|||||||
projectors := []string{}
|
projectors := []string{}
|
||||||
opts := api.DefaultOptions()
|
opts := api.DefaultOptions()
|
||||||
t.Run("cpu", func(t *testing.T) {
|
t.Run("cpu", func(t *testing.T) {
|
||||||
estimate := EstimateGPULayers(gpus, ggml, projectors, opts)
|
estimate := EstimateGPULayers(gpus, ggml, projectors, opts, 1)
|
||||||
assert.Equal(t, 0, estimate.Layers)
|
assert.Equal(t, 0, estimate.Layers)
|
||||||
assert.Equal(t, uint64(0), estimate.Graph)
|
assert.Equal(t, uint64(0), estimate.Graph)
|
||||||
})
|
})
|
||||||
@@ -112,7 +112,7 @@ func TestEstimateGPULayers(t *testing.T) {
|
|||||||
gpus[1].FreeMemory += gpuMinimumMemory + layerSize + s.layer1*layerSize + 1
|
gpus[1].FreeMemory += gpuMinimumMemory + layerSize + s.layer1*layerSize + 1
|
||||||
gpus[0].FreeMemory += max(graphFullOffload, graphPartialOffload)
|
gpus[0].FreeMemory += max(graphFullOffload, graphPartialOffload)
|
||||||
gpus[1].FreeMemory += max(graphFullOffload, graphPartialOffload)
|
gpus[1].FreeMemory += max(graphFullOffload, graphPartialOffload)
|
||||||
estimate := EstimateGPULayers(gpus, ggml, projectors, opts)
|
estimate := EstimateGPULayers(gpus, ggml, projectors, opts, 1)
|
||||||
assert.Equal(t, int(s.expect0+s.expect1), estimate.Layers, "scenario %d: %v", i, s)
|
assert.Equal(t, int(s.expect0+s.expect1), estimate.Layers, "scenario %d: %v", i, s)
|
||||||
assert.Equal(t, fmt.Sprintf("%d,%d", s.expect0, s.expect1), estimate.TensorSplit, "scenario %d: %v", i, s)
|
assert.Equal(t, fmt.Sprintf("%d,%d", s.expect0, s.expect1), estimate.TensorSplit, "scenario %d: %v", i, s)
|
||||||
var layerSums uint64
|
var layerSums uint64
|
||||||
|
|||||||
@@ -109,7 +109,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
|
|||||||
gpus = discover.GetCPUInfo()
|
gpus = discover.GetCPUInfo()
|
||||||
}
|
}
|
||||||
|
|
||||||
estimate := EstimateGPULayers(gpus, f, projectors, opts)
|
estimate := EstimateGPULayers(gpus, f, projectors, opts, numParallel)
|
||||||
if len(gpus) > 1 || gpus[0].Library != "cpu" {
|
if len(gpus) > 1 || gpus[0].Library != "cpu" {
|
||||||
switch {
|
switch {
|
||||||
case gpus[0].Library == "metal" && estimate.VRAMSize > systemTotalMemory:
|
case gpus[0].Library == "metal" && estimate.VRAMSize > systemTotalMemory:
|
||||||
|
|||||||
@@ -110,16 +110,61 @@ type Context interface {
|
|||||||
MaxGraphNodes() int
|
MaxGraphNodes() int
|
||||||
Close()
|
Close()
|
||||||
|
|
||||||
// Input returns a context appropriate for creating input tensors
|
// Input returns a context appropriate for creating tensors that are
|
||||||
|
// inputs to the model (which includes things like output locations)
|
||||||
Input() Context
|
Input() Context
|
||||||
|
|
||||||
// Output returns a context appropriate for creating output tensors
|
|
||||||
Output() Context
|
|
||||||
|
|
||||||
// Layer returns a context appropriate for creating intermediate tensors
|
// Layer returns a context appropriate for creating intermediate tensors
|
||||||
Layer(int) Context
|
Layer(int) Context
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RopeType represents different RoPE (Rotary Position Embedding) implementation types
|
||||||
|
type RopeType int
|
||||||
|
|
||||||
|
// Available RoPE implementation types
|
||||||
|
const (
|
||||||
|
RopeTypeNormal RopeType = iota // Standard RoPE implementation
|
||||||
|
RopeTypeNeox // NeoX-style RoPE implementation
|
||||||
|
RopeTypeMRoPE // Multimodal RoPE implementation
|
||||||
|
RopeTypeVision // Vision-specific RoPE implementation
|
||||||
|
)
|
||||||
|
|
||||||
|
type YarnConfig struct {
|
||||||
|
YarnCtxTrain int // Context size used during training (for YaRN scaling)
|
||||||
|
YarnExtFactor float32 // Extension factor for YaRN
|
||||||
|
YarnAttnFactor float32 // Attention scaling factor for YaRN
|
||||||
|
YarnBetaFast float32 // Fast decay parameter for YaRN
|
||||||
|
YarnBetaSlow float32 // Slow decay parameter for YaRN
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultYarnConfig returns a default configuration for YaRN (Yet Another Rope Extension)
|
||||||
|
func DefaultYarnConfig(nCtx int32) *YarnConfig {
|
||||||
|
return &YarnConfig{
|
||||||
|
YarnCtxTrain: int(nCtx),
|
||||||
|
YarnExtFactor: 0.0,
|
||||||
|
YarnAttnFactor: 1.0,
|
||||||
|
YarnBetaFast: 32.0,
|
||||||
|
YarnBetaSlow: 1.0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RoPEConfig holds configuration for Rotary Position Embedding
|
||||||
|
type RoPEConfig struct {
|
||||||
|
// Dim is the dimensionality for applying rotary embeddings
|
||||||
|
Dim uint32
|
||||||
|
|
||||||
|
// Type specifies the RoPE implementation variant
|
||||||
|
Type RopeType
|
||||||
|
|
||||||
|
// Base controls frequency decay for the embeddings
|
||||||
|
Base float32
|
||||||
|
|
||||||
|
// Scale allows scaling the effective context length
|
||||||
|
Scale float32
|
||||||
|
|
||||||
|
*YarnConfig
|
||||||
|
}
|
||||||
|
|
||||||
type Tensor interface {
|
type Tensor interface {
|
||||||
Dim(n int) int
|
Dim(n int) int
|
||||||
Stride(n int) int
|
Stride(n int) int
|
||||||
@@ -143,7 +188,7 @@ type Tensor interface {
|
|||||||
AvgPool2D(ctx Context, k, s int, p float32) Tensor
|
AvgPool2D(ctx Context, k, s int, p float32) Tensor
|
||||||
Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
|
Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
|
||||||
|
|
||||||
RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor
|
RoPE(ctx Context, positionIDs, ropeFactors Tensor, config RoPEConfig) Tensor
|
||||||
|
|
||||||
Tanh(ctx Context) Tensor
|
Tanh(ctx Context) Tensor
|
||||||
GELU(ctx Context) Tensor
|
GELU(ctx Context) Tensor
|
||||||
|
|||||||
@@ -48,9 +48,6 @@ type Backend struct {
|
|||||||
// input is the backend used for inputs
|
// input is the backend used for inputs
|
||||||
input *C.struct_ggml_backend_buffer_type
|
input *C.struct_ggml_backend_buffer_type
|
||||||
|
|
||||||
// output is the backend used for outputs
|
|
||||||
output *C.struct_ggml_backend_buffer_type
|
|
||||||
|
|
||||||
// layers is the backend used for repeating layers
|
// layers is the backend used for repeating layers
|
||||||
layers map[int]*C.struct_ggml_backend_buffer_type
|
layers map[int]*C.struct_ggml_backend_buffer_type
|
||||||
|
|
||||||
@@ -401,7 +398,6 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend,
|
|||||||
C._Bool(len(gpus) > 1 && slices.Contains(gpus, output.d)),
|
C._Bool(len(gpus) > 1 && slices.Contains(gpus, output.d)),
|
||||||
),
|
),
|
||||||
input: deviceBufferTypes[input.d],
|
input: deviceBufferTypes[input.d],
|
||||||
output: deviceBufferTypes[output.d],
|
|
||||||
layers: func() map[int]*C.struct_ggml_backend_buffer_type {
|
layers: func() map[int]*C.struct_ggml_backend_buffer_type {
|
||||||
m := make(map[int]*C.struct_ggml_backend_buffer_type)
|
m := make(map[int]*C.struct_ggml_backend_buffer_type)
|
||||||
for i, layer := range layers {
|
for i, layer := range layers {
|
||||||
@@ -482,19 +478,6 @@ func (c Context) Input() ml.Context {
|
|||||||
return &c
|
return &c
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c Context) Output() ml.Context {
|
|
||||||
if c.b.output != nil {
|
|
||||||
return &Context{
|
|
||||||
b: c.b,
|
|
||||||
ctx: c.ctx,
|
|
||||||
buft: c.b.output,
|
|
||||||
maxGraphNodes: c.maxGraphNodes,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &c
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c Context) Layer(i int) ml.Context {
|
func (c Context) Layer(i int) ml.Context {
|
||||||
if buft, ok := c.b.layers[i]; ok {
|
if buft, ok := c.b.layers[i]; ok {
|
||||||
return &Context{
|
return &Context{
|
||||||
@@ -924,6 +907,8 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GGML RoPE types
|
||||||
|
// These are the types used in the C implementation of RoPE
|
||||||
const (
|
const (
|
||||||
ropeTypeNorm C.int = 0
|
ropeTypeNorm C.int = 0
|
||||||
ropeTypeNeox C.int = 2
|
ropeTypeNeox C.int = 2
|
||||||
@@ -931,7 +916,8 @@ const (
|
|||||||
ropeTypeVision C.int = 24
|
ropeTypeVision C.int = 24
|
||||||
)
|
)
|
||||||
|
|
||||||
func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32) ml.Tensor {
|
// RoPE applies Rotary Position Embeddings to the tensor
|
||||||
|
func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, config ml.RoPEConfig) ml.Tensor {
|
||||||
if ropeFactors == nil {
|
if ropeFactors == nil {
|
||||||
ropeFactors = &Tensor{b: t.b}
|
ropeFactors = &Tensor{b: t.b}
|
||||||
}
|
}
|
||||||
@@ -941,19 +927,41 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
|
|||||||
dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32)
|
dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if config.YarnConfig == nil {
|
||||||
|
config.YarnConfig = ml.DefaultYarnConfig(131072) // 131072 is the default for LLaMA, so it is common at the time of writing
|
||||||
|
}
|
||||||
|
|
||||||
|
// Map Go RopeType to C implementation constants
|
||||||
|
var ropeTypeC C.int
|
||||||
|
switch config.Type {
|
||||||
|
case ml.RopeTypeNormal:
|
||||||
|
ropeTypeC = ropeTypeNorm
|
||||||
|
case ml.RopeTypeNeox:
|
||||||
|
ropeTypeC = ropeTypeNeox
|
||||||
|
case ml.RopeTypeMRoPE:
|
||||||
|
ropeTypeC = ropeTypeMrope
|
||||||
|
case ml.RopeTypeVision:
|
||||||
|
ropeTypeC = ropeTypeVision
|
||||||
|
default:
|
||||||
|
ropeTypeC = ropeTypeNorm
|
||||||
|
}
|
||||||
|
|
||||||
return &Tensor{
|
return &Tensor{
|
||||||
b: t.b,
|
b: t.b,
|
||||||
t: C.ggml_rope_ext(
|
t: C.ggml_rope_ext(
|
||||||
ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
|
ctx.(*Context).ctx,
|
||||||
C.int(ropeDim),
|
dequant,
|
||||||
C.int(ropeType),
|
positionIDs.(*Tensor).t,
|
||||||
131072, // YaRN n_ctx_train
|
ropeFactors.(*Tensor).t,
|
||||||
C.float(ropeBase),
|
C.int(config.Dim),
|
||||||
C.float(ropeScale),
|
ropeTypeC,
|
||||||
0., // YaRN ext_factor
|
C.int(config.YarnCtxTrain),
|
||||||
1., // YaRN attn_factor
|
C.float(config.Base),
|
||||||
32., // YaRN beta_fast
|
C.float(config.Scale),
|
||||||
1., // YaRN beta_slow
|
C.float(config.YarnExtFactor),
|
||||||
|
C.float(config.YarnAttnFactor),
|
||||||
|
C.float(config.YarnBetaFast),
|
||||||
|
C.float(config.YarnBetaSlow),
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -61,11 +61,13 @@
|
|||||||
#define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000
|
#define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000
|
||||||
#define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a
|
#define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a
|
||||||
#define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA
|
#define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA
|
||||||
|
#define GGML_CUDA_CC_RDNA4 (GGML_CUDA_CC_OFFSET_AMD + 0x1200) // RX 9000
|
||||||
|
|
||||||
#define GGML_CUDA_CC_IS_RDNA(cc) (cc >= GGML_CUDA_CC_RDNA1)
|
#define GGML_CUDA_CC_IS_RDNA(cc) (cc >= GGML_CUDA_CC_RDNA1)
|
||||||
#define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2)
|
#define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2)
|
||||||
#define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
|
#define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
|
||||||
#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3)
|
#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3 && cc < GGML_CUDA_CC_RDNA4)
|
||||||
|
#define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4)
|
||||||
#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA)
|
#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA)
|
||||||
#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
|
#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
|
||||||
|
|
||||||
@@ -386,7 +388,7 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
|
|||||||
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
||||||
#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(RDNA2)
|
#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(RDNA2)
|
||||||
c = __builtin_amdgcn_sdot4(a, b, c, false);
|
c = __builtin_amdgcn_sdot4(a, b, c, false);
|
||||||
#elif defined(RDNA3)
|
#elif defined(RDNA3) || defined(RDNA4)
|
||||||
c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
|
c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
|
||||||
#elif defined(__gfx1010__) || defined(__gfx900__)
|
#elif defined(__gfx1010__) || defined(__gfx900__)
|
||||||
int tmp1;
|
int tmp1;
|
||||||
|
|||||||
2
ml/backend/ggml/ggml/src/ggml-cuda/mmq.cu
vendored
2
ml/backend/ggml/ggml/src/ggml-cuda/mmq.cu
vendored
@@ -149,5 +149,5 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
|
|||||||
return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||||
}
|
}
|
||||||
|
|
||||||
return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
return (!GGML_CUDA_CC_IS_RDNA4(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||||
}
|
}
|
||||||
|
|||||||
4
ml/backend/ggml/ggml/src/ggml-cuda/mmq.cuh
vendored
4
ml/backend/ggml/ggml/src/ggml-cuda/mmq.cuh
vendored
@@ -2577,9 +2577,9 @@ static __device__ void mul_mat_q_process_tile(
|
|||||||
|
|
||||||
template <ggml_type type, int mmq_x, int nwarps, bool need_check>
|
template <ggml_type type, int mmq_x, int nwarps, bool need_check>
|
||||||
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
||||||
#if defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
|
#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
|
||||||
__launch_bounds__(WARP_SIZE*nwarps, 2)
|
__launch_bounds__(WARP_SIZE*nwarps, 2)
|
||||||
#endif // defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
|
#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
|
||||||
#else
|
#else
|
||||||
#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
|
#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
|
||||||
__launch_bounds__(WARP_SIZE*nwarps, 1)
|
__launch_bounds__(WARP_SIZE*nwarps, 1)
|
||||||
|
|||||||
4
ml/backend/ggml/ggml/src/ggml-cuda/mmvq.cu
vendored
4
ml/backend/ggml/ggml/src/ggml-cuda/mmvq.cu
vendored
@@ -62,13 +62,13 @@ static __global__ void mul_mat_vec_q(
|
|||||||
|
|
||||||
constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
|
constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
|
||||||
|
|
||||||
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
|
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4))
|
||||||
constexpr int nwarps = 1;
|
constexpr int nwarps = 1;
|
||||||
constexpr int rows_per_cuda_block = 1;
|
constexpr int rows_per_cuda_block = 1;
|
||||||
#else
|
#else
|
||||||
constexpr int nwarps = ncols_y <= 4 ? 4 : 2;
|
constexpr int nwarps = ncols_y <= 4 ? 4 : 2;
|
||||||
constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
|
constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
|
||||||
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
|
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3) && !defined(RDNA4)
|
||||||
|
|
||||||
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
|
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
|
||||||
const int row0 = rows_per_cuda_block*blockIdx.x;
|
const int row0 = rows_per_cuda_block*blockIdx.x;
|
||||||
|
|||||||
@@ -150,6 +150,10 @@
|
|||||||
#define CDNA
|
#define CDNA
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#if defined(__gfx1200__) || defined(__gfx1201__)
|
||||||
|
#define RDNA4
|
||||||
|
#endif
|
||||||
|
|
||||||
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
|
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
|
||||||
defined(__gfx1150__) || defined(__gfx1151__)
|
defined(__gfx1150__) || defined(__gfx1151__)
|
||||||
#define RDNA3
|
#define RDNA3
|
||||||
|
|||||||
@@ -13,10 +13,11 @@ import (
|
|||||||
type Options struct {
|
type Options struct {
|
||||||
hiddenSize, numHeads, numKVHeads int
|
hiddenSize, numHeads, numKVHeads int
|
||||||
attnKeyLen, attnValLen int
|
attnKeyLen, attnValLen int
|
||||||
eps, ropeBase, ropeScale float32
|
eps float32
|
||||||
attnLogitSoftcap float32
|
attnLogitSoftcap float32
|
||||||
finalLogitSoftcap float32
|
finalLogitSoftcap float32
|
||||||
largeModelScaling bool
|
largeModelScaling bool
|
||||||
|
ropeConfig ml.RoPEConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
type Model struct {
|
type Model struct {
|
||||||
@@ -55,10 +56,15 @@ func New(c ml.Config) (model.Model, error) {
|
|||||||
attnKeyLen: int(c.Uint("attention.key_length")),
|
attnKeyLen: int(c.Uint("attention.key_length")),
|
||||||
attnValLen: int(c.Uint("attention.value_length")),
|
attnValLen: int(c.Uint("attention.value_length")),
|
||||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||||
ropeBase: c.Float("rope.freq_base", 10000.0),
|
|
||||||
ropeScale: c.Float("rope.freq_scale", 1.0),
|
|
||||||
attnLogitSoftcap: c.Float("attn_logit_softcapping"),
|
attnLogitSoftcap: c.Float("attn_logit_softcapping"),
|
||||||
finalLogitSoftcap: c.Float("final_logit_softcapping"),
|
finalLogitSoftcap: c.Float("final_logit_softcapping"),
|
||||||
|
ropeConfig: ml.RoPEConfig{
|
||||||
|
Base: c.Float("rope.freq_base", 10000.0),
|
||||||
|
Scale: c.Float("rope.freq_scale", 1.0),
|
||||||
|
Dim: c.Uint("attention.key_length"),
|
||||||
|
Type: ml.RopeTypeNormal,
|
||||||
|
YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -78,11 +84,10 @@ type SelfAttention struct {
|
|||||||
|
|
||||||
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
||||||
batchSize := hiddenState.Dim(1)
|
batchSize := hiddenState.Dim(1)
|
||||||
ropeType := uint32(2)
|
|
||||||
|
|
||||||
q := sa.Query.Forward(ctx, hiddenState)
|
q := sa.Query.Forward(ctx, hiddenState)
|
||||||
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
|
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
|
||||||
q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
|
q = q.RoPE(ctx, positionIDs, nil, opts.ropeConfig)
|
||||||
|
|
||||||
if opts.largeModelScaling {
|
if opts.largeModelScaling {
|
||||||
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
|
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
|
||||||
@@ -92,7 +97,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
|||||||
|
|
||||||
k := sa.Key.Forward(ctx, hiddenState)
|
k := sa.Key.Forward(ctx, hiddenState)
|
||||||
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
|
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
|
||||||
k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
|
k = k.RoPE(ctx, positionIDs, nil, opts.ropeConfig)
|
||||||
|
|
||||||
v := sa.Value.Forward(ctx, hiddenState)
|
v := sa.Value.Forward(ctx, hiddenState)
|
||||||
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
|
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
|
||||||
@@ -122,7 +127,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||||
return key.RoPE(ctx, shift, nil, uint32(m.Options.attnKeyLen), uint32(2), m.Options.ropeBase, m.Options.ropeScale), nil
|
return key.RoPE(ctx, shift, nil, m.ropeConfig), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type MLP struct {
|
type MLP struct {
|
||||||
|
|||||||
@@ -13,9 +13,11 @@ import (
|
|||||||
type TextOptions struct {
|
type TextOptions struct {
|
||||||
hiddenSize, numHeads, numKVHeads int
|
hiddenSize, numHeads, numKVHeads int
|
||||||
attnKeyLen, attnValLen int
|
attnKeyLen, attnValLen int
|
||||||
eps, ropeScale float32
|
eps float32
|
||||||
ropeLocalBase, ropeGlobalBase float32
|
|
||||||
largeModelScaling bool
|
largeModelScaling bool
|
||||||
|
|
||||||
|
ropeLocalConfig ml.RoPEConfig
|
||||||
|
ropeGlobalConfig ml.RoPEConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
type TextModel struct {
|
type TextModel struct {
|
||||||
@@ -62,9 +64,21 @@ func newTextModel(c ml.Config) *TextModel {
|
|||||||
attnKeyLen: int(c.Uint("attention.key_length", 256)),
|
attnKeyLen: int(c.Uint("attention.key_length", 256)),
|
||||||
attnValLen: int(c.Uint("attention.value_length", 256)),
|
attnValLen: int(c.Uint("attention.value_length", 256)),
|
||||||
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
|
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
|
||||||
ropeLocalBase: c.Float("rope.local.freq_base", 10000.0),
|
|
||||||
ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
|
ropeLocalConfig: ml.RoPEConfig{
|
||||||
ropeScale: c.Float("rope.freq_scale", 1.0),
|
Base: c.Float("rope.local.freq_base", 10000.0),
|
||||||
|
Scale: c.Float("rope.freq_scale", 1.0),
|
||||||
|
Dim: c.Uint("attention.key_length", 256),
|
||||||
|
Type: ml.RopeTypeNeox,
|
||||||
|
YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
|
||||||
|
},
|
||||||
|
ropeGlobalConfig: ml.RoPEConfig{
|
||||||
|
Base: c.Float("rope.global.freq_base", 1000000.0),
|
||||||
|
Scale: c.Float("rope.freq_scale", 1.0),
|
||||||
|
Dim: c.Uint("attention.key_length", 256),
|
||||||
|
Type: ml.RopeTypeNeox,
|
||||||
|
YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -86,17 +100,16 @@ type TextSelfAttention struct {
|
|||||||
|
|
||||||
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
|
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
|
||||||
batchSize := hiddenState.Dim(1)
|
batchSize := hiddenState.Dim(1)
|
||||||
ropeType := uint32(2)
|
|
||||||
|
|
||||||
ropeBase := opts.ropeLocalBase
|
ropeConfig := opts.ropeLocalConfig
|
||||||
if (layer+1)%gemmaGlobalCacheCount == 0 {
|
if (layer+1)%gemmaGlobalCacheCount == 0 {
|
||||||
ropeBase = opts.ropeGlobalBase
|
ropeConfig = opts.ropeGlobalConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
q := sa.Query.Forward(ctx, hiddenState)
|
q := sa.Query.Forward(ctx, hiddenState)
|
||||||
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
|
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
|
||||||
q = sa.QueryNorm.Forward(ctx, q, opts.eps)
|
q = sa.QueryNorm.Forward(ctx, q, opts.eps)
|
||||||
q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
|
q = q.RoPE(ctx, positionIDs, nil, ropeConfig)
|
||||||
|
|
||||||
if opts.largeModelScaling {
|
if opts.largeModelScaling {
|
||||||
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
|
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
|
||||||
@@ -107,7 +120,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
|
|||||||
k := sa.Key.Forward(ctx, hiddenState)
|
k := sa.Key.Forward(ctx, hiddenState)
|
||||||
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
|
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
|
||||||
k = sa.KeyNorm.Forward(ctx, k, opts.eps)
|
k = sa.KeyNorm.Forward(ctx, k, opts.eps)
|
||||||
k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
|
k = k.RoPE(ctx, positionIDs, nil, ropeConfig)
|
||||||
|
|
||||||
v := sa.Value.Forward(ctx, hiddenState)
|
v := sa.Value.Forward(ctx, hiddenState)
|
||||||
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
|
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
|
||||||
@@ -120,12 +133,12 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||||
ropeBase := m.TextOptions.ropeLocalBase
|
ropeConfig := m.ropeLocalConfig
|
||||||
if (layer+1)%gemmaGlobalCacheCount == 0 {
|
if (layer+1)%gemmaGlobalCacheCount == 0 {
|
||||||
ropeBase = m.TextOptions.ropeGlobalBase
|
ropeConfig = m.ropeGlobalConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
return key.RoPE(ctx, shift, nil, uint32(m.TextOptions.attnKeyLen), uint32(2), ropeBase, m.TextOptions.ropeScale), nil
|
return key.RoPE(ctx, shift, nil, ropeConfig), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type TextMLP struct {
|
type TextMLP struct {
|
||||||
|
|||||||
@@ -14,8 +14,8 @@ import (
|
|||||||
|
|
||||||
type Options struct {
|
type Options struct {
|
||||||
hiddenSize, numHeads, numKVHeads int
|
hiddenSize, numHeads, numKVHeads int
|
||||||
eps, ropeBase, ropeScale float32
|
eps float32
|
||||||
ropeDim uint32
|
ropeConfig ml.RoPEConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
type Model struct {
|
type Model struct {
|
||||||
@@ -54,9 +54,13 @@ func New(c ml.Config) (model.Model, error) {
|
|||||||
numHeads: int(c.Uint("attention.head_count")),
|
numHeads: int(c.Uint("attention.head_count")),
|
||||||
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||||
ropeBase: c.Float("rope.freq_base"),
|
ropeConfig: ml.RoPEConfig{
|
||||||
ropeScale: c.Float("rope.freq_scale", 1),
|
Base: c.Float("rope.freq_base"),
|
||||||
ropeDim: c.Uint("rope.dimension_count"),
|
Scale: c.Float("rope.freq_scale", 1),
|
||||||
|
Dim: c.Uint("rope.dimension_count"),
|
||||||
|
Type: ml.RopeTypeNormal,
|
||||||
|
YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -76,15 +80,14 @@ type SelfAttention struct {
|
|||||||
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
||||||
batchSize := hiddenState.Dim(1)
|
batchSize := hiddenState.Dim(1)
|
||||||
headDim := opts.hiddenSize / opts.numHeads
|
headDim := opts.hiddenSize / opts.numHeads
|
||||||
ropeType := uint32(0)
|
|
||||||
|
|
||||||
q := sa.Query.Forward(ctx, hiddenState)
|
q := sa.Query.Forward(ctx, hiddenState)
|
||||||
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
||||||
q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
|
q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeConfig)
|
||||||
|
|
||||||
k := sa.Key.Forward(ctx, hiddenState)
|
k := sa.Key.Forward(ctx, hiddenState)
|
||||||
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||||
k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
|
k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeConfig)
|
||||||
|
|
||||||
v := sa.Value.Forward(ctx, hiddenState)
|
v := sa.Value.Forward(ctx, hiddenState)
|
||||||
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||||
@@ -97,7 +100,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||||
return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil
|
return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, m.ropeConfig), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type MLP struct {
|
type MLP struct {
|
||||||
|
|||||||
@@ -20,15 +20,14 @@ type TextSelfAttention struct {
|
|||||||
func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
|
func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
|
||||||
batchSize := hiddenState.Dim(1)
|
batchSize := hiddenState.Dim(1)
|
||||||
headDim := opts.hiddenSize / opts.numHeads
|
headDim := opts.hiddenSize / opts.numHeads
|
||||||
ropeType := uint32(0)
|
|
||||||
|
|
||||||
query := sa.Query.Forward(ctx, hiddenState)
|
query := sa.Query.Forward(ctx, hiddenState)
|
||||||
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
||||||
query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
|
query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeConfig)
|
||||||
|
|
||||||
key := sa.Key.Forward(ctx, hiddenState)
|
key := sa.Key.Forward(ctx, hiddenState)
|
||||||
key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||||
key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
|
key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeConfig)
|
||||||
|
|
||||||
value := sa.Value.Forward(ctx, hiddenState)
|
value := sa.Value.Forward(ctx, hiddenState)
|
||||||
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||||
@@ -43,7 +42,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
|
|||||||
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||||
// This will only get called for layers in the cache, which are just the self attention layers
|
// This will only get called for layers in the cache, which are just the self attention layers
|
||||||
if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
|
if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
|
||||||
return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeDim, uint32(0), m.ropeBase, m.ropeScale), nil
|
return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeConfig), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return key, nil
|
return key, nil
|
||||||
@@ -198,8 +197,8 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs,
|
|||||||
|
|
||||||
type TextModelOptions struct {
|
type TextModelOptions struct {
|
||||||
hiddenSize, numHeads, numKVHeads int
|
hiddenSize, numHeads, numKVHeads int
|
||||||
eps, ropeBase, ropeScale float32
|
eps float32
|
||||||
ropeDim uint32
|
ropeConfig ml.RoPEConfig
|
||||||
|
|
||||||
crossAttentionLayers []uint32
|
crossAttentionLayers []uint32
|
||||||
}
|
}
|
||||||
@@ -240,10 +239,14 @@ func newTextModel(c ml.Config) *TextModel {
|
|||||||
numHeads: int(c.Uint("attention.head_count")),
|
numHeads: int(c.Uint("attention.head_count")),
|
||||||
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||||
ropeBase: c.Float("rope.freq_base"),
|
|
||||||
ropeScale: c.Float("rope.freq_scale", 1),
|
|
||||||
ropeDim: c.Uint("rope.dimension_count"),
|
|
||||||
crossAttentionLayers: c.Uints("attention.cross_attention_layers"),
|
crossAttentionLayers: c.Uints("attention.cross_attention_layers"),
|
||||||
|
ropeConfig: ml.RoPEConfig{
|
||||||
|
Base: c.Float("rope.freq_base"),
|
||||||
|
Scale: c.Float("rope.freq_scale", 1),
|
||||||
|
Dim: c.Uint("rope.dimension_count"),
|
||||||
|
Type: ml.RopeTypeNormal,
|
||||||
|
YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -213,8 +213,16 @@ func (c *InputCache) ShiftDiscard(inputLen int, numKeep int) int {
|
|||||||
return discard
|
return discard
|
||||||
}
|
}
|
||||||
|
|
||||||
// Frees up space in the KV cache by deleting the oldest half of history and shifting
|
type ErrReprocessInputs struct {
|
||||||
// the newest half into that space (saving numKeep inputs at the beginning).
|
Inputs []input
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ErrReprocessInputs) Error() string {
|
||||||
|
return fmt.Sprintf("kv cache shift not supported, inputs need reprocessing (input count: %v)", len(e.Inputs))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ShiftCacheSlot frees up space in the KV cache by deleting the oldest half of history
|
||||||
|
// and shifting the newest half into that space (saving numKeep inputs at the beginning).
|
||||||
//
|
//
|
||||||
// Assumes that at least 1 entry can be freed up by shifting (i.e. numKeep < numCtx)
|
// Assumes that at least 1 entry can be freed up by shifting (i.e. numKeep < numCtx)
|
||||||
func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
|
func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
|
||||||
@@ -222,7 +230,8 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
|
|||||||
return fmt.Errorf("unable to shift context - keep exceeds context (keep: %v context: %v)", numKeep, c.numCtx)
|
return fmt.Errorf("unable to shift context - keep exceeds context (keep: %v context: %v)", numKeep, c.numCtx)
|
||||||
}
|
}
|
||||||
|
|
||||||
discard := c.ShiftDiscard(len(slot.Inputs), numKeep)
|
inputLen := len(slot.Inputs)
|
||||||
|
discard := c.ShiftDiscard(inputLen, numKeep)
|
||||||
|
|
||||||
if discard <= 0 {
|
if discard <= 0 {
|
||||||
return nil
|
return nil
|
||||||
@@ -231,16 +240,42 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
|
|||||||
slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs),
|
slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs),
|
||||||
"keep", numKeep, "discard", discard)
|
"keep", numKeep, "discard", discard)
|
||||||
|
|
||||||
// TODO (jessegross): KV cache removal can fail for certain types of models
|
var shiftFailed bool
|
||||||
if !c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+discard) {
|
|
||||||
return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v)", slot.Id, numKeep, discard)
|
|
||||||
}
|
|
||||||
c.lc.KvCacheSeqAdd(slot.Id, numKeep+discard, len(slot.Inputs), -discard)
|
|
||||||
|
|
||||||
for i := numKeep + discard; i < len(slot.Inputs); i++ {
|
if c.lc.KvCacheCanShift() {
|
||||||
|
// For models that support shifting, attempt to shift the KV cache
|
||||||
|
if !c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+discard) {
|
||||||
|
shiftFailed = true
|
||||||
|
slog.Debug("kv cache removal not supported, clearing cache and returning inputs for reprocessing", "id", slot.Id)
|
||||||
|
} else {
|
||||||
|
c.lc.KvCacheSeqAdd(slot.Id, numKeep+discard, inputLen, -discard)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// For models that don't support shifting
|
||||||
|
shiftFailed = true
|
||||||
|
slog.Debug("kv cache cannot shift, clearing cache and returning inputs for reprocessing", "id", slot.Id)
|
||||||
|
}
|
||||||
|
|
||||||
|
if shiftFailed {
|
||||||
|
// Create new input slice with preserved tokens (numKeep + remaining tokens after discard)
|
||||||
|
newInputs := make([]input, numKeep+inputLen-(numKeep+discard))
|
||||||
|
copy(newInputs[:numKeep], slot.Inputs[:numKeep])
|
||||||
|
copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:])
|
||||||
|
|
||||||
|
// Clear the entire KV cache
|
||||||
|
_ = c.lc.KvCacheSeqRm(slot.Id, 0, -1)
|
||||||
|
// Reset the slot inputs since we've cleared the cache
|
||||||
|
slot.Inputs = []input{}
|
||||||
|
|
||||||
|
// Return error with inputs that need to be reprocessed
|
||||||
|
return &ErrReprocessInputs{Inputs: newInputs}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Standard shift succeeded - update input array
|
||||||
|
for i := numKeep + discard; i < inputLen; i++ {
|
||||||
slot.Inputs[i-discard] = slot.Inputs[i]
|
slot.Inputs[i-discard] = slot.Inputs[i]
|
||||||
}
|
}
|
||||||
slot.Inputs = slot.Inputs[:len(slot.Inputs)-discard]
|
slot.Inputs = slot.Inputs[:inputLen-discard]
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -389,8 +389,16 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
|||||||
if len(seq.pendingInputs) == 0 {
|
if len(seq.pendingInputs) == 0 {
|
||||||
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
|
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
var reprocess *ErrReprocessInputs
|
||||||
|
if errors.As(err, &reprocess) {
|
||||||
|
// Prepend these inputs to the sequence's inputs queue for reprocessing
|
||||||
|
seq.inputs = append(reprocess.Inputs, seq.inputs...)
|
||||||
|
// Continue processing as normal
|
||||||
|
continue
|
||||||
|
} else {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -599,7 +607,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
if errors.Is(err, context.Canceled) {
|
if errors.Is(err, context.Canceled) {
|
||||||
slog.Info("aborting completion request due to client closing the connection")
|
slog.Info("aborting completion request due to client closing the connection")
|
||||||
} else {
|
} else {
|
||||||
slog.Error("Failed to acquire semaphore", "error", err)
|
http.Error(w, fmt.Sprintf("Failed to acquire semaphore: %v", err), http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -611,6 +619,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
|
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
s.seqsSem.Release(1)
|
||||||
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
|
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -626,6 +635,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
|
||||||
if !found {
|
if !found {
|
||||||
|
s.seqsSem.Release(1)
|
||||||
http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
|
http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -691,7 +701,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
|
|||||||
if errors.Is(err, context.Canceled) {
|
if errors.Is(err, context.Canceled) {
|
||||||
slog.Info("aborting embeddings request due to client closing the connection")
|
slog.Info("aborting embeddings request due to client closing the connection")
|
||||||
} else {
|
} else {
|
||||||
slog.Error("Failed to acquire semaphore", "error", err)
|
http.Error(w, fmt.Sprintf("Failed to acquire semaphore: %v", err), http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -703,6 +713,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
|
|||||||
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, false)
|
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
s.seqsSem.Release(1)
|
||||||
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
|
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -715,6 +726,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
|
|||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
|
||||||
if !found {
|
if !found {
|
||||||
|
s.seqsSem.Release(1)
|
||||||
http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
|
http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -239,6 +239,14 @@ func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 {
|
|||||||
return discard
|
return discard
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ErrReprocessInputs struct {
|
||||||
|
Inputs []input.Input
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ErrReprocessInputs) Error() string {
|
||||||
|
return fmt.Sprintf("kv cache shift not supported, inputs need reprocessing (input count: %v)", len(e.Inputs))
|
||||||
|
}
|
||||||
|
|
||||||
// Frees up space in the KV cache by deleting the oldest half of history and shifting
|
// Frees up space in the KV cache by deleting the oldest half of history and shifting
|
||||||
// the newest half into that space (saving numKeep inputs at the beginning).
|
// the newest half into that space (saving numKeep inputs at the beginning).
|
||||||
//
|
//
|
||||||
@@ -258,11 +266,23 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int32) error {
|
|||||||
slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs),
|
slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs),
|
||||||
"keep", numKeep, "discard", discard)
|
"keep", numKeep, "discard", discard)
|
||||||
|
|
||||||
// TODO (jessegross): KV cache removal can fail for certain types of models
|
|
||||||
if c.cache != nil {
|
if c.cache != nil {
|
||||||
err := c.cache.Remove(slot.Id, numKeep, numKeep+discard)
|
err := c.cache.Remove(slot.Id, numKeep, numKeep+discard)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v): %w", slot.Id, numKeep, discard, err)
|
slog.Debug("kv cache removal unsupported, clearing cache and returning inputs for reprocessing",
|
||||||
|
"id", slot.Id, "error", err)
|
||||||
|
|
||||||
|
// Create new input slice with preserved tokens (numKeep + remaining tokens after discard)
|
||||||
|
newInputs := make([]input.Input, numKeep+inputLen-(numKeep+discard))
|
||||||
|
copy(newInputs[:numKeep], slot.Inputs[:numKeep])
|
||||||
|
copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:])
|
||||||
|
|
||||||
|
// Reset the cache
|
||||||
|
_ = c.cache.Remove(slot.Id, 0, -1)
|
||||||
|
slot.Inputs = []input.Input{}
|
||||||
|
|
||||||
|
// Return error with inputs that need to be reprocessed
|
||||||
|
return &ErrReprocessInputs{Inputs: newInputs}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,13 @@
|
|||||||
package ollamarunner
|
package ollamarunner
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
"image"
|
"image"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/ml"
|
||||||
"github.com/ollama/ollama/model/input"
|
"github.com/ollama/ollama/model/input"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -425,3 +428,91 @@ func TestLoadCacheSlot(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Mock implementation of the Cache interface
|
||||||
|
type mockCache struct {
|
||||||
|
shouldFail bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implement only the methods needed for the test
|
||||||
|
func (m *mockCache) Remove(seq int, beginIndex, endIndex int32) error {
|
||||||
|
if m.shouldFail {
|
||||||
|
return fmt.Errorf("mock cache removal error")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stub implementations for other interface methods
|
||||||
|
func (m *mockCache) SetLayer(layer int) {}
|
||||||
|
func (m *mockCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) { return nil, nil, nil }
|
||||||
|
func (m *mockCache) Put(ctx ml.Context, key, value ml.Tensor) {}
|
||||||
|
func (m *mockCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {}
|
||||||
|
func (m *mockCache) Close() {}
|
||||||
|
func (m *mockCache) StartForward(ctx ml.Context, batch input.Batch) error { return nil }
|
||||||
|
func (m *mockCache) CopyPrefix(srcSeq, dstSeq int, len int32) {}
|
||||||
|
func (m *mockCache) SetConfig(ml.CacheConfig) {}
|
||||||
|
|
||||||
|
func TestShiftCacheSlot(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
numCtx int32
|
||||||
|
inputs []input.Input
|
||||||
|
numKeep int32
|
||||||
|
cacheErr bool
|
||||||
|
wantErr any
|
||||||
|
wantInputsLen int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Normal shift",
|
||||||
|
numCtx: 10,
|
||||||
|
inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
|
||||||
|
numKeep: 2,
|
||||||
|
cacheErr: false, // No error
|
||||||
|
wantErr: nil,
|
||||||
|
wantInputsLen: 6, // After discarding 4 tokens
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Cache removal fails",
|
||||||
|
numCtx: 10,
|
||||||
|
inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
|
||||||
|
numKeep: 2,
|
||||||
|
cacheErr: true,
|
||||||
|
wantErr: &ErrReprocessInputs{},
|
||||||
|
wantInputsLen: 0, // Original inputs should be cleared
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
mock := &mockCache{shouldFail: tt.cacheErr}
|
||||||
|
c := InputCache{
|
||||||
|
numCtx: tt.numCtx,
|
||||||
|
cache: mock,
|
||||||
|
}
|
||||||
|
slot := &InputCacheSlot{
|
||||||
|
Id: 123,
|
||||||
|
Inputs: make([]input.Input, len(tt.inputs)),
|
||||||
|
}
|
||||||
|
copy(slot.Inputs, tt.inputs)
|
||||||
|
|
||||||
|
err := c.ShiftCacheSlot(slot, tt.numKeep)
|
||||||
|
|
||||||
|
if tt.wantErr != nil {
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("Expected error but got nil")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !errors.As(err, &tt.wantErr) {
|
||||||
|
t.Errorf("Expected error of type %T but got %T: %v", tt.wantErr, err, err)
|
||||||
|
}
|
||||||
|
} else if err != nil {
|
||||||
|
t.Errorf("Unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(slot.Inputs) != tt.wantInputsLen {
|
||||||
|
t.Errorf("Slot inputs length after operation: got %v, want %v", len(slot.Inputs), tt.wantInputsLen)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -267,6 +267,9 @@ type Server struct {
|
|||||||
// KV cache
|
// KV cache
|
||||||
cache *InputCache
|
cache *InputCache
|
||||||
|
|
||||||
|
// next sequence for prompt processing to avoid starvation
|
||||||
|
nextSeq int
|
||||||
|
|
||||||
// multimodalHash generates hashes for comparing equality
|
// multimodalHash generates hashes for comparing equality
|
||||||
// of non-text data
|
// of non-text data
|
||||||
multimodalHash maphash.Hash
|
multimodalHash maphash.Hash
|
||||||
@@ -351,14 +354,19 @@ func (s *Server) processBatch() error {
|
|||||||
var batchInputs []int32
|
var batchInputs []int32
|
||||||
var batch input.Batch
|
var batch input.Batch
|
||||||
|
|
||||||
for i, seq := range s.seqs {
|
resumeSeq := -1
|
||||||
|
seqIdx := s.nextSeq - 1
|
||||||
|
for range s.seqs {
|
||||||
|
seqIdx = (seqIdx + 1) % len(s.seqs)
|
||||||
|
seq := s.seqs[seqIdx]
|
||||||
|
|
||||||
if seq == nil {
|
if seq == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// if past the num predict limit
|
// if past the num predict limit
|
||||||
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
|
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
|
||||||
s.removeSequence(i, "limit")
|
s.removeSequence(seqIdx, "limit")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -369,16 +377,23 @@ func (s *Server) processBatch() error {
|
|||||||
|
|
||||||
batchSize := s.batchSize
|
batchSize := s.batchSize
|
||||||
|
|
||||||
for j, inp := range seq.inputs {
|
for i, inp := range seq.inputs {
|
||||||
// If we are required to put following inputs into a single batch then extend the
|
// If we are required to put following inputs into a single batch then extend the
|
||||||
// batch size. Since we are only extending the size the minimum amount possible, this
|
// batch size. Since we are only extending the size the minimum amount possible, this
|
||||||
// will cause a break if we have pending inputs.
|
// will cause a break if we have existing inputs.
|
||||||
minBatch := 1 + inp.SameBatch
|
minBatch := 1 + inp.SameBatch
|
||||||
if minBatch > batchSize {
|
if minBatch > batchSize {
|
||||||
batchSize = minBatch
|
batchSize = minBatch
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(seq.pendingInputs)+minBatch > batchSize {
|
// Stop if the required batch would put us over the total batch size (including tokens
|
||||||
|
// added by other sequences). If we haven't been able to add anything yet then pick up
|
||||||
|
// here again for the next batch to avoid starvation, though we can opportunistically
|
||||||
|
// check if other sequences can still squeeze something in.
|
||||||
|
if len(batchInputs)+minBatch > batchSize {
|
||||||
|
if len(seq.pendingInputs) == 0 && resumeSeq == -1 {
|
||||||
|
resumeSeq = seqIdx
|
||||||
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -392,9 +407,17 @@ func (s *Server) processBatch() error {
|
|||||||
|
|
||||||
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
|
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
var reprocess *ErrReprocessInputs
|
||||||
|
if errors.As(err, &reprocess) {
|
||||||
|
// Prepend these inputs to the sequence's inputs queue for reprocessing
|
||||||
|
seq.inputs = append(reprocess.Inputs, seq.inputs...)
|
||||||
|
// Skip this sequence but continue processing the rest
|
||||||
|
continue
|
||||||
|
} else {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
batchInputs = append(batchInputs, inp.Token)
|
batchInputs = append(batchInputs, inp.Token)
|
||||||
if inp.Multimodal != nil {
|
if inp.Multimodal != nil {
|
||||||
@@ -405,7 +428,7 @@ func (s *Server) processBatch() error {
|
|||||||
batch.Sequences = append(batch.Sequences, seq.cache.Id)
|
batch.Sequences = append(batch.Sequences, seq.cache.Id)
|
||||||
|
|
||||||
seq.iBatch = len(batch.Outputs)
|
seq.iBatch = len(batch.Outputs)
|
||||||
if j+1 == len(seq.inputs) {
|
if i+1 == len(seq.inputs) {
|
||||||
batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
|
batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
|
||||||
}
|
}
|
||||||
seq.pendingInputs = append(seq.pendingInputs, inp)
|
seq.pendingInputs = append(seq.pendingInputs, inp)
|
||||||
@@ -414,6 +437,12 @@ func (s *Server) processBatch() error {
|
|||||||
seq.inputs = seq.inputs[len(seq.pendingInputs):]
|
seq.inputs = seq.inputs[len(seq.pendingInputs):]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if resumeSeq != -1 {
|
||||||
|
s.nextSeq = resumeSeq
|
||||||
|
} else {
|
||||||
|
s.nextSeq = seqIdx + 1
|
||||||
|
}
|
||||||
|
|
||||||
if len(batchInputs) == 0 {
|
if len(batchInputs) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -588,7 +617,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
if errors.Is(err, context.Canceled) {
|
if errors.Is(err, context.Canceled) {
|
||||||
slog.Info("aborting completion request due to client closing the connection")
|
slog.Info("aborting completion request due to client closing the connection")
|
||||||
} else {
|
} else {
|
||||||
slog.Error("Failed to acquire semaphore", "error", err)
|
http.Error(w, fmt.Sprintf("Failed to acquire semaphore: %v", err), http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -600,6 +629,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs)
|
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
s.seqsSem.Release(1)
|
||||||
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
|
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -613,6 +643,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
|
||||||
if !found {
|
if !found {
|
||||||
|
s.seqsSem.Release(1)
|
||||||
http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
|
http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ const maxRetries = 6
|
|||||||
var (
|
var (
|
||||||
errMaxRetriesExceeded = errors.New("max retries exceeded")
|
errMaxRetriesExceeded = errors.New("max retries exceeded")
|
||||||
errPartStalled = errors.New("part stalled")
|
errPartStalled = errors.New("part stalled")
|
||||||
|
errMaxRedirectsExceeded = errors.New("maximum redirects exceeded (10) for directURL")
|
||||||
)
|
)
|
||||||
|
|
||||||
var blobDownloadManager sync.Map
|
var blobDownloadManager sync.Map
|
||||||
@@ -236,7 +237,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
|
|||||||
|
|
||||||
newOpts.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
newOpts.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||||
if len(via) > 10 {
|
if len(via) > 10 {
|
||||||
return errors.New("maximum redirects exceeded (10) for directURL")
|
return errMaxRedirectsExceeded
|
||||||
}
|
}
|
||||||
|
|
||||||
// if the hostname is the same, allow the redirect
|
// if the hostname is the same, allow the redirect
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ var (
|
|||||||
errCapabilityCompletion = errors.New("completion")
|
errCapabilityCompletion = errors.New("completion")
|
||||||
errCapabilityTools = errors.New("tools")
|
errCapabilityTools = errors.New("tools")
|
||||||
errCapabilityInsert = errors.New("insert")
|
errCapabilityInsert = errors.New("insert")
|
||||||
|
errInsecureProtocol = errors.New("insecure protocol http")
|
||||||
)
|
)
|
||||||
|
|
||||||
type Capability string
|
type Capability string
|
||||||
@@ -479,7 +480,7 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
|||||||
fn(api.ProgressResponse{Status: "retrieving manifest"})
|
fn(api.ProgressResponse{Status: "retrieving manifest"})
|
||||||
|
|
||||||
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
|
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
|
||||||
return errors.New("insecure protocol http")
|
return errInsecureProtocol
|
||||||
}
|
}
|
||||||
|
|
||||||
manifest, _, err := GetManifest(mp)
|
manifest, _, err := GetManifest(mp)
|
||||||
@@ -543,7 +544,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
|||||||
}
|
}
|
||||||
|
|
||||||
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
|
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
|
||||||
return errors.New("insecure protocol http")
|
return errInsecureProtocol
|
||||||
}
|
}
|
||||||
|
|
||||||
fn(api.ProgressResponse{Status: "pulling manifest"})
|
fn(api.ProgressResponse{Status: "pulling manifest"})
|
||||||
|
|||||||
@@ -421,14 +421,6 @@ func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func canRetry(err error) bool {
|
|
||||||
var re *Error
|
|
||||||
if !errors.As(err, &re) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return re.Status >= 500
|
|
||||||
}
|
|
||||||
|
|
||||||
// trackingReader is an io.Reader that tracks the number of bytes read and
|
// trackingReader is an io.Reader that tracks the number of bytes read and
|
||||||
// calls the update function with the layer, the number of bytes read.
|
// calls the update function with the layer, the number of bytes read.
|
||||||
//
|
//
|
||||||
@@ -514,13 +506,40 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cacheKey := fmt.Sprintf(
|
||||||
|
"v1 pull chunksum %s %s %d-%d",
|
||||||
|
l.Digest,
|
||||||
|
cs.Digest,
|
||||||
|
cs.Chunk.Start,
|
||||||
|
cs.Chunk.End,
|
||||||
|
)
|
||||||
|
cacheKeyDigest := blob.DigestFromBytes(cacheKey)
|
||||||
|
_, err := c.Get(cacheKeyDigest)
|
||||||
|
if err == nil {
|
||||||
|
received.Add(cs.Chunk.Size())
|
||||||
|
t.update(l, cs.Chunk.Size(), ErrCached)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
g.Go(func() (err error) {
|
g.Go(func() (err error) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
// Ignore cache key write errors for now. We've already
|
||||||
|
// reported to trace that the chunk is complete.
|
||||||
|
//
|
||||||
|
// Ideally, we should only report completion to trace
|
||||||
|
// after successful cache commit. This current approach
|
||||||
|
// works but could trigger unnecessary redownloads if
|
||||||
|
// the checkpoint key is missing on next pull.
|
||||||
|
//
|
||||||
|
// Not incorrect, just suboptimal - fix this in a
|
||||||
|
// future update.
|
||||||
|
_ = blob.PutBytes(c, cacheKeyDigest, cacheKey)
|
||||||
|
|
||||||
received.Add(cs.Chunk.Size())
|
received.Add(cs.Chunk.Size())
|
||||||
} else {
|
} else {
|
||||||
err = fmt.Errorf("error downloading %s: %w", cs.Digest.Short(), err)
|
t.update(l, 0, err)
|
||||||
}
|
}
|
||||||
wg.Done()
|
wg.Done()
|
||||||
}()
|
}()
|
||||||
@@ -563,7 +582,7 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if received.Load() != expected {
|
if received.Load() != expected {
|
||||||
return fmt.Errorf("%w: received %d/%d", ErrIncomplete, received.Load(), expected)
|
return fmt.Errorf("%w: received %d/%d bytes", ErrIncomplete, received.Load(), expected)
|
||||||
}
|
}
|
||||||
|
|
||||||
md := blob.DigestFromBytes(m.Data)
|
md := blob.DigestFromBytes(m.Data)
|
||||||
@@ -608,6 +627,30 @@ func (m *Manifest) Layer(d blob.Digest) *Layer {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Manifest) All() iter.Seq[*Layer] {
|
||||||
|
return func(yield func(*Layer) bool) {
|
||||||
|
if !yield(m.Config) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, l := range m.Layers {
|
||||||
|
if !yield(l) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manifest) Size() int64 {
|
||||||
|
var size int64
|
||||||
|
if m.Config != nil {
|
||||||
|
size += m.Config.Size
|
||||||
|
}
|
||||||
|
for _, l := range m.Layers {
|
||||||
|
size += l.Size
|
||||||
|
}
|
||||||
|
return size
|
||||||
|
}
|
||||||
|
|
||||||
// MarshalJSON implements json.Marshaler.
|
// MarshalJSON implements json.Marshaler.
|
||||||
//
|
//
|
||||||
// NOTE: It adds an empty config object to the manifest, which is required by
|
// NOTE: It adds an empty config object to the manifest, which is required by
|
||||||
@@ -750,20 +793,32 @@ func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Se
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// A chunksums response is a sequence of chunksums in a
|
// The response is a sequence of chunksums.
|
||||||
// simple, easy to parse line-oriented format.
|
|
||||||
//
|
//
|
||||||
// Example:
|
// Chunksums are chunks of a larger blob that can be
|
||||||
|
// downloaded and verified independently.
|
||||||
//
|
//
|
||||||
// >> GET /v2/<namespace>/<model>/chunksums/<digest>
|
// The chunksums endpoint is a GET request that returns a
|
||||||
|
// sequence of chunksums in the following format:
|
||||||
//
|
//
|
||||||
// << HTTP/1.1 200 OK
|
// > GET /v2/<namespace>/<model>/chunksums/<digest>
|
||||||
// << Content-Location: <blobURL>
|
|
||||||
// <<
|
|
||||||
// << <digest> <start>-<end>
|
|
||||||
// << ...
|
|
||||||
//
|
//
|
||||||
// The blobURL is the URL to download the chunks from.
|
// < HTTP/1.1 200 OK
|
||||||
|
// < Content-Location: <blobURL>
|
||||||
|
// <
|
||||||
|
// < <digest> <start>-<end>
|
||||||
|
// < ...
|
||||||
|
//
|
||||||
|
// The <blobURL> is the URL to download the chunks from and
|
||||||
|
// each <digest> is the digest of the chunk, and <start>-<end>
|
||||||
|
// is the range the chunk in the blob.
|
||||||
|
//
|
||||||
|
// Ranges may be used directly in Range headers like
|
||||||
|
// "bytes=<start>-<end>".
|
||||||
|
//
|
||||||
|
// The chunksums returned are guaranteed to be contiguous and
|
||||||
|
// include all bytes of the layer. If the stream is cut short,
|
||||||
|
// clients should retry.
|
||||||
|
|
||||||
chunksumsURL := fmt.Sprintf("%s://%s/v2/%s/%s/chunksums/%s",
|
chunksumsURL := fmt.Sprintf("%s://%s/v2/%s/%s/chunksums/%s",
|
||||||
scheme,
|
scheme,
|
||||||
|
|||||||
@@ -9,17 +9,14 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"math/rand/v2"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
"slices"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/server/internal/cache/blob"
|
"github.com/ollama/ollama/server/internal/cache/blob"
|
||||||
"github.com/ollama/ollama/server/internal/testutil"
|
"github.com/ollama/ollama/server/internal/testutil"
|
||||||
@@ -338,15 +335,8 @@ func TestPushCommitRoundtripError(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkNotExist(t *testing.T, err error) {
|
|
||||||
t.Helper()
|
|
||||||
if !errors.Is(err, fs.ErrNotExist) {
|
|
||||||
t.Fatalf("err = %v; want fs.ErrNotExist", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRegistryPullInvalidName(t *testing.T) {
|
func TestRegistryPullInvalidName(t *testing.T) {
|
||||||
rc, _ := newClient(t, nil)
|
rc, _ := newRegistryClient(t, nil)
|
||||||
err := rc.Pull(t.Context(), "://")
|
err := rc.Pull(t.Context(), "://")
|
||||||
if !errors.Is(err, ErrNameInvalid) {
|
if !errors.Is(err, ErrNameInvalid) {
|
||||||
t.Errorf("err = %v; want %v", err, ErrNameInvalid)
|
t.Errorf("err = %v; want %v", err, ErrNameInvalid)
|
||||||
@@ -362,197 +352,16 @@ func TestRegistryPullInvalidManifest(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, resp := range cases {
|
for _, resp := range cases {
|
||||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
rc, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
io.WriteString(w, resp)
|
io.WriteString(w, resp)
|
||||||
})
|
})
|
||||||
err := rc.Pull(t.Context(), "x")
|
err := rc.Pull(t.Context(), "http://example.com/a/b")
|
||||||
if !errors.Is(err, ErrManifestInvalid) {
|
if !errors.Is(err, ErrManifestInvalid) {
|
||||||
t.Errorf("err = %v; want invalid manifest", err)
|
t.Errorf("err = %v; want invalid manifest", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRegistryPullNotCached(t *testing.T) {
|
|
||||||
check := testutil.Checker(t)
|
|
||||||
|
|
||||||
var c *blob.DiskCache
|
|
||||||
var rc *Registry
|
|
||||||
|
|
||||||
d := blob.DigestFromBytes("some data")
|
|
||||||
rc, c = newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if strings.Contains(r.URL.Path, "/blobs/") {
|
|
||||||
io.WriteString(w, "some data")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":9}]}`, d)
|
|
||||||
})
|
|
||||||
|
|
||||||
// Confirm that the layer does not exist locally
|
|
||||||
_, err := rc.ResolveLocal("model")
|
|
||||||
checkNotExist(t, err)
|
|
||||||
|
|
||||||
_, err = c.Get(d)
|
|
||||||
checkNotExist(t, err)
|
|
||||||
|
|
||||||
err = rc.Pull(t.Context(), "model")
|
|
||||||
check(err)
|
|
||||||
|
|
||||||
mw, err := rc.Resolve(t.Context(), "model")
|
|
||||||
check(err)
|
|
||||||
mg, err := rc.ResolveLocal("model")
|
|
||||||
check(err)
|
|
||||||
if !reflect.DeepEqual(mw, mg) {
|
|
||||||
t.Errorf("mw = %v; mg = %v", mw, mg)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Confirm successful download
|
|
||||||
info, err := c.Get(d)
|
|
||||||
check(err)
|
|
||||||
if info.Digest != d {
|
|
||||||
t.Errorf("info.Digest = %v; want %v", info.Digest, d)
|
|
||||||
}
|
|
||||||
if info.Size != 9 {
|
|
||||||
t.Errorf("info.Size = %v; want %v", info.Size, 9)
|
|
||||||
}
|
|
||||||
|
|
||||||
data, err := os.ReadFile(c.GetFile(d))
|
|
||||||
check(err)
|
|
||||||
if string(data) != "some data" {
|
|
||||||
t.Errorf("data = %q; want %q", data, "exists")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRegistryPullCached(t *testing.T) {
|
|
||||||
cached := blob.DigestFromBytes("exists")
|
|
||||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if strings.Contains(r.URL.Path, "/blobs/") {
|
|
||||||
w.WriteHeader(499) // should not be called
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if strings.Contains(r.URL.Path, "/manifests/") {
|
|
||||||
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":6}]}`, cached)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
var errs []error
|
|
||||||
var reads []int64
|
|
||||||
ctx := WithTrace(t.Context(), &Trace{
|
|
||||||
Update: func(d *Layer, n int64, err error) {
|
|
||||||
t.Logf("update %v %d %v", d, n, err)
|
|
||||||
reads = append(reads, n)
|
|
||||||
errs = append(errs, err)
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
err := rc.Pull(ctx, "single")
|
|
||||||
testutil.Check(t, err)
|
|
||||||
|
|
||||||
want := []int64{0, 6}
|
|
||||||
if !errors.Is(errors.Join(errs...), ErrCached) {
|
|
||||||
t.Errorf("errs = %v; want %v", errs, ErrCached)
|
|
||||||
}
|
|
||||||
if !slices.Equal(reads, want) {
|
|
||||||
t.Errorf("pairs = %v; want %v", reads, want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRegistryPullManifestNotFound(t *testing.T) {
|
|
||||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.WriteHeader(http.StatusNotFound)
|
|
||||||
})
|
|
||||||
err := rc.Pull(t.Context(), "notfound")
|
|
||||||
checkErrCode(t, err, 404, "")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRegistryPullResolveRemoteError(t *testing.T) {
|
|
||||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
|
||||||
io.WriteString(w, `{"errors":[{"code":"an_error"}]}`)
|
|
||||||
})
|
|
||||||
err := rc.Pull(t.Context(), "single")
|
|
||||||
checkErrCode(t, err, 500, "an_error")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRegistryPullResolveRoundtripError(t *testing.T) {
|
|
||||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if strings.Contains(r.URL.Path, "/manifests/") {
|
|
||||||
w.WriteHeader(499) // force RoundTrip error
|
|
||||||
return
|
|
||||||
}
|
|
||||||
})
|
|
||||||
err := rc.Pull(t.Context(), "single")
|
|
||||||
if !errors.Is(err, errRoundTrip) {
|
|
||||||
t.Errorf("err = %v; want %v", err, errRoundTrip)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestRegistryPullMixedCachedNotCached tests that cached layers do not
|
|
||||||
// interfere with pulling layers that are not cached
|
|
||||||
func TestRegistryPullMixedCachedNotCached(t *testing.T) {
|
|
||||||
x := blob.DigestFromBytes("xxxxxx")
|
|
||||||
e := blob.DigestFromBytes("exists")
|
|
||||||
y := blob.DigestFromBytes("yyyyyy")
|
|
||||||
|
|
||||||
for i := range 10 {
|
|
||||||
t.Logf("iteration %d", i)
|
|
||||||
|
|
||||||
digests := []blob.Digest{x, e, y}
|
|
||||||
|
|
||||||
rand.Shuffle(len(digests), func(i, j int) {
|
|
||||||
digests[i], digests[j] = digests[j], digests[i]
|
|
||||||
})
|
|
||||||
|
|
||||||
manifest := fmt.Sprintf(`{
|
|
||||||
"layers": [
|
|
||||||
{"digest":"%s","size":6},
|
|
||||||
{"digest":"%s","size":6},
|
|
||||||
{"digest":"%s","size":6}
|
|
||||||
]
|
|
||||||
}`, digests[0], digests[1], digests[2])
|
|
||||||
|
|
||||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
switch path.Base(r.URL.Path) {
|
|
||||||
case "latest":
|
|
||||||
io.WriteString(w, manifest)
|
|
||||||
case x.String():
|
|
||||||
io.WriteString(w, "xxxxxx")
|
|
||||||
case e.String():
|
|
||||||
io.WriteString(w, "exists")
|
|
||||||
case y.String():
|
|
||||||
io.WriteString(w, "yyyyyy")
|
|
||||||
default:
|
|
||||||
panic(fmt.Sprintf("unexpected request: %v", r))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
ctx := WithTrace(t.Context(), &Trace{
|
|
||||||
Update: func(l *Layer, n int64, err error) {
|
|
||||||
t.Logf("update %v %d %v", l, n, err)
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
// Check that we pull all layers that we can.
|
|
||||||
|
|
||||||
err := rc.Pull(ctx, "mixed")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, d := range digests {
|
|
||||||
info, err := c.Get(d)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Get(%v): %v", d, err)
|
|
||||||
}
|
|
||||||
if info.Size != 6 {
|
|
||||||
t.Errorf("info.Size = %v; want %v", info.Size, 6)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRegistryResolveByDigest(t *testing.T) {
|
func TestRegistryResolveByDigest(t *testing.T) {
|
||||||
check := testutil.Checker(t)
|
check := testutil.Checker(t)
|
||||||
|
|
||||||
@@ -590,26 +399,6 @@ func TestInsecureSkipVerify(t *testing.T) {
|
|||||||
testutil.Check(t, err)
|
testutil.Check(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCanRetry(t *testing.T) {
|
|
||||||
cases := []struct {
|
|
||||||
err error
|
|
||||||
want bool
|
|
||||||
}{
|
|
||||||
{nil, false},
|
|
||||||
{errors.New("x"), false},
|
|
||||||
{ErrCached, false},
|
|
||||||
{ErrManifestInvalid, false},
|
|
||||||
{ErrNameInvalid, false},
|
|
||||||
{&Error{Status: 100}, false},
|
|
||||||
{&Error{Status: 500}, true},
|
|
||||||
}
|
|
||||||
for _, tt := range cases {
|
|
||||||
if got := canRetry(tt.err); got != tt.want {
|
|
||||||
t.Errorf("CanRetry(%v) = %v; want %v", tt.err, got, tt.want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestErrorUnmarshal(t *testing.T) {
|
func TestErrorUnmarshal(t *testing.T) {
|
||||||
cases := []struct {
|
cases := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -761,17 +550,23 @@ func TestParseNameExtended(t *testing.T) {
|
|||||||
|
|
||||||
func TestUnlink(t *testing.T) {
|
func TestUnlink(t *testing.T) {
|
||||||
t.Run("found by name", func(t *testing.T) {
|
t.Run("found by name", func(t *testing.T) {
|
||||||
rc, _ := newClient(t, nil)
|
check := testutil.Checker(t)
|
||||||
|
|
||||||
|
rc, _ := newRegistryClient(t, nil)
|
||||||
|
// make a blob and link it
|
||||||
|
d := blob.DigestFromBytes("{}")
|
||||||
|
err := blob.PutBytes(rc.Cache, d, "{}")
|
||||||
|
check(err)
|
||||||
|
err = rc.Cache.Link("registry.ollama.ai/library/single:latest", d)
|
||||||
|
check(err)
|
||||||
|
|
||||||
// confirm linked
|
// confirm linked
|
||||||
_, err := rc.ResolveLocal("single")
|
_, err = rc.ResolveLocal("single")
|
||||||
if err != nil {
|
check(err)
|
||||||
t.Errorf("unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// unlink
|
// unlink
|
||||||
_, err = rc.Unlink("single")
|
_, err = rc.Unlink("single")
|
||||||
testutil.Check(t, err)
|
check(err)
|
||||||
|
|
||||||
// confirm unlinked
|
// confirm unlinked
|
||||||
_, err = rc.ResolveLocal("single")
|
_, err = rc.ResolveLocal("single")
|
||||||
@@ -780,7 +575,7 @@ func TestUnlink(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
t.Run("not found by name", func(t *testing.T) {
|
t.Run("not found by name", func(t *testing.T) {
|
||||||
rc, _ := newClient(t, nil)
|
rc, _ := newRegistryClient(t, nil)
|
||||||
ok, err := rc.Unlink("manifestNotFound")
|
ok, err := rc.Unlink("manifestNotFound")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@@ -791,78 +586,368 @@ func TestUnlink(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPullChunksums(t *testing.T) {
|
// Many tests from here out, in this file are based on a single blob, "abc",
|
||||||
check := testutil.Checker(t)
|
// with the checksum of its sha256 hash. The checksum is:
|
||||||
|
//
|
||||||
|
// "abc" -> sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad
|
||||||
|
//
|
||||||
|
// Using the literal value instead of a constant with fmt.Xprintf calls proved
|
||||||
|
// to be the most readable and maintainable approach. The sum is consistently
|
||||||
|
// used in the tests and unique so searches do not yield false positives.
|
||||||
|
|
||||||
content := "hello"
|
func checkRequest(t *testing.T, req *http.Request, method, path string) {
|
||||||
var chunksums string
|
t.Helper()
|
||||||
contentDigest := func() blob.Digest {
|
if got := req.URL.Path; got != path {
|
||||||
return blob.DigestFromBytes(content)
|
t.Errorf("URL = %q, want %q", got, path)
|
||||||
|
}
|
||||||
|
if req.Method != method {
|
||||||
|
t.Errorf("Method = %q, want %q", req.Method, method)
|
||||||
}
|
}
|
||||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
switch {
|
|
||||||
case strings.Contains(r.URL.Path, "/manifests/latest"):
|
|
||||||
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":%d}]}`, contentDigest(), len(content))
|
|
||||||
case strings.HasSuffix(r.URL.Path, "/chunksums/"+contentDigest().String()):
|
|
||||||
loc := fmt.Sprintf("http://blob.store/v2/library/test/blobs/%s", contentDigest())
|
|
||||||
w.Header().Set("Content-Location", loc)
|
|
||||||
io.WriteString(w, chunksums)
|
|
||||||
case strings.Contains(r.URL.Path, "/blobs/"+contentDigest().String()):
|
|
||||||
http.ServeContent(w, r, contentDigest().String(), time.Time{}, strings.NewReader(content))
|
|
||||||
default:
|
|
||||||
t.Errorf("unexpected request: %v", r)
|
|
||||||
http.NotFound(w, r)
|
|
||||||
}
|
}
|
||||||
})
|
|
||||||
|
|
||||||
rc.MaxStreams = 1 // prevent concurrent chunk downloads
|
func newRegistryClient(t *testing.T, h http.HandlerFunc) (*Registry, context.Context) {
|
||||||
rc.ChunkingThreshold = 1 // for all blobs to be chunked
|
s := httptest.NewServer(h)
|
||||||
|
t.Cleanup(s.Close)
|
||||||
|
cache, err := blob.Open(t.TempDir())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
var mu sync.Mutex
|
|
||||||
var reads []int64
|
|
||||||
ctx := WithTrace(t.Context(), &Trace{
|
ctx := WithTrace(t.Context(), &Trace{
|
||||||
Update: func(l *Layer, n int64, err error) {
|
Update: func(l *Layer, n int64, err error) {
|
||||||
t.Logf("Update: %v %d %v", l, n, err)
|
t.Log("trace:", l.Digest.Short(), n, err)
|
||||||
mu.Lock()
|
|
||||||
reads = append(reads, n)
|
|
||||||
mu.Unlock()
|
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
chunksums = fmt.Sprintf("%s 0-2\n%s 3-4\n",
|
rc := &Registry{
|
||||||
blob.DigestFromBytes("hel"),
|
Cache: cache,
|
||||||
blob.DigestFromBytes("lo"),
|
HTTPClient: &http.Client{Transport: &http.Transport{
|
||||||
)
|
Dial: func(network, addr string) (net.Conn, error) {
|
||||||
err := rc.Pull(ctx, "test")
|
return net.Dial(network, s.Listener.Addr().String())
|
||||||
check(err)
|
},
|
||||||
wantReads := []int64{
|
}},
|
||||||
0, // initial signaling of layer pull starting
|
|
||||||
3, // first chunk read
|
|
||||||
2, // second chunk read
|
|
||||||
}
|
}
|
||||||
if !slices.Equal(reads, wantReads) {
|
return rc, ctx
|
||||||
t.Errorf("reads = %v; want %v", reads, wantReads)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
mw, err := rc.Resolve(t.Context(), "test")
|
func TestPullChunked(t *testing.T) {
|
||||||
check(err)
|
var steps atomic.Int64
|
||||||
mg, err := rc.ResolveLocal("test")
|
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
check(err)
|
switch steps.Add(1) {
|
||||||
if !reflect.DeepEqual(mw, mg) {
|
case 1:
|
||||||
t.Errorf("mw = %v; mg = %v", mw, mg)
|
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
||||||
|
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
|
||||||
|
case 2:
|
||||||
|
checkRequest(t, r, "GET", "/v2/library/abc/chunksums/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||||
|
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||||
|
fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab"))
|
||||||
|
fmt.Fprintf(w, "%s 2-2\n", blob.DigestFromBytes("c"))
|
||||||
|
case 3, 4:
|
||||||
|
checkRequest(t, r, "GET", "/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||||
|
switch rng := r.Header.Get("Range"); rng {
|
||||||
|
case "bytes=0-1":
|
||||||
|
io.WriteString(w, "ab")
|
||||||
|
case "bytes=2-2":
|
||||||
|
t.Logf("writing c")
|
||||||
|
io.WriteString(w, "c")
|
||||||
|
default:
|
||||||
|
t.Errorf("unexpected range %q", rng)
|
||||||
}
|
}
|
||||||
for i := range mg.Layers {
|
default:
|
||||||
_, err = c.Get(mg.Layers[i].Digest)
|
t.Errorf("unexpected steps %d: %v", steps.Load(), r)
|
||||||
if err != nil {
|
http.Error(w, "unexpected steps", http.StatusInternalServerError)
|
||||||
t.Errorf("Get(%v): %v", mg.Layers[i].Digest, err)
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
c.ChunkingThreshold = 1 // force chunking
|
||||||
|
|
||||||
|
err := c.Pull(ctx, "http://o.com/library/abc")
|
||||||
|
testutil.Check(t, err)
|
||||||
|
|
||||||
|
_, err = c.Cache.Resolve("o.com/library/abc:latest")
|
||||||
|
testutil.Check(t, err)
|
||||||
|
|
||||||
|
if g := steps.Load(); g != 4 {
|
||||||
|
t.Fatalf("got %d steps, want 4", g)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// missing chunks
|
func TestPullCached(t *testing.T) {
|
||||||
content = "llama"
|
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
chunksums = fmt.Sprintf("%s 0-1\n", blob.DigestFromBytes("ll"))
|
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
||||||
err = rc.Pull(ctx, "missingchunks")
|
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
|
||||||
|
})
|
||||||
|
|
||||||
|
check := testutil.Checker(t)
|
||||||
|
|
||||||
|
// Premeptively cache the blob
|
||||||
|
d, err := blob.ParseDigest("sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||||
|
check(err)
|
||||||
|
err = blob.PutBytes(c.Cache, d, []byte("abc"))
|
||||||
|
check(err)
|
||||||
|
|
||||||
|
// Pull only the manifest, which should be enough to resolve the cached blob
|
||||||
|
err = c.Pull(ctx, "http://o.com/library/abc")
|
||||||
|
check(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPullManifestError(t *testing.T) {
|
||||||
|
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
io.WriteString(w, `{"errors":[{"code":"MANIFEST_UNKNOWN"}]}`)
|
||||||
|
})
|
||||||
|
|
||||||
|
err := c.Pull(ctx, "http://o.com/library/abc")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("expected error because of missing chunks")
|
t.Fatalf("expected error")
|
||||||
|
}
|
||||||
|
var got *Error
|
||||||
|
if !errors.Is(err, ErrModelNotFound) {
|
||||||
|
t.Fatalf("err = %v, want %v", got, ErrModelNotFound)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPullLayerError(t *testing.T) {
|
||||||
|
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
||||||
|
io.WriteString(w, `!`)
|
||||||
|
})
|
||||||
|
|
||||||
|
err := c.Pull(ctx, "http://o.com/library/abc")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected error")
|
||||||
|
}
|
||||||
|
var want *json.SyntaxError
|
||||||
|
if !errors.As(err, &want) {
|
||||||
|
t.Fatalf("err = %T, want %T", err, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPullLayerChecksumError(t *testing.T) {
|
||||||
|
var step atomic.Int64
|
||||||
|
c, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch step.Add(1) {
|
||||||
|
case 1:
|
||||||
|
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
||||||
|
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
|
||||||
|
case 2:
|
||||||
|
checkRequest(t, r, "GET", "/v2/library/abc/chunksums/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||||
|
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||||
|
fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab"))
|
||||||
|
fmt.Fprintf(w, "%s 2-2\n", blob.DigestFromBytes("c"))
|
||||||
|
case 3:
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
io.WriteString(w, `{"errors":[{"code":"BLOB_UNKNOWN"}]}`)
|
||||||
|
case 4:
|
||||||
|
io.WriteString(w, "c")
|
||||||
|
default:
|
||||||
|
t.Errorf("unexpected steps %d: %v", step.Load(), r)
|
||||||
|
http.Error(w, "unexpected steps", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
c.MaxStreams = 1
|
||||||
|
c.ChunkingThreshold = 1 // force chunking
|
||||||
|
|
||||||
|
var written atomic.Int64
|
||||||
|
ctx := WithTrace(t.Context(), &Trace{
|
||||||
|
Update: func(l *Layer, n int64, err error) {
|
||||||
|
t.Log("trace:", l.Digest.Short(), n, err)
|
||||||
|
written.Add(n)
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
err := c.Pull(ctx, "http://o.com/library/abc")
|
||||||
|
var got *Error
|
||||||
|
if !errors.As(err, &got) || got.Code != "BLOB_UNKNOWN" {
|
||||||
|
t.Fatalf("err = %v, want %v", err, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
if g := written.Load(); g != 1 {
|
||||||
|
t.Fatalf("wrote %d bytes, want 1", g)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPullChunksumStreamError(t *testing.T) {
|
||||||
|
var step atomic.Int64
|
||||||
|
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch step.Add(1) {
|
||||||
|
case 1:
|
||||||
|
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
||||||
|
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
|
||||||
|
case 2:
|
||||||
|
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||||
|
|
||||||
|
// Write one valid chunksum and one invalid chunksum
|
||||||
|
fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab")) // valid
|
||||||
|
fmt.Fprint(w, "sha256:!") // invalid
|
||||||
|
case 3:
|
||||||
|
io.WriteString(w, "ab")
|
||||||
|
default:
|
||||||
|
t.Errorf("unexpected steps %d: %v", step.Load(), r)
|
||||||
|
http.Error(w, "unexpected steps", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
c.ChunkingThreshold = 1 // force chunking
|
||||||
|
|
||||||
|
got := c.Pull(ctx, "http://o.com/library/abc")
|
||||||
|
if !errors.Is(got, ErrIncomplete) {
|
||||||
|
t.Fatalf("err = %v, want %v", got, ErrIncomplete)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type flushAfterWriter struct {
|
||||||
|
w io.Writer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *flushAfterWriter) Write(p []byte) (n int, err error) {
|
||||||
|
n, err = f.w.Write(p)
|
||||||
|
f.w.(http.Flusher).Flush() // panic if not a flusher
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPullChunksumStreaming(t *testing.T) {
|
||||||
|
csr, csw := io.Pipe()
|
||||||
|
defer csw.Close()
|
||||||
|
|
||||||
|
var step atomic.Int64
|
||||||
|
c, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch step.Add(1) {
|
||||||
|
case 1:
|
||||||
|
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
||||||
|
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
|
||||||
|
case 2:
|
||||||
|
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||||
|
fw := &flushAfterWriter{w} // ensure client gets data as it arrives by aggressively flushing
|
||||||
|
_, err := io.Copy(fw, csr)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("copy: %v", err)
|
||||||
|
}
|
||||||
|
case 3:
|
||||||
|
io.WriteString(w, "ab")
|
||||||
|
case 4:
|
||||||
|
io.WriteString(w, "c")
|
||||||
|
default:
|
||||||
|
t.Errorf("unexpected steps %d: %v", step.Load(), r)
|
||||||
|
http.Error(w, "unexpected steps", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
c.ChunkingThreshold = 1 // force chunking
|
||||||
|
|
||||||
|
update := make(chan int64, 1)
|
||||||
|
ctx := WithTrace(t.Context(), &Trace{
|
||||||
|
Update: func(l *Layer, n int64, err error) {
|
||||||
|
t.Log("trace:", l.Digest.Short(), n, err)
|
||||||
|
if n > 0 {
|
||||||
|
update <- n
|
||||||
|
}
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
errc := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
errc <- c.Pull(ctx, "http://o.com/library/abc")
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Send first chunksum and ensure it kicks off work immediately
|
||||||
|
fmt.Fprintf(csw, "%s 0-1\n", blob.DigestFromBytes("ab"))
|
||||||
|
if g := <-update; g != 2 {
|
||||||
|
t.Fatalf("got %d, want 2", g)
|
||||||
|
}
|
||||||
|
|
||||||
|
// now send the second chunksum and ensure it kicks off work immediately
|
||||||
|
fmt.Fprintf(csw, "%s 2-2\n", blob.DigestFromBytes("c"))
|
||||||
|
if g := <-update; g != 1 {
|
||||||
|
t.Fatalf("got %d, want 1", g)
|
||||||
|
}
|
||||||
|
csw.Close()
|
||||||
|
testutil.Check(t, <-errc)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPullChunksumsCached(t *testing.T) {
|
||||||
|
var step atomic.Int64
|
||||||
|
c, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch step.Add(1) {
|
||||||
|
case 1:
|
||||||
|
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
||||||
|
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
|
||||||
|
case 2:
|
||||||
|
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||||
|
fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab"))
|
||||||
|
fmt.Fprintf(w, "%s 2-2\n", blob.DigestFromBytes("c"))
|
||||||
|
case 3, 4:
|
||||||
|
switch rng := r.Header.Get("Range"); rng {
|
||||||
|
case "bytes=0-1":
|
||||||
|
io.WriteString(w, "ab")
|
||||||
|
case "bytes=2-2":
|
||||||
|
io.WriteString(w, "c")
|
||||||
|
default:
|
||||||
|
t.Errorf("unexpected range %q", rng)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
t.Errorf("unexpected steps %d: %v", step.Load(), r)
|
||||||
|
http.Error(w, "unexpected steps", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
c.MaxStreams = 1 // force serial processing of chunksums
|
||||||
|
c.ChunkingThreshold = 1 // force chunking
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(t.Context())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Cancel the pull after the first chunksum is processed, but before
|
||||||
|
// the second chunksum is processed (which is waiting because
|
||||||
|
// MaxStreams=1). This should cause the second chunksum to error out
|
||||||
|
// leaving the blob incomplete.
|
||||||
|
ctx = WithTrace(ctx, &Trace{
|
||||||
|
Update: func(l *Layer, n int64, err error) {
|
||||||
|
if n > 0 {
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
},
|
||||||
|
})
|
||||||
|
err := c.Pull(ctx, "http://o.com/library/abc")
|
||||||
|
if !errors.Is(err, context.Canceled) {
|
||||||
|
t.Fatalf("err = %v, want %v", err, context.Canceled)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = c.Cache.Resolve("o.com/library/abc:latest")
|
||||||
|
if !errors.Is(err, fs.ErrNotExist) {
|
||||||
|
t.Fatalf("err = %v, want nil", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset state and pull again to ensure the blob chunks that should
|
||||||
|
// have been cached are, and the remaining chunk was downloaded, making
|
||||||
|
// the blob complete.
|
||||||
|
step.Store(0)
|
||||||
|
var written atomic.Int64
|
||||||
|
var cached atomic.Int64
|
||||||
|
ctx = WithTrace(t.Context(), &Trace{
|
||||||
|
Update: func(l *Layer, n int64, err error) {
|
||||||
|
t.Log("trace:", l.Digest.Short(), n, err)
|
||||||
|
if errors.Is(err, ErrCached) {
|
||||||
|
cached.Add(n)
|
||||||
|
}
|
||||||
|
written.Add(n)
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
check := testutil.Checker(t)
|
||||||
|
|
||||||
|
err = c.Pull(ctx, "http://o.com/library/abc")
|
||||||
|
check(err)
|
||||||
|
|
||||||
|
_, err = c.Cache.Resolve("o.com/library/abc:latest")
|
||||||
|
check(err)
|
||||||
|
|
||||||
|
if g := written.Load(); g != 3 {
|
||||||
|
t.Fatalf("wrote %d bytes, want 3", g)
|
||||||
|
}
|
||||||
|
if g := cached.Load(); g != 2 { // "ab" should have been cached
|
||||||
|
t.Fatalf("cached %d bytes, want 3", g)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -31,9 +31,10 @@ const (
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
ErrInvalidImageFormat = errors.New("invalid image format")
|
ErrInvalidImageFormat = errors.New("invalid image format")
|
||||||
|
ErrInvalidDigestFormat = errors.New("invalid digest format")
|
||||||
ErrInvalidProtocol = errors.New("invalid protocol scheme")
|
ErrInvalidProtocol = errors.New("invalid protocol scheme")
|
||||||
ErrInsecureProtocol = errors.New("insecure protocol http")
|
ErrInsecureProtocol = errors.New("insecure protocol http")
|
||||||
ErrInvalidDigestFormat = errors.New("invalid digest format")
|
ErrModelPathInvalid = errors.New("invalid model path")
|
||||||
)
|
)
|
||||||
|
|
||||||
func ParseModelPath(name string) ModelPath {
|
func ParseModelPath(name string) ModelPath {
|
||||||
@@ -73,8 +74,6 @@ func ParseModelPath(name string) ModelPath {
|
|||||||
return mp
|
return mp
|
||||||
}
|
}
|
||||||
|
|
||||||
var errModelPathInvalid = errors.New("invalid model path")
|
|
||||||
|
|
||||||
func (mp ModelPath) GetNamespaceRepository() string {
|
func (mp ModelPath) GetNamespaceRepository() string {
|
||||||
return fmt.Sprintf("%s/%s", mp.Namespace, mp.Repository)
|
return fmt.Sprintf("%s/%s", mp.Namespace, mp.Repository)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -777,7 +777,7 @@ func (s *Server) ShowHandler(c *gin.Context) {
|
|||||||
func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||||
name := model.ParseName(req.Model)
|
name := model.ParseName(req.Model)
|
||||||
if !name.IsValid() {
|
if !name.IsValid() {
|
||||||
return nil, errModelPathInvalid
|
return nil, ErrModelPathInvalid
|
||||||
}
|
}
|
||||||
name, err := getExistingName(name)
|
name, err := getExistingName(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -711,7 +711,7 @@ func pickBestFullFitByLibrary(req *LlmRequest, f *ggml.GGML, gpus discover.GpuIn
|
|||||||
req.opts.NumCtx = req.origNumCtx * p
|
req.opts.NumCtx = req.origNumCtx * p
|
||||||
if !envconfig.SchedSpread() {
|
if !envconfig.SchedSpread() {
|
||||||
for _, g := range sgl {
|
for _, g := range sgl {
|
||||||
if ok, estimatedVRAM = llm.PredictServerFit([]discover.GpuInfo{g}, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
|
if ok, estimatedVRAM = llm.PredictServerFit([]discover.GpuInfo{g}, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, p); ok {
|
||||||
slog.Info("new model will fit in available VRAM in single GPU, loading", "model", req.model.ModelPath, "gpu", g.ID, "parallel", p, "available", g.FreeMemory, "required", format.HumanBytes2(estimatedVRAM))
|
slog.Info("new model will fit in available VRAM in single GPU, loading", "model", req.model.ModelPath, "gpu", g.ID, "parallel", p, "available", g.FreeMemory, "required", format.HumanBytes2(estimatedVRAM))
|
||||||
*numParallel = p
|
*numParallel = p
|
||||||
return []discover.GpuInfo{g}
|
return []discover.GpuInfo{g}
|
||||||
@@ -727,7 +727,7 @@ func pickBestFullFitByLibrary(req *LlmRequest, f *ggml.GGML, gpus discover.GpuIn
|
|||||||
// Now try all the GPUs
|
// Now try all the GPUs
|
||||||
for _, p := range numParallelToTry {
|
for _, p := range numParallelToTry {
|
||||||
req.opts.NumCtx = req.origNumCtx * p
|
req.opts.NumCtx = req.origNumCtx * p
|
||||||
if ok, estimatedVRAM = llm.PredictServerFit(sgl, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
|
if ok, estimatedVRAM = llm.PredictServerFit(sgl, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, p); ok {
|
||||||
slog.Info("new model will fit in available VRAM, loading", "model", req.model.ModelPath, "library", sgl[0].Library, "parallel", p, "required", format.HumanBytes2(estimatedVRAM))
|
slog.Info("new model will fit in available VRAM, loading", "model", req.model.ModelPath, "library", sgl[0].Library, "parallel", p, "required", format.HumanBytes2(estimatedVRAM))
|
||||||
*numParallel = p
|
*numParallel = p
|
||||||
return sgl
|
return sgl
|
||||||
@@ -750,7 +750,7 @@ func pickBestPartialFitByLibrary(req *LlmRequest, f *ggml.GGML, gpus discover.Gp
|
|||||||
var bestEstimate uint64
|
var bestEstimate uint64
|
||||||
var bestFit int
|
var bestFit int
|
||||||
for i, gl := range byLibrary {
|
for i, gl := range byLibrary {
|
||||||
_, estimatedVRAM := llm.PredictServerFit(gl, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts)
|
_, estimatedVRAM := llm.PredictServerFit(gl, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, *numParallel)
|
||||||
if estimatedVRAM > bestEstimate {
|
if estimatedVRAM > bestEstimate {
|
||||||
bestEstimate = estimatedVRAM
|
bestEstimate = estimatedVRAM
|
||||||
bestFit = i
|
bestFit = i
|
||||||
@@ -825,7 +825,7 @@ func (s *Scheduler) expireRunner(model *Model) {
|
|||||||
// If not, pick a runner to unload, else return nil and the request can be loaded
|
// If not, pick a runner to unload, else return nil and the request can be loaded
|
||||||
func (s *Scheduler) maybeFindCPURunnerToUnload(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList) *runnerRef {
|
func (s *Scheduler) maybeFindCPURunnerToUnload(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList) *runnerRef {
|
||||||
slog.Debug("evaluating if CPU model load will fit in available system memory")
|
slog.Debug("evaluating if CPU model load will fit in available system memory")
|
||||||
estimate := llm.EstimateGPULayers(gpus, f, req.model.ProjectorPaths, req.opts)
|
estimate := llm.EstimateGPULayers(gpus, f, req.model.ProjectorPaths, req.opts, req.opts.NumCtx/req.origNumCtx)
|
||||||
if estimate.TotalSize <= gpus[0].FreeMemory {
|
if estimate.TotalSize <= gpus[0].FreeMemory {
|
||||||
slog.Debug("cpu inference mode, model fits in available system memory", "model", format.HumanBytes2(estimate.TotalSize), "available", format.HumanBytes2(gpus[0].FreeMemory))
|
slog.Debug("cpu inference mode, model fits in available system memory", "model", format.HumanBytes2(estimate.TotalSize), "available", format.HumanBytes2(gpus[0].FreeMemory))
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
Reference in New Issue
Block a user