mirror of
https://github.com/ollama/ollama.git
synced 2026-04-23 17:29:54 +02:00
Compare commits
21 Commits
parth/samp
...
brucemacd/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
159821594c | ||
|
|
cbeb2aab4f | ||
|
|
96df15edfc | ||
|
|
c001b98087 | ||
|
|
23fc8e92eb | ||
|
|
4059a297a6 | ||
|
|
66b2539238 | ||
|
|
ef27d52e79 | ||
|
|
b2a465296d | ||
|
|
5d097277ef | ||
|
|
071a9872cb | ||
|
|
0bd0454ea7 | ||
|
|
01aa788722 | ||
|
|
ead27aa9fe | ||
|
|
b816ff86c9 | ||
|
|
e5d84fb90b | ||
|
|
dd66712e31 | ||
|
|
f66216e399 | ||
|
|
f4f0992b6e | ||
|
|
1feff61977 | ||
|
|
5e0b904e88 |
@@ -86,9 +86,9 @@ if(CMAKE_CUDA_COMPILER)
|
||||
)
|
||||
endif()
|
||||
|
||||
set(WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX "^gfx(906|908|90a):xnack[+-]$"
|
||||
set(WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX "^gfx(906|908|90a|1200|1201):xnack[+-]$"
|
||||
CACHE STRING
|
||||
"Regular expression describing AMDGPU_TARGETS not supported on Windows. Override to force building these targets. Default \"^gfx(906|908|90a):xnack[+-]$\"."
|
||||
"Regular expression describing AMDGPU_TARGETS not supported on Windows. Override to force building these targets. Default \"^gfx(906|908|90a|1200|1201):xnack[+-]$\"."
|
||||
)
|
||||
|
||||
check_language(HIP)
|
||||
@@ -97,7 +97,7 @@ if(CMAKE_HIP_COMPILER)
|
||||
|
||||
find_package(hip REQUIRED)
|
||||
if(NOT AMDGPU_TARGETS)
|
||||
list(FILTER AMDGPU_TARGETS INCLUDE REGEX "^gfx(900|94[012]|101[02]|1030|110[012])$")
|
||||
list(FILTER AMDGPU_TARGETS INCLUDE REGEX "^gfx(900|94[012]|101[02]|1030|110[012]|120[01])$")
|
||||
elseif(WIN32 AND WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX)
|
||||
list(FILTER AMDGPU_TARGETS EXCLUDE REGEX ${WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX})
|
||||
endif()
|
||||
|
||||
@@ -56,7 +56,7 @@
|
||||
"name": "ROCm 6",
|
||||
"inherits": [ "ROCm" ],
|
||||
"cacheVariables": {
|
||||
"AMDGPU_TARGETS": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-"
|
||||
"AMDGPU_TARGETS": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx1200;gfx1201;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-"
|
||||
}
|
||||
}
|
||||
],
|
||||
|
||||
@@ -285,6 +285,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt)
|
||||
- [HTML UI](https://github.com/rtcfirefly/ollama-ui)
|
||||
- [Saddle](https://github.com/jikkuatwork/saddle)
|
||||
- [TagSpaces](https://www.tagspaces.org) (A platform for file based apps, [utilizing Ollama](https://docs.tagspaces.org/ai/) for the generation of tags and descriptions)
|
||||
- [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama)
|
||||
- [Chatbot UI v2](https://github.com/mckaywrigley/chatbot-ui)
|
||||
- [Typescript UI](https://github.com/ollama-interface/Ollama-Gui?tab=readme-ov-file)
|
||||
@@ -394,6 +395,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)
|
||||
- [Ellama](https://github.com/zeozeozeo/ellama) (Friendly native app to chat with an Ollama instance)
|
||||
- [screenpipe](https://github.com/mediar-ai/screenpipe) Build agents powered by your screen history
|
||||
- [Ollamb](https://github.com/hengkysteen/ollamb) (Simple yet rich in features, cross-platform built with Flutter and designed for Ollama. Try the [web demo](https://hengkysteen.github.io/demo/ollamb/).)
|
||||
- [Writeopia](https://github.com/Writeopia/Writeopia) (Text editor with integration with Ollama)
|
||||
|
||||
### Cloud
|
||||
|
||||
@@ -433,7 +436,9 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [SwollamaCLI](https://github.com/marcusziade/Swollama) bundled with the Swollama Swift package. [Demo](https://github.com/marcusziade/Swollama?tab=readme-ov-file#cli-usage)
|
||||
- [aichat](https://github.com/sigoden/aichat) All-in-one LLM CLI tool featuring Shell Assistant, Chat-REPL, RAG, AI tools & agents, with access to OpenAI, Claude, Gemini, Ollama, Groq, and more.
|
||||
- [PowershAI](https://github.com/rrg92/powershai) PowerShell module that brings AI to terminal on Windows, including support for Ollama
|
||||
- [DeepShell](https://github.com/Abyss-c0re/deepshell) Your self-hosted AI assistant. Interactive Shell, Files and Folders analysis.
|
||||
- [orbiton](https://github.com/xyproto/orbiton) Configuration-free text editor and IDE with support for tab completion with Ollama.
|
||||
- [orca-cli](https://github.com/molbal/orca-cli) Ollama Registry CLI Application - Browse, pull and download models from Ollama Registry in your terminal.
|
||||
|
||||
### Apple Vision Pro
|
||||
|
||||
|
||||
@@ -111,6 +111,7 @@ func GetCPUDetails() ([]CPU, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer file.Close()
|
||||
return linuxCPUDetails(file)
|
||||
}
|
||||
|
||||
@@ -168,13 +169,11 @@ func linuxCPUDetails(file io.Reader) ([]CPU, error) {
|
||||
for id, s := range socketByID {
|
||||
s.CoreCount = len(coreBySocket[id])
|
||||
s.ThreadCount = 0
|
||||
for _, tc := range threadsByCoreBySocket[id] {
|
||||
s.ThreadCount += tc
|
||||
}
|
||||
|
||||
// This only works if HT is enabled, consider a more reliable model, maybe cache size comparisons?
|
||||
efficiencyCoreCount := 0
|
||||
for _, threads := range threadsByCoreBySocket[id] {
|
||||
s.ThreadCount += threads
|
||||
if threads == 1 {
|
||||
efficiencyCoreCount++
|
||||
}
|
||||
|
||||
@@ -20,7 +20,13 @@ Please refer to the [GPU docs](./gpu.md).
|
||||
|
||||
## How can I specify the context window size?
|
||||
|
||||
By default, Ollama uses a context window size of 2048 tokens. This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context length to 8K, use: `OLLAMA_CONTEXT_LENGTH=8192 ollama serve`.
|
||||
By default, Ollama uses a context window size of 2048 tokens.
|
||||
|
||||
This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context window to 8K, use:
|
||||
|
||||
```shell
|
||||
OLLAMA_CONTEXT_LENGTH=8192 ollama serve
|
||||
```
|
||||
|
||||
To change this when using `ollama run`, use `/set parameter`:
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ cat ~/.ollama/logs/server.log
|
||||
On **Linux** systems with systemd, the logs can be found with this command:
|
||||
|
||||
```shell
|
||||
journalctl -u ollama --no-pager
|
||||
journalctl -u ollama --no-pager --follow --pager-end
|
||||
```
|
||||
|
||||
When you run Ollama in a **container**, the logs go to stdout/stderr in the container:
|
||||
|
||||
@@ -413,7 +413,7 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
|
||||
}, offset, nil
|
||||
}
|
||||
|
||||
func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialOffload, fullOffload uint64) {
|
||||
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) {
|
||||
embedding := f.KV().EmbeddingLength()
|
||||
heads := f.KV().HeadCount()
|
||||
headsKV := f.KV().HeadCountKV()
|
||||
@@ -426,7 +426,10 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
|
||||
layers := f.Tensors().GroupLayers()
|
||||
|
||||
bytesPerElement := kvCacheBytesPerElement(kvCacheType)
|
||||
kv = uint64(float64(context*f.KV().BlockCount()*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
|
||||
kv = make([]uint64, f.KV().BlockCount())
|
||||
for i := range kv {
|
||||
kv[i] = uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
|
||||
}
|
||||
|
||||
switch f.KV().Architecture() {
|
||||
case "llama":
|
||||
@@ -460,16 +463,14 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
|
||||
case "mllama":
|
||||
var visionTokens, tiles uint64 = 1601, 4
|
||||
|
||||
if crossAttentionLayers, ok := f.KV()["mllama.attention.cross_attention_layers"].(*array); ok {
|
||||
kv = headsKV *
|
||||
(embeddingHeadsK + embeddingHeadsV) * // one for K, one for V
|
||||
(2* // sizeof(float16)
|
||||
(f.KV().BlockCount()-uint64(crossAttentionLayers.size))* // num non-cross attention layers
|
||||
context +
|
||||
4* // sizeof(float32)
|
||||
uint64(crossAttentionLayers.size)* // num cross attention layers
|
||||
visionTokens*
|
||||
tiles)
|
||||
crossAttentionLayers := f.KV().Uints("attention.cross_attention_layers")
|
||||
for i := range kv {
|
||||
if slices.Contains(crossAttentionLayers, uint32(i)) {
|
||||
kv[i] = headsKV * (embeddingHeadsK + embeddingHeadsV) *
|
||||
4 * // sizeof(float32)
|
||||
visionTokens *
|
||||
tiles
|
||||
}
|
||||
}
|
||||
|
||||
fullOffload = max(
|
||||
@@ -505,6 +506,20 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
|
||||
4*embeddingHeadsK*context*8+
|
||||
embedding*embeddingHeadsK*heads*9/16,
|
||||
)
|
||||
|
||||
// Gemma2 also has sliding window attention but we only have an optimized implementation in the Ollama
|
||||
// engine. Gemma3 always uses the Ollama engine.
|
||||
if f.KV().Architecture() == "gemma3" {
|
||||
const gemma3GlobalCacheCount = 6
|
||||
slidingWindow := (uint64(numParallel) * uint64(f.KV().Uint("attention.sliding_window"))) + batch
|
||||
for i := range kv {
|
||||
// Every 6th layer is a global layer, which is the full context size that has already been set. The other
|
||||
// layers are the smaller local (sliding) layers.
|
||||
if (i+1)%gemma3GlobalCacheCount != 0 {
|
||||
kv[i] = uint64(float64(slidingWindow*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
|
||||
}
|
||||
}
|
||||
}
|
||||
case "command-r":
|
||||
fullOffload = max(
|
||||
4*batch*(embedding+vocab),
|
||||
|
||||
@@ -119,10 +119,10 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity
|
||||
}
|
||||
|
||||
var cacheSize int
|
||||
if c.windowSize == math.MaxInt32 || capacity < int(c.windowSize)+maxBatch {
|
||||
if c.windowSize == math.MaxInt32 || capacity < int(c.windowSize) {
|
||||
cacheSize = maxSequences * capacity
|
||||
} else {
|
||||
cacheSize = maxSequences * (int(c.windowSize) + maxBatch)
|
||||
cacheSize = (maxSequences * int(c.windowSize)) + maxBatch
|
||||
}
|
||||
cacheSize = roundUp(cacheSize, c.config.CachePadding)
|
||||
c.cells = make([]cacheCell, cacheSize)
|
||||
|
||||
@@ -362,7 +362,6 @@ func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
|
||||
}
|
||||
|
||||
func (c *testContext) Input() ml.Context { return c }
|
||||
func (c *testContext) Output() ml.Context { return c }
|
||||
func (c *testContext) Layer(int) ml.Context { return c }
|
||||
|
||||
func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }
|
||||
@@ -463,7 +462,7 @@ func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, dim, ropeType uint32, base, scale float32) ml.Tensor {
|
||||
func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, config ml.RoPEConfig) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
|
||||
@@ -166,6 +166,10 @@ func (c *Context) KvCacheDefrag() {
|
||||
C.llama_kv_cache_defrag(c.c)
|
||||
}
|
||||
|
||||
func (c *Context) KvCacheCanShift() bool {
|
||||
return bool(C.llama_kv_cache_can_shift(c.c))
|
||||
}
|
||||
|
||||
// Get the embeddings for a sequence id
|
||||
func (c *Context) GetEmbeddingsSeq(seqId int) []float32 {
|
||||
e := unsafe.Pointer(C.llama_get_embeddings_seq(c.c, C.int(seqId)))
|
||||
|
||||
103
llama/patches/0022-add-rdna4-support.patch
Normal file
103
llama/patches/0022-add-rdna4-support.patch
Normal file
@@ -0,0 +1,103 @@
|
||||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||
From: Saman <saman.khatir@amd.com>
|
||||
Date: Wed, 19 Mar 2025 14:02:26 -0700
|
||||
Subject: [PATCH] add rdna4 support
|
||||
|
||||
---
|
||||
ggml/src/ggml-cuda/common.cuh | 6 ++++--
|
||||
ggml/src/ggml-cuda/mmq.cu | 2 +-
|
||||
ggml/src/ggml-cuda/mmq.cuh | 4 ++--
|
||||
ggml/src/ggml-cuda/mmvq.cu | 4 ++--
|
||||
ggml/src/ggml-cuda/vendors/hip.h | 4 ++++
|
||||
5 files changed, 13 insertions(+), 7 deletions(-)
|
||||
|
||||
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
|
||||
index adf0d3ec..b24593fc 100644
|
||||
--- a/ggml/src/ggml-cuda/common.cuh
|
||||
+++ b/ggml/src/ggml-cuda/common.cuh
|
||||
@@ -61,11 +61,13 @@
|
||||
#define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000
|
||||
#define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a
|
||||
#define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA
|
||||
+#define GGML_CUDA_CC_RDNA4 (GGML_CUDA_CC_OFFSET_AMD + 0x1200) // RX 9000
|
||||
|
||||
#define GGML_CUDA_CC_IS_RDNA(cc) (cc >= GGML_CUDA_CC_RDNA1)
|
||||
#define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2)
|
||||
#define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
|
||||
-#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3)
|
||||
+#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3 && cc < GGML_CUDA_CC_RDNA4)
|
||||
+#define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4)
|
||||
#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA)
|
||||
#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
|
||||
|
||||
@@ -386,7 +388,7 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
|
||||
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
||||
#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(RDNA2)
|
||||
c = __builtin_amdgcn_sdot4(a, b, c, false);
|
||||
-#elif defined(RDNA3)
|
||||
+#elif defined(RDNA3) || defined(RDNA4)
|
||||
c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
|
||||
#elif defined(__gfx1010__) || defined(__gfx900__)
|
||||
int tmp1;
|
||||
diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu
|
||||
index 10f2ebb1..933d945c 100644
|
||||
--- a/ggml/src/ggml-cuda/mmq.cu
|
||||
+++ b/ggml/src/ggml-cuda/mmq.cu
|
||||
@@ -149,5 +149,5 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
|
||||
return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||
}
|
||||
|
||||
- return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||
+ return (!GGML_CUDA_CC_IS_RDNA4(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||
}
|
||||
diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh
|
||||
index 0451c65f..66ce2bc9 100644
|
||||
--- a/ggml/src/ggml-cuda/mmq.cuh
|
||||
+++ b/ggml/src/ggml-cuda/mmq.cuh
|
||||
@@ -2577,9 +2577,9 @@ static __device__ void mul_mat_q_process_tile(
|
||||
|
||||
template <ggml_type type, int mmq_x, int nwarps, bool need_check>
|
||||
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
||||
-#if defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
|
||||
+#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
|
||||
__launch_bounds__(WARP_SIZE*nwarps, 2)
|
||||
-#endif // defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
|
||||
+#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
|
||||
#else
|
||||
#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
|
||||
__launch_bounds__(WARP_SIZE*nwarps, 1)
|
||||
diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
|
||||
index 4fb466ca..23ae7abc 100644
|
||||
--- a/ggml/src/ggml-cuda/mmvq.cu
|
||||
+++ b/ggml/src/ggml-cuda/mmvq.cu
|
||||
@@ -62,13 +62,13 @@ static __global__ void mul_mat_vec_q(
|
||||
|
||||
constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
|
||||
|
||||
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
|
||||
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4))
|
||||
constexpr int nwarps = 1;
|
||||
constexpr int rows_per_cuda_block = 1;
|
||||
#else
|
||||
constexpr int nwarps = ncols_y <= 4 ? 4 : 2;
|
||||
constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
|
||||
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
|
||||
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3) && !defined(RDNA4)
|
||||
|
||||
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
|
||||
const int row0 = rows_per_cuda_block*blockIdx.x;
|
||||
diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h
|
||||
index 81964611..a62544b5 100644
|
||||
--- a/ggml/src/ggml-cuda/vendors/hip.h
|
||||
+++ b/ggml/src/ggml-cuda/vendors/hip.h
|
||||
@@ -150,6 +150,10 @@
|
||||
#define CDNA
|
||||
#endif
|
||||
|
||||
+#if defined(__gfx1200__) || defined(__gfx1201__)
|
||||
+#define RDNA4
|
||||
+#endif
|
||||
+
|
||||
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
|
||||
defined(__gfx1150__) || defined(__gfx1151__)
|
||||
#define RDNA3
|
||||
@@ -15,12 +15,12 @@ import (
|
||||
)
|
||||
|
||||
// This algorithm looks for a complete fit to determine if we need to unload other models
|
||||
func PredictServerFit(allGpus discover.GpuInfoList, f *ggml.GGML, adapters, projectors []string, opts api.Options) (bool, uint64) {
|
||||
func PredictServerFit(allGpus discover.GpuInfoList, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (bool, uint64) {
|
||||
// Split up the GPUs by type and try them
|
||||
var estimatedVRAM uint64
|
||||
for _, gpus := range allGpus.ByLibrary() {
|
||||
var layerCount int
|
||||
estimate := EstimateGPULayers(gpus, f, projectors, opts)
|
||||
estimate := EstimateGPULayers(gpus, f, projectors, opts, numParallel)
|
||||
layerCount, estimatedVRAM = estimate.Layers, estimate.VRAMSize
|
||||
if opts.NumGPU < 0 {
|
||||
if layerCount > 0 && layerCount >= int(f.KV().BlockCount()+1) {
|
||||
@@ -71,7 +71,7 @@ type MemoryEstimate struct {
|
||||
|
||||
// Given a model and one or more GPU targets, predict how many layers and bytes we can load, and the total size
|
||||
// The GPUs provided must all be the same Library
|
||||
func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []string, opts api.Options) MemoryEstimate {
|
||||
func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []string, opts api.Options, numParallel int) MemoryEstimate {
|
||||
// Graph size for a partial offload, applies to all GPUs
|
||||
var graphPartialOffload uint64
|
||||
|
||||
@@ -137,13 +137,19 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
|
||||
}
|
||||
}
|
||||
|
||||
kv, graphPartialOffload, graphFullOffload := f.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), kvct)
|
||||
kv, graphPartialOffload, graphFullOffload := f.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), numParallel, kvct)
|
||||
|
||||
// KV is proportional to the number of layers
|
||||
layerSize += kv / f.KV().BlockCount()
|
||||
if len(kv) > 0 {
|
||||
layerSize += kv[0]
|
||||
}
|
||||
|
||||
var kvTotal uint64
|
||||
for _, kvLayer := range kv {
|
||||
kvTotal += kvLayer
|
||||
}
|
||||
|
||||
if graphPartialOffload == 0 {
|
||||
graphPartialOffload = f.KV().GQA() * kv / 6
|
||||
graphPartialOffload = f.KV().GQA() * kvTotal / 6
|
||||
}
|
||||
if graphFullOffload == 0 {
|
||||
graphFullOffload = graphPartialOffload
|
||||
@@ -217,7 +223,7 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
|
||||
// Some models have inconsistent layer sizes
|
||||
if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok {
|
||||
layerSize = blk.Size()
|
||||
layerSize += kv / f.KV().BlockCount()
|
||||
layerSize += kv[i]
|
||||
memoryWeights += blk.Size()
|
||||
}
|
||||
|
||||
@@ -315,7 +321,7 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
|
||||
layersRequested: opts.NumGPU,
|
||||
layersModel: int(f.KV().BlockCount()) + 1,
|
||||
availableList: availableList,
|
||||
kv: kv,
|
||||
kv: kvTotal,
|
||||
allocationsList: allocationsList,
|
||||
memoryWeights: memoryWeights,
|
||||
memoryLayerOutput: memoryLayerOutput,
|
||||
@@ -374,7 +380,7 @@ func (m MemoryEstimate) LogValue() slog.Value {
|
||||
slog.Group(
|
||||
"weights",
|
||||
// memory of the weights
|
||||
"total", format.HumanBytes2(m.memoryWeights),
|
||||
"total", format.HumanBytes2(m.memoryWeights+m.memoryLayerOutput),
|
||||
// memory of repeating layers
|
||||
"repeating", format.HumanBytes2(m.memoryWeights),
|
||||
// memory of non-repeating layers
|
||||
|
||||
@@ -61,7 +61,7 @@ func TestEstimateGPULayers(t *testing.T) {
|
||||
projectors := []string{}
|
||||
opts := api.DefaultOptions()
|
||||
t.Run("cpu", func(t *testing.T) {
|
||||
estimate := EstimateGPULayers(gpus, ggml, projectors, opts)
|
||||
estimate := EstimateGPULayers(gpus, ggml, projectors, opts, 1)
|
||||
assert.Equal(t, 0, estimate.Layers)
|
||||
assert.Equal(t, uint64(0), estimate.Graph)
|
||||
})
|
||||
@@ -112,7 +112,7 @@ func TestEstimateGPULayers(t *testing.T) {
|
||||
gpus[1].FreeMemory += gpuMinimumMemory + layerSize + s.layer1*layerSize + 1
|
||||
gpus[0].FreeMemory += max(graphFullOffload, graphPartialOffload)
|
||||
gpus[1].FreeMemory += max(graphFullOffload, graphPartialOffload)
|
||||
estimate := EstimateGPULayers(gpus, ggml, projectors, opts)
|
||||
estimate := EstimateGPULayers(gpus, ggml, projectors, opts, 1)
|
||||
assert.Equal(t, int(s.expect0+s.expect1), estimate.Layers, "scenario %d: %v", i, s)
|
||||
assert.Equal(t, fmt.Sprintf("%d,%d", s.expect0, s.expect1), estimate.TensorSplit, "scenario %d: %v", i, s)
|
||||
var layerSums uint64
|
||||
|
||||
@@ -109,7 +109,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
|
||||
gpus = discover.GetCPUInfo()
|
||||
}
|
||||
|
||||
estimate := EstimateGPULayers(gpus, f, projectors, opts)
|
||||
estimate := EstimateGPULayers(gpus, f, projectors, opts, numParallel)
|
||||
if len(gpus) > 1 || gpus[0].Library != "cpu" {
|
||||
switch {
|
||||
case gpus[0].Library == "metal" && estimate.VRAMSize > systemTotalMemory:
|
||||
|
||||
@@ -110,16 +110,61 @@ type Context interface {
|
||||
MaxGraphNodes() int
|
||||
Close()
|
||||
|
||||
// Input returns a context appropriate for creating input tensors
|
||||
// Input returns a context appropriate for creating tensors that are
|
||||
// inputs to the model (which includes things like output locations)
|
||||
Input() Context
|
||||
|
||||
// Output returns a context appropriate for creating output tensors
|
||||
Output() Context
|
||||
|
||||
// Layer returns a context appropriate for creating intermediate tensors
|
||||
Layer(int) Context
|
||||
}
|
||||
|
||||
// RopeType represents different RoPE (Rotary Position Embedding) implementation types
|
||||
type RopeType int
|
||||
|
||||
// Available RoPE implementation types
|
||||
const (
|
||||
RopeTypeNormal RopeType = iota // Standard RoPE implementation
|
||||
RopeTypeNeox // NeoX-style RoPE implementation
|
||||
RopeTypeMRoPE // Multimodal RoPE implementation
|
||||
RopeTypeVision // Vision-specific RoPE implementation
|
||||
)
|
||||
|
||||
type YarnConfig struct {
|
||||
YarnCtxTrain int // Context size used during training (for YaRN scaling)
|
||||
YarnExtFactor float32 // Extension factor for YaRN
|
||||
YarnAttnFactor float32 // Attention scaling factor for YaRN
|
||||
YarnBetaFast float32 // Fast decay parameter for YaRN
|
||||
YarnBetaSlow float32 // Slow decay parameter for YaRN
|
||||
}
|
||||
|
||||
// DefaultYarnConfig returns a default configuration for YaRN (Yet Another Rope Extension)
|
||||
func DefaultYarnConfig(nCtx int32) *YarnConfig {
|
||||
return &YarnConfig{
|
||||
YarnCtxTrain: int(nCtx),
|
||||
YarnExtFactor: 0.0,
|
||||
YarnAttnFactor: 1.0,
|
||||
YarnBetaFast: 32.0,
|
||||
YarnBetaSlow: 1.0,
|
||||
}
|
||||
}
|
||||
|
||||
// RoPEConfig holds configuration for Rotary Position Embedding
|
||||
type RoPEConfig struct {
|
||||
// Dim is the dimensionality for applying rotary embeddings
|
||||
Dim uint32
|
||||
|
||||
// Type specifies the RoPE implementation variant
|
||||
Type RopeType
|
||||
|
||||
// Base controls frequency decay for the embeddings
|
||||
Base float32
|
||||
|
||||
// Scale allows scaling the effective context length
|
||||
Scale float32
|
||||
|
||||
*YarnConfig
|
||||
}
|
||||
|
||||
type Tensor interface {
|
||||
Dim(n int) int
|
||||
Stride(n int) int
|
||||
@@ -143,7 +188,7 @@ type Tensor interface {
|
||||
AvgPool2D(ctx Context, k, s int, p float32) Tensor
|
||||
Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
|
||||
|
||||
RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor
|
||||
RoPE(ctx Context, positionIDs, ropeFactors Tensor, config RoPEConfig) Tensor
|
||||
|
||||
Tanh(ctx Context) Tensor
|
||||
GELU(ctx Context) Tensor
|
||||
|
||||
@@ -48,9 +48,6 @@ type Backend struct {
|
||||
// input is the backend used for inputs
|
||||
input *C.struct_ggml_backend_buffer_type
|
||||
|
||||
// output is the backend used for outputs
|
||||
output *C.struct_ggml_backend_buffer_type
|
||||
|
||||
// layers is the backend used for repeating layers
|
||||
layers map[int]*C.struct_ggml_backend_buffer_type
|
||||
|
||||
@@ -400,8 +397,7 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend,
|
||||
C.size_t(maxGraphNodes),
|
||||
C._Bool(len(gpus) > 1 && slices.Contains(gpus, output.d)),
|
||||
),
|
||||
input: deviceBufferTypes[input.d],
|
||||
output: deviceBufferTypes[output.d],
|
||||
input: deviceBufferTypes[input.d],
|
||||
layers: func() map[int]*C.struct_ggml_backend_buffer_type {
|
||||
m := make(map[int]*C.struct_ggml_backend_buffer_type)
|
||||
for i, layer := range layers {
|
||||
@@ -482,19 +478,6 @@ func (c Context) Input() ml.Context {
|
||||
return &c
|
||||
}
|
||||
|
||||
func (c Context) Output() ml.Context {
|
||||
if c.b.output != nil {
|
||||
return &Context{
|
||||
b: c.b,
|
||||
ctx: c.ctx,
|
||||
buft: c.b.output,
|
||||
maxGraphNodes: c.maxGraphNodes,
|
||||
}
|
||||
}
|
||||
|
||||
return &c
|
||||
}
|
||||
|
||||
func (c Context) Layer(i int) ml.Context {
|
||||
if buft, ok := c.b.layers[i]; ok {
|
||||
return &Context{
|
||||
@@ -924,6 +907,8 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
// GGML RoPE types
|
||||
// These are the types used in the C implementation of RoPE
|
||||
const (
|
||||
ropeTypeNorm C.int = 0
|
||||
ropeTypeNeox C.int = 2
|
||||
@@ -931,7 +916,8 @@ const (
|
||||
ropeTypeVision C.int = 24
|
||||
)
|
||||
|
||||
func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32) ml.Tensor {
|
||||
// RoPE applies Rotary Position Embeddings to the tensor
|
||||
func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, config ml.RoPEConfig) ml.Tensor {
|
||||
if ropeFactors == nil {
|
||||
ropeFactors = &Tensor{b: t.b}
|
||||
}
|
||||
@@ -941,19 +927,41 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
|
||||
dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32)
|
||||
}
|
||||
|
||||
if config.YarnConfig == nil {
|
||||
config.YarnConfig = ml.DefaultYarnConfig(131072) // 131072 is the default for LLaMA, so it is common at the time of writing
|
||||
}
|
||||
|
||||
// Map Go RopeType to C implementation constants
|
||||
var ropeTypeC C.int
|
||||
switch config.Type {
|
||||
case ml.RopeTypeNormal:
|
||||
ropeTypeC = ropeTypeNorm
|
||||
case ml.RopeTypeNeox:
|
||||
ropeTypeC = ropeTypeNeox
|
||||
case ml.RopeTypeMRoPE:
|
||||
ropeTypeC = ropeTypeMrope
|
||||
case ml.RopeTypeVision:
|
||||
ropeTypeC = ropeTypeVision
|
||||
default:
|
||||
ropeTypeC = ropeTypeNorm
|
||||
}
|
||||
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_rope_ext(
|
||||
ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
|
||||
C.int(ropeDim),
|
||||
C.int(ropeType),
|
||||
131072, // YaRN n_ctx_train
|
||||
C.float(ropeBase),
|
||||
C.float(ropeScale),
|
||||
0., // YaRN ext_factor
|
||||
1., // YaRN attn_factor
|
||||
32., // YaRN beta_fast
|
||||
1., // YaRN beta_slow
|
||||
ctx.(*Context).ctx,
|
||||
dequant,
|
||||
positionIDs.(*Tensor).t,
|
||||
ropeFactors.(*Tensor).t,
|
||||
C.int(config.Dim),
|
||||
ropeTypeC,
|
||||
C.int(config.YarnCtxTrain),
|
||||
C.float(config.Base),
|
||||
C.float(config.Scale),
|
||||
C.float(config.YarnExtFactor),
|
||||
C.float(config.YarnAttnFactor),
|
||||
C.float(config.YarnBetaFast),
|
||||
C.float(config.YarnBetaSlow),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -61,11 +61,13 @@
|
||||
#define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000
|
||||
#define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a
|
||||
#define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA
|
||||
#define GGML_CUDA_CC_RDNA4 (GGML_CUDA_CC_OFFSET_AMD + 0x1200) // RX 9000
|
||||
|
||||
#define GGML_CUDA_CC_IS_RDNA(cc) (cc >= GGML_CUDA_CC_RDNA1)
|
||||
#define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2)
|
||||
#define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
|
||||
#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3)
|
||||
#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3 && cc < GGML_CUDA_CC_RDNA4)
|
||||
#define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4)
|
||||
#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA)
|
||||
#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
|
||||
|
||||
@@ -386,7 +388,7 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
|
||||
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
||||
#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(RDNA2)
|
||||
c = __builtin_amdgcn_sdot4(a, b, c, false);
|
||||
#elif defined(RDNA3)
|
||||
#elif defined(RDNA3) || defined(RDNA4)
|
||||
c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
|
||||
#elif defined(__gfx1010__) || defined(__gfx900__)
|
||||
int tmp1;
|
||||
|
||||
2
ml/backend/ggml/ggml/src/ggml-cuda/mmq.cu
vendored
2
ml/backend/ggml/ggml/src/ggml-cuda/mmq.cu
vendored
@@ -149,5 +149,5 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
|
||||
return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||
}
|
||||
|
||||
return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||
return (!GGML_CUDA_CC_IS_RDNA4(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||
}
|
||||
|
||||
4
ml/backend/ggml/ggml/src/ggml-cuda/mmq.cuh
vendored
4
ml/backend/ggml/ggml/src/ggml-cuda/mmq.cuh
vendored
@@ -2577,9 +2577,9 @@ static __device__ void mul_mat_q_process_tile(
|
||||
|
||||
template <ggml_type type, int mmq_x, int nwarps, bool need_check>
|
||||
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
||||
#if defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
|
||||
#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
|
||||
__launch_bounds__(WARP_SIZE*nwarps, 2)
|
||||
#endif // defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
|
||||
#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
|
||||
#else
|
||||
#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
|
||||
__launch_bounds__(WARP_SIZE*nwarps, 1)
|
||||
|
||||
4
ml/backend/ggml/ggml/src/ggml-cuda/mmvq.cu
vendored
4
ml/backend/ggml/ggml/src/ggml-cuda/mmvq.cu
vendored
@@ -62,13 +62,13 @@ static __global__ void mul_mat_vec_q(
|
||||
|
||||
constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
|
||||
|
||||
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
|
||||
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4))
|
||||
constexpr int nwarps = 1;
|
||||
constexpr int rows_per_cuda_block = 1;
|
||||
#else
|
||||
constexpr int nwarps = ncols_y <= 4 ? 4 : 2;
|
||||
constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
|
||||
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
|
||||
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3) && !defined(RDNA4)
|
||||
|
||||
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
|
||||
const int row0 = rows_per_cuda_block*blockIdx.x;
|
||||
|
||||
@@ -150,6 +150,10 @@
|
||||
#define CDNA
|
||||
#endif
|
||||
|
||||
#if defined(__gfx1200__) || defined(__gfx1201__)
|
||||
#define RDNA4
|
||||
#endif
|
||||
|
||||
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
|
||||
defined(__gfx1150__) || defined(__gfx1151__)
|
||||
#define RDNA3
|
||||
|
||||
@@ -13,10 +13,11 @@ import (
|
||||
type Options struct {
|
||||
hiddenSize, numHeads, numKVHeads int
|
||||
attnKeyLen, attnValLen int
|
||||
eps, ropeBase, ropeScale float32
|
||||
eps float32
|
||||
attnLogitSoftcap float32
|
||||
finalLogitSoftcap float32
|
||||
largeModelScaling bool
|
||||
ropeConfig ml.RoPEConfig
|
||||
}
|
||||
|
||||
type Model struct {
|
||||
@@ -55,10 +56,15 @@ func New(c ml.Config) (model.Model, error) {
|
||||
attnKeyLen: int(c.Uint("attention.key_length")),
|
||||
attnValLen: int(c.Uint("attention.value_length")),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||
ropeBase: c.Float("rope.freq_base", 10000.0),
|
||||
ropeScale: c.Float("rope.freq_scale", 1.0),
|
||||
attnLogitSoftcap: c.Float("attn_logit_softcapping"),
|
||||
finalLogitSoftcap: c.Float("final_logit_softcapping"),
|
||||
ropeConfig: ml.RoPEConfig{
|
||||
Base: c.Float("rope.freq_base", 10000.0),
|
||||
Scale: c.Float("rope.freq_scale", 1.0),
|
||||
Dim: c.Uint("attention.key_length"),
|
||||
Type: ml.RopeTypeNormal,
|
||||
YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -78,11 +84,10 @@ type SelfAttention struct {
|
||||
|
||||
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
||||
batchSize := hiddenState.Dim(1)
|
||||
ropeType := uint32(2)
|
||||
|
||||
q := sa.Query.Forward(ctx, hiddenState)
|
||||
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
|
||||
q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
|
||||
q = q.RoPE(ctx, positionIDs, nil, opts.ropeConfig)
|
||||
|
||||
if opts.largeModelScaling {
|
||||
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
|
||||
@@ -92,7 +97,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
||||
|
||||
k := sa.Key.Forward(ctx, hiddenState)
|
||||
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
|
||||
k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
|
||||
k = k.RoPE(ctx, positionIDs, nil, opts.ropeConfig)
|
||||
|
||||
v := sa.Value.Forward(ctx, hiddenState)
|
||||
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
|
||||
@@ -122,7 +127,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
||||
}
|
||||
|
||||
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
return key.RoPE(ctx, shift, nil, uint32(m.Options.attnKeyLen), uint32(2), m.Options.ropeBase, m.Options.ropeScale), nil
|
||||
return key.RoPE(ctx, shift, nil, m.ropeConfig), nil
|
||||
}
|
||||
|
||||
type MLP struct {
|
||||
|
||||
@@ -13,9 +13,11 @@ import (
|
||||
type TextOptions struct {
|
||||
hiddenSize, numHeads, numKVHeads int
|
||||
attnKeyLen, attnValLen int
|
||||
eps, ropeScale float32
|
||||
ropeLocalBase, ropeGlobalBase float32
|
||||
eps float32
|
||||
largeModelScaling bool
|
||||
|
||||
ropeLocalConfig ml.RoPEConfig
|
||||
ropeGlobalConfig ml.RoPEConfig
|
||||
}
|
||||
|
||||
type TextModel struct {
|
||||
@@ -56,15 +58,27 @@ func newTextModel(c ml.Config) *TextModel {
|
||||
),
|
||||
Layers: make([]TextLayer, numBlocks),
|
||||
TextOptions: &TextOptions{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||
attnKeyLen: int(c.Uint("attention.key_length", 256)),
|
||||
attnValLen: int(c.Uint("attention.value_length", 256)),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
|
||||
ropeLocalBase: c.Float("rope.local.freq_base", 10000.0),
|
||||
ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
|
||||
ropeScale: c.Float("rope.freq_scale", 1.0),
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||
attnKeyLen: int(c.Uint("attention.key_length", 256)),
|
||||
attnValLen: int(c.Uint("attention.value_length", 256)),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
|
||||
|
||||
ropeLocalConfig: ml.RoPEConfig{
|
||||
Base: c.Float("rope.local.freq_base", 10000.0),
|
||||
Scale: c.Float("rope.freq_scale", 1.0),
|
||||
Dim: c.Uint("attention.key_length", 256),
|
||||
Type: ml.RopeTypeNeox,
|
||||
YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
|
||||
},
|
||||
ropeGlobalConfig: ml.RoPEConfig{
|
||||
Base: c.Float("rope.global.freq_base", 1000000.0),
|
||||
Scale: c.Float("rope.freq_scale", 1.0),
|
||||
Dim: c.Uint("attention.key_length", 256),
|
||||
Type: ml.RopeTypeNeox,
|
||||
YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -86,17 +100,16 @@ type TextSelfAttention struct {
|
||||
|
||||
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
|
||||
batchSize := hiddenState.Dim(1)
|
||||
ropeType := uint32(2)
|
||||
|
||||
ropeBase := opts.ropeLocalBase
|
||||
ropeConfig := opts.ropeLocalConfig
|
||||
if (layer+1)%gemmaGlobalCacheCount == 0 {
|
||||
ropeBase = opts.ropeGlobalBase
|
||||
ropeConfig = opts.ropeGlobalConfig
|
||||
}
|
||||
|
||||
q := sa.Query.Forward(ctx, hiddenState)
|
||||
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
|
||||
q = sa.QueryNorm.Forward(ctx, q, opts.eps)
|
||||
q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
|
||||
q = q.RoPE(ctx, positionIDs, nil, ropeConfig)
|
||||
|
||||
if opts.largeModelScaling {
|
||||
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
|
||||
@@ -107,7 +120,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
|
||||
k := sa.Key.Forward(ctx, hiddenState)
|
||||
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
|
||||
k = sa.KeyNorm.Forward(ctx, k, opts.eps)
|
||||
k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
|
||||
k = k.RoPE(ctx, positionIDs, nil, ropeConfig)
|
||||
|
||||
v := sa.Value.Forward(ctx, hiddenState)
|
||||
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
|
||||
@@ -120,12 +133,12 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
|
||||
}
|
||||
|
||||
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
ropeBase := m.TextOptions.ropeLocalBase
|
||||
ropeConfig := m.ropeLocalConfig
|
||||
if (layer+1)%gemmaGlobalCacheCount == 0 {
|
||||
ropeBase = m.TextOptions.ropeGlobalBase
|
||||
ropeConfig = m.ropeGlobalConfig
|
||||
}
|
||||
|
||||
return key.RoPE(ctx, shift, nil, uint32(m.TextOptions.attnKeyLen), uint32(2), ropeBase, m.TextOptions.ropeScale), nil
|
||||
return key.RoPE(ctx, shift, nil, ropeConfig), nil
|
||||
}
|
||||
|
||||
type TextMLP struct {
|
||||
|
||||
@@ -14,8 +14,8 @@ import (
|
||||
|
||||
type Options struct {
|
||||
hiddenSize, numHeads, numKVHeads int
|
||||
eps, ropeBase, ropeScale float32
|
||||
ropeDim uint32
|
||||
eps float32
|
||||
ropeConfig ml.RoPEConfig
|
||||
}
|
||||
|
||||
type Model struct {
|
||||
@@ -54,9 +54,13 @@ func New(c ml.Config) (model.Model, error) {
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||
ropeBase: c.Float("rope.freq_base"),
|
||||
ropeScale: c.Float("rope.freq_scale", 1),
|
||||
ropeDim: c.Uint("rope.dimension_count"),
|
||||
ropeConfig: ml.RoPEConfig{
|
||||
Base: c.Float("rope.freq_base"),
|
||||
Scale: c.Float("rope.freq_scale", 1),
|
||||
Dim: c.Uint("rope.dimension_count"),
|
||||
Type: ml.RopeTypeNormal,
|
||||
YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -76,15 +80,14 @@ type SelfAttention struct {
|
||||
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
||||
batchSize := hiddenState.Dim(1)
|
||||
headDim := opts.hiddenSize / opts.numHeads
|
||||
ropeType := uint32(0)
|
||||
|
||||
q := sa.Query.Forward(ctx, hiddenState)
|
||||
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
||||
q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
|
||||
q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeConfig)
|
||||
|
||||
k := sa.Key.Forward(ctx, hiddenState)
|
||||
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
|
||||
k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeConfig)
|
||||
|
||||
v := sa.Value.Forward(ctx, hiddenState)
|
||||
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
@@ -97,7 +100,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
||||
}
|
||||
|
||||
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil
|
||||
return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, m.ropeConfig), nil
|
||||
}
|
||||
|
||||
type MLP struct {
|
||||
|
||||
@@ -20,15 +20,14 @@ type TextSelfAttention struct {
|
||||
func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
|
||||
batchSize := hiddenState.Dim(1)
|
||||
headDim := opts.hiddenSize / opts.numHeads
|
||||
ropeType := uint32(0)
|
||||
|
||||
query := sa.Query.Forward(ctx, hiddenState)
|
||||
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
||||
query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
|
||||
query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeConfig)
|
||||
|
||||
key := sa.Key.Forward(ctx, hiddenState)
|
||||
key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
|
||||
key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeConfig)
|
||||
|
||||
value := sa.Value.Forward(ctx, hiddenState)
|
||||
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
@@ -43,7 +42,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
|
||||
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
// This will only get called for layers in the cache, which are just the self attention layers
|
||||
if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
|
||||
return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeDim, uint32(0), m.ropeBase, m.ropeScale), nil
|
||||
return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeConfig), nil
|
||||
}
|
||||
|
||||
return key, nil
|
||||
@@ -198,8 +197,8 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs,
|
||||
|
||||
type TextModelOptions struct {
|
||||
hiddenSize, numHeads, numKVHeads int
|
||||
eps, ropeBase, ropeScale float32
|
||||
ropeDim uint32
|
||||
eps float32
|
||||
ropeConfig ml.RoPEConfig
|
||||
|
||||
crossAttentionLayers []uint32
|
||||
}
|
||||
@@ -240,10 +239,14 @@ func newTextModel(c ml.Config) *TextModel {
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||
ropeBase: c.Float("rope.freq_base"),
|
||||
ropeScale: c.Float("rope.freq_scale", 1),
|
||||
ropeDim: c.Uint("rope.dimension_count"),
|
||||
crossAttentionLayers: c.Uints("attention.cross_attention_layers"),
|
||||
ropeConfig: ml.RoPEConfig{
|
||||
Base: c.Float("rope.freq_base"),
|
||||
Scale: c.Float("rope.freq_scale", 1),
|
||||
Dim: c.Uint("rope.dimension_count"),
|
||||
Type: ml.RopeTypeNormal,
|
||||
YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -213,8 +213,16 @@ func (c *InputCache) ShiftDiscard(inputLen int, numKeep int) int {
|
||||
return discard
|
||||
}
|
||||
|
||||
// Frees up space in the KV cache by deleting the oldest half of history and shifting
|
||||
// the newest half into that space (saving numKeep inputs at the beginning).
|
||||
type ErrReprocessInputs struct {
|
||||
Inputs []input
|
||||
}
|
||||
|
||||
func (e *ErrReprocessInputs) Error() string {
|
||||
return fmt.Sprintf("kv cache shift not supported, inputs need reprocessing (input count: %v)", len(e.Inputs))
|
||||
}
|
||||
|
||||
// ShiftCacheSlot frees up space in the KV cache by deleting the oldest half of history
|
||||
// and shifting the newest half into that space (saving numKeep inputs at the beginning).
|
||||
//
|
||||
// Assumes that at least 1 entry can be freed up by shifting (i.e. numKeep < numCtx)
|
||||
func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
|
||||
@@ -222,7 +230,8 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
|
||||
return fmt.Errorf("unable to shift context - keep exceeds context (keep: %v context: %v)", numKeep, c.numCtx)
|
||||
}
|
||||
|
||||
discard := c.ShiftDiscard(len(slot.Inputs), numKeep)
|
||||
inputLen := len(slot.Inputs)
|
||||
discard := c.ShiftDiscard(inputLen, numKeep)
|
||||
|
||||
if discard <= 0 {
|
||||
return nil
|
||||
@@ -231,16 +240,42 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
|
||||
slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs),
|
||||
"keep", numKeep, "discard", discard)
|
||||
|
||||
// TODO (jessegross): KV cache removal can fail for certain types of models
|
||||
if !c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+discard) {
|
||||
return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v)", slot.Id, numKeep, discard)
|
||||
}
|
||||
c.lc.KvCacheSeqAdd(slot.Id, numKeep+discard, len(slot.Inputs), -discard)
|
||||
var shiftFailed bool
|
||||
|
||||
for i := numKeep + discard; i < len(slot.Inputs); i++ {
|
||||
if c.lc.KvCacheCanShift() {
|
||||
// For models that support shifting, attempt to shift the KV cache
|
||||
if !c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+discard) {
|
||||
shiftFailed = true
|
||||
slog.Debug("kv cache removal not supported, clearing cache and returning inputs for reprocessing", "id", slot.Id)
|
||||
} else {
|
||||
c.lc.KvCacheSeqAdd(slot.Id, numKeep+discard, inputLen, -discard)
|
||||
}
|
||||
} else {
|
||||
// For models that don't support shifting
|
||||
shiftFailed = true
|
||||
slog.Debug("kv cache cannot shift, clearing cache and returning inputs for reprocessing", "id", slot.Id)
|
||||
}
|
||||
|
||||
if shiftFailed {
|
||||
// Create new input slice with preserved tokens (numKeep + remaining tokens after discard)
|
||||
newInputs := make([]input, numKeep+inputLen-(numKeep+discard))
|
||||
copy(newInputs[:numKeep], slot.Inputs[:numKeep])
|
||||
copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:])
|
||||
|
||||
// Clear the entire KV cache
|
||||
_ = c.lc.KvCacheSeqRm(slot.Id, 0, -1)
|
||||
// Reset the slot inputs since we've cleared the cache
|
||||
slot.Inputs = []input{}
|
||||
|
||||
// Return error with inputs that need to be reprocessed
|
||||
return &ErrReprocessInputs{Inputs: newInputs}
|
||||
}
|
||||
|
||||
// Standard shift succeeded - update input array
|
||||
for i := numKeep + discard; i < inputLen; i++ {
|
||||
slot.Inputs[i-discard] = slot.Inputs[i]
|
||||
}
|
||||
slot.Inputs = slot.Inputs[:len(slot.Inputs)-discard]
|
||||
slot.Inputs = slot.Inputs[:inputLen-discard]
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -389,7 +389,15 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
||||
if len(seq.pendingInputs) == 0 {
|
||||
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
|
||||
if err != nil {
|
||||
return err
|
||||
var reprocess *ErrReprocessInputs
|
||||
if errors.As(err, &reprocess) {
|
||||
// Prepend these inputs to the sequence's inputs queue for reprocessing
|
||||
seq.inputs = append(reprocess.Inputs, seq.inputs...)
|
||||
// Continue processing as normal
|
||||
continue
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
break
|
||||
@@ -599,7 +607,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
slog.Info("aborting completion request due to client closing the connection")
|
||||
} else {
|
||||
slog.Error("Failed to acquire semaphore", "error", err)
|
||||
http.Error(w, fmt.Sprintf("Failed to acquire semaphore: %v", err), http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -611,6 +619,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
|
||||
if err != nil {
|
||||
s.mu.Unlock()
|
||||
s.seqsSem.Release(1)
|
||||
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@@ -626,6 +635,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
s.mu.Unlock()
|
||||
|
||||
if !found {
|
||||
s.seqsSem.Release(1)
|
||||
http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@@ -691,7 +701,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
slog.Info("aborting embeddings request due to client closing the connection")
|
||||
} else {
|
||||
slog.Error("Failed to acquire semaphore", "error", err)
|
||||
http.Error(w, fmt.Sprintf("Failed to acquire semaphore: %v", err), http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -703,6 +713,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
|
||||
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, false)
|
||||
if err != nil {
|
||||
s.mu.Unlock()
|
||||
s.seqsSem.Release(1)
|
||||
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@@ -715,6 +726,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
|
||||
s.mu.Unlock()
|
||||
|
||||
if !found {
|
||||
s.seqsSem.Release(1)
|
||||
http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -239,6 +239,14 @@ func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 {
|
||||
return discard
|
||||
}
|
||||
|
||||
type ErrReprocessInputs struct {
|
||||
Inputs []input.Input
|
||||
}
|
||||
|
||||
func (e *ErrReprocessInputs) Error() string {
|
||||
return fmt.Sprintf("kv cache shift not supported, inputs need reprocessing (input count: %v)", len(e.Inputs))
|
||||
}
|
||||
|
||||
// Frees up space in the KV cache by deleting the oldest half of history and shifting
|
||||
// the newest half into that space (saving numKeep inputs at the beginning).
|
||||
//
|
||||
@@ -258,11 +266,23 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int32) error {
|
||||
slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs),
|
||||
"keep", numKeep, "discard", discard)
|
||||
|
||||
// TODO (jessegross): KV cache removal can fail for certain types of models
|
||||
if c.cache != nil {
|
||||
err := c.cache.Remove(slot.Id, numKeep, numKeep+discard)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v): %w", slot.Id, numKeep, discard, err)
|
||||
slog.Debug("kv cache removal unsupported, clearing cache and returning inputs for reprocessing",
|
||||
"id", slot.Id, "error", err)
|
||||
|
||||
// Create new input slice with preserved tokens (numKeep + remaining tokens after discard)
|
||||
newInputs := make([]input.Input, numKeep+inputLen-(numKeep+discard))
|
||||
copy(newInputs[:numKeep], slot.Inputs[:numKeep])
|
||||
copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:])
|
||||
|
||||
// Reset the cache
|
||||
_ = c.cache.Remove(slot.Id, 0, -1)
|
||||
slot.Inputs = []input.Input{}
|
||||
|
||||
// Return error with inputs that need to be reprocessed
|
||||
return &ErrReprocessInputs{Inputs: newInputs}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
package ollamarunner
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"image"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
@@ -425,3 +428,91 @@ func TestLoadCacheSlot(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Mock implementation of the Cache interface
|
||||
type mockCache struct {
|
||||
shouldFail bool
|
||||
}
|
||||
|
||||
// Implement only the methods needed for the test
|
||||
func (m *mockCache) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
if m.shouldFail {
|
||||
return fmt.Errorf("mock cache removal error")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stub implementations for other interface methods
|
||||
func (m *mockCache) SetLayer(layer int) {}
|
||||
func (m *mockCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) { return nil, nil, nil }
|
||||
func (m *mockCache) Put(ctx ml.Context, key, value ml.Tensor) {}
|
||||
func (m *mockCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {}
|
||||
func (m *mockCache) Close() {}
|
||||
func (m *mockCache) StartForward(ctx ml.Context, batch input.Batch) error { return nil }
|
||||
func (m *mockCache) CopyPrefix(srcSeq, dstSeq int, len int32) {}
|
||||
func (m *mockCache) SetConfig(ml.CacheConfig) {}
|
||||
|
||||
func TestShiftCacheSlot(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
numCtx int32
|
||||
inputs []input.Input
|
||||
numKeep int32
|
||||
cacheErr bool
|
||||
wantErr any
|
||||
wantInputsLen int
|
||||
}{
|
||||
{
|
||||
name: "Normal shift",
|
||||
numCtx: 10,
|
||||
inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
|
||||
numKeep: 2,
|
||||
cacheErr: false, // No error
|
||||
wantErr: nil,
|
||||
wantInputsLen: 6, // After discarding 4 tokens
|
||||
},
|
||||
{
|
||||
name: "Cache removal fails",
|
||||
numCtx: 10,
|
||||
inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
|
||||
numKeep: 2,
|
||||
cacheErr: true,
|
||||
wantErr: &ErrReprocessInputs{},
|
||||
wantInputsLen: 0, // Original inputs should be cleared
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mock := &mockCache{shouldFail: tt.cacheErr}
|
||||
c := InputCache{
|
||||
numCtx: tt.numCtx,
|
||||
cache: mock,
|
||||
}
|
||||
slot := &InputCacheSlot{
|
||||
Id: 123,
|
||||
Inputs: make([]input.Input, len(tt.inputs)),
|
||||
}
|
||||
copy(slot.Inputs, tt.inputs)
|
||||
|
||||
err := c.ShiftCacheSlot(slot, tt.numKeep)
|
||||
|
||||
if tt.wantErr != nil {
|
||||
if err == nil {
|
||||
t.Errorf("Expected error but got nil")
|
||||
return
|
||||
}
|
||||
|
||||
if !errors.As(err, &tt.wantErr) {
|
||||
t.Errorf("Expected error of type %T but got %T: %v", tt.wantErr, err, err)
|
||||
}
|
||||
} else if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(slot.Inputs) != tt.wantInputsLen {
|
||||
t.Errorf("Slot inputs length after operation: got %v, want %v", len(slot.Inputs), tt.wantInputsLen)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -267,6 +267,9 @@ type Server struct {
|
||||
// KV cache
|
||||
cache *InputCache
|
||||
|
||||
// next sequence for prompt processing to avoid starvation
|
||||
nextSeq int
|
||||
|
||||
// multimodalHash generates hashes for comparing equality
|
||||
// of non-text data
|
||||
multimodalHash maphash.Hash
|
||||
@@ -351,14 +354,19 @@ func (s *Server) processBatch() error {
|
||||
var batchInputs []int32
|
||||
var batch input.Batch
|
||||
|
||||
for i, seq := range s.seqs {
|
||||
resumeSeq := -1
|
||||
seqIdx := s.nextSeq - 1
|
||||
for range s.seqs {
|
||||
seqIdx = (seqIdx + 1) % len(s.seqs)
|
||||
seq := s.seqs[seqIdx]
|
||||
|
||||
if seq == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// if past the num predict limit
|
||||
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
|
||||
s.removeSequence(i, "limit")
|
||||
s.removeSequence(seqIdx, "limit")
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -369,16 +377,23 @@ func (s *Server) processBatch() error {
|
||||
|
||||
batchSize := s.batchSize
|
||||
|
||||
for j, inp := range seq.inputs {
|
||||
for i, inp := range seq.inputs {
|
||||
// If we are required to put following inputs into a single batch then extend the
|
||||
// batch size. Since we are only extending the size the minimum amount possible, this
|
||||
// will cause a break if we have pending inputs.
|
||||
// will cause a break if we have existing inputs.
|
||||
minBatch := 1 + inp.SameBatch
|
||||
if minBatch > batchSize {
|
||||
batchSize = minBatch
|
||||
}
|
||||
|
||||
if len(seq.pendingInputs)+minBatch > batchSize {
|
||||
// Stop if the required batch would put us over the total batch size (including tokens
|
||||
// added by other sequences). If we haven't been able to add anything yet then pick up
|
||||
// here again for the next batch to avoid starvation, though we can opportunistically
|
||||
// check if other sequences can still squeeze something in.
|
||||
if len(batchInputs)+minBatch > batchSize {
|
||||
if len(seq.pendingInputs) == 0 && resumeSeq == -1 {
|
||||
resumeSeq = seqIdx
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
@@ -392,7 +407,15 @@ func (s *Server) processBatch() error {
|
||||
|
||||
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
|
||||
if err != nil {
|
||||
return err
|
||||
var reprocess *ErrReprocessInputs
|
||||
if errors.As(err, &reprocess) {
|
||||
// Prepend these inputs to the sequence's inputs queue for reprocessing
|
||||
seq.inputs = append(reprocess.Inputs, seq.inputs...)
|
||||
// Skip this sequence but continue processing the rest
|
||||
continue
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -405,7 +428,7 @@ func (s *Server) processBatch() error {
|
||||
batch.Sequences = append(batch.Sequences, seq.cache.Id)
|
||||
|
||||
seq.iBatch = len(batch.Outputs)
|
||||
if j+1 == len(seq.inputs) {
|
||||
if i+1 == len(seq.inputs) {
|
||||
batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
|
||||
}
|
||||
seq.pendingInputs = append(seq.pendingInputs, inp)
|
||||
@@ -414,6 +437,12 @@ func (s *Server) processBatch() error {
|
||||
seq.inputs = seq.inputs[len(seq.pendingInputs):]
|
||||
}
|
||||
|
||||
if resumeSeq != -1 {
|
||||
s.nextSeq = resumeSeq
|
||||
} else {
|
||||
s.nextSeq = seqIdx + 1
|
||||
}
|
||||
|
||||
if len(batchInputs) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -588,7 +617,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
slog.Info("aborting completion request due to client closing the connection")
|
||||
} else {
|
||||
slog.Error("Failed to acquire semaphore", "error", err)
|
||||
http.Error(w, fmt.Sprintf("Failed to acquire semaphore: %v", err), http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -600,6 +629,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs)
|
||||
if err != nil {
|
||||
s.mu.Unlock()
|
||||
s.seqsSem.Release(1)
|
||||
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@@ -613,6 +643,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
s.mu.Unlock()
|
||||
|
||||
if !found {
|
||||
s.seqsSem.Release(1)
|
||||
http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -29,8 +29,9 @@ import (
|
||||
const maxRetries = 6
|
||||
|
||||
var (
|
||||
errMaxRetriesExceeded = errors.New("max retries exceeded")
|
||||
errPartStalled = errors.New("part stalled")
|
||||
errMaxRetriesExceeded = errors.New("max retries exceeded")
|
||||
errPartStalled = errors.New("part stalled")
|
||||
errMaxRedirectsExceeded = errors.New("maximum redirects exceeded (10) for directURL")
|
||||
)
|
||||
|
||||
var blobDownloadManager sync.Map
|
||||
@@ -236,7 +237,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
|
||||
|
||||
newOpts.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
if len(via) > 10 {
|
||||
return errors.New("maximum redirects exceeded (10) for directURL")
|
||||
return errMaxRedirectsExceeded
|
||||
}
|
||||
|
||||
// if the hostname is the same, allow the redirect
|
||||
|
||||
@@ -35,6 +35,7 @@ var (
|
||||
errCapabilityCompletion = errors.New("completion")
|
||||
errCapabilityTools = errors.New("tools")
|
||||
errCapabilityInsert = errors.New("insert")
|
||||
errInsecureProtocol = errors.New("insecure protocol http")
|
||||
)
|
||||
|
||||
type Capability string
|
||||
@@ -479,7 +480,7 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
fn(api.ProgressResponse{Status: "retrieving manifest"})
|
||||
|
||||
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
|
||||
return errors.New("insecure protocol http")
|
||||
return errInsecureProtocol
|
||||
}
|
||||
|
||||
manifest, _, err := GetManifest(mp)
|
||||
@@ -543,7 +544,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
}
|
||||
|
||||
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
|
||||
return errors.New("insecure protocol http")
|
||||
return errInsecureProtocol
|
||||
}
|
||||
|
||||
fn(api.ProgressResponse{Status: "pulling manifest"})
|
||||
|
||||
@@ -421,14 +421,6 @@ func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func canRetry(err error) bool {
|
||||
var re *Error
|
||||
if !errors.As(err, &re) {
|
||||
return false
|
||||
}
|
||||
return re.Status >= 500
|
||||
}
|
||||
|
||||
// trackingReader is an io.Reader that tracks the number of bytes read and
|
||||
// calls the update function with the layer, the number of bytes read.
|
||||
//
|
||||
@@ -514,13 +506,40 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
||||
break
|
||||
}
|
||||
|
||||
cacheKey := fmt.Sprintf(
|
||||
"v1 pull chunksum %s %s %d-%d",
|
||||
l.Digest,
|
||||
cs.Digest,
|
||||
cs.Chunk.Start,
|
||||
cs.Chunk.End,
|
||||
)
|
||||
cacheKeyDigest := blob.DigestFromBytes(cacheKey)
|
||||
_, err := c.Get(cacheKeyDigest)
|
||||
if err == nil {
|
||||
received.Add(cs.Chunk.Size())
|
||||
t.update(l, cs.Chunk.Size(), ErrCached)
|
||||
continue
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
g.Go(func() (err error) {
|
||||
defer func() {
|
||||
if err == nil {
|
||||
// Ignore cache key write errors for now. We've already
|
||||
// reported to trace that the chunk is complete.
|
||||
//
|
||||
// Ideally, we should only report completion to trace
|
||||
// after successful cache commit. This current approach
|
||||
// works but could trigger unnecessary redownloads if
|
||||
// the checkpoint key is missing on next pull.
|
||||
//
|
||||
// Not incorrect, just suboptimal - fix this in a
|
||||
// future update.
|
||||
_ = blob.PutBytes(c, cacheKeyDigest, cacheKey)
|
||||
|
||||
received.Add(cs.Chunk.Size())
|
||||
} else {
|
||||
err = fmt.Errorf("error downloading %s: %w", cs.Digest.Short(), err)
|
||||
t.update(l, 0, err)
|
||||
}
|
||||
wg.Done()
|
||||
}()
|
||||
@@ -563,7 +582,7 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
||||
return err
|
||||
}
|
||||
if received.Load() != expected {
|
||||
return fmt.Errorf("%w: received %d/%d", ErrIncomplete, received.Load(), expected)
|
||||
return fmt.Errorf("%w: received %d/%d bytes", ErrIncomplete, received.Load(), expected)
|
||||
}
|
||||
|
||||
md := blob.DigestFromBytes(m.Data)
|
||||
@@ -608,6 +627,30 @@ func (m *Manifest) Layer(d blob.Digest) *Layer {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manifest) All() iter.Seq[*Layer] {
|
||||
return func(yield func(*Layer) bool) {
|
||||
if !yield(m.Config) {
|
||||
return
|
||||
}
|
||||
for _, l := range m.Layers {
|
||||
if !yield(l) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manifest) Size() int64 {
|
||||
var size int64
|
||||
if m.Config != nil {
|
||||
size += m.Config.Size
|
||||
}
|
||||
for _, l := range m.Layers {
|
||||
size += l.Size
|
||||
}
|
||||
return size
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler.
|
||||
//
|
||||
// NOTE: It adds an empty config object to the manifest, which is required by
|
||||
@@ -750,20 +793,32 @@ func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Se
|
||||
return
|
||||
}
|
||||
|
||||
// A chunksums response is a sequence of chunksums in a
|
||||
// simple, easy to parse line-oriented format.
|
||||
// The response is a sequence of chunksums.
|
||||
//
|
||||
// Example:
|
||||
// Chunksums are chunks of a larger blob that can be
|
||||
// downloaded and verified independently.
|
||||
//
|
||||
// >> GET /v2/<namespace>/<model>/chunksums/<digest>
|
||||
// The chunksums endpoint is a GET request that returns a
|
||||
// sequence of chunksums in the following format:
|
||||
//
|
||||
// << HTTP/1.1 200 OK
|
||||
// << Content-Location: <blobURL>
|
||||
// <<
|
||||
// << <digest> <start>-<end>
|
||||
// << ...
|
||||
// > GET /v2/<namespace>/<model>/chunksums/<digest>
|
||||
//
|
||||
// The blobURL is the URL to download the chunks from.
|
||||
// < HTTP/1.1 200 OK
|
||||
// < Content-Location: <blobURL>
|
||||
// <
|
||||
// < <digest> <start>-<end>
|
||||
// < ...
|
||||
//
|
||||
// The <blobURL> is the URL to download the chunks from and
|
||||
// each <digest> is the digest of the chunk, and <start>-<end>
|
||||
// is the range the chunk in the blob.
|
||||
//
|
||||
// Ranges may be used directly in Range headers like
|
||||
// "bytes=<start>-<end>".
|
||||
//
|
||||
// The chunksums returned are guaranteed to be contiguous and
|
||||
// include all bytes of the layer. If the stream is cut short,
|
||||
// clients should retry.
|
||||
|
||||
chunksumsURL := fmt.Sprintf("%s://%s/v2/%s/%s/chunksums/%s",
|
||||
scheme,
|
||||
|
||||
@@ -9,17 +9,14 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"math/rand/v2"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/server/internal/cache/blob"
|
||||
"github.com/ollama/ollama/server/internal/testutil"
|
||||
@@ -338,15 +335,8 @@ func TestPushCommitRoundtripError(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func checkNotExist(t *testing.T, err error) {
|
||||
t.Helper()
|
||||
if !errors.Is(err, fs.ErrNotExist) {
|
||||
t.Fatalf("err = %v; want fs.ErrNotExist", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryPullInvalidName(t *testing.T) {
|
||||
rc, _ := newClient(t, nil)
|
||||
rc, _ := newRegistryClient(t, nil)
|
||||
err := rc.Pull(t.Context(), "://")
|
||||
if !errors.Is(err, ErrNameInvalid) {
|
||||
t.Errorf("err = %v; want %v", err, ErrNameInvalid)
|
||||
@@ -362,197 +352,16 @@ func TestRegistryPullInvalidManifest(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, resp := range cases {
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
rc, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
io.WriteString(w, resp)
|
||||
})
|
||||
err := rc.Pull(t.Context(), "x")
|
||||
err := rc.Pull(t.Context(), "http://example.com/a/b")
|
||||
if !errors.Is(err, ErrManifestInvalid) {
|
||||
t.Errorf("err = %v; want invalid manifest", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryPullNotCached(t *testing.T) {
|
||||
check := testutil.Checker(t)
|
||||
|
||||
var c *blob.DiskCache
|
||||
var rc *Registry
|
||||
|
||||
d := blob.DigestFromBytes("some data")
|
||||
rc, c = newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "/blobs/") {
|
||||
io.WriteString(w, "some data")
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":9}]}`, d)
|
||||
})
|
||||
|
||||
// Confirm that the layer does not exist locally
|
||||
_, err := rc.ResolveLocal("model")
|
||||
checkNotExist(t, err)
|
||||
|
||||
_, err = c.Get(d)
|
||||
checkNotExist(t, err)
|
||||
|
||||
err = rc.Pull(t.Context(), "model")
|
||||
check(err)
|
||||
|
||||
mw, err := rc.Resolve(t.Context(), "model")
|
||||
check(err)
|
||||
mg, err := rc.ResolveLocal("model")
|
||||
check(err)
|
||||
if !reflect.DeepEqual(mw, mg) {
|
||||
t.Errorf("mw = %v; mg = %v", mw, mg)
|
||||
}
|
||||
|
||||
// Confirm successful download
|
||||
info, err := c.Get(d)
|
||||
check(err)
|
||||
if info.Digest != d {
|
||||
t.Errorf("info.Digest = %v; want %v", info.Digest, d)
|
||||
}
|
||||
if info.Size != 9 {
|
||||
t.Errorf("info.Size = %v; want %v", info.Size, 9)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(c.GetFile(d))
|
||||
check(err)
|
||||
if string(data) != "some data" {
|
||||
t.Errorf("data = %q; want %q", data, "exists")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryPullCached(t *testing.T) {
|
||||
cached := blob.DigestFromBytes("exists")
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "/blobs/") {
|
||||
w.WriteHeader(499) // should not be called
|
||||
return
|
||||
}
|
||||
if strings.Contains(r.URL.Path, "/manifests/") {
|
||||
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":6}]}`, cached)
|
||||
}
|
||||
})
|
||||
|
||||
var errs []error
|
||||
var reads []int64
|
||||
ctx := WithTrace(t.Context(), &Trace{
|
||||
Update: func(d *Layer, n int64, err error) {
|
||||
t.Logf("update %v %d %v", d, n, err)
|
||||
reads = append(reads, n)
|
||||
errs = append(errs, err)
|
||||
},
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := rc.Pull(ctx, "single")
|
||||
testutil.Check(t, err)
|
||||
|
||||
want := []int64{0, 6}
|
||||
if !errors.Is(errors.Join(errs...), ErrCached) {
|
||||
t.Errorf("errs = %v; want %v", errs, ErrCached)
|
||||
}
|
||||
if !slices.Equal(reads, want) {
|
||||
t.Errorf("pairs = %v; want %v", reads, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryPullManifestNotFound(t *testing.T) {
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
})
|
||||
err := rc.Pull(t.Context(), "notfound")
|
||||
checkErrCode(t, err, 404, "")
|
||||
}
|
||||
|
||||
func TestRegistryPullResolveRemoteError(t *testing.T) {
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
io.WriteString(w, `{"errors":[{"code":"an_error"}]}`)
|
||||
})
|
||||
err := rc.Pull(t.Context(), "single")
|
||||
checkErrCode(t, err, 500, "an_error")
|
||||
}
|
||||
|
||||
func TestRegistryPullResolveRoundtripError(t *testing.T) {
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "/manifests/") {
|
||||
w.WriteHeader(499) // force RoundTrip error
|
||||
return
|
||||
}
|
||||
})
|
||||
err := rc.Pull(t.Context(), "single")
|
||||
if !errors.Is(err, errRoundTrip) {
|
||||
t.Errorf("err = %v; want %v", err, errRoundTrip)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRegistryPullMixedCachedNotCached tests that cached layers do not
|
||||
// interfere with pulling layers that are not cached
|
||||
func TestRegistryPullMixedCachedNotCached(t *testing.T) {
|
||||
x := blob.DigestFromBytes("xxxxxx")
|
||||
e := blob.DigestFromBytes("exists")
|
||||
y := blob.DigestFromBytes("yyyyyy")
|
||||
|
||||
for i := range 10 {
|
||||
t.Logf("iteration %d", i)
|
||||
|
||||
digests := []blob.Digest{x, e, y}
|
||||
|
||||
rand.Shuffle(len(digests), func(i, j int) {
|
||||
digests[i], digests[j] = digests[j], digests[i]
|
||||
})
|
||||
|
||||
manifest := fmt.Sprintf(`{
|
||||
"layers": [
|
||||
{"digest":"%s","size":6},
|
||||
{"digest":"%s","size":6},
|
||||
{"digest":"%s","size":6}
|
||||
]
|
||||
}`, digests[0], digests[1], digests[2])
|
||||
|
||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch path.Base(r.URL.Path) {
|
||||
case "latest":
|
||||
io.WriteString(w, manifest)
|
||||
case x.String():
|
||||
io.WriteString(w, "xxxxxx")
|
||||
case e.String():
|
||||
io.WriteString(w, "exists")
|
||||
case y.String():
|
||||
io.WriteString(w, "yyyyyy")
|
||||
default:
|
||||
panic(fmt.Sprintf("unexpected request: %v", r))
|
||||
}
|
||||
})
|
||||
|
||||
ctx := WithTrace(t.Context(), &Trace{
|
||||
Update: func(l *Layer, n int64, err error) {
|
||||
t.Logf("update %v %d %v", l, n, err)
|
||||
},
|
||||
})
|
||||
|
||||
// Check that we pull all layers that we can.
|
||||
|
||||
err := rc.Pull(ctx, "mixed")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for _, d := range digests {
|
||||
info, err := c.Get(d)
|
||||
if err != nil {
|
||||
t.Fatalf("Get(%v): %v", d, err)
|
||||
}
|
||||
if info.Size != 6 {
|
||||
t.Errorf("info.Size = %v; want %v", info.Size, 6)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryResolveByDigest(t *testing.T) {
|
||||
check := testutil.Checker(t)
|
||||
|
||||
@@ -590,26 +399,6 @@ func TestInsecureSkipVerify(t *testing.T) {
|
||||
testutil.Check(t, err)
|
||||
}
|
||||
|
||||
func TestCanRetry(t *testing.T) {
|
||||
cases := []struct {
|
||||
err error
|
||||
want bool
|
||||
}{
|
||||
{nil, false},
|
||||
{errors.New("x"), false},
|
||||
{ErrCached, false},
|
||||
{ErrManifestInvalid, false},
|
||||
{ErrNameInvalid, false},
|
||||
{&Error{Status: 100}, false},
|
||||
{&Error{Status: 500}, true},
|
||||
}
|
||||
for _, tt := range cases {
|
||||
if got := canRetry(tt.err); got != tt.want {
|
||||
t.Errorf("CanRetry(%v) = %v; want %v", tt.err, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorUnmarshal(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
@@ -761,17 +550,23 @@ func TestParseNameExtended(t *testing.T) {
|
||||
|
||||
func TestUnlink(t *testing.T) {
|
||||
t.Run("found by name", func(t *testing.T) {
|
||||
rc, _ := newClient(t, nil)
|
||||
check := testutil.Checker(t)
|
||||
|
||||
rc, _ := newRegistryClient(t, nil)
|
||||
// make a blob and link it
|
||||
d := blob.DigestFromBytes("{}")
|
||||
err := blob.PutBytes(rc.Cache, d, "{}")
|
||||
check(err)
|
||||
err = rc.Cache.Link("registry.ollama.ai/library/single:latest", d)
|
||||
check(err)
|
||||
|
||||
// confirm linked
|
||||
_, err := rc.ResolveLocal("single")
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
_, err = rc.ResolveLocal("single")
|
||||
check(err)
|
||||
|
||||
// unlink
|
||||
_, err = rc.Unlink("single")
|
||||
testutil.Check(t, err)
|
||||
check(err)
|
||||
|
||||
// confirm unlinked
|
||||
_, err = rc.ResolveLocal("single")
|
||||
@@ -780,7 +575,7 @@ func TestUnlink(t *testing.T) {
|
||||
}
|
||||
})
|
||||
t.Run("not found by name", func(t *testing.T) {
|
||||
rc, _ := newClient(t, nil)
|
||||
rc, _ := newRegistryClient(t, nil)
|
||||
ok, err := rc.Unlink("manifestNotFound")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -791,78 +586,368 @@ func TestUnlink(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestPullChunksums(t *testing.T) {
|
||||
check := testutil.Checker(t)
|
||||
// Many tests from here out, in this file are based on a single blob, "abc",
|
||||
// with the checksum of its sha256 hash. The checksum is:
|
||||
//
|
||||
// "abc" -> sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad
|
||||
//
|
||||
// Using the literal value instead of a constant with fmt.Xprintf calls proved
|
||||
// to be the most readable and maintainable approach. The sum is consistently
|
||||
// used in the tests and unique so searches do not yield false positives.
|
||||
|
||||
content := "hello"
|
||||
var chunksums string
|
||||
contentDigest := func() blob.Digest {
|
||||
return blob.DigestFromBytes(content)
|
||||
func checkRequest(t *testing.T, req *http.Request, method, path string) {
|
||||
t.Helper()
|
||||
if got := req.URL.Path; got != path {
|
||||
t.Errorf("URL = %q, want %q", got, path)
|
||||
}
|
||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case strings.Contains(r.URL.Path, "/manifests/latest"):
|
||||
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":%d}]}`, contentDigest(), len(content))
|
||||
case strings.HasSuffix(r.URL.Path, "/chunksums/"+contentDigest().String()):
|
||||
loc := fmt.Sprintf("http://blob.store/v2/library/test/blobs/%s", contentDigest())
|
||||
w.Header().Set("Content-Location", loc)
|
||||
io.WriteString(w, chunksums)
|
||||
case strings.Contains(r.URL.Path, "/blobs/"+contentDigest().String()):
|
||||
http.ServeContent(w, r, contentDigest().String(), time.Time{}, strings.NewReader(content))
|
||||
default:
|
||||
t.Errorf("unexpected request: %v", r)
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
})
|
||||
if req.Method != method {
|
||||
t.Errorf("Method = %q, want %q", req.Method, method)
|
||||
}
|
||||
}
|
||||
|
||||
rc.MaxStreams = 1 // prevent concurrent chunk downloads
|
||||
rc.ChunkingThreshold = 1 // for all blobs to be chunked
|
||||
func newRegistryClient(t *testing.T, h http.HandlerFunc) (*Registry, context.Context) {
|
||||
s := httptest.NewServer(h)
|
||||
t.Cleanup(s.Close)
|
||||
cache, err := blob.Open(t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var mu sync.Mutex
|
||||
var reads []int64
|
||||
ctx := WithTrace(t.Context(), &Trace{
|
||||
Update: func(l *Layer, n int64, err error) {
|
||||
t.Logf("Update: %v %d %v", l, n, err)
|
||||
mu.Lock()
|
||||
reads = append(reads, n)
|
||||
mu.Unlock()
|
||||
t.Log("trace:", l.Digest.Short(), n, err)
|
||||
},
|
||||
})
|
||||
|
||||
chunksums = fmt.Sprintf("%s 0-2\n%s 3-4\n",
|
||||
blob.DigestFromBytes("hel"),
|
||||
blob.DigestFromBytes("lo"),
|
||||
)
|
||||
err := rc.Pull(ctx, "test")
|
||||
check(err)
|
||||
wantReads := []int64{
|
||||
0, // initial signaling of layer pull starting
|
||||
3, // first chunk read
|
||||
2, // second chunk read
|
||||
}
|
||||
if !slices.Equal(reads, wantReads) {
|
||||
t.Errorf("reads = %v; want %v", reads, wantReads)
|
||||
rc := &Registry{
|
||||
Cache: cache,
|
||||
HTTPClient: &http.Client{Transport: &http.Transport{
|
||||
Dial: func(network, addr string) (net.Conn, error) {
|
||||
return net.Dial(network, s.Listener.Addr().String())
|
||||
},
|
||||
}},
|
||||
}
|
||||
return rc, ctx
|
||||
}
|
||||
|
||||
mw, err := rc.Resolve(t.Context(), "test")
|
||||
check(err)
|
||||
mg, err := rc.ResolveLocal("test")
|
||||
check(err)
|
||||
if !reflect.DeepEqual(mw, mg) {
|
||||
t.Errorf("mw = %v; mg = %v", mw, mg)
|
||||
}
|
||||
for i := range mg.Layers {
|
||||
_, err = c.Get(mg.Layers[i].Digest)
|
||||
if err != nil {
|
||||
t.Errorf("Get(%v): %v", mg.Layers[i].Digest, err)
|
||||
func TestPullChunked(t *testing.T) {
|
||||
var steps atomic.Int64
|
||||
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch steps.Add(1) {
|
||||
case 1:
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
||||
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
|
||||
case 2:
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/chunksums/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||
fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab"))
|
||||
fmt.Fprintf(w, "%s 2-2\n", blob.DigestFromBytes("c"))
|
||||
case 3, 4:
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||
switch rng := r.Header.Get("Range"); rng {
|
||||
case "bytes=0-1":
|
||||
io.WriteString(w, "ab")
|
||||
case "bytes=2-2":
|
||||
t.Logf("writing c")
|
||||
io.WriteString(w, "c")
|
||||
default:
|
||||
t.Errorf("unexpected range %q", rng)
|
||||
}
|
||||
default:
|
||||
t.Errorf("unexpected steps %d: %v", steps.Load(), r)
|
||||
http.Error(w, "unexpected steps", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// missing chunks
|
||||
content = "llama"
|
||||
chunksums = fmt.Sprintf("%s 0-1\n", blob.DigestFromBytes("ll"))
|
||||
err = rc.Pull(ctx, "missingchunks")
|
||||
if err == nil {
|
||||
t.Error("expected error because of missing chunks")
|
||||
c.ChunkingThreshold = 1 // force chunking
|
||||
|
||||
err := c.Pull(ctx, "http://o.com/library/abc")
|
||||
testutil.Check(t, err)
|
||||
|
||||
_, err = c.Cache.Resolve("o.com/library/abc:latest")
|
||||
testutil.Check(t, err)
|
||||
|
||||
if g := steps.Load(); g != 4 {
|
||||
t.Fatalf("got %d steps, want 4", g)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPullCached(t *testing.T) {
|
||||
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
||||
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
|
||||
})
|
||||
|
||||
check := testutil.Checker(t)
|
||||
|
||||
// Premeptively cache the blob
|
||||
d, err := blob.ParseDigest("sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||
check(err)
|
||||
err = blob.PutBytes(c.Cache, d, []byte("abc"))
|
||||
check(err)
|
||||
|
||||
// Pull only the manifest, which should be enough to resolve the cached blob
|
||||
err = c.Pull(ctx, "http://o.com/library/abc")
|
||||
check(err)
|
||||
}
|
||||
|
||||
func TestPullManifestError(t *testing.T) {
|
||||
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
io.WriteString(w, `{"errors":[{"code":"MANIFEST_UNKNOWN"}]}`)
|
||||
})
|
||||
|
||||
err := c.Pull(ctx, "http://o.com/library/abc")
|
||||
if err == nil {
|
||||
t.Fatalf("expected error")
|
||||
}
|
||||
var got *Error
|
||||
if !errors.Is(err, ErrModelNotFound) {
|
||||
t.Fatalf("err = %v, want %v", got, ErrModelNotFound)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPullLayerError(t *testing.T) {
|
||||
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
||||
io.WriteString(w, `!`)
|
||||
})
|
||||
|
||||
err := c.Pull(ctx, "http://o.com/library/abc")
|
||||
if err == nil {
|
||||
t.Fatalf("expected error")
|
||||
}
|
||||
var want *json.SyntaxError
|
||||
if !errors.As(err, &want) {
|
||||
t.Fatalf("err = %T, want %T", err, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPullLayerChecksumError(t *testing.T) {
|
||||
var step atomic.Int64
|
||||
c, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch step.Add(1) {
|
||||
case 1:
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
||||
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
|
||||
case 2:
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/chunksums/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||
fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab"))
|
||||
fmt.Fprintf(w, "%s 2-2\n", blob.DigestFromBytes("c"))
|
||||
case 3:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
io.WriteString(w, `{"errors":[{"code":"BLOB_UNKNOWN"}]}`)
|
||||
case 4:
|
||||
io.WriteString(w, "c")
|
||||
default:
|
||||
t.Errorf("unexpected steps %d: %v", step.Load(), r)
|
||||
http.Error(w, "unexpected steps", http.StatusInternalServerError)
|
||||
}
|
||||
})
|
||||
|
||||
c.MaxStreams = 1
|
||||
c.ChunkingThreshold = 1 // force chunking
|
||||
|
||||
var written atomic.Int64
|
||||
ctx := WithTrace(t.Context(), &Trace{
|
||||
Update: func(l *Layer, n int64, err error) {
|
||||
t.Log("trace:", l.Digest.Short(), n, err)
|
||||
written.Add(n)
|
||||
},
|
||||
})
|
||||
|
||||
err := c.Pull(ctx, "http://o.com/library/abc")
|
||||
var got *Error
|
||||
if !errors.As(err, &got) || got.Code != "BLOB_UNKNOWN" {
|
||||
t.Fatalf("err = %v, want %v", err, got)
|
||||
}
|
||||
|
||||
if g := written.Load(); g != 1 {
|
||||
t.Fatalf("wrote %d bytes, want 1", g)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPullChunksumStreamError(t *testing.T) {
|
||||
var step atomic.Int64
|
||||
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch step.Add(1) {
|
||||
case 1:
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
||||
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
|
||||
case 2:
|
||||
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||
|
||||
// Write one valid chunksum and one invalid chunksum
|
||||
fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab")) // valid
|
||||
fmt.Fprint(w, "sha256:!") // invalid
|
||||
case 3:
|
||||
io.WriteString(w, "ab")
|
||||
default:
|
||||
t.Errorf("unexpected steps %d: %v", step.Load(), r)
|
||||
http.Error(w, "unexpected steps", http.StatusInternalServerError)
|
||||
}
|
||||
})
|
||||
|
||||
c.ChunkingThreshold = 1 // force chunking
|
||||
|
||||
got := c.Pull(ctx, "http://o.com/library/abc")
|
||||
if !errors.Is(got, ErrIncomplete) {
|
||||
t.Fatalf("err = %v, want %v", got, ErrIncomplete)
|
||||
}
|
||||
}
|
||||
|
||||
type flushAfterWriter struct {
|
||||
w io.Writer
|
||||
}
|
||||
|
||||
func (f *flushAfterWriter) Write(p []byte) (n int, err error) {
|
||||
n, err = f.w.Write(p)
|
||||
f.w.(http.Flusher).Flush() // panic if not a flusher
|
||||
return
|
||||
}
|
||||
|
||||
func TestPullChunksumStreaming(t *testing.T) {
|
||||
csr, csw := io.Pipe()
|
||||
defer csw.Close()
|
||||
|
||||
var step atomic.Int64
|
||||
c, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch step.Add(1) {
|
||||
case 1:
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
||||
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
|
||||
case 2:
|
||||
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||
fw := &flushAfterWriter{w} // ensure client gets data as it arrives by aggressively flushing
|
||||
_, err := io.Copy(fw, csr)
|
||||
if err != nil {
|
||||
t.Errorf("copy: %v", err)
|
||||
}
|
||||
case 3:
|
||||
io.WriteString(w, "ab")
|
||||
case 4:
|
||||
io.WriteString(w, "c")
|
||||
default:
|
||||
t.Errorf("unexpected steps %d: %v", step.Load(), r)
|
||||
http.Error(w, "unexpected steps", http.StatusInternalServerError)
|
||||
}
|
||||
})
|
||||
|
||||
c.ChunkingThreshold = 1 // force chunking
|
||||
|
||||
update := make(chan int64, 1)
|
||||
ctx := WithTrace(t.Context(), &Trace{
|
||||
Update: func(l *Layer, n int64, err error) {
|
||||
t.Log("trace:", l.Digest.Short(), n, err)
|
||||
if n > 0 {
|
||||
update <- n
|
||||
}
|
||||
},
|
||||
})
|
||||
|
||||
errc := make(chan error, 1)
|
||||
go func() {
|
||||
errc <- c.Pull(ctx, "http://o.com/library/abc")
|
||||
}()
|
||||
|
||||
// Send first chunksum and ensure it kicks off work immediately
|
||||
fmt.Fprintf(csw, "%s 0-1\n", blob.DigestFromBytes("ab"))
|
||||
if g := <-update; g != 2 {
|
||||
t.Fatalf("got %d, want 2", g)
|
||||
}
|
||||
|
||||
// now send the second chunksum and ensure it kicks off work immediately
|
||||
fmt.Fprintf(csw, "%s 2-2\n", blob.DigestFromBytes("c"))
|
||||
if g := <-update; g != 1 {
|
||||
t.Fatalf("got %d, want 1", g)
|
||||
}
|
||||
csw.Close()
|
||||
testutil.Check(t, <-errc)
|
||||
}
|
||||
|
||||
func TestPullChunksumsCached(t *testing.T) {
|
||||
var step atomic.Int64
|
||||
c, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch step.Add(1) {
|
||||
case 1:
|
||||
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
||||
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
|
||||
case 2:
|
||||
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||
fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab"))
|
||||
fmt.Fprintf(w, "%s 2-2\n", blob.DigestFromBytes("c"))
|
||||
case 3, 4:
|
||||
switch rng := r.Header.Get("Range"); rng {
|
||||
case "bytes=0-1":
|
||||
io.WriteString(w, "ab")
|
||||
case "bytes=2-2":
|
||||
io.WriteString(w, "c")
|
||||
default:
|
||||
t.Errorf("unexpected range %q", rng)
|
||||
}
|
||||
default:
|
||||
t.Errorf("unexpected steps %d: %v", step.Load(), r)
|
||||
http.Error(w, "unexpected steps", http.StatusInternalServerError)
|
||||
}
|
||||
})
|
||||
|
||||
c.MaxStreams = 1 // force serial processing of chunksums
|
||||
c.ChunkingThreshold = 1 // force chunking
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
defer cancel()
|
||||
|
||||
// Cancel the pull after the first chunksum is processed, but before
|
||||
// the second chunksum is processed (which is waiting because
|
||||
// MaxStreams=1). This should cause the second chunksum to error out
|
||||
// leaving the blob incomplete.
|
||||
ctx = WithTrace(ctx, &Trace{
|
||||
Update: func(l *Layer, n int64, err error) {
|
||||
if n > 0 {
|
||||
cancel()
|
||||
}
|
||||
},
|
||||
})
|
||||
err := c.Pull(ctx, "http://o.com/library/abc")
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatalf("err = %v, want %v", err, context.Canceled)
|
||||
}
|
||||
|
||||
_, err = c.Cache.Resolve("o.com/library/abc:latest")
|
||||
if !errors.Is(err, fs.ErrNotExist) {
|
||||
t.Fatalf("err = %v, want nil", err)
|
||||
}
|
||||
|
||||
// Reset state and pull again to ensure the blob chunks that should
|
||||
// have been cached are, and the remaining chunk was downloaded, making
|
||||
// the blob complete.
|
||||
step.Store(0)
|
||||
var written atomic.Int64
|
||||
var cached atomic.Int64
|
||||
ctx = WithTrace(t.Context(), &Trace{
|
||||
Update: func(l *Layer, n int64, err error) {
|
||||
t.Log("trace:", l.Digest.Short(), n, err)
|
||||
if errors.Is(err, ErrCached) {
|
||||
cached.Add(n)
|
||||
}
|
||||
written.Add(n)
|
||||
},
|
||||
})
|
||||
|
||||
check := testutil.Checker(t)
|
||||
|
||||
err = c.Pull(ctx, "http://o.com/library/abc")
|
||||
check(err)
|
||||
|
||||
_, err = c.Cache.Resolve("o.com/library/abc:latest")
|
||||
check(err)
|
||||
|
||||
if g := written.Load(); g != 3 {
|
||||
t.Fatalf("wrote %d bytes, want 3", g)
|
||||
}
|
||||
if g := cached.Load(); g != 2 { // "ab" should have been cached
|
||||
t.Fatalf("cached %d bytes, want 3", g)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,9 +31,10 @@ const (
|
||||
|
||||
var (
|
||||
ErrInvalidImageFormat = errors.New("invalid image format")
|
||||
ErrInvalidDigestFormat = errors.New("invalid digest format")
|
||||
ErrInvalidProtocol = errors.New("invalid protocol scheme")
|
||||
ErrInsecureProtocol = errors.New("insecure protocol http")
|
||||
ErrInvalidDigestFormat = errors.New("invalid digest format")
|
||||
ErrModelPathInvalid = errors.New("invalid model path")
|
||||
)
|
||||
|
||||
func ParseModelPath(name string) ModelPath {
|
||||
@@ -73,8 +74,6 @@ func ParseModelPath(name string) ModelPath {
|
||||
return mp
|
||||
}
|
||||
|
||||
var errModelPathInvalid = errors.New("invalid model path")
|
||||
|
||||
func (mp ModelPath) GetNamespaceRepository() string {
|
||||
return fmt.Sprintf("%s/%s", mp.Namespace, mp.Repository)
|
||||
}
|
||||
|
||||
@@ -777,7 +777,7 @@ func (s *Server) ShowHandler(c *gin.Context) {
|
||||
func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
name := model.ParseName(req.Model)
|
||||
if !name.IsValid() {
|
||||
return nil, errModelPathInvalid
|
||||
return nil, ErrModelPathInvalid
|
||||
}
|
||||
name, err := getExistingName(name)
|
||||
if err != nil {
|
||||
|
||||
@@ -711,7 +711,7 @@ func pickBestFullFitByLibrary(req *LlmRequest, f *ggml.GGML, gpus discover.GpuIn
|
||||
req.opts.NumCtx = req.origNumCtx * p
|
||||
if !envconfig.SchedSpread() {
|
||||
for _, g := range sgl {
|
||||
if ok, estimatedVRAM = llm.PredictServerFit([]discover.GpuInfo{g}, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
|
||||
if ok, estimatedVRAM = llm.PredictServerFit([]discover.GpuInfo{g}, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, p); ok {
|
||||
slog.Info("new model will fit in available VRAM in single GPU, loading", "model", req.model.ModelPath, "gpu", g.ID, "parallel", p, "available", g.FreeMemory, "required", format.HumanBytes2(estimatedVRAM))
|
||||
*numParallel = p
|
||||
return []discover.GpuInfo{g}
|
||||
@@ -727,7 +727,7 @@ func pickBestFullFitByLibrary(req *LlmRequest, f *ggml.GGML, gpus discover.GpuIn
|
||||
// Now try all the GPUs
|
||||
for _, p := range numParallelToTry {
|
||||
req.opts.NumCtx = req.origNumCtx * p
|
||||
if ok, estimatedVRAM = llm.PredictServerFit(sgl, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
|
||||
if ok, estimatedVRAM = llm.PredictServerFit(sgl, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, p); ok {
|
||||
slog.Info("new model will fit in available VRAM, loading", "model", req.model.ModelPath, "library", sgl[0].Library, "parallel", p, "required", format.HumanBytes2(estimatedVRAM))
|
||||
*numParallel = p
|
||||
return sgl
|
||||
@@ -750,7 +750,7 @@ func pickBestPartialFitByLibrary(req *LlmRequest, f *ggml.GGML, gpus discover.Gp
|
||||
var bestEstimate uint64
|
||||
var bestFit int
|
||||
for i, gl := range byLibrary {
|
||||
_, estimatedVRAM := llm.PredictServerFit(gl, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts)
|
||||
_, estimatedVRAM := llm.PredictServerFit(gl, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, *numParallel)
|
||||
if estimatedVRAM > bestEstimate {
|
||||
bestEstimate = estimatedVRAM
|
||||
bestFit = i
|
||||
@@ -825,7 +825,7 @@ func (s *Scheduler) expireRunner(model *Model) {
|
||||
// If not, pick a runner to unload, else return nil and the request can be loaded
|
||||
func (s *Scheduler) maybeFindCPURunnerToUnload(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList) *runnerRef {
|
||||
slog.Debug("evaluating if CPU model load will fit in available system memory")
|
||||
estimate := llm.EstimateGPULayers(gpus, f, req.model.ProjectorPaths, req.opts)
|
||||
estimate := llm.EstimateGPULayers(gpus, f, req.model.ProjectorPaths, req.opts, req.opts.NumCtx/req.origNumCtx)
|
||||
if estimate.TotalSize <= gpus[0].FreeMemory {
|
||||
slog.Debug("cpu inference mode, model fits in available system memory", "model", format.HumanBytes2(estimate.TotalSize), "available", format.HumanBytes2(gpus[0].FreeMemory))
|
||||
return nil
|
||||
|
||||
Reference in New Issue
Block a user