Mirror of https://github.com/ollama/ollama.git (synced 2026-04-20 06:54:29 +02:00)

Compare commits: pdevine/bf...v0.6.4-rc0 (47 commits)
| SHA1 |
|---|
| b51e0f397c |
| b42970063d |
| 493385eb3e |
| 9876c9faa4 |
| 4e415029b3 |
| e172f095ba |
| c001b98087 |
| 23fc8e92eb |
| 4059a297a6 |
| 66b2539238 |
| ef27d52e79 |
| b2a465296d |
| 5d097277ef |
| 071a9872cb |
| 0bd0454ea7 |
| 01aa788722 |
| ead27aa9fe |
| b816ff86c9 |
| e5d84fb90b |
| dd66712e31 |
| f66216e399 |
| f4f0992b6e |
| 1feff61977 |
| 5e0b904e88 |
| 131f0355a5 |
| ce929984a3 |
| 4b34930a31 |
| 74bd09652d |
| fb6252d786 |
| c794fef2f2 |
| 00ebda8cc4 |
| d14ce75b95 |
| 2d6eac9084 |
| 3ed7ad3ab3 |
| 6d1103048e |
| 0ff28758b3 |
| d3e9ca3eda |
| 0fbfcf3c9c |
| 0c220935bd |
| ffbfe833da |
| 42a14f7f63 |
| f8c3dbe5b5 |
| b078dd157c |
| 2ddacd7516 |
| da0e345200 |
| df94175a0f |
| 61a8825216 |
```diff
@@ -86,9 +86,9 @@ if(CMAKE_CUDA_COMPILER)
   )
 endif()
 
-set(WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX "^gfx(906|908|90a):xnack[+-]$"
+set(WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX "^gfx(906|908|90a|1200|1201):xnack[+-]$"
     CACHE STRING
-    "Regular expression describing AMDGPU_TARGETS not supported on Windows. Override to force building these targets. Default \"^gfx(906|908|90a):xnack[+-]$\"."
+    "Regular expression describing AMDGPU_TARGETS not supported on Windows. Override to force building these targets. Default \"^gfx(906|908|90a|1200|1201):xnack[+-]$\"."
 )
 
 check_language(HIP)
@@ -97,7 +97,7 @@ if(CMAKE_HIP_COMPILER)
 
   find_package(hip REQUIRED)
   if(NOT AMDGPU_TARGETS)
-    list(FILTER AMDGPU_TARGETS INCLUDE REGEX "^gfx(900|94[012]|101[02]|1030|110[012])$")
+    list(FILTER AMDGPU_TARGETS INCLUDE REGEX "^gfx(900|94[012]|101[02]|1030|110[012]|120[01])$")
   elseif(WIN32 AND WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX)
     list(FILTER AMDGPU_TARGETS EXCLUDE REGEX ${WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX})
   endif()
```
```diff
@@ -56,7 +56,7 @@
       "name": "ROCm 6",
       "inherits": [ "ROCm" ],
       "cacheVariables": {
-        "AMDGPU_TARGETS": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-"
+        "AMDGPU_TARGETS": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx1200;gfx1201;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-"
       }
     }
   ],
```
```diff
@@ -285,6 +285,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt)
 - [HTML UI](https://github.com/rtcfirefly/ollama-ui)
 - [Saddle](https://github.com/jikkuatwork/saddle)
+- [TagSpaces](https://www.tagspaces.org) (A platform for file based apps, [utilizing Ollama](https://docs.tagspaces.org/ai/) for the generation of tags and descriptions)
 - [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama)
 - [Chatbot UI v2](https://github.com/mckaywrigley/chatbot-ui)
 - [Typescript UI](https://github.com/ollama-interface/Ollama-Gui?tab=readme-ov-file)
@@ -324,6 +325,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [RWKV-Runner](https://github.com/josStorer/RWKV-Runner) (RWKV offline LLM deployment tool, also usable as a client for ChatGPT and Ollama)
 - [Ollama Grid Search](https://github.com/dezoito/ollama-grid-search) (app to evaluate and compare models)
 - [Olpaka](https://github.com/Otacon/olpaka) (User-friendly Flutter Web App for Ollama)
+- [Casibase](https://casibase.org) (An open source AI knowledge base and dialogue system combining the latest RAG, SSO, ollama support and multiple large language models.)
 - [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS)
 - [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama)
 - [Shinkai Desktop](https://github.com/dcSpark/shinkai-apps) (Two click install Local AI using Ollama + Files + RAG)
@@ -394,6 +396,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)
 - [Ellama](https://github.com/zeozeozeo/ellama) (Friendly native app to chat with an Ollama instance)
 - [screenpipe](https://github.com/mediar-ai/screenpipe) Build agents powered by your screen history
+- [Ollamb](https://github.com/hengkysteen/ollamb) (Simple yet rich in features, cross-platform built with Flutter and designed for Ollama. Try the [web demo](https://hengkysteen.github.io/demo/ollamb/).)
+- [Writeopia](https://github.com/Writeopia/Writeopia) (Text editor with integration with Ollama)
 
 ### Cloud
 
@@ -433,7 +437,9 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [SwollamaCLI](https://github.com/marcusziade/Swollama) bundled with the Swollama Swift package. [Demo](https://github.com/marcusziade/Swollama?tab=readme-ov-file#cli-usage)
 - [aichat](https://github.com/sigoden/aichat) All-in-one LLM CLI tool featuring Shell Assistant, Chat-REPL, RAG, AI tools & agents, with access to OpenAI, Claude, Gemini, Ollama, Groq, and more.
 - [PowershAI](https://github.com/rrg92/powershai) PowerShell module that brings AI to terminal on Windows, including support for Ollama
+- [DeepShell](https://github.com/Abyss-c0re/deepshell) Your self-hosted AI assistant. Interactive Shell, Files and Folders analysis.
 - [orbiton](https://github.com/xyproto/orbiton) Configuration-free text editor and IDE with support for tab completion with Ollama.
+- [orca-cli](https://github.com/molbal/orca-cli) Ollama Registry CLI Application - Browse, pull and download models from Ollama Registry in your terminal.
 
 ### Apple Vision Pro
 
@@ -512,6 +518,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Ollama for Zig](https://github.com/dravenk/ollama-zig)
 - [Abso](https://github.com/lunary-ai/abso) (OpenAI-compatible TypeScript SDK for any LLM provider)
 - [Nichey](https://github.com/goodreasonai/nichey) is a Python package for generating custom wikis for your research topic
+- [Ollama for D](https://github.com/kassane/ollama-d)
 
 ### Mobile
```
**api/types.go** (24 lines changed)

```diff
@@ -12,6 +12,7 @@ import (
 	"time"
 
 	"github.com/ollama/ollama/envconfig"
+	"github.com/ollama/ollama/types/model"
 )
 
 // StatusError is an error with an HTTP status code and message.
@@ -81,7 +82,7 @@ type GenerateRequest struct {
 
 	// Options lists model-specific options. For example, temperature can be
 	// set through this field, if the model supports it.
-	Options map[string]interface{} `json:"options"`
+	Options map[string]any `json:"options"`
 }
 
 // ChatRequest describes a request sent by [Client.Chat].
@@ -106,7 +107,7 @@ type ChatRequest struct {
 	Tools `json:"tools,omitempty"`
 
 	// Options lists model-specific options.
-	Options map[string]interface{} `json:"options"`
+	Options map[string]any `json:"options"`
 }
 
 type Tools []Tool
@@ -260,7 +261,7 @@ type EmbedRequest struct {
 	Truncate *bool `json:"truncate,omitempty"`
 
 	// Options lists model-specific options.
-	Options map[string]interface{} `json:"options"`
+	Options map[string]any `json:"options"`
 }
 
 // EmbedResponse is the response from [Client.Embed].
@@ -286,7 +287,7 @@ type EmbeddingRequest struct {
 	KeepAlive *Duration `json:"keep_alive,omitempty"`
 
 	// Options lists model-specific options.
-	Options map[string]interface{} `json:"options"`
+	Options map[string]any `json:"options"`
 }
 
 // EmbeddingResponse is the response from [Client.Embeddings].
@@ -332,7 +333,7 @@ type ShowRequest struct {
 	Template string `json:"template"`
 	Verbose  bool   `json:"verbose"`
 
-	Options map[string]interface{} `json:"options"`
+	Options map[string]any `json:"options"`
 
 	// Deprecated: set the model name with Model instead
 	Name string `json:"name"`
@@ -350,6 +351,7 @@ type ShowResponse struct {
 	ModelInfo     map[string]any     `json:"model_info,omitempty"`
 	ProjectorInfo map[string]any     `json:"projector_info,omitempty"`
 	Tensors       []Tensor           `json:"tensors,omitempty"`
+	Capabilities  []model.Capability `json:"capabilities,omitempty"`
 	ModifiedAt    time.Time          `json:"modified_at,omitempty"`
 }
 
@@ -503,7 +505,7 @@ func (m *Metrics) Summary() {
 	}
 }
 
-func (opts *Options) FromMap(m map[string]interface{}) error {
+func (opts *Options) FromMap(m map[string]any) error {
 	valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
 	typeOpts := reflect.TypeOf(opts).Elem()   // types of the fields in the options struct
 
@@ -560,12 +562,12 @@ func (opts *Options) FromMap(m map[string]any) error {
 			}
 			field.SetString(val)
 		case reflect.Slice:
-			// JSON unmarshals to []interface{}, not []string
-			val, ok := val.([]interface{})
+			// JSON unmarshals to []any, not []string
+			val, ok := val.([]any)
 			if !ok {
 				return fmt.Errorf("option %q must be of type array", key)
 			}
-			// convert []interface{} to []string
+			// convert []any to []string
 			slice := make([]string, len(val))
 			for i, item := range val {
 				str, ok := item.(string)
@@ -672,7 +674,7 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
 }
 
 // FormatParams converts specified parameter options to their correct types
-func FormatParams(params map[string][]string) (map[string]interface{}, error) {
+func FormatParams(params map[string][]string) (map[string]any, error) {
 	opts := Options{}
 	valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
 	typeOpts := reflect.TypeOf(opts)           // types of the fields in the options struct
@@ -686,7 +688,7 @@ func FormatParams(params map[string][]string) (map[string]any, error) {
 		}
 	}
 
-	out := make(map[string]interface{})
+	out := make(map[string]any)
 	// iterate params and set values based on json struct tags
 	for key, vals := range params {
 		if opt, ok := jsonOpts[key]; !ok {
```
```diff
@@ -134,7 +134,7 @@ func TestUseMmapParsingFromJSON(t *testing.T) {
 
 	for _, test := range tests {
 		t.Run(test.name, func(t *testing.T) {
-			var oMap map[string]interface{}
+			var oMap map[string]any
 			err := json.Unmarshal([]byte(test.req), &oMap)
 			require.NoError(t, err)
 			opts := DefaultOptions()
```
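Taken together, the `api/types.go` changes above swap `interface{}` for `any` and add a `Capabilities` field to `ShowResponse`. A minimal client-side sketch of reading that new field (a hedged example, not part of this diff; it assumes a running server and uses `llava` as in the docs example further down):

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	// ShowResponse now carries a Capabilities slice; older servers simply omit it.
	resp, err := client.Show(context.Background(), &api.ShowRequest{Model: "llava"})
	if err != nil {
		log.Fatal(err)
	}
	for _, c := range resp.Capabilities {
		fmt.Println(c) // e.g. completion, vision
	}
}
```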
**benchmark/server_benchmark_test.go** (new file, 178 lines)

```go
package benchmark

import (
	"context"
	"flag"
	"fmt"
	"testing"
	"time"

	"github.com/ollama/ollama/api"
)

// Command line flags
var modelFlag string

func init() {
	flag.StringVar(&modelFlag, "m", "", "Name of the model to benchmark")
	flag.Lookup("m").DefValue = "model"
}

// modelName returns the model name from flags, failing the test if not set
func modelName(b *testing.B) string {
	if modelFlag == "" {
		b.Fatal("Error: -m flag is required for benchmark tests")
	}
	return modelFlag
}

type TestCase struct {
	name      string
	prompt    string
	maxTokens int
}

// runGenerateBenchmark contains the common generate and metrics logic
func runGenerateBenchmark(b *testing.B, ctx context.Context, client *api.Client, req *api.GenerateRequest) {
	start := time.Now()
	var ttft time.Duration
	var metrics api.Metrics

	err := client.Generate(ctx, req, func(resp api.GenerateResponse) error {
		if ttft == 0 && resp.Response != "" {
			ttft = time.Since(start)
		}
		if resp.Done {
			metrics = resp.Metrics
		}
		return nil
	})

	// Report custom metrics as part of the benchmark results
	b.ReportMetric(float64(ttft.Milliseconds()), "ttft_ms")
	b.ReportMetric(float64(metrics.LoadDuration.Milliseconds()), "load_ms")

	// Token throughput metrics
	promptThroughput := float64(metrics.PromptEvalCount) / metrics.PromptEvalDuration.Seconds()
	genThroughput := float64(metrics.EvalCount) / metrics.EvalDuration.Seconds()
	b.ReportMetric(promptThroughput, "prompt_tok/s")
	b.ReportMetric(genThroughput, "gen_tok/s")

	// Token counts
	b.ReportMetric(float64(metrics.PromptEvalCount), "prompt_tokens")
	b.ReportMetric(float64(metrics.EvalCount), "gen_tokens")
	if err != nil {
		b.Fatal(err)
	}
}

// BenchmarkColdStart runs benchmarks with model loading from cold state
func BenchmarkColdStart(b *testing.B) {
	client := setup(b)
	tests := []TestCase{
		{"short_prompt", "Write a long story", 100},
		{"medium_prompt", "Write a detailed economic analysis", 500},
		{"long_prompt", "Write a comprehensive AI research paper", 1000},
	}
	m := modelName(b)

	for _, tt := range tests {
		b.Run(fmt.Sprintf("%s/cold/%s", m, tt.name), func(b *testing.B) {
			ctx := context.Background()

			// Set number of tokens as our throughput metric
			b.SetBytes(int64(tt.maxTokens))

			for b.Loop() {
				b.StopTimer()
				// Ensure model is unloaded before each iteration
				unload(client, m, b)
				b.StartTimer()

				req := &api.GenerateRequest{
					Model:   m,
					Prompt:  tt.prompt,
					Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
				}

				runGenerateBenchmark(b, ctx, client, req)
			}
		})
	}
}

// BenchmarkWarmStart runs benchmarks with pre-loaded model
func BenchmarkWarmStart(b *testing.B) {
	client := setup(b)
	tests := []TestCase{
		{"short_prompt", "Write a long story", 100},
		{"medium_prompt", "Write a detailed economic analysis", 500},
		{"long_prompt", "Write a comprehensive AI research paper", 1000},
	}
	m := modelName(b)

	for _, tt := range tests {
		b.Run(fmt.Sprintf("%s/warm/%s", m, tt.name), func(b *testing.B) {
			ctx := context.Background()

			// Pre-warm the model
			warmup(client, m, tt.prompt, b)

			// Set number of tokens as our throughput metric
			b.SetBytes(int64(tt.maxTokens))

			for b.Loop() {
				req := &api.GenerateRequest{
					Model:   m,
					Prompt:  tt.prompt,
					Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
				}

				runGenerateBenchmark(b, ctx, client, req)
			}
		})
	}
}

// setup verifies server and model availability
func setup(b *testing.B) *api.Client {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		b.Fatal(err)
	}
	if _, err := client.Show(context.Background(), &api.ShowRequest{Model: modelName(b)}); err != nil {
		b.Fatalf("Model unavailable: %v", err)
	}

	return client
}

// warmup ensures the model is loaded and warmed up
func warmup(client *api.Client, model string, prompt string, b *testing.B) {
	for range 3 {
		err := client.Generate(
			context.Background(),
			&api.GenerateRequest{
				Model:   model,
				Prompt:  prompt,
				Options: map[string]any{"num_predict": 50, "temperature": 0.1},
			},
			func(api.GenerateResponse) error { return nil },
		)
		if err != nil {
			b.Logf("Error during model warm-up: %v", err)
		}
	}
}

// unload forces model unloading using KeepAlive: 0 parameter
func unload(client *api.Client, model string, b *testing.B) {
	req := &api.GenerateRequest{
		Model:     model,
		KeepAlive: &api.Duration{Duration: 0},
	}
	if err := client.Generate(context.Background(), req, func(api.GenerateResponse) error { return nil }); err != nil {
		b.Logf("Unload error: %v", err)
	}
	time.Sleep(1 * time.Second)
}
```
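The benchmark's `unload` helper doubles as a recipe for forcing a model out of memory. A standalone sketch of the same KeepAlive-zero request outside the test harness (hedged example; the model name is a placeholder):

```go
package main

import (
	"context"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}
	req := &api.GenerateRequest{
		Model:     "llama3.2",                   // placeholder model name
		KeepAlive: &api.Duration{Duration: 0},   // ask the server to unload immediately
	}
	// An empty generate with KeepAlive: 0 triggers the unload.
	if err := client.Generate(context.Background(), req,
		func(api.GenerateResponse) error { return nil }); err != nil {
		log.Fatal(err)
	}
}
```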
**cmd/cmd.go** (21 lines changed)

```diff
@@ -18,6 +18,7 @@ import (
 	"os/signal"
 	"path/filepath"
 	"runtime"
+	"slices"
 	"sort"
 	"strconv"
 	"strings"
@@ -267,7 +268,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 	opts := runOptions{
 		Model:    args[0],
 		WordWrap: os.Getenv("TERM") == "xterm-256color",
-		Options:  map[string]interface{}{},
+		Options:  map[string]any{},
 	}
 
 	format, err := cmd.Flags().GetString("format")
@@ -339,6 +340,11 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 		return err
 	}
 
+	opts.MultiModal = slices.Contains(info.Capabilities, model.CapabilityVision)
+
+	// TODO: remove the projector info and vision info checks below,
+	// these are left in for backwards compatibility with older servers
+	// that don't have the capabilities field in the model info
 	if len(info.ProjectorInfo) != 0 {
 		opts.MultiModal = true
 	}
@@ -669,6 +675,15 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
 		return
 	})
 
+	if len(resp.Capabilities) > 0 {
+		tableRender("Capabilities", func() (rows [][]string) {
+			for _, capability := range resp.Capabilities {
+				rows = append(rows, []string{"", capability.String()})
+			}
+			return
+		})
+	}
+
 	if resp.ProjectorInfo != nil {
 		tableRender("Projector", func() (rows [][]string) {
 			arch := resp.ProjectorInfo["general.architecture"].(string)
@@ -703,6 +718,8 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
 	for _, k := range keys {
 		var v string
 		switch vData := resp.ModelInfo[k].(type) {
+		case bool:
+			v = fmt.Sprintf("%t", vData)
 		case string:
 			v = vData
 		case float64:
@@ -835,7 +852,7 @@ type runOptions struct {
 	Format     string
 	System     string
 	Images     []api.ImageData
-	Options    map[string]interface{}
+	Options    map[string]any
 	MultiModal bool
 	KeepAlive  *api.Duration
 }
```
```diff
@@ -16,6 +16,7 @@ import (
 	"github.com/spf13/cobra"
 
 	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/types/model"
 )
 
 func TestShowInfo(t *testing.T) {
@@ -87,6 +88,8 @@ func TestShowInfo(t *testing.T) {
 		ModelInfo: map[string]any{
 			"general.architecture":    "test",
 			"general.parameter_count": float64(8_000_000_000),
+			"some.true_bool":          true,
+			"some.false_bool":         false,
 			"test.context_length":     float64(1000),
 			"test.embedding_length":   float64(11434),
 		},
@@ -111,6 +114,8 @@ func TestShowInfo(t *testing.T) {
   Metadata
     general.architecture       test
     general.parameter_count    8e+09
+    some.false_bool            false
+    some.true_bool             true
     test.context_length        1000
     test.embedding_length      11434
 
@@ -256,6 +261,34 @@ Weigh anchor!
 			t.Errorf("unexpected output (-want +got):\n%s", diff)
 		}
 	})
 
+	t.Run("capabilities", func(t *testing.T) {
+		var b bytes.Buffer
+		if err := showInfo(&api.ShowResponse{
+			Details: api.ModelDetails{
+				Family:            "test",
+				ParameterSize:     "7B",
+				QuantizationLevel: "FP16",
+			},
+			Capabilities: []model.Capability{model.CapabilityVision, model.CapabilityTools},
+		}, false, &b); err != nil {
+			t.Fatal(err)
+		}
+
+		expect := "  Model\n" +
+			"    architecture    test    \n" +
+			"    parameters      7B      \n" +
+			"    quantization    FP16    \n" +
+			"\n" +
+			"  Capabilities\n" +
+			"    vision    \n" +
+			"    tools     \n" +
+			"\n"
+
+		if diff := cmp.Diff(expect, b.String()); diff != "" {
+			t.Errorf("unexpected output (-want +got):\n%s", diff)
+		}
+	})
 }
 
 func TestDeleteHandler(t *testing.T) {
```
```diff
@@ -201,7 +201,7 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
 	case "CohereForCausalLM":
 		conv = &commandrModel{}
 	default:
-		return errors.New("unsupported architecture")
+		return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
 	}
 
 	if err := json.Unmarshal(bts, conv); err != nil {
```
```diff
@@ -1360,7 +1360,7 @@ func file_sentencepiece_model_proto_rawDescGZIP() []byte {
 
 var file_sentencepiece_model_proto_enumTypes = make([]protoimpl.EnumInfo, 2)
 var file_sentencepiece_model_proto_msgTypes = make([]protoimpl.MessageInfo, 6)
-var file_sentencepiece_model_proto_goTypes = []interface{}{
+var file_sentencepiece_model_proto_goTypes = []any{
 	(TrainerSpec_ModelType)(0),         // 0: sentencepiece.TrainerSpec.ModelType
 	(ModelProto_SentencePiece_Type)(0), // 1: sentencepiece.ModelProto.SentencePiece.Type
 	(*TrainerSpec)(nil),                // 2: sentencepiece.TrainerSpec
@@ -1392,7 +1392,7 @@ func file_sentencepiece_model_proto_init() {
 		return
 	}
 	if !protoimpl.UnsafeEnabled {
-		file_sentencepiece_model_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
+		file_sentencepiece_model_proto_msgTypes[0].Exporter = func(v any, i int) any {
 			switch v := v.(*TrainerSpec); i {
 			case 0:
 				return &v.state
@@ -1406,7 +1406,7 @@ func file_sentencepiece_model_proto_init() {
 				return nil
 			}
 		}
-		file_sentencepiece_model_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
+		file_sentencepiece_model_proto_msgTypes[1].Exporter = func(v any, i int) any {
 			switch v := v.(*NormalizerSpec); i {
 			case 0:
 				return &v.state
@@ -1420,7 +1420,7 @@ func file_sentencepiece_model_proto_init() {
 				return nil
 			}
 		}
-		file_sentencepiece_model_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
+		file_sentencepiece_model_proto_msgTypes[2].Exporter = func(v any, i int) any {
 			switch v := v.(*SelfTestData); i {
 			case 0:
 				return &v.state
@@ -1434,7 +1434,7 @@ func file_sentencepiece_model_proto_init() {
 				return nil
 			}
 		}
-		file_sentencepiece_model_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} {
+		file_sentencepiece_model_proto_msgTypes[3].Exporter = func(v any, i int) any {
 			switch v := v.(*ModelProto); i {
 			case 0:
 				return &v.state
@@ -1448,7 +1448,7 @@ func file_sentencepiece_model_proto_init() {
 				return nil
 			}
 		}
-		file_sentencepiece_model_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} {
+		file_sentencepiece_model_proto_msgTypes[4].Exporter = func(v any, i int) any {
 			switch v := v.(*SelfTestData_Sample); i {
 			case 0:
 				return &v.state
@@ -1460,7 +1460,7 @@ func file_sentencepiece_model_proto_init() {
 				return nil
 			}
 		}
-		file_sentencepiece_model_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} {
+		file_sentencepiece_model_proto_msgTypes[5].Exporter = func(v any, i int) any {
 			switch v := v.(*ModelProto_SentencePiece); i {
 			case 0:
 				return &v.state
```
```diff
@@ -12,7 +12,7 @@ func IsNUMA() bool {
 		// numa support in llama.cpp is linux only
 		return false
 	}
-	ids := map[string]interface{}{}
+	ids := map[string]any{}
 	packageIds, _ := filepath.Glob("/sys/devices/system/cpu/cpu*/topology/physical_package_id")
 	for _, packageId := range packageIds {
 		id, err := os.ReadFile(packageId)
@@ -111,6 +111,7 @@ func GetCPUDetails() ([]CPU, error) {
 	if err != nil {
 		return nil, err
 	}
+	defer file.Close()
 	return linuxCPUDetails(file)
 }
 
@@ -168,13 +169,11 @@ func linuxCPUDetails(file io.Reader) ([]CPU, error) {
 	for id, s := range socketByID {
 		s.CoreCount = len(coreBySocket[id])
 		s.ThreadCount = 0
-		for _, tc := range threadsByCoreBySocket[id] {
-			s.ThreadCount += tc
-		}
 
 		// This only works if HT is enabled, consider a more reliable model, maybe cache size comparisons?
 		efficiencyCoreCount := 0
 		for _, threads := range threadsByCoreBySocket[id] {
+			s.ThreadCount += threads
 			if threads == 1 {
 				efficiencyCoreCount++
 			}
```
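The `linuxCPUDetails` rework above folds thread counting into the efficiency-core detection pass. A self-contained sketch of that single-pass logic, with made-up threads-per-core data (hedged illustration, not the real sysfs parsing):

```go
package main

import "fmt"

func main() {
	// Hypothetical threads-per-core for one socket: 4 P-cores with HT, 4 E-cores.
	threadsByCore := []int{2, 2, 2, 2, 1, 1, 1, 1}

	threadCount, efficiencyCoreCount := 0, 0
	for _, threads := range threadsByCore {
		threadCount += threads // thread counting now happens in the same loop
		if threads == 1 {
			// Heuristic from the diff: a single-threaded core is treated as an
			// efficiency core (only meaningful when hyper-threading is enabled).
			efficiencyCoreCount++
		}
	}
	fmt.Println("threads:", threadCount, "efficiency cores:", efficiencyCoreCount)
}
```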
**docs/api.md** (12 lines changed)

````diff
@@ -558,6 +558,10 @@ Final response:
 {
   "model": "llama3.2",
   "created_at": "2023-08-04T19:22:45.499127Z",
+  "message": {
+    "role": "assistant",
+    "content": ""
+  },
   "done": true,
   "total_duration": 4883583458,
   "load_duration": 1334875,
@@ -1213,7 +1217,7 @@ Show information about a model including details, modelfile, template, parameters
 
 ```shell
 curl http://localhost:11434/api/show -d '{
-  "model": "llama3.2"
+  "model": "llava"
 }'
 ```
 
@@ -1256,7 +1260,11 @@ curl http://localhost:11434/api/show -d '{
     "tokenizer.ggml.pre": "llama-bpe",
     "tokenizer.ggml.token_type": [], // populates if `verbose=true`
     "tokenizer.ggml.tokens": [] // populates if `verbose=true`
-  }
+  },
+  "capabilities": [
+    "completion",
+    "vision"
+  ],
 }
 ```
````
**docs/benchmark.md** (new file, 59 lines)

````markdown
# Benchmark

Go benchmark tests that measure end-to-end performance of a running Ollama server. Run these tests to evaluate model inference performance on your hardware and measure the impact of code changes.

## When to use

Run these benchmarks when:
- Making changes to the model inference engine
- Modifying model loading/unloading logic
- Changing prompt processing or token generation code
- Implementing a new model architecture
- Testing performance across different hardware setups

## Prerequisites

- Ollama server running locally with `ollama serve` on `127.0.0.1:11434`

## Usage and Examples

>[!NOTE]
>All commands must be run from the root directory of the Ollama project.

Basic syntax:
```bash
go test -bench=. ./benchmark/... -m $MODEL_NAME
```

Required flags:
- `-bench=.`: Run all benchmarks
- `-m`: Model name to benchmark

Optional flags:
- `-count N`: Number of times to run the benchmark (useful for statistical analysis)
- `-timeout T`: Maximum time for the benchmark to run (e.g. "10m" for 10 minutes)

Common usage patterns:

Single benchmark run with a model specified:
```bash
go test -bench=. ./benchmark/... -m llama3.3
```

## Output metrics

The benchmark reports several key metrics:

- `gen_tok/s`: Generated tokens per second
- `prompt_tok/s`: Prompt processing tokens per second
- `ttft_ms`: Time to first token in milliseconds
- `load_ms`: Model load time in milliseconds
- `gen_tokens`: Total tokens generated
- `prompt_tokens`: Total prompt tokens processed

Each benchmark runs two scenarios:
- Cold start: Model is loaded from disk for each test
- Warm start: Model is pre-loaded in memory

Three prompt lengths are tested for each scenario:
- Short prompt (100 tokens)
- Medium prompt (500 tokens)
- Long prompt (1000 tokens)
````
````diff
@@ -20,7 +20,13 @@ Please refer to the [GPU docs](./gpu.md).
 
 ## How can I specify the context window size?
 
-By default, Ollama uses a context window size of 2048 tokens. This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context length to 8K, use: `OLLAMA_CONTEXT_LENGTH=8192 ollama serve`.
+By default, Ollama uses a context window size of 2048 tokens.
+
+This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context window to 8K, use:
+
+```shell
+OLLAMA_CONTEXT_LENGTH=8192 ollama serve
+```
 
 To change this when using `ollama run`, use `/set parameter`:
````
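API clients can apply the same kind of override per request through the options map. A hedged Go sketch (the model name and value are placeholders; `num_ctx` is the option the integration tests elsewhere in this diff use):

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}
	req := &api.GenerateRequest{
		Model:   "llama3.2", // placeholder model
		Prompt:  "Why is the sky blue?",
		Options: map[string]any{"num_ctx": 8192}, // per-request context window
	}
	if err := client.Generate(context.Background(), req, func(r api.GenerateResponse) error {
		fmt.Print(r.Response) // stream the tokens as they arrive
		return nil
	}); err != nil {
		log.Fatal(err)
	}
}
```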
````diff
@@ -9,7 +9,7 @@ cat ~/.ollama/logs/server.log
 On **Linux** systems with systemd, the logs can be found with this command:
 
 ```shell
-journalctl -u ollama --no-pager
+journalctl -u ollama --no-pager --follow --pager-end
 ```
 
 When you run Ollama in a **container**, the logs go to stdout/stderr in the container:
````
```diff
@@ -5,7 +5,7 @@ import (
 	"time"
 )
 
-func assertEqual(t *testing.T, a interface{}, b interface{}) {
+func assertEqual(t *testing.T, a any, b any) {
 	if a != b {
 		t.Errorf("Assert failed, expected %v, got %v", b, a)
 	}
```
```diff
@@ -413,7 +413,7 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
 	}, offset, nil
 }
 
-func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialOffload, fullOffload uint64) {
+func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) {
 	embedding := f.KV().EmbeddingLength()
 	heads := f.KV().HeadCount()
 	headsKV := f.KV().HeadCountKV()
@@ -426,7 +426,10 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
 	layers := f.Tensors().GroupLayers()
 
 	bytesPerElement := kvCacheBytesPerElement(kvCacheType)
-	kv = uint64(float64(context*f.KV().BlockCount()*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
+	kv = make([]uint64, f.KV().BlockCount())
+	for i := range kv {
+		kv[i] = uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
+	}
 
 	switch f.KV().Architecture() {
 	case "llama":
@@ -460,16 +463,14 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
 	case "mllama":
 		var visionTokens, tiles uint64 = 1601, 4
 
-		if crossAttentionLayers, ok := f.KV()["mllama.attention.cross_attention_layers"].(*array); ok {
-			kv = headsKV *
-				(embeddingHeadsK + embeddingHeadsV) * // one for K, one for V
-				(2* // sizeof(float16)
-					(f.KV().BlockCount()-uint64(crossAttentionLayers.size))* // num non-cross attention layers
-					context +
-					4* // sizeof(float32)
-						uint64(crossAttentionLayers.size)* // num cross attention layers
-						visionTokens*
-						tiles)
+		crossAttentionLayers := f.KV().Uints("attention.cross_attention_layers")
+		for i := range kv {
+			if slices.Contains(crossAttentionLayers, uint32(i)) {
+				kv[i] = headsKV * (embeddingHeadsK + embeddingHeadsV) *
+					4 * // sizeof(float32)
+					visionTokens *
+					tiles
+			}
 		}
 
 		fullOffload = max(
@@ -505,6 +506,20 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
 			4*embeddingHeadsK*context*8+
 				embedding*embeddingHeadsK*heads*9/16,
 		)
+
+		// Gemma2 also has sliding window attention but we only have an optimized implementation in the Ollama
+		// engine. Gemma3 always uses the Ollama engine.
+		if f.KV().Architecture() == "gemma3" {
+			const gemma3GlobalCacheCount = 6
+			slidingWindow := (uint64(numParallel) * uint64(f.KV().Uint("attention.sliding_window"))) + batch
+			for i := range kv {
+				// Every 6th layer is a global layer, which is the full context size that has already been set. The other
+				// layers are the smaller local (sliding) layers.
+				if (i+1)%gemma3GlobalCacheCount != 0 {
+					kv[i] = uint64(float64(slidingWindow*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
+				}
+			}
+		}
 	case "command-r":
 		fullOffload = max(
 			4*batch*(embedding+vocab),
```
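To make the gemma3 arithmetic above concrete, here is a self-contained sketch of the per-layer KV sizing with invented gemma3-like numbers (none of these values are read from a real GGUF; they are assumptions for illustration only):

```go
package main

import "fmt"

func main() {
	const (
		blockCount             = 34   // hypothetical layer count
		context                = 8192 // context length per sequence, in tokens
		batch                  = 512  // hypothetical max batch size
		numParallel            = 1    // parallel sequences
		slidingWindowLen       = 1024 // hypothetical attention.sliding_window
		headsKV                = 4
		embeddingHeadsK        = 256
		embeddingHeadsV        = 256
		bytesPerElement        = 2.0 // f16 cache entries
		gemma3GlobalCacheCount = 6   // every 6th layer stays global
	)

	// Start every layer at the full-context size, as the diff does.
	kv := make([]uint64, blockCount)
	for i := range kv {
		kv[i] = uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
	}

	// Shrink the local (sliding) layers; global layers keep the full size.
	slidingWindow := uint64(numParallel)*slidingWindowLen + batch
	var total uint64
	for i := range kv {
		if (i+1)%gemma3GlobalCacheCount != 0 {
			kv[i] = uint64(float64(slidingWindow*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
		}
		total += kv[i]
	}
	fmt.Printf("total KV cache: %.1f MiB\n", float64(total)/(1<<20))
}
```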
```diff
@@ -22,7 +22,7 @@ func TestOrcaMiniBlueSky(t *testing.T) {
 		Model:  "orca-mini",
 		Prompt: "why is the sky blue?",
 		Stream: &stream,
-		Options: map[string]interface{}{
+		Options: map[string]any{
 			"temperature": 0,
 			"seed":        123,
 		},
@@ -39,7 +39,7 @@ func TestUnicode(t *testing.T) {
 		Model:  "deepseek-coder-v2:16b-lite-instruct-q2_K",
 		Prompt: "天空为什么是蓝色的?",
 		Stream: &stream,
-		Options: map[string]interface{}{
+		Options: map[string]any{
 			"temperature": 0,
 			"seed":        123,
 			// Workaround deepseek context shifting bug
@@ -61,7 +61,7 @@ func TestExtendedUnicodeOutput(t *testing.T) {
 		Model:  "gemma2:2b",
 		Prompt: "Output some smily face emoji",
 		Stream: &stream,
-		Options: map[string]interface{}{
+		Options: map[string]any{
 			"temperature": 0,
 			"seed":        123,
 		},
@@ -96,7 +96,7 @@ func TestUnicodeModelDir(t *testing.T) {
 		Model:  "orca-mini",
 		Prompt: "why is the sky blue?",
 		Stream: &stream,
-		Options: map[string]interface{}{
+		Options: map[string]any{
 			"temperature": 0,
 			"seed":        123,
 		},
```
```diff
@@ -25,7 +25,7 @@ func TestMultiModelConcurrency(t *testing.T) {
 			Prompt:    "why is the ocean blue?",
 			Stream:    &stream,
 			KeepAlive: &api.Duration{Duration: 10 * time.Second},
-			Options: map[string]interface{}{
+			Options: map[string]any{
 				"seed":        42,
 				"temperature": 0.0,
 			},
@@ -34,7 +34,7 @@ func TestMultiModelConcurrency(t *testing.T) {
 			Prompt:    "what is the origin of the us thanksgiving holiday?",
 			Stream:    &stream,
 			KeepAlive: &api.Duration{Duration: 10 * time.Second},
-			Options: map[string]interface{}{
+			Options: map[string]any{
 				"seed":        42,
 				"temperature": 0.0,
 			},
```
```diff
@@ -23,7 +23,7 @@ func TestLongInputContext(t *testing.T) {
 		Model:  "llama2",
 		Prompt: "Oh, don’t speak to me of Austria. Perhaps I don’t understand things, but Austria never has wished, and does not wish, for war. She is betraying us! Russia alone must save Europe. Our gracious sovereign recognizes his high vocation and will be true to it. That is the one thing I have faith in! Our good and wonderful sovereign has to perform the noblest role on earth, and he is so virtuous and noble that God will not forsake him. He will fulfill his vocation and crush the hydra of revolution, which has become more terrible than ever in the person of this murderer and villain! We alone must avenge the blood of the just one.... Whom, I ask you, can we rely on?... England with her commercial spirit will not and cannot understand the Emperor Alexander’s loftiness of soul. She has refused to evacuate Malta. She wanted to find, and still seeks, some secret motive in our actions. What answer did Novosíltsev get? None. The English have not understood and cannot understand the self-abnegation of our Emperor who wants nothing for himself, but only desires the good of mankind. And what have they promised? Nothing! And what little they have promised they will not perform! Prussia has always declared that Buonaparte is invincible, and that all Europe is powerless before him.... And I don’t believe a word that Hardenburg says, or Haugwitz either. This famous Prussian neutrality is just a trap. I have faith only in God and the lofty destiny of our adored monarch. He will save Europe! What country is this referring to?",
 		Stream: &stream,
-		Options: map[string]interface{}{
+		Options: map[string]any{
 			"temperature": 0,
 			"seed":        123,
 			"num_ctx":     128,
@@ -50,7 +50,7 @@ func TestContextExhaustion(t *testing.T) {
 		Model:  "llama2",
 		Prompt: "Write me a story with a ton of emojis?",
 		Stream: &stream,
-		Options: map[string]interface{}{
+		Options: map[string]any{
 			"temperature": 0,
 			"seed":        123,
 			"num_ctx":     128,
```
```diff
@@ -19,7 +19,7 @@ func TestIntegrationLlava(t *testing.T) {
 		Model:  "llava:7b",
 		Prompt: "what does the text in this image say?",
 		Stream: &stream,
-		Options: map[string]interface{}{
+		Options: map[string]any{
 			"seed":        42,
 			"temperature": 0.0,
 		},
@@ -47,7 +47,7 @@ func TestIntegrationMllama(t *testing.T) {
 		Model:  "x/llama3.2-vision",
 		Prompt: "what does the text in this image say?",
 		Stream: &stream,
-		Options: map[string]interface{}{
+		Options: map[string]any{
 			"seed":        42,
 			"temperature": 0.0,
 		},
@@ -75,7 +75,7 @@ func TestIntegrationSplitBatch(t *testing.T) {
 		System: "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed aliquet, justo in malesuada lobortis, odio ligula volutpat quam, quis faucibus ipsum magna quis sapien. Aliquam in venenatis diam, eu viverra magna. Phasellus imperdiet hendrerit volutpat. Vivamus sem ex, facilisis placerat felis non, dictum elementum est. Phasellus aliquam imperdiet lacus, eget placerat ligula sodales vel. Pellentesque nec auctor mi. Curabitur arcu nisi, faucibus eget nunc id, viverra interdum mi. Curabitur ornare ipsum ex, ac euismod ex aliquam in. Vestibulum id magna at purus accumsan fermentum. Proin scelerisque posuere nunc quis interdum. Maecenas sed mollis nisl. Etiam vitae ipsum interdum, placerat est quis, tincidunt velit. Nullam tempor nibh non lorem volutpat efficitur. Cras laoreet diam imperdiet ipsum auctor bibendum. Suspendisse ultrices urna sed metus sagittis suscipit. Quisque ullamcorper aliquam nibh ut mollis. Aenean dapibus mauris pharetra, venenatis elit ac, hendrerit odio. Cras vestibulum erat tempor, lobortis justo eu, lobortis ipsum. Nam laoreet dapibus sem. Proin vel diam ultrices, elementum ante et, ornare lectus. Proin eu accumsan nisl. Praesent ac ex vitae ipsum vulputate tristique facilisis sit amet lacus. Nullam faucibus magna a pellentesque pretium. Nunc lacinia ullamcorper sollicitudin. Donec vitae accumsan turpis, sed porttitor est. Donec porttitor mi vitae augue faucibus, vel mollis diam tincidunt.",
 		Prompt: "what does the text in this image say?",
 		Stream: &stream,
-		Options: map[string]interface{}{
+		Options: map[string]any{
 			"seed":        42,
 			"temperature": 0.0,
 		},
```
```diff
@@ -20,7 +20,7 @@ var (
 		Model:  "orca-mini",
 		Prompt: "why is the ocean blue?",
 		Stream: &stream,
-		Options: map[string]interface{}{
+		Options: map[string]any{
 			"seed":        42,
 			"temperature": 0.0,
 		},
@@ -28,7 +28,7 @@ var (
 		Model:  "orca-mini",
 		Prompt: "what is the origin of the us thanksgiving holiday?",
 		Stream: &stream,
-		Options: map[string]interface{}{
+		Options: map[string]any{
 			"seed":        42,
 			"temperature": 0.0,
 		},
```
```diff
@@ -32,7 +32,7 @@ func TestMaxQueue(t *testing.T) {
 	req := api.GenerateRequest{
 		Model:  "orca-mini",
 		Prompt: "write a long historical fiction story about christopher columbus. use at least 10 facts from his actual journey",
-		Options: map[string]interface{}{
+		Options: map[string]any{
 			"seed":        42,
 			"temperature": 0.0,
 		},
```
@@ -291,7 +291,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
|||||||
Prompt: "why is the ocean blue?",
|
Prompt: "why is the ocean blue?",
|
||||||
Stream: &stream,
|
Stream: &stream,
|
||||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||||
Options: map[string]interface{}{
|
Options: map[string]any{
|
||||||
"seed": 42,
|
"seed": 42,
|
||||||
"temperature": 0.0,
|
"temperature": 0.0,
|
||||||
},
|
},
|
||||||
@@ -300,7 +300,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
|||||||
Prompt: "why is the color of dirt brown?",
|
Prompt: "why is the color of dirt brown?",
|
||||||
Stream: &stream,
|
Stream: &stream,
|
||||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||||
Options: map[string]interface{}{
|
Options: map[string]any{
|
||||||
"seed": 42,
|
"seed": 42,
|
||||||
"temperature": 0.0,
|
"temperature": 0.0,
|
||||||
},
|
},
|
||||||
@@ -309,7 +309,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
|||||||
Prompt: "what is the origin of the us thanksgiving holiday?",
|
Prompt: "what is the origin of the us thanksgiving holiday?",
|
||||||
Stream: &stream,
|
Stream: &stream,
|
||||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||||
Options: map[string]interface{}{
|
Options: map[string]any{
|
||||||
"seed": 42,
|
"seed": 42,
|
||||||
"temperature": 0.0,
|
"temperature": 0.0,
|
||||||
},
|
},
|
||||||
@@ -318,7 +318,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
|||||||
Prompt: "what is the origin of independence day?",
|
Prompt: "what is the origin of independence day?",
|
||||||
Stream: &stream,
|
Stream: &stream,
|
||||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||||
Options: map[string]interface{}{
|
Options: map[string]any{
|
||||||
"seed": 42,
|
"seed": 42,
|
||||||
"temperature": 0.0,
|
"temperature": 0.0,
|
||||||
},
|
},
|
||||||
@@ -327,7 +327,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
|||||||
Prompt: "what is the composition of air?",
|
Prompt: "what is the composition of air?",
|
||||||
Stream: &stream,
|
Stream: &stream,
|
||||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||||
Options: map[string]interface{}{
|
Options: map[string]any{
|
||||||
"seed": 42,
|
"seed": 42,
|
||||||
"temperature": 0.0,
|
"temperature": 0.0,
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -43,8 +43,13 @@ type Cache interface {
|
|||||||
|
|
||||||
// ** cache management **
|
// ** cache management **
|
||||||
|
|
||||||
// Init sets up runtime parameters
|
// Init sets up runtime parameters.
|
||||||
Init(backend ml.Backend, dtype ml.DType, capacity int32)
|
// backend: Used to allocate cache data storage and execute management operations (such as defrag)
|
||||||
|
// dtype: The data type for storing cache entries
|
||||||
|
// maxSequences: The maximum number of sequences stored in the cache - across all batches
|
||||||
|
// capacity: The number of cache entries to store, per sequence
|
||||||
|
// maxBatch: The maximum number of tokens that can occur in a single batch
|
||||||
|
Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int)
|
||||||
|
|
||||||
// Close closes the cache and frees resources associated with it
|
// Close closes the cache and frees resources associated with it
|
||||||
Close()
|
Close()
|
||||||
@@ -52,11 +57,16 @@ type Cache interface {
|
|||||||
// StartForward is called before the start of the model's forward pass.
|
// StartForward is called before the start of the model's forward pass.
|
||||||
// For each token in the coming batch, there must be a corresponding
|
// For each token in the coming batch, there must be a corresponding
|
||||||
// entry in positions and seqs.
|
// entry in positions and seqs.
|
||||||
StartForward(ctx ml.Context, opts input.Options) error
|
StartForward(ctx ml.Context, batch input.Batch) error
|
||||||
|
|
||||||
// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
|
// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
|
||||||
CopyPrefix(srcSeq, dstSeq int, len int32)
|
CopyPrefix(srcSeq, dstSeq int, len int32)
|
||||||
|
|
||||||
|
// CanResume returns true if the cache can continue with the next token at
|
||||||
|
// the given position and sequence. Assumes that the caller has already
|
||||||
|
// verified the contents of the cache.
|
||||||
|
CanResume(seq int, pos int32) bool
|
||||||
|
|
||||||
// Remove deletes tokens in the range [beginIndex, endIndex) from seq. Set
|
// Remove deletes tokens in the range [beginIndex, endIndex) from seq. Set
|
||||||
// endIndex to math.MaxInt32 to remove everything starting at beginIndex.
|
// endIndex to math.MaxInt32 to remove everything starting at beginIndex.
|
||||||
//
|
//
|
||||||
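Taken together, the expanded `Init` documentation and the new `CanResume` method define the contract a caller must follow. A minimal, hypothetical sketch of driving it (the import paths, the `ml.Backend` value, and the sizing numbers are assumptions for illustration, not part of this change):

```go
// Package cacheexample sketches a caller of the revised kvcache.Cache
// contract; it is illustrative, not code from this change.
package cacheexample

import (
	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
)

// prepareCache sizes the cache up front: the caller now states how many
// sequences it will run, the per-sequence capacity, and the largest batch,
// instead of a single capacity value.
func prepareCache(b ml.Backend) *kvcache.Causal {
	c := kvcache.NewCausalCache(nil) // nil shift fn: context shifting unsupported
	c.Init(b, ml.DTypeF16, 4, 2048, 512)
	return c
}

// canReuse asks whether generation can pick up at pos for seq; a sliding
// window cache may have evicted the history that pos would attend to, in
// which case the caller should re-evaluate the full history instead.
func canReuse(c *kvcache.Causal, seq int, pos int32) bool {
	return c.CanResume(seq, pos)
}
```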
@@ -20,7 +20,6 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e
 // The mask is of shape history size, batch size
 type Causal struct {
 	DType ml.DType
-	Capacity int32
 	windowSize int32

 	opts CausalOptions
@@ -98,7 +97,7 @@ func NewSWACache(windowSize int32, shift shiftFn) *Causal {
 	}
 }

-func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
+func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
 	if c.config == nil {
 		var config ml.CacheConfig
 		if cc, ok := backend.(ml.BackendCacheConfig); ok {
@@ -119,9 +118,16 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
 		c.config.MaskDType = ml.DTypeF32
 	}

+	var cacheSize int
+	if c.windowSize == math.MaxInt32 || capacity < int(c.windowSize) {
+		cacheSize = maxSequences * capacity
+	} else {
+		cacheSize = (maxSequences * int(c.windowSize)) + maxBatch
+	}
+	cacheSize = roundUp(cacheSize, c.config.CachePadding)
+	c.cells = make([]cacheCell, cacheSize)
+
 	c.DType = dtype
-	c.Capacity = int32(roundUp(int(capacity), c.config.CachePadding))
-	c.cells = make([]cacheCell, c.Capacity)
 	c.cellRanges = make(map[int]cellRange)
 	c.backend = backend
 }
@@ -140,12 +146,14 @@ func (c *Causal) Close() {
 	}
 }

-func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error {
-	c.curBatchSize = len(opts.Positions)
-	c.curSequences = opts.Sequences
-	c.curPositions = opts.Positions
+func (c *Causal) StartForward(ctx ml.Context, batch input.Batch) error {
+	c.curBatchSize = len(batch.Positions)
+	c.curSequences = batch.Sequences
+	c.curPositions = batch.Positions
 	c.opts.Except = nil

+	c.updateSlidingWindow()
+
 	var err error
 	c.curLoc, err = c.findStartLoc()
 	if errors.Is(err, ErrKvCacheFull) {
@@ -157,8 +165,8 @@ func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error {
 	}

 	c.curCellRange = newRange()
-	for i, pos := range opts.Positions {
-		seq := opts.Sequences[i]
+	for i, pos := range batch.Positions {
+		seq := batch.Sequences[i]

 		c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}

@@ -210,7 +218,51 @@ func (c *Causal) findStartLoc() (int, error) {
 		}
 	}

-	return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, c.Capacity)
+	return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, len(c.cells))
+}
+
+func (c *Causal) updateSlidingWindow() {
+	if c.windowSize == math.MaxInt32 {
+		return
+	}
+
+	// create a map of unique sequences to the lowest position in that sequence
+	lowestPos := make(map[int]int32)
+	for i := range c.curPositions {
+		seq := c.curSequences[i]
+
+		pos, ok := lowestPos[seq]
+		if !ok {
+			pos = c.curPositions[i]
+		} else if c.curPositions[i] < pos {
+			pos = c.curPositions[i]
+		}
+
+		lowestPos[seq] = pos
+	}
+
+	// delete any entries that are beyond the window of the oldest position in the sequence
+	for seq, pos := range lowestPos {
+		oldRange, ok := c.cellRanges[seq]
+		if !ok {
+			continue
+		}
+
+		newRange := newRange()
+
+		for i := oldRange.min; i <= oldRange.max; i++ {
+			if slices.Contains(c.cells[i].sequences, seq) {
+				if c.cells[i].pos < pos-c.windowSize {
+					c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
+				} else {
+					newRange.min = min(newRange.min, i)
+					newRange.max = max(newRange.max, i)
+				}
+			}
+		}
+
+		c.cellRanges[seq] = newRange
+	}
 }

 func roundDown(length, pad int) int {
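The sizing rule introduced in `Init` above is plain arithmetic. The following standalone restatement (a sketch with illustrative numbers, not code from this change) makes the two regimes and the padding step explicit:

```go
package main

import "fmt"

// cacheCells restates the sizing rule from Causal.Init: a full causal cache
// needs capacity cells per sequence, while a sliding-window cache only needs
// the window per sequence plus room for one in-flight batch.
func cacheCells(maxSequences, capacity, maxBatch, windowSize, pad int) int {
	const noWindow = 1<<31 - 1 // stands in for math.MaxInt32, i.e. no sliding window
	var cells int
	if windowSize == noWindow || capacity < windowSize {
		cells = maxSequences * capacity
	} else {
		cells = maxSequences*windowSize + maxBatch
	}
	// round up to the backend's cache padding
	return (cells + pad - 1) / pad * pad
}

func main() {
	// e.g. 4 sequences, 8192-token context, window 1024, batch 512, pad 32:
	// the windowed cache needs 4*1024+512 = 4608 cells instead of 4*8192.
	fmt.Println(cacheCells(4, 8192, 512, 1024, 32)) // 4608
}
```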
@@ -265,7 +317,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
 	return maskTensor, nil
 }

-func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
+func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) {
 	for i, key := range c.keys {
 		if key == nil {
 			continue
@@ -275,8 +327,8 @@ func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
 		numKVHeads := key.Dim(1)
 		rowSize := key.Stride(2)

-		kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*len)
-		kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*len)
+		kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*length)
+		kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*length)

 		value := c.values[i]
 		var vSrcView, vDstView ml.Tensor
@@ -284,14 +336,14 @@ func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
 			vHeadDim := value.Dim(1)
 			elemSize := value.Stride(0)

-			vSrcView = value.View(ctx, elemSize*src, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
-			vDstView = value.View(ctx, elemSize*dst, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
+			vSrcView = value.View(ctx, elemSize*src, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
+			vDstView = value.View(ctx, elemSize*dst, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
 		} else {
 			vHeadDim := value.Dim(0)
 			rowSize := value.Stride(2)

-			vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*len)
-			vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*len)
+			vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*length)
+			vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*length)
 		}

 		ctx.Forward(
@@ -321,7 +373,8 @@ func (c *Causal) defrag() {
 	ctx := c.backend.NewContext()

 	// For every move, 6 tensors are required per layer (2 views and a
-	// copy for each of k and v).
+	// copy for each of k and v). We also need to refer to the original
+	// k and v cache tensors - once per layer, not per move.
 	layers := 0
 	for _, key := range c.keys {
 		if key == nil {
@@ -330,7 +383,7 @@ func (c *Causal) defrag() {
 		layers++
 	}

-	maxMoves := ctx.MaxGraphNodes() / (6 * layers)
+	maxMoves := (ctx.MaxGraphNodes() - 2*layers) / (6 * layers)
 	moves := 0

 	var pendingSrc, pendingDst, pendingLen int
@@ -479,14 +532,14 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
 	}

 	if _, ok := c.keys[c.curLayer]; !ok {
-		c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity))
+		c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, len(c.cells))
 	}

 	if _, ok := c.values[c.curLayer]; !ok {
 		if c.config.PermutedV {
-			c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads)
+			c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), vHeadDim, numKVHeads)
 		} else {
-			c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity))
+			c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, len(c.cells))
 		}
 	}

@@ -497,7 +550,7 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
 		elemSize := c.values[c.curLayer].Stride(0)

 		value = value.Permute(ctx, 1, 2, 0, 3)
-		ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)))
+		ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, len(c.cells)*elemSize, vHeadDim*numKVHeads)))
 	} else {
 		rowSize := c.values[c.curLayer].Stride(2)

@@ -528,6 +581,35 @@ func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
 	c.cellRanges[dstSeq] = seqRange
 }

+func (c *Causal) CanResume(seq int, pos int32) bool {
+	if c.windowSize == math.MaxInt32 {
+		return true
+	}
+
+	seqRange, ok := c.cellRanges[seq]
+	if !ok {
+		return false
+	}
+
+	// for sliding window, check that the window of the new sequence is contained in
+	// the window of what we are storing
+	var last int32 = -1
+	for i := seqRange.min; i <= seqRange.max; i++ {
+		if slices.Contains(c.cells[i].sequences, seq) {
+			last = max(last, c.cells[i].pos)
+		}
+	}
+
+	if last == -1 {
+		return false
+	}
+
+	lastWindowStart := max(0, last-c.windowSize)
+	posWindowStart := max(0, pos-c.windowSize)
+
+	return posWindowStart >= lastWindowStart
+}
+
 func (c *Causal) shift(seq int, beginIndex, offset int32) error {
 	if c.shiftFn == nil {
 		return ErrNotSupported
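The window-overlap test in `CanResume` reduces to comparing two window start positions. A small self-contained restatement, with hypothetical values that mirror the test added later in this change:

```go
package main

import "fmt"

// canResume restates the sliding-window check from Causal.CanResume: resuming
// at pos is safe only if pos's attention window starts at or after the window
// of the newest cached position, i.e. no token pos still needs was evicted.
func canResume(last, pos, windowSize int32) bool {
	lastWindowStart := max(int32(0), last-windowSize)
	posWindowStart := max(int32(0), pos-windowSize)
	return posWindowStart >= lastWindowStart
}

func main() {
	// window 4, newest cached position 5:
	fmt.Println(canResume(5, 3, 4)) // false: position 3's window needs evicted tokens
	fmt.Println(canResume(5, 5, 4)) // true: resuming at the newest position
}
```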
@@ -582,6 +664,12 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
 }

 func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
+	// TODO(jessegross): We should check to see if removing the middle of the sequence will
+	// cause the sliding window to encompass tokens that we no longer have. If so, then we
+	// should return an error, which will trigger the runner to evaluate the full history and
+	// rebuild the window. However, if we have multimodal inputs in our history, this reuse
+	// results in use after free, so we don't do it for now.
+
 	var offset int32
 	if endIndex != math.MaxInt32 {
 		offset = beginIndex - endIndex
@@ -596,8 +684,7 @@ func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
 		} else {
 			if c.cells[i].pos >= endIndex {
 				if slices.ContainsFunc(c.cells[i].sequences, func(s int) bool { return s != seq }) {
-					// TODO(jessegross): Need to be careful about data shared between sequences
-					return errors.New("shifting on cells shared by multiple sequences not yet implemented")
+					return errors.New("shifting cells shared by multiple sequences not supported")
 				}

 				c.cells[i].pos += offset

@@ -25,7 +25,7 @@ func TestStore(t *testing.T) {
 	cache := NewCausalCache(nil)
 	defer cache.Close()

-	cache.Init(backend, ml.DTypeF16, 16)
+	cache.Init(backend, ml.DTypeF16, 1, 16, 16)

 	tests := []testCase{
 		{
@@ -58,11 +58,11 @@ func TestSWA(t *testing.T) {
 	cache := NewSWACache(1, nil)
 	defer cache.Close()

-	cache.Init(backend, ml.DTypeF32, 16)
+	cache.Init(backend, ml.DTypeF16, 1, 16, 16)

 	tests := []testCase{
 		{
-			name: "SlidingWindow",
+			name: "FirstBatch",
 			in: []float32{1, 2, 3, 4},
 			inShape: []int{1, 1, 4},
 			seqs: []int{0, 0, 0, 0},
@@ -71,6 +71,16 @@ func TestSWA(t *testing.T) {
 			expectedShape: []int{1, 1, 4},
 			expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
 		},
+		{
+			name: "SecondBatch",
+			in: []float32{5, 6},
+			inShape: []int{1, 1, 2},
+			seqs: []int{0, 0},
+			pos: []int32{4, 5},
+			expected: []float32{5, 6, 3, 4},
+			expectedShape: []int{1, 1, 4},
+			expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1))},
+		},
 	}

 	testCache(t, backend, cache, tests)
@@ -81,7 +91,7 @@ func TestSequences(t *testing.T) {
 	cache := NewCausalCache(nil)
 	defer cache.Close()

-	cache.Init(backend, ml.DTypeF16, 16)
+	cache.Init(backend, ml.DTypeF16, 1, 16, 16)

 	tests := []testCase{
 		{
@@ -116,7 +126,7 @@ func TestRemove(t *testing.T) {
 	})
 	defer cache.Close()

-	cache.Init(backend, ml.DTypeF16, 16)
+	cache.Init(backend, ml.DTypeF16, 1, 16, 16)

 	tests := []testCase{
 		{
@@ -181,7 +191,7 @@ func TestDefrag(t *testing.T) {
 	})
 	defer cache.Close()

-	cache.Init(backend, ml.DTypeF16, 16)
+	cache.Init(backend, ml.DTypeF16, 1, 16, 16)

 	tests := []testCase{
 		{
@@ -229,7 +239,7 @@ func TestCopy(t *testing.T) {
 	cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil })
 	defer cache.Close()

-	cache.Init(backend, ml.DTypeF16, 16)
+	cache.Init(backend, ml.DTypeF16, 1, 16, 16)

 	tests := []testCase{
 		{
@@ -270,7 +280,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
 		context := backend.NewContext()
 		defer context.Close()

-		err := cache.StartForward(context, input.Options{Positions: test.pos, Sequences: test.seqs})
+		err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs})
 		if err != nil {
 			panic(err)
 		}
@@ -290,6 +300,77 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
 	}
 }

+func TestCanResume(t *testing.T) {
+	backend := &testBackend{}
+	windowSize := int32(4)
+	cache := NewSWACache(windowSize, nil)
+	defer cache.Close()
+
+	cache.Init(backend, ml.DTypeF16, 1, 16, 16)
+
+	context := backend.NewContext()
+	defer context.Close()
+
+	err := cache.StartForward(context, input.Batch{
+		Positions: []int32{0, 1, 2, 3},
+		Sequences: []int{0, 0, 0, 0},
+	})
+	if err != nil {
+		t.Fatalf("StartForward failed: %v", err)
+	}
+
+	cache.SetLayer(0)
+	tensor, _ := context.FromFloatSlice([]float32{1, 2, 3, 4}, 1, 1, 4)
+	cache.Put(context, tensor, tensor)
+
+	// with window size 4, nothing has slid out of the window yet
+	if !cache.CanResume(0, 0) {
+		t.Errorf("CanResume(0, 0) = false, want true (within window)")
+	}
+	if !cache.CanResume(0, 1) {
+		t.Errorf("CanResume(0, 1) = false, want true (within window)")
+	}
+	if !cache.CanResume(0, 2) {
+		t.Errorf("CanResume(0, 2) = false, want true (within window)")
+	}
+	if !cache.CanResume(0, 3) {
+		t.Errorf("CanResume(0, 3) = false, want true (latest position)")
+	}
+
+	// shift window by adding position 4
+	err = cache.StartForward(context, input.Batch{
+		Positions: []int32{4, 5},
+		Sequences: []int{0, 0},
+	})
+	if err != nil {
+		t.Fatalf("StartForward failed: %v", err)
+	}
+
+	cache.SetLayer(0)
+	tensor, _ = context.FromFloatSlice([]float32{5, 6}, 1, 1, 2)
+	cache.Put(context, tensor, tensor)
+
+	// only the latest position has overlapping windows
+	if cache.CanResume(0, 0) {
+		t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)")
+	}
+	if cache.CanResume(0, 1) {
+		t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)")
+	}
+	if cache.CanResume(0, 2) {
+		t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)")
+	}
+	if cache.CanResume(0, 3) {
+		t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)")
+	}
+	if cache.CanResume(0, 4) {
+		t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)")
+	}
+	if !cache.CanResume(0, 5) {
+		t.Errorf("after shift: CanResume(0, 5) = false, want true (latest position)")
+	}
+}
+
 type testBackend struct{}

 func (b *testBackend) Config() ml.Config {
@@ -352,7 +433,6 @@ func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
 }

 func (c *testContext) Input() ml.Context { return c }
-func (c *testContext) Output() ml.Context { return c }
 func (c *testContext) Layer(int) ml.Context { return c }

 func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }

@@ -49,7 +49,7 @@ func NewEncoderCache() *EncoderCache {
 	}
 }

-func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
+func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
 	if c.config == nil {
 		var config ml.CacheConfig
 		if cc, ok := backend.(ml.BackendCacheConfig); ok {
@@ -58,6 +58,10 @@ func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32)
 		c.config = &config
 	}

+	if maxSequences > 1 {
+		panic(fmt.Errorf("encoder cache does not support multiple sequences; requested: %v", maxSequences))
+	}
+
 	if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
 		panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
 	}
@@ -79,10 +83,10 @@ func (c *EncoderCache) Close() {
 	}
 }

-func (c *EncoderCache) StartForward(ctx ml.Context, opts input.Options) error {
+func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch) error {
 	// We work with the most recent image
-	if len(opts.Multimodal) > 0 {
-		c.curPos = opts.Positions[opts.Multimodal[len(opts.Multimodal)-1].Index]
+	if len(batch.Multimodal) > 0 {
+		c.curPos = batch.Positions[batch.Multimodal[len(batch.Multimodal)-1].Index]
 	}

 	return nil
@@ -130,6 +134,10 @@ func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
 	panic("encoder cache does not support multiple sequences")
 }

+func (c *EncoderCache) CanResume(seq int, pos int32) bool {
+	return true
+}
+
 func (c *EncoderCache) Remove(seq int, beginIndex, endIndex int32) error {
 	if c.encoderPos >= beginIndex && c.encoderPos < endIndex {
 		c.encoderCached = false

@@ -23,9 +23,9 @@ func NewWrapperCache(caches ...Cache) *WrapperCache {
 	}
 }

-func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
+func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
 	for _, cache := range c.caches {
-		cache.Init(backend, dtype, capacity)
+		cache.Init(backend, dtype, maxSequences, capacity, maxBatch)
 	}
 }

@@ -41,14 +41,14 @@ func (c *WrapperCache) Close() {
 	}
 }

-func (c *WrapperCache) StartForward(ctx ml.Context, opts input.Options) error {
+func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch) error {
 	for i, cache := range c.caches {
-		err := cache.StartForward(ctx, opts)
+		err := cache.StartForward(ctx, batch)
 		if err != nil {
 			// unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
 			for j := i - 1; j >= 0; j-- {
-				for k := range opts.Positions {
-					_ = c.caches[j].Remove(opts.Sequences[k], opts.Positions[k], math.MaxInt32)
+				for k := range batch.Positions {
+					_ = c.caches[j].Remove(batch.Sequences[k], batch.Positions[k], math.MaxInt32)
 				}
 			}
 			return err
@@ -87,6 +87,16 @@ func (c *WrapperCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
 	}
 }

+func (c *WrapperCache) CanResume(seq int, pos int32) bool {
+	for _, cache := range c.caches {
+		if !cache.CanResume(seq, pos) {
+			return false
+		}
+	}
+
+	return true
+}
+
 func (c *WrapperCache) Remove(seq int, beginIndex, endIndex int32) error {
 	// If the one of these fails, the caller is supposed to retry with endIndex set to math.MaxInt32, which should not fail
 	for _, cache := range c.caches {

@@ -166,6 +166,10 @@ func (c *Context) KvCacheDefrag() {
 	C.llama_kv_cache_defrag(c.c)
 }

+func (c *Context) KvCacheCanShift() bool {
+	return bool(C.llama_kv_cache_can_shift(c.c))
+}
+
 // Get the embeddings for a sequence id
 func (c *Context) GetEmbeddingsSeq(seqId int) []float32 {
 	e := unsafe.Pointer(C.llama_get_embeddings_seq(c.c, C.int(seqId)))

llama/patches/0022-add-rdna4-support.patch (new file)
@@ -0,0 +1,103 @@
+From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
+From: Saman <saman.khatir@amd.com>
+Date: Wed, 19 Mar 2025 14:02:26 -0700
+Subject: [PATCH] add rdna4 support
+
+---
+ ggml/src/ggml-cuda/common.cuh    | 6 ++++--
+ ggml/src/ggml-cuda/mmq.cu        | 2 +-
+ ggml/src/ggml-cuda/mmq.cuh       | 4 ++--
+ ggml/src/ggml-cuda/mmvq.cu       | 4 ++--
+ ggml/src/ggml-cuda/vendors/hip.h | 4 ++++
+ 5 files changed, 13 insertions(+), 7 deletions(-)
+
+diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
+index adf0d3ec..b24593fc 100644
+--- a/ggml/src/ggml-cuda/common.cuh
++++ b/ggml/src/ggml-cuda/common.cuh
+@@ -61,11 +61,13 @@
+ #define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000
+ #define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a
+ #define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA
++#define GGML_CUDA_CC_RDNA4 (GGML_CUDA_CC_OFFSET_AMD + 0x1200) // RX 9000
+
+ #define GGML_CUDA_CC_IS_RDNA(cc) (cc >= GGML_CUDA_CC_RDNA1)
+ #define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2)
+ #define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
+-#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3)
++#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3 && cc < GGML_CUDA_CC_RDNA4)
++#define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4)
+ #define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA)
+ #define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
+
+@@ -386,7 +388,7 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
+ #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+ #if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(RDNA2)
+     c = __builtin_amdgcn_sdot4(a, b, c, false);
+-#elif defined(RDNA3)
++#elif defined(RDNA3) || defined(RDNA4)
+     c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
+ #elif defined(__gfx1010__) || defined(__gfx900__)
+     int tmp1;
+diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu
+index 10f2ebb1..933d945c 100644
+--- a/ggml/src/ggml-cuda/mmq.cu
++++ b/ggml/src/ggml-cuda/mmq.cu
+@@ -149,5 +149,5 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
+         return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
+     }
+
+-    return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
++    return (!GGML_CUDA_CC_IS_RDNA4(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
+ }
+diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh
+index 0451c65f..66ce2bc9 100644
+--- a/ggml/src/ggml-cuda/mmq.cuh
++++ b/ggml/src/ggml-cuda/mmq.cuh
+@@ -2577,9 +2577,9 @@ static __device__ void mul_mat_q_process_tile(
+
+ template <ggml_type type, int mmq_x, int nwarps, bool need_check>
+ #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+-#if defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
++#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
+     __launch_bounds__(WARP_SIZE*nwarps, 2)
+-#endif // defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
++#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
+ #else
+ #if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
+     __launch_bounds__(WARP_SIZE*nwarps, 1)
+diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
+index 4fb466ca..23ae7abc 100644
+--- a/ggml/src/ggml-cuda/mmvq.cu
++++ b/ggml/src/ggml-cuda/mmvq.cu
+@@ -62,13 +62,13 @@ static __global__ void mul_mat_vec_q(
+
+     constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
+
+-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
++#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4))
+     constexpr int nwarps = 1;
+     constexpr int rows_per_cuda_block = 1;
+ #else
+     constexpr int nwarps = ncols_y <= 4 ? 4 : 2;
+     constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
+-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
++#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3) && !defined(RDNA4)
+
+     const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
+     const int row0 = rows_per_cuda_block*blockIdx.x;
+diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h
+index 81964611..a62544b5 100644
+--- a/ggml/src/ggml-cuda/vendors/hip.h
++++ b/ggml/src/ggml-cuda/vendors/hip.h
+@@ -150,6 +150,10 @@
+ #define CDNA
+ #endif
+
++#if defined(__gfx1200__) || defined(__gfx1201__)
++#define RDNA4
++#endif
++
+ #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
+     defined(__gfx1150__) || defined(__gfx1151__)
+ #define RDNA3

@@ -15,12 +15,12 @@ import (
 )

 // This algorithm looks for a complete fit to determine if we need to unload other models
-func PredictServerFit(allGpus discover.GpuInfoList, f *ggml.GGML, adapters, projectors []string, opts api.Options) (bool, uint64) {
+func PredictServerFit(allGpus discover.GpuInfoList, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (bool, uint64) {
 	// Split up the GPUs by type and try them
 	var estimatedVRAM uint64
 	for _, gpus := range allGpus.ByLibrary() {
 		var layerCount int
-		estimate := EstimateGPULayers(gpus, f, projectors, opts)
+		estimate := EstimateGPULayers(gpus, f, projectors, opts, numParallel)
 		layerCount, estimatedVRAM = estimate.Layers, estimate.VRAMSize
 		if opts.NumGPU < 0 {
 			if layerCount > 0 && layerCount >= int(f.KV().BlockCount()+1) {
@@ -71,7 +71,7 @@ type MemoryEstimate struct {

 // Given a model and one or more GPU targets, predict how many layers and bytes we can load, and the total size
 // The GPUs provided must all be the same Library
-func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []string, opts api.Options) MemoryEstimate {
+func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []string, opts api.Options, numParallel int) MemoryEstimate {
 	// Graph size for a partial offload, applies to all GPUs
 	var graphPartialOffload uint64

@@ -137,13 +137,19 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
 		}
 	}

-	kv, graphPartialOffload, graphFullOffload := f.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), kvct)
+	kv, graphPartialOffload, graphFullOffload := f.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), numParallel, kvct)

-	// KV is proportional to the number of layers
-	layerSize += kv / f.KV().BlockCount()
+	if len(kv) > 0 {
+		layerSize += kv[0]
+	}
+
+	var kvTotal uint64
+	for _, kvLayer := range kv {
+		kvTotal += kvLayer
+	}

 	if graphPartialOffload == 0 {
-		graphPartialOffload = f.KV().GQA() * kv / 6
+		graphPartialOffload = f.KV().GQA() * kvTotal / 6
 	}
 	if graphFullOffload == 0 {
 		graphFullOffload = graphPartialOffload
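`GraphSize` now reports KV cache size per layer rather than a single aggregate, so callers take `kv[0]` as the per-layer estimate and sum the slice for totals. A minimal restatement of that bookkeeping (a sketch with illustrative values, not code from this change):

```go
package main

import "fmt"

// sumKV restates the new bookkeeping: with a per-layer KV slice, the first
// element feeds the per-layer size estimate and the sum feeds the
// graph-size heuristics that previously used one aggregate value.
func sumKV(kv []uint64) (perLayer, total uint64) {
	if len(kv) > 0 {
		perLayer = kv[0]
	}
	for _, kvLayer := range kv {
		total += kvLayer
	}
	return perLayer, total
}

func main() {
	kv := []uint64{1 << 20, 1 << 20, 1 << 20} // three layers of 1 MiB each
	perLayer, total := sumKV(kv)
	fmt.Println(perLayer, total) // 1048576 3145728
}
```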
@@ -217,7 +223,7 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
 		// Some models have inconsistent layer sizes
 		if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok {
 			layerSize = blk.Size()
-			layerSize += kv / f.KV().BlockCount()
+			layerSize += kv[i]
 			memoryWeights += blk.Size()
 		}

@@ -315,7 +321,7 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
 		layersRequested: opts.NumGPU,
 		layersModel: int(f.KV().BlockCount()) + 1,
 		availableList: availableList,
-		kv: kv,
+		kv: kvTotal,
 		allocationsList: allocationsList,
 		memoryWeights: memoryWeights,
 		memoryLayerOutput: memoryLayerOutput,
@@ -374,7 +380,7 @@ func (m MemoryEstimate) LogValue() slog.Value {
 		slog.Group(
 			"weights",
 			// memory of the weights
-			"total", format.HumanBytes2(m.memoryWeights),
+			"total", format.HumanBytes2(m.memoryWeights+m.memoryLayerOutput),
 			// memory of repeating layers
 			"repeating", format.HumanBytes2(m.memoryWeights),
 			// memory of non-repeating layers

@@ -61,7 +61,7 @@ func TestEstimateGPULayers(t *testing.T) {
 	projectors := []string{}
 	opts := api.DefaultOptions()
 	t.Run("cpu", func(t *testing.T) {
-		estimate := EstimateGPULayers(gpus, ggml, projectors, opts)
+		estimate := EstimateGPULayers(gpus, ggml, projectors, opts, 1)
 		assert.Equal(t, 0, estimate.Layers)
 		assert.Equal(t, uint64(0), estimate.Graph)
 	})
@@ -112,7 +112,7 @@ func TestEstimateGPULayers(t *testing.T) {
 		gpus[1].FreeMemory += gpuMinimumMemory + layerSize + s.layer1*layerSize + 1
 		gpus[0].FreeMemory += max(graphFullOffload, graphPartialOffload)
 		gpus[1].FreeMemory += max(graphFullOffload, graphPartialOffload)
-		estimate := EstimateGPULayers(gpus, ggml, projectors, opts)
+		estimate := EstimateGPULayers(gpus, ggml, projectors, opts, 1)
 		assert.Equal(t, int(s.expect0+s.expect1), estimate.Layers, "scenario %d: %v", i, s)
 		assert.Equal(t, fmt.Sprintf("%d,%d", s.expect0, s.expect1), estimate.TensorSplit, "scenario %d: %v", i, s)
 		var layerSums uint64

@@ -109,7 +109,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
 		gpus = discover.GetCPUInfo()
 	}

-	estimate := EstimateGPULayers(gpus, f, projectors, opts)
+	estimate := EstimateGPULayers(gpus, f, projectors, opts, numParallel)
 	if len(gpus) > 1 || gpus[0].Library != "cpu" {
 		switch {
 		case gpus[0].Library == "metal" && estimate.VRAMSize > systemTotalMemory:

@@ -2,6 +2,7 @@ package ml

 import (
 	"bytes"
+	"context"
 	"encoding/binary"
 	"fmt"
 	"os"
@@ -60,6 +61,10 @@ type CacheConfig struct {

 // BackendParams controls how the backend loads and executes models
 type BackendParams struct {
+	// Progress is a callback function that allows reporting percentage completion
+	// of model loading
+	Progress func(float32)
+
 	// NumThreads sets the number of threads to use if running on the CPU
 	NumThreads int

@@ -76,9 +81,9 @@ type BackendParams struct {
 	FlashAttention bool
 }

-var backends = make(map[string]func(*os.File, BackendParams) (Backend, error))
+var backends = make(map[string]func(context.Context, *os.File, BackendParams) (Backend, error))

-func RegisterBackend(name string, f func(*os.File, BackendParams) (Backend, error)) {
+func RegisterBackend(name string, f func(context.Context, *os.File, BackendParams) (Backend, error)) {
 	if _, ok := backends[name]; ok {
 		panic("backend: backend already registered")
 	}
@@ -86,9 +91,9 @@ func RegisterBackend(name string, f func(*os.File, BackendParams) (Backend, erro
 	backends[name] = f
 }

-func NewBackend(f *os.File, params BackendParams) (Backend, error) {
+func NewBackend(ctx context.Context, f *os.File, params BackendParams) (Backend, error) {
 	if backend, ok := backends["ggml"]; ok {
-		return backend(f, params)
+		return backend(ctx, f, params)
 	}

 	return nil, fmt.Errorf("unsupported backend")
@@ -105,12 +110,10 @@ type Context interface {
 	MaxGraphNodes() int
 	Close()

-	// Input returns a context appropriate for creating input tensors
+	// Input returns a context appropriate for creating tensors that are
+	// inputs to the model (which includes things like output locations)
 	Input() Context

-	// Output returns a context appropriate for creating output tensors
-	Output() Context
-
 	// Layer returns a context appropriate for creating intermediate tensors
 	Layer(int) Context
 }

@@ -9,15 +9,17 @@ package ggml
 import "C"

 import (
-	"errors"
+	"context"
 	"fmt"
 	"io"
 	"log/slog"
 	"maps"
 	"os"
+	"runtime"
 	"slices"
 	"strconv"
 	"strings"
+	"sync/atomic"
 	"unicode"
 	"unsafe"

@@ -46,9 +48,6 @@ type Backend struct {
 	// input is the backend used for inputs
 	input *C.struct_ggml_backend_buffer_type

-	// output is the backend used for outputs
-	output *C.struct_ggml_backend_buffer_type
-
 	// layers is the backend used for repeating layers
 	layers map[int]*C.struct_ggml_backend_buffer_type

@@ -58,7 +57,7 @@ type Backend struct {
 	maxGraphNodes int
 }

-func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
+func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend, error) {
 	meta, n, err := fs.Decode(r, -1)
 	if err != nil {
 		return nil, err
@@ -297,12 +296,16 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
 		}
 	}

-	// concurrently read in tensor data. uses a section reader which is safe for concurrent reads
-	sr := io.NewSectionReader(r, int64(meta.Tensors().Offset), n-int64(meta.Tensors().Offset))
-	var g errgroup.Group
+	var doneBytes atomic.Uint64
+	totalBytes := uint64(n) - meta.Tensors().Offset
+
+	g, ctx := errgroup.WithContext(ctx)
+	g.SetLimit(runtime.GOMAXPROCS(0))
 	for _, t := range meta.Tensors().Items() {
-		for _, target := range targets[t.Name] {
-			g.Go(func() error {
+		g.Go(func() error {
+			tts := make([]*C.struct_ggml_tensor, max(1, len(targets[t.Name])))
+			for i := range tts {
+				target := targets[t.Name][i]
 				if target == "" {
 					target = t.Name
 				}
@@ -312,25 +315,44 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
 				return fmt.Errorf("unassigned tensor: %s", t.Name)
 			}

-			bts := C.malloc(C.size_t(t.Size()))
-			if bts == nil {
-				return errors.New("failed to allocate tensor buffer")
-			}
-			defer C.free(bts)
+				tts[i] = tt
+			}
+
+			sr := io.NewSectionReader(r, int64(meta.Tensors().Offset+t.Offset), int64(t.Size()))
+			bts := make([]byte, 128*format.KibiByte)

-			buf := unsafe.Slice((*byte)(bts), t.Size())
-			n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), buf)
-			if err != nil || n != len(buf) {
-				return errors.New("read failed")
+			var s uint64
+			for s < t.Size() {
+				n, err := io.ReadFull(sr, bts[:min(len(bts), int(t.Size()-s))])
+				if err != nil {
+					return err
+				}
+
+				for _, tt := range tts {
+					C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), C.size_t(s), C.size_t(n))
+				}
+
+				s += uint64(n)
+
+				if params.Progress != nil {
+					done := doneBytes.Add(uint64(n))
+					params.Progress(float32(done) / float32(totalBytes))
+				}
 			}

-			C.ggml_backend_tensor_set(tt, bts, 0, C.size_t(t.Size()))
 			return nil
 		})
 	}
-	}

-	if g.Wait() != nil {
+	// start a goroutine to cancel the errgroup if the parent context is done
+	go func() {
+		<-ctx.Done()
+		g.Go(func() error {
+			return ctx.Err()
+		})
+	}()
+
+	if err := g.Wait(); err != nil {
 		return nil, err
 	}

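With byte-level progress tracking in place, the new context-aware entry point and `Progress` callback can be driven from caller code. A hypothetical sketch (the backend-registration import path and the model filename are assumptions, not part of this change):

```go
package main

import (
	"context"
	"fmt"
	"os"

	"github.com/ollama/ollama/ml"
	_ "github.com/ollama/ollama/ml/backend/ggml" // assumed to register the "ggml" backend
)

// loadWithProgress sketches the new ml.NewBackend signature: cancellation
// flows through ctx, and percentage completion arrives via the Progress
// callback as tensor data is read in chunks.
func loadWithProgress(ctx context.Context, path string) (ml.Backend, error) {
	f, err := os.Open(path)
	if err != nil {
		return nil, err
	}
	defer f.Close()

	return ml.NewBackend(ctx, f, ml.BackendParams{
		Progress: func(p float32) {
			fmt.Printf("\rloading: %3.0f%%", p*100)
		},
	})
}

func main() {
	be, err := loadWithProgress(context.Background(), "model.gguf") // path is illustrative
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
	_ = be
	fmt.Println()
}
```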
@@ -376,7 +398,6 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
|
|||||||
C._Bool(len(gpus) > 1 && slices.Contains(gpus, output.d)),
|
C._Bool(len(gpus) > 1 && slices.Contains(gpus, output.d)),
|
||||||
),
|
),
|
||||||
input: deviceBufferTypes[input.d],
|
input: deviceBufferTypes[input.d],
|
||||||
output: deviceBufferTypes[output.d],
|
|
||||||
layers: func() map[int]*C.struct_ggml_backend_buffer_type {
|
layers: func() map[int]*C.struct_ggml_backend_buffer_type {
|
||||||
m := make(map[int]*C.struct_ggml_backend_buffer_type)
|
m := make(map[int]*C.struct_ggml_backend_buffer_type)
|
||||||
for i, layer := range layers {
|
for i, layer := range layers {
|
||||||
@@ -457,19 +478,6 @@ func (c Context) Input() ml.Context {
|
|||||||
return &c
|
return &c
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c Context) Output() ml.Context {
|
|
||||||
if c.b.output != nil {
|
|
||||||
return &Context{
|
|
||||||
b: c.b,
|
|
||||||
ctx: c.ctx,
|
|
||||||
buft: c.b.output,
|
|
||||||
maxGraphNodes: c.maxGraphNodes,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &c
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c Context) Layer(i int) ml.Context {
|
func (c Context) Layer(i int) ml.Context {
|
||||||
if buft, ok := c.b.layers[i]; ok {
|
if buft, ok := c.b.layers[i]; ok {
|
||||||
return &Context{
|
return &Context{
|
||||||
|
|||||||
@@ -61,11 +61,13 @@
|
|||||||
#define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000
|
#define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000
|
||||||
#define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a
|
#define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a
|
||||||
#define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA
|
#define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA
|
||||||
|
#define GGML_CUDA_CC_RDNA4 (GGML_CUDA_CC_OFFSET_AMD + 0x1200) // RX 9000
|
||||||
|
|
||||||
#define GGML_CUDA_CC_IS_RDNA(cc) (cc >= GGML_CUDA_CC_RDNA1)
|
#define GGML_CUDA_CC_IS_RDNA(cc) (cc >= GGML_CUDA_CC_RDNA1)
|
||||||
#define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2)
|
#define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2)
|
||||||
#define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
|
#define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
|
||||||
#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3)
|
#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3 && cc < GGML_CUDA_CC_RDNA4)
|
||||||
|
#define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4)
|
||||||
#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA)
|
#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA)
|
||||||
#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
|
#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
|
||||||
|
|
||||||
@@ -386,7 +388,7 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
 #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 #if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(RDNA2)
     c = __builtin_amdgcn_sdot4(a, b, c, false);
-#elif defined(RDNA3)
+#elif defined(RDNA3) || defined(RDNA4)
     c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
 #elif defined(__gfx1010__) || defined(__gfx900__)
     int tmp1;
ml/backend/ggml/ggml/src/ggml-cuda/mmq.cu (vendored, 2 changed lines)
@@ -149,5 +149,5 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
         return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
     }
 
-    return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
+    return (!GGML_CUDA_CC_IS_RDNA4(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
 }
ml/backend/ggml/ggml/src/ggml-cuda/mmq.cuh (vendored, 4 changed lines)
@@ -2577,9 +2577,9 @@ static __device__ void mul_mat_q_process_tile(
 
 template <ggml_type type, int mmq_x, int nwarps, bool need_check>
 #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
-#if defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
+#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
     __launch_bounds__(WARP_SIZE*nwarps, 2)
-#endif // defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
+#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
 #else
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
     __launch_bounds__(WARP_SIZE*nwarps, 1)
ml/backend/ggml/ggml/src/ggml-cuda/mmvq.cu (vendored, 4 changed lines)
@@ -62,13 +62,13 @@ static __global__ void mul_mat_vec_q(
 
     constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
 
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4))
     constexpr int nwarps              = 1;
     constexpr int rows_per_cuda_block = 1;
 #else
     constexpr int nwarps              = ncols_y <= 4 ? 4 : 2;
     constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3) && !defined(RDNA4)
 
     const     int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
     const     int row0 = rows_per_cuda_block*blockIdx.x;
@@ -150,6 +150,10 @@
 #define CDNA
 #endif
 
+#if defined(__gfx1200__) || defined(__gfx1201__)
+#define RDNA4
+#endif
+
 #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
     defined(__gfx1150__) || defined(__gfx1151__)
 #define RDNA3
@@ -1,5 +1,7 @@
 package input
 
+import "github.com/ollama/ollama/ml"
+
 // Input represents one token in the input stream
 type Input struct {
 	// Token is a single element of text.
@@ -33,11 +35,24 @@ type MultimodalIndex struct {
 	Multimodal any
 }
 
-// Options contains the inputs for a model forward pass
-type Options struct {
-	Inputs []int32
+// Batch contains the inputs for a model forward pass
+type Batch struct {
+	// Inputs is the input tokens, including placeholders for multimodal inputs.
+	Inputs ml.Tensor
+
+	// Multimodal is a set of multimodal embeddings previously created by
+	// EncodeMultimodal, along with an index into Inputs. Unused for text-only
+	// models or for batches without multimodal elements.
 	Multimodal []MultimodalIndex
+
+	// Positions is the position for each Input, relative to its sequence. Equal
+	// in length to Inputs.
 	Positions []int32
+
+	// Sequences is the sequence for each Input. Equal in length to Inputs.
 	Sequences []int
+
+	// Outputs are the set of indicies into Inputs for which output data should
+	// be returned.
 	Outputs []int32
 }
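A minimal sketch of how a caller might fill the new input.Batch for two sequences, assuming only the field semantics documented above (Inputs itself is now a tensor and is populated separately, as shown later in this diff):

batch := input.Batch{
	// Positions restart per sequence; Sequences maps each token to its slot.
	// Both run in parallel with the token stream.
	Positions: []int32{0, 1, 2, 0, 1},
	Sequences: []int{0, 0, 0, 1, 1},
	// Request logits only for the last token of each sequence.
	Outputs: []int32{2, 4},
}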
@@ -1,6 +1,7 @@
 package model
 
 import (
+	"context"
 	"errors"
 	"fmt"
 	_ "image/jpeg"
@@ -26,7 +27,7 @@ var ErrNoVisionModel = errors.New("this model is missing data required for image
 
 // Model implements a specific model architecture, defining the forward pass and any model-specific configuration
 type Model interface {
-	Forward(ml.Context, input.Options) (ml.Tensor, error)
+	Forward(ml.Context, input.Batch) (ml.Tensor, error)
 
 	Backend() ml.Backend
 	Config() config
@@ -94,14 +95,14 @@ func Register(name string, f func(ml.Config) (Model, error)) {
 }
 
 // New initializes a new model instance with the provided configuration based on the metadata in the model file
-func New(modelPath string, params ml.BackendParams) (Model, error) {
+func New(ctx context.Context, modelPath string, params ml.BackendParams) (Model, error) {
 	r, err := os.Open(modelPath)
 	if err != nil {
 		return nil, err
 	}
 	defer r.Close()
 
-	b, err := ml.NewBackend(r, params)
+	b, err := ml.NewBackend(ctx, r, params)
 	if err != nil {
 		return nil, err
 	}
@@ -280,24 +281,30 @@ func canNil(t reflect.Type) bool {
 		t.Kind() == reflect.Slice
 }
 
-func Forward(ctx ml.Context, m Model, opts input.Options) (ml.Tensor, error) {
-	if len(opts.Positions) != len(opts.Sequences) {
-		return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(opts.Positions), len(opts.Sequences))
+func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Tensor, error) {
+	if len(batch.Positions) != len(batch.Sequences) {
+		return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences))
 	}
 
-	if len(opts.Positions) < 1 {
+	if len(batch.Positions) < 1 {
 		return nil, errors.New("batch size cannot be less than 1")
 	}
 
+	var err error
+	batch.Inputs, err = ctx.Input().FromIntSlice(inputs, len(inputs))
+	if err != nil {
+		return nil, err
+	}
+
 	cache := m.Config().Cache
 	if cache != nil {
-		err := cache.StartForward(ctx, opts)
+		err := cache.StartForward(ctx, batch)
 		if err != nil {
 			return nil, err
 		}
 	}
 
-	t, err := m.Forward(ctx, opts)
+	t, err := m.Forward(ctx, batch)
 	if err != nil {
 		return nil, err
 	}
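The split between the raw `inputs []int32` and the batch means callers hand over token IDs and let Forward lift them into batch.Inputs on the input buffer type. A hedged usage sketch, mirroring the runner call that appears later in this diff:

ctx := s.model.Backend().NewContext()
defer ctx.Close()

// batchInputs carries the raw token IDs; batch carries positions, sequences,
// output indices and any multimodal embeddings.
modelOutput, err := model.Forward(ctx, s.model, batchInputs, batch)
if err != nil {
	return fmt.Errorf("failed to decode batch: %w", err)
}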
@@ -163,7 +163,7 @@ func TestGetTextProcessor(t *testing.T) {
 
 type notTextProcessorModel struct{}
 
-func (notTextProcessorModel) Forward(ml.Context, input.Options) (ml.Tensor, error) {
+func (notTextProcessorModel) Forward(ml.Context, input.Batch) (ml.Tensor, error) {
 	panic("unimplemented")
 }
 
@@ -38,7 +38,6 @@ const (
 func New(c ml.Config) (model.Model, error) {
 	m := Model{
 		SentencePieceModel: model.NewSentencePieceModel(
-			c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
 			&model.Vocabulary{
 				Values: c.Strings("tokenizer.ggml.tokens"),
 				Scores: c.Floats("tokenizer.ggml.scores"),
@@ -168,23 +167,18 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
 	return hiddenState.Add(ctx, residual)
 }
 
-func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
-	inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
+func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
+	positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
 	if err != nil {
 		return nil, err
 	}
 
-	positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
+	outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
 	if err != nil {
 		return nil, err
 	}
 
-	outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
-	if err != nil {
-		return nil, err
-	}
-
-	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
+	hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
 	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
 
 	if len(m.Layers) == gemma27BLayerCount {
@@ -211,8 +205,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
 	// final logit softcap
 	hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.Options.finalLogitSoftcap))
 	hiddenState = hiddenState.Tanh(ctx)
-	hiddenState = hiddenState.Scale(ctx, float64(m.Options.finalLogitSoftcap))
-	return hiddenState.Rows(ctx, outputs), nil
+	return hiddenState.Scale(ctx, float64(m.Options.finalLogitSoftcap)), nil
 }
 
 func init() {
@@ -55,7 +55,6 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i
 func New(c ml.Config) (model.Model, error) {
 	m := Model{
 		SentencePieceModel: model.NewSentencePieceModel(
-			c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
 			&model.Vocabulary{
 				Values: c.Strings("tokenizer.ggml.tokens"),
 				Scores: c.Floats("tokenizer.ggml.scores"),
@@ -139,23 +138,18 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
 	return result, nil
 }
 
-func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
-	inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
+func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
+	positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
 	if err != nil {
 		return nil, err
 	}
 
-	positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
+	outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
 	if err != nil {
 		return nil, err
 	}
 
-	outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
-	if err != nil {
-		return nil, err
-	}
-
-	return m.TextModel.Forward(ctx, inputs, positions, outputs, opts, m.Cache), nil
+	return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
 }
 
 func init() {
@@ -45,7 +45,6 @@ func newTextModel(c ml.Config) *TextModel {
 
 	m := TextModel{
 		SentencePieceModel: model.NewSentencePieceModel(
-			c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
 			&model.Vocabulary{
 				Values: c.Strings("tokenizer.ggml.tokens"),
 				Scores: c.Floats("tokenizer.ggml.scores"),
@@ -171,13 +170,13 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
 	return hiddenState.Add(ctx, residual)
 }
 
-func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, opts input.Options, cache kvcache.Cache) ml.Tensor {
+func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
 	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
 	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
 
 	// set image embeddings
 	var except []int
-	for _, image := range opts.Multimodal {
+	for _, image := range batch.Multimodal {
 		visionOutputs := image.Multimodal.(ml.Tensor)
 		ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
 
@@ -139,23 +139,18 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
 	return hiddenState.Add(ctx, residual)
 }
 
-func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
-	inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
+func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
+	positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
 	if err != nil {
 		return nil, err
 	}
 
-	positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
+	outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
 	if err != nil {
 		return nil, err
 	}
 
-	outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
-	if err != nil {
-		return nil, err
-	}
-
-	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
+	hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
 
 	for i, layer := range m.Layers {
 		m.Cache.SetLayer(i)
@@ -135,32 +135,27 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
 	return inputs, nil
 }
 
-func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
+func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 	var crossAttentionStates ml.Tensor
-	if len(opts.Multimodal) > 0 {
-		images := opts.Multimodal[len(opts.Multimodal)-1].Multimodal.([]ml.Tensor)
+	if len(batch.Multimodal) > 0 {
+		images := batch.Multimodal[len(batch.Multimodal)-1].Multimodal.([]ml.Tensor)
 		if len(images) > 0 {
 			crossAttentionStates = images[len(images)-1]
 		}
 	}
 
-	inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
+	positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
 	if err != nil {
 		return nil, err
 	}
 
-	positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
-	if err != nil {
-		return nil, err
-	}
-
-	outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
+	outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
 	if err != nil {
 		return nil, err
 	}
 
 	// TODO: attention mask, cross attention mask
-	return m.TextModel.Forward(ctx, inputs, positions, outputs, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
+	return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
 }
 
 func init() {
@@ -1,29 +1,23 @@
 package model
 
 import (
-	"iter"
+	"container/heap"
+	"fmt"
 	"log/slog"
+	"strconv"
 	"strings"
-
-	"github.com/dlclark/regexp2"
-	queue "github.com/emirpasic/gods/v2/queues/priorityqueue"
 )
 
 const spmWhitespaceSep = "▁"
 
-func replaceWhitespaceBySeperator(s string) string {
-	return strings.ReplaceAll(s, " ", spmWhitespaceSep)
-}
-
 type SentencePieceModel struct {
 	maxTokenLen int
-	pre         *regexp2.Regexp
 	vocab       *Vocabulary
 }
 
 var _ TextProcessor = (*SentencePieceModel)(nil)
 
-func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel {
+func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel {
 	slog.Debug("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
 
 	counter := map[int]int{}
@@ -44,7 +38,6 @@ func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel {
 
 	return SentencePieceModel{
 		maxTokenLen: maxTokenLen,
-		pre:         regexp2.MustCompile(pre, regexp2.Unicode|regexp2.RE2),
 		vocab:       vocab,
 	}
 }
@@ -53,20 +46,9 @@ func (spm SentencePieceModel) Is(id int32, special Special) bool {
 	return spm.vocab.Is(id, special)
 }
 
-func (spm *SentencePieceModel) split(s string) iter.Seq[string] {
-	return func(yield func(string) bool) {
-		for m, _ := spm.pre.FindStringMatch(s); m != nil; m, _ = spm.pre.FindNextMatch(m) {
-			if !yield(m.String()) {
-				break
-			}
-		}
-	}
-}
-
 func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error) {
 	fragments := []fragment{{value: s}}
 	for _, special := range spm.vocab.SpecialVocabulary() {
-		// TODO: process special tokens concurrently
 		id := spm.vocab.Encode(special)
 		for i := 0; i < len(fragments); i++ {
 			frag := fragments[i]
@@ -91,7 +73,6 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
 			fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
 		}
 	}
-	slog.Debug("fragments", "frags", fragments)
 
 	var ids []int32
 	for _, frag := range fragments {
@@ -100,26 +81,17 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
 			continue
 		}
 
-		for split := range spm.split(frag.value) {
-			split = replaceWhitespaceBySeperator(split)
+		text := strings.ReplaceAll(frag.value, " ", spmWhitespaceSep)
 
-			var sb strings.Builder
-			sb.Write([]byte(split))
-			if id := spm.vocab.Encode(sb.String()); id >= 0 {
-				ids = append(ids, id)
-				continue
-			}
+		if id := spm.vocab.Encode(text); id >= 0 {
+			ids = append(ids, id)
+			continue
+		}
 
-			runes := []rune(sb.String())
-			pq := queue.NewWith(func(a, b any) int {
-				priA := a.(*candidate)
-				priB := b.(*candidate)
-				if priA.score > priB.score || (priA.score == priB.score && priA.a < priB.a) {
-					return -1
-				}
-				return 1
-			})
+		q := &queue{}
+		heap.Init(q)
 
+		runes := []rune(text)
 		merges := make([]merge, len(runes))
 		for r := range runes {
 			merges[r] = merge{
@@ -129,8 +101,6 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
 			}
 		}
 
-		slog.Debug("tokenizer", "merges", merges)
-
 		pairwise := func(a, b int) *candidate {
 			if a < 0 || b >= len(runes) {
 				return nil
@@ -142,34 +112,24 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
 					a:     a,
 					b:     b,
 					score: spm.vocab.Scores[id],
+					size:  len(left) + len(right),
 				}
 			}
 
 			return nil
 		}
 
 		for i := range len(runes) - 1 {
 			if pair := pairwise(i, i+1); pair != nil {
-				pq.Enqueue(pair)
+				heap.Push(q, pair)
 			}
 		}
 
-		pqv := pq.Values()
-		for _, v := range pqv {
-			e := v.(*candidate)
-			slog.Debug("candidate", "candidate", e)
-		}
-
-		for !pq.Empty() {
-			v, _ := pq.Dequeue()
-			pair := v.(*candidate)
+		for q.Len() > 0 {
+			pair := heap.Pop(q).(*candidate)
 			left, right := merges[pair.a], merges[pair.b]
 
-			slog.Debug("pair", "left", left, "right", right)
-			if len(left.runes) == 0 || len(right.runes) == 0 {
-				continue
-			}
-
-			if id := spm.vocab.Encode(string(left.runes) + string(right.runes)); id < 0 {
+			if string(left.runes) == "" || string(right.runes) == "" || len(string(left.runes))+len(string(right.runes)) != pair.size {
 				continue
 			}
 
@@ -181,24 +141,36 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
 			}
 
 			if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
-				pq.Enqueue(pair)
+				heap.Push(q, pair)
 			}
 
 			if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
-				pq.Enqueue(pair)
+				heap.Push(q, pair)
 			}
 		}
 
-		slog.Debug("merges", "merges", merges)
-
 		for _, merge := range merges {
-			if len(merge.runes) > 0 {
-				if id := spm.vocab.Encode(string(merge.runes)); id >= 0 {
+			if token := string(merge.runes); token != "" {
+				id := spm.vocab.Encode(token)
+
+				if id >= 0 {
 					ids = append(ids, id)
+					continue
+				}
+
+				// Fallback to byte tokenization
+				var result []int32
+				for _, b := range []byte(token) {
+					byteToken := fmt.Sprintf("<0x%02X>", b)
+					unknownID := spm.vocab.Encode(byteToken)
+					if unknownID >= 0 {
+						result = append(result, unknownID)
 					} else {
-					slog.Debug("missing token", "token", string(merge.runes))
+						slog.Debug("unknown byte token", "byte", b, "token", byteToken)
 					}
 				}
+
+				ids = append(ids, result...)
 			}
 		}
 	}
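The fallback branch above is SentencePiece byte tokenization: a merge that never made it into the vocabulary is re-encoded one raw byte at a time as "<0xNN>" tokens. A standalone sketch of that loop, with the `encode` parameter standing in for spm.vocab.Encode (hypothetical helper, shown for illustration):

func byteFallback(token string, encode func(string) int32) []int32 {
	var ids []int32
	for _, b := range []byte(token) {
		// Byte tokens are spelled "<0xNN>" with an upper-case two-digit hex value.
		if id := encode(fmt.Sprintf("<0x%02X>", b)); id >= 0 {
			ids = append(ids, id)
		}
	}
	return ids
}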
@@ -229,6 +201,30 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
 type candidate struct {
 	a, b  int
 	score float32
+	size  int
+}
+
+type queue []*candidate
+
+func (q queue) Len() int { return len(q) }
+
+func (q queue) Less(i, j int) bool {
+	return (q[i].score > q[j].score) || (q[i].score == q[j].score && q[i].a < q[j].a)
+}
+
+func (q queue) Swap(i, j int) { q[i], q[j] = q[j], q[i] }
+
+func (q *queue) Push(x interface{}) {
+	item := x.(*candidate)
+	*q = append(*q, item)
+}
+
+func (q *queue) Pop() interface{} {
+	old := *q
+	n := len(old)
+	item := old[n-1]
+	*q = old[0 : n-1]
+	return item
 }
 
 func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
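This drops the gods priority-queue dependency in favor of the standard library: container/heap only requires the five methods above, and Less reproduces the old comparator (higher score first, lower start index on ties). A self-contained version of the same queue, runnable as-is:

package main

import (
	"container/heap"
	"fmt"
)

type candidate struct {
	a, b  int
	score float32
}

// queue pops the highest-scoring merge first, breaking ties on the lower
// start index, matching the comparator previously passed to queue.NewWith.
type queue []*candidate

func (q queue) Len() int { return len(q) }
func (q queue) Less(i, j int) bool {
	return q[i].score > q[j].score || (q[i].score == q[j].score && q[i].a < q[j].a)
}
func (q queue) Swap(i, j int) { q[i], q[j] = q[j], q[i] }

func (q *queue) Push(x interface{}) { *q = append(*q, x.(*candidate)) }
func (q *queue) Pop() interface{} {
	old := *q
	n := len(old)
	item := old[n-1]
	*q = old[:n-1]
	return item
}

func main() {
	q := &queue{{a: 1, score: 0.5}, {a: 2, score: 0.9}, {a: 0, score: 0.9}}
	heap.Init(q)
	for q.Len() > 0 {
		c := heap.Pop(q).(*candidate)
		fmt.Println(c.a, c.score) // 0 0.9, then 2 0.9, then 1 0.5
	}
}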
@@ -236,11 +232,26 @@ func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
 	for _, id := range ids {
 		data := spm.vocab.Decode(id)
 		data = strings.ReplaceAll(data, spmWhitespaceSep, " ")
+
+		// For tokenizers that use byte tokens like "<0xEA>"
+		// convert them to the partial unicode character
+		// so they are buffered correctly by the runner instead
+		// of being sent back to the api as "<0xEA>"
+		if len(data) == 6 && strings.HasPrefix(data, "<0x") && strings.HasSuffix(data, ">") {
+			byteVal, err := strconv.ParseUint(data[1:5], 0, 8)
+			if err != nil {
+				return "", fmt.Errorf("failed to parse hex byte: %v", err)
+			}
+
+			if err := sb.WriteByte(byte(byteVal)); err != nil {
+				return "", err
+			}
+		} else {
 			if _, err := sb.WriteString(data); err != nil {
 				return "", err
 			}
+		}
 	}
 
-	slog.Debug("decoded", "ids", ids, "text", sb.String())
 	return sb.String(), nil
 }
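For a byte token like "<0xEA>", data[1:5] is "0xEA": passing base 0 lets strconv.ParseUint pick the base from the 0x prefix, and bitSize 8 guarantees the value fits in a single byte. Writing the raw byte rather than the literal "<0xEA>" text matters because one UTF-8 character may span several byte tokens, and the runner buffers partial runes until they complete. A tiny standalone check of the parse:

v, err := strconv.ParseUint("0xEA", 0, 8)
if err != nil {
	panic(err)
}
fmt.Printf("%d 0x%X\n", v, v) // 234 0xEA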
@@ -25,8 +25,6 @@ func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
 		t.Fatal(err)
 	}
 
-	preTokenizer := `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`
-
 	var v Vocabulary
 
 	for _, piece := range spm.GetPieces() {
@@ -47,7 +45,7 @@ func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
 		}
 	}
 
-	return NewSentencePieceModel(preTokenizer, &v)
+	return NewSentencePieceModel(&v)
 }
 
 func TestSentencePieceEncode(t *testing.T) {
@@ -116,3 +114,59 @@ func TestSentencePieceEncode(t *testing.T) {
 		}
 	})
 }
+
+func TestSentencePieceModelDecodeByteTokens(t *testing.T) {
+	vocab := &Vocabulary{
+		Values: []string{
+			"normal",
+			"<0xEA>",
+			"<0x41>",
+			"<0xC3>",
+			"<0xA3>",
+		},
+		Types: []uint32{
+			TOKEN_TYPE_NORMAL,
+			TOKEN_TYPE_BYTE,
+			TOKEN_TYPE_BYTE,
+			TOKEN_TYPE_BYTE,
+			TOKEN_TYPE_BYTE,
+		},
+		Scores: []float32{0, 0, 0, 0, 0},
+	}
+
+	spm := NewSentencePieceModel(vocab)
+
+	tests := []struct {
+		name     string
+		ids      []int32
+		expected string
+	}{
+		{
+			name:     "single byte token",
+			ids:      []int32{1},
+			expected: "\xea",
+		},
+		{
+			name:     "ASCII byte token",
+			ids:      []int32{2},
+			expected: "A",
+		},
+		{
+			name:     "multiple byte tokens forming UTF-8 character",
+			ids:      []int32{3, 4},
+			expected: "ã",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result, err := spm.Decode(tt.ids)
+			if err != nil {
+				t.Errorf("failed to decode token IDs %v: %v", tt.ids, err)
+			}
+			if result != tt.expected {
+				t.Errorf("got %q, want %q", result, tt.expected)
+			}
+		})
+	}
+}
@@ -25,7 +25,7 @@ var finishReasonToolCalls = "tool_calls"
 type Error struct {
 	Message string      `json:"message"`
 	Type    string      `json:"type"`
-	Param   interface{} `json:"param"`
+	Param   any         `json:"param"`
 	Code    *string     `json:"code"`
 }
 
@@ -465,7 +465,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
 		}
 	}
 
-	options := make(map[string]interface{})
+	options := make(map[string]any)
 
 	switch stop := r.Stop.(type) {
 	case string:
@@ -219,7 +219,7 @@ func TestChatMiddleware(t *testing.T) {
 					{
 						Function: api.ToolCallFunction{
 							Name: "get_current_weather",
-							Arguments: map[string]interface{}{
+							Arguments: map[string]any{
 								"location": "Paris, France",
 								"format":   "celsius",
 							},
@@ -213,8 +213,16 @@ func (c *InputCache) ShiftDiscard(inputLen int, numKeep int) int {
 	return discard
 }
 
-// Frees up space in the KV cache by deleting the oldest half of history and shifting
-// the newest half into that space (saving numKeep inputs at the beginning).
+type ErrReprocessInputs struct {
+	Inputs []input
+}
+
+func (e *ErrReprocessInputs) Error() string {
+	return fmt.Sprintf("kv cache shift not supported, inputs need reprocessing (input count: %v)", len(e.Inputs))
+}
+
+// ShiftCacheSlot frees up space in the KV cache by deleting the oldest half of history
+// and shifting the newest half into that space (saving numKeep inputs at the beginning).
 //
 // Assumes that at least 1 entry can be freed up by shifting (i.e. numKeep < numCtx)
 func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
@@ -222,7 +230,8 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
 		return fmt.Errorf("unable to shift context - keep exceeds context (keep: %v context: %v)", numKeep, c.numCtx)
 	}
 
-	discard := c.ShiftDiscard(len(slot.Inputs), numKeep)
+	inputLen := len(slot.Inputs)
+	discard := c.ShiftDiscard(inputLen, numKeep)
 
 	if discard <= 0 {
 		return nil
@@ -231,16 +240,42 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
 	slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs),
 		"keep", numKeep, "discard", discard)
 
-	// TODO (jessegross): KV cache removal can fail for certain types of models
-	if !c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+discard) {
-		return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v)", slot.Id, numKeep, discard)
-	}
-	c.lc.KvCacheSeqAdd(slot.Id, numKeep+discard, len(slot.Inputs), -discard)
+	var shiftFailed bool
 
-	for i := numKeep + discard; i < len(slot.Inputs); i++ {
+	if c.lc.KvCacheCanShift() {
+		// For models that support shifting, attempt to shift the KV cache
+		if !c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+discard) {
+			shiftFailed = true
+			slog.Debug("kv cache removal not supported, clearing cache and returning inputs for reprocessing", "id", slot.Id)
+		} else {
+			c.lc.KvCacheSeqAdd(slot.Id, numKeep+discard, inputLen, -discard)
+		}
+	} else {
+		// For models that don't support shifting
+		shiftFailed = true
+		slog.Debug("kv cache cannot shift, clearing cache and returning inputs for reprocessing", "id", slot.Id)
+	}
+
+	if shiftFailed {
+		// Create new input slice with preserved tokens (numKeep + remaining tokens after discard)
+		newInputs := make([]input, numKeep+inputLen-(numKeep+discard))
+		copy(newInputs[:numKeep], slot.Inputs[:numKeep])
+		copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:])
+
+		// Clear the entire KV cache
+		_ = c.lc.KvCacheSeqRm(slot.Id, 0, -1)
+		// Reset the slot inputs since we've cleared the cache
+		slot.Inputs = []input{}
+
+		// Return error with inputs that need to be reprocessed
+		return &ErrReprocessInputs{Inputs: newInputs}
+	}
+
+	// Standard shift succeeded - update input array
+	for i := numKeep + discard; i < inputLen; i++ {
 		slot.Inputs[i-discard] = slot.Inputs[i]
 	}
-	slot.Inputs = slot.Inputs[:len(slot.Inputs)-discard]
+	slot.Inputs = slot.Inputs[:inputLen-discard]
 
 	return nil
 }
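A worked example of the reconstruction above, using the same numbers the new test exercises (numCtx 10, numKeep 2, ten inputs, discard 4): the kept prefix [0,2) and the tail [6,10) survive, so six tokens come back for reprocessing while the cache itself is wiped.

// inputLen = 10, numKeep = 2, discard = 4
newInputs := make([]input, numKeep+inputLen-(numKeep+discard)) // len 6
copy(newInputs[:numKeep], slot.Inputs[:numKeep])               // inputs 0-1
copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:])       // inputs 6-9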
@@ -389,8 +389,16 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 			if len(seq.pendingInputs) == 0 {
 				err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
 				if err != nil {
-					return err
+					var reprocess *ErrReprocessInputs
+					if errors.As(err, &reprocess) {
+						// Prepend these inputs to the sequence's inputs queue for reprocessing
+						seq.inputs = append(reprocess.Inputs, seq.inputs...)
+						// Continue processing as normal
+						continue
+					} else {
+						return err
+					}
 				}
 			} else {
 				break
 			}
@@ -599,7 +607,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		if errors.Is(err, context.Canceled) {
 			slog.Info("aborting completion request due to client closing the connection")
 		} else {
-			slog.Error("Failed to acquire semaphore", "error", err)
+			http.Error(w, fmt.Sprintf("Failed to acquire semaphore: %v", err), http.StatusInternalServerError)
 		}
 		return
 	}
@@ -611,6 +619,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
 	if err != nil {
 		s.mu.Unlock()
+		s.seqsSem.Release(1)
 		http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
 		return
 	}
@@ -626,6 +635,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	s.mu.Unlock()
 
 	if !found {
+		s.seqsSem.Release(1)
 		http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
 		return
 	}
@@ -691,7 +701,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
 		if errors.Is(err, context.Canceled) {
 			slog.Info("aborting embeddings request due to client closing the connection")
 		} else {
-			slog.Error("Failed to acquire semaphore", "error", err)
+			http.Error(w, fmt.Sprintf("Failed to acquire semaphore: %v", err), http.StatusInternalServerError)
 		}
 		return
 	}
@@ -703,6 +713,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
 	seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, false)
 	if err != nil {
 		s.mu.Unlock()
+		s.seqsSem.Release(1)
 		http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
 		return
 	}
@@ -715,6 +726,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
 	s.mu.Unlock()
 
 	if !found {
+		s.seqsSem.Release(1)
 		http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
 		return
 	}
@@ -31,8 +31,10 @@ type InputCache struct {
 	cache kvcache.Cache
 }
 
-func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, multiUserCache bool) (*InputCache, error) {
-	if kvSize/int32(numSlots) < 1 {
+func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, batchSize int, multiUserCache bool) (*InputCache, error) {
+	numCtx := kvSize / int32(numSlots)
+
+	if numCtx < 1 {
 		return nil, fmt.Errorf("must have at least one kv cache entry per parallel sequence (kv: %v parallel: %v)", kvSize, numSlots)
 	}
 
@@ -44,11 +46,11 @@ func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots
 
 	cache := model.Config().Cache
 	if cache != nil {
-		cache.Init(model.Backend(), kvCacheTypeFromStr(kvCacheType), kvSize)
+		cache.Init(model.Backend(), kvCacheTypeFromStr(kvCacheType), numSlots, int(numCtx), batchSize)
 	}
 
 	return &InputCache{
-		numCtx:         kvSize / int32(numSlots),
+		numCtx:         numCtx,
 		enabled:        cache != nil,
 		slots:          slots,
 		multiUserCache: multiUserCache,
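The per-slot context arithmetic is unchanged but now computed once, and the new Init signature passes the slot count, per-slot capacity, and batch size to the cache instead of a single global size. A worked example with hypothetical numbers:

// Illustration only: an 8192-entry KV cache shared by 4 parallel sequences
// leaves 2048 entries of context per sequence; anything below 1 is rejected.
kvSize, numSlots := int32(8192), 4
numCtx := kvSize / int32(numSlots) // 2048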
@@ -116,6 +118,10 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []inp
 	}
 
 	if c.cache != nil {
+		if numPast > 0 && !c.cache.CanResume(slot.Id, numPast) {
+			numPast = 0
+		}
+
 		err = c.cache.Remove(slot.Id, numPast, math.MaxInt32)
 		if err != nil {
 			// Some models don't support partial erasure
@@ -223,6 +229,8 @@ func countCommonPrefix(a []input.Input, b []input.Input) int32 {
 	return count
 }
 
+// TODO(jessegross): If we need to reprocess the inputs we should ensure that
+// we don't split up a SameBatch
 func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 {
 	targetFree := (c.numCtx - numKeep) / 2
 	targetFree = max(targetFree, 1)
@@ -237,6 +245,14 @@ func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 {
 	return discard
 }
 
+type ErrReprocessInputs struct {
+	Inputs []input.Input
+}
+
+func (e *ErrReprocessInputs) Error() string {
+	return fmt.Sprintf("kv cache shift not supported, inputs need reprocessing (input count: %v)", len(e.Inputs))
+}
+
 // Frees up space in the KV cache by deleting the oldest half of history and shifting
 // the newest half into that space (saving numKeep inputs at the beginning).
 //
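ErrReprocessInputs is an error that carries data: rather than a bare failure, the cache hands the surviving inputs back so the runner can re-queue them. A minimal standalone sketch of the pattern; note that errors.As matches the concrete *ErrReprocessInputs pointer type, so the error must be created and returned as a pointer:

package main

import (
	"errors"
	"fmt"
)

// ErrReprocessInputs mirrors the error type above: the surviving inputs ride
// along as payload so the caller can re-queue them instead of failing hard.
type ErrReprocessInputs struct{ Inputs []int }

func (e *ErrReprocessInputs) Error() string {
	return fmt.Sprintf("inputs need reprocessing (count: %d)", len(e.Inputs))
}

func main() {
	var err error = &ErrReprocessInputs{Inputs: []int{1, 2, 3}}

	var reprocess *ErrReprocessInputs
	if errors.As(err, &reprocess) {
		fmt.Println("re-queue", reprocess.Inputs) // re-queue [1 2 3]
	}
}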
@@ -256,11 +272,23 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int32) error {
 	slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs),
 		"keep", numKeep, "discard", discard)
 
-	// TODO (jessegross): KV cache removal can fail for certain types of models
 	if c.cache != nil {
 		err := c.cache.Remove(slot.Id, numKeep, numKeep+discard)
 		if err != nil {
-			return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v): %w", slot.Id, numKeep, discard, err)
+			slog.Debug("kv cache removal unsupported, clearing cache and returning inputs for reprocessing",
+				"id", slot.Id, "error", err)
+
+			// Create new input slice with preserved tokens (numKeep + remaining tokens after discard)
+			newInputs := make([]input.Input, numKeep+inputLen-(numKeep+discard))
+			copy(newInputs[:numKeep], slot.Inputs[:numKeep])
+			copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:])
+
+			// Reset the cache
+			_ = c.cache.Remove(slot.Id, 0, -1)
+			slot.Inputs = []input.Input{}
+
+			// Return error with inputs that need to be reprocessed
+			return &ErrReprocessInputs{Inputs: newInputs}
 		}
 	}
 
@@ -1,10 +1,13 @@
 package ollamarunner
 
 import (
+	"errors"
+	"fmt"
 	"image"
 	"testing"
 	"time"
 
+	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/model/input"
 )
 
@@ -425,3 +428,92 @@ func TestLoadCacheSlot(t *testing.T) {
 		})
 	}
 }
+
+// Mock implementation of the Cache interface
+type mockCache struct {
+	shouldFail bool
+}
+
+// Implement only the methods needed for the test
+func (m *mockCache) Remove(seq int, beginIndex, endIndex int32) error {
+	if m.shouldFail {
+		return fmt.Errorf("mock cache removal error")
+	}
+	return nil
+}
+
+// Stub implementations for other interface methods
+func (m *mockCache) SetLayer(layer int)                                   {}
+func (m *mockCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) { return nil, nil, nil }
+func (m *mockCache) Put(ctx ml.Context, key, value ml.Tensor)             {}
+func (m *mockCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {}
+func (m *mockCache) Close()                                               {}
+func (m *mockCache) StartForward(ctx ml.Context, batch input.Batch) error { return nil }
+func (m *mockCache) CopyPrefix(srcSeq, dstSeq int, len int32)             {}
+func (m *mockCache) SetConfig(ml.CacheConfig)                             {}
+func (m *mockCache) CanResume(seq int, pos int32) bool                    { return true }
+
+func TestShiftCacheSlot(t *testing.T) {
+	tests := []struct {
+		name          string
+		numCtx        int32
+		inputs        []input.Input
+		numKeep       int32
+		cacheErr      bool
+		wantErr       any
+		wantInputsLen int
+	}{
+		{
+			name:          "Normal shift",
+			numCtx:        10,
+			inputs:        []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
+			numKeep:       2,
+			cacheErr:      false, // No error
+			wantErr:       nil,
+			wantInputsLen: 6, // After discarding 4 tokens
+		},
+		{
+			name:          "Cache removal fails",
+			numCtx:        10,
+			inputs:        []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
+			numKeep:       2,
+			cacheErr:      true,
+			wantErr:       &ErrReprocessInputs{},
+			wantInputsLen: 0, // Original inputs should be cleared
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			mock := &mockCache{shouldFail: tt.cacheErr}
+			c := InputCache{
+				numCtx: tt.numCtx,
+				cache:  mock,
+			}
+			slot := &InputCacheSlot{
+				Id:     123,
+				Inputs: make([]input.Input, len(tt.inputs)),
+			}
+			copy(slot.Inputs, tt.inputs)
+
+			err := c.ShiftCacheSlot(slot, tt.numKeep)
+
+			if tt.wantErr != nil {
+				if err == nil {
+					t.Errorf("Expected error but got nil")
+					return
+				}
+
+				if !errors.As(err, &tt.wantErr) {
+					t.Errorf("Expected error of type %T but got %T: %v", tt.wantErr, err, err)
+				}
+			} else if err != nil {
+				t.Errorf("Unexpected error: %v", err)
+			}
+
+			if len(slot.Inputs) != tt.wantInputsLen {
+				t.Errorf("Slot inputs length after operation: got %v, want %v", len(slot.Inputs), tt.wantInputsLen)
+			}
+		})
+	}
+}
@@ -115,16 +115,41 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 		params.numKeep = int32(len(inputs))
 	}
 
-	// TODO(jessegross): We should ensure that we always leave minBatch of context space to shift,
-	// otherwise we might truncate or split the batch against the model's wishes
-
 	// Ensure that at least 1 input can be discarded during shift
 	params.numKeep = min(params.numKeep, s.cache.numCtx-1)
 
 	if int32(len(inputs)) > s.cache.numCtx {
 		discard := int32(len(inputs)) - s.cache.numCtx
+		promptStart := params.numKeep + discard
+
+		// If we need to truncate in the middle of a unbreakable batch, remove the entire batch
+		sameBatch := 0
+		for i, inp := range inputs {
+			if sameBatch > 0 {
+				sameBatch--
+
+				if promptStart == int32(i) {
+					promptStart++
+				}
+			} else if promptStart == int32(i) {
+				break
+			}
+
+			if inp.SameBatch != 0 {
+				if int32(i) < params.numKeep {
+					return nil, fmt.Errorf("SameBatch may not be specified within numKeep (index: %v numKeep: %v SameBatch: %v)", i, params.numKeep, inp.SameBatch)
+				}
+
+				sameBatch = inp.SameBatch
+			}
+		}
+
+		if promptStart >= int32(len(inputs)) {
+			return nil, errors.New("entire prompt removed by truncation")
+		}
+
 		newInputs := inputs[:params.numKeep]
-		newInputs = append(newInputs, inputs[params.numKeep+discard:]...)
+		newInputs = append(newInputs, inputs[promptStart:]...)
 
 		slog.Warn("truncating input prompt", "limit", s.cache.numCtx, "prompt", len(inputs), "keep", params.numKeep, "new", len(newInputs))
 		inputs = newInputs
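The truncation walk never splits an unbreakable batch: if the computed cut point lands inside a SameBatch run, it is pushed past the run's end so the whole batch is dropped. A standalone mirror of that walk with hypothetical data (SameBatch counts the inputs that must follow in the same batch):

package main

import "fmt"

type inp struct{ SameBatch int }

// advancePastSameBatch mirrors the loop above: if promptStart lands inside an
// unbreakable batch, advance it until the whole batch is excluded.
func advancePastSameBatch(inputs []inp, promptStart int32) int32 {
	sameBatch := 0
	for i, in := range inputs {
		if sameBatch > 0 {
			sameBatch--
			if promptStart == int32(i) {
				promptStart++
			}
		} else if promptStart == int32(i) {
			break
		}
		if in.SameBatch != 0 {
			sameBatch = in.SameBatch
		}
	}
	return promptStart
}

func main() {
	// Index 2 opens a batch covering the three inputs that follow it.
	inputs := []inp{{}, {}, {SameBatch: 3}, {}, {}, {}, {}}
	fmt.Println(advancePastSameBatch(inputs, 3)) // 6
}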
@@ -267,6 +292,9 @@ type Server struct {
 	// KV cache
 	cache *InputCache
 
+	// next sequence for prompt processing to avoid starvation
+	nextSeq int
+
 	// multimodalHash generates hashes for comparing equality
 	// of non-text data
 	multimodalHash maphash.Hash
@@ -348,16 +376,22 @@ func (s *Server) processBatch() error {
 	}
 	defer s.mu.Unlock()
 
-	var options input.Options
+	var batchInputs []int32
+	var batch input.Batch
+
+	resumeSeq := -1
+	seqIdx := s.nextSeq - 1
+	for range s.seqs {
+		seqIdx = (seqIdx + 1) % len(s.seqs)
+		seq := s.seqs[seqIdx]
 
-	for i, seq := range s.seqs {
 		if seq == nil {
 			continue
 		}
 
 		// if past the num predict limit
 		if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
-			s.removeSequence(i, "limit")
+			s.removeSequence(seqIdx, "limit")
 			continue
 		}
 
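The scan is now round-robin: it starts at nextSeq and wraps with a modulus, so a sequence that was cut off by a full batch gets first claim on the next one instead of always losing to lower slot numbers. A minimal standalone sketch of the visiting order:

package main

import "fmt"

// roundRobin visits all n slots exactly once, starting at start and wrapping.
func roundRobin(n, start int, visit func(int)) {
	idx := start - 1
	for i := 0; i < n; i++ {
		idx = (idx + 1) % n
		visit(idx)
	}
}

func main() {
	roundRobin(4, 2, func(i int) { fmt.Print(i, " ") }) // 2 3 0 1
}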
@@ -368,16 +402,23 @@ func (s *Server) processBatch() error {
 
 		batchSize := s.batchSize
 
-		for j, inp := range seq.inputs {
+		for i, inp := range seq.inputs {
 			// If we are required to put following inputs into a single batch then extend the
 			// batch size. Since we are only extending the size the minimum amount possible, this
-			// will cause a break if we have pending inputs.
+			// will cause a break if we have existing inputs.
 			minBatch := 1 + inp.SameBatch
 			if minBatch > batchSize {
 				batchSize = minBatch
 			}
 
-			if len(seq.pendingInputs)+minBatch > batchSize {
+			// Stop if the required batch would put us over the total batch size (including tokens
+			// added by other sequences). If we haven't been able to add anything yet then pick up
+			// here again for the next batch to avoid starvation, though we can opportunistically
+			// check if other sequences can still squeeze something in.
+			if len(batchInputs)+minBatch > batchSize {
+				if len(seq.pendingInputs) == 0 && resumeSeq == -1 {
+					resumeSeq = seqIdx
+				}
 				break
 			}
 
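Note: a minimal sketch of the round-robin resume pattern these hunks introduce, assuming only what the diff shows. Scanning starts just after the previous resume point, so a sequence that was starved last pass is visited first; the names here (`pickOrder`, `numSeqs`) are illustrative, not the runner's API.

package main

import "fmt"

// pickOrder returns the visit order for sequences given the resume
// index: the sequence that was starved last pass is visited first.
func pickOrder(numSeqs, nextSeq int) []int {
	order := make([]int, 0, numSeqs)
	idx := nextSeq - 1
	for n := 0; n < numSeqs; n++ {
		idx = (idx + 1) % numSeqs
		order = append(order, idx)
	}
	return order
}

func main() {
	// With 4 sequences and resume point 2, sequence 2 is tried first.
	fmt.Println(pickOrder(4, 2)) // [2 3 0 1]
}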
@@ -391,21 +432,29 @@ func (s *Server) processBatch() error {
 
 			err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
 			if err != nil {
+				var reprocess *ErrReprocessInputs
+				if errors.As(err, &reprocess) {
+					// Prepend these inputs to the sequence's inputs queue for reprocessing
+					seq.inputs = append(reprocess.Inputs, seq.inputs...)
+					// Skip this sequence but continue processing the rest
+					continue
+				} else {
 					return err
+				}
 			}
 
-			options.Inputs = append(options.Inputs, inp.Token)
-			if inp.Multimodal != nil {
-				options.Multimodal = append(options.Multimodal, input.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: inp.Multimodal})
-			}
-
-			options.Positions = append(options.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
-			options.Sequences = append(options.Sequences, seq.cache.Id)
+			batchInputs = append(batchInputs, inp.Token)
+			if inp.Multimodal != nil {
+				batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: inp.Multimodal})
+			}
 
-			seq.iBatch = len(options.Outputs)
-			if j+1 == len(seq.inputs) {
-				options.Outputs = append(options.Outputs, int32(len(options.Inputs)-1))
+			batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
+			batch.Sequences = append(batch.Sequences, seq.cache.Id)
+
+			seq.iBatch = len(batch.Outputs)
+			if i+1 == len(seq.inputs) {
+				batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
 			}
 			seq.pendingInputs = append(seq.pendingInputs, inp)
 		}
@@ -413,14 +462,20 @@ func (s *Server) processBatch() error {
 		seq.inputs = seq.inputs[len(seq.pendingInputs):]
 	}
 
-	if len(options.Inputs) == 0 {
+	if resumeSeq != -1 {
+		s.nextSeq = resumeSeq
+	} else {
+		s.nextSeq = seqIdx + 1
+	}
+
+	if len(batchInputs) == 0 {
 		return nil
 	}
 
 	ctx := s.model.Backend().NewContext()
 	defer ctx.Close()
 
-	modelOutput, err := model.Forward(ctx, s.model, options)
+	modelOutput, err := model.Forward(ctx, s.model, batchInputs, batch)
 	if err != nil {
 		return fmt.Errorf("failed to decode batch: %w", err)
 	}
@@ -460,7 +515,7 @@ func (s *Server) processBatch() error {
 	}
 
 	// sample a token
-	vocabSize := len(logits) / len(options.Outputs)
+	vocabSize := len(logits) / len(batch.Outputs)
 
 	token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize])
 	if err != nil {
@@ -587,7 +642,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	if errors.Is(err, context.Canceled) {
 		slog.Info("aborting completion request due to client closing the connection")
 	} else {
-		slog.Error("Failed to acquire semaphore", "error", err)
+		http.Error(w, fmt.Sprintf("Failed to acquire semaphore: %v", err), http.StatusInternalServerError)
 	}
 	return
 }
@@ -599,6 +654,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs)
 	if err != nil {
 		s.mu.Unlock()
+		s.seqsSem.Release(1)
 		http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
 		return
 	}
@@ -612,6 +668,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	s.mu.Unlock()
 
 	if !found {
+		s.seqsSem.Release(1)
 		http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
 		return
 	}
@@ -677,6 +734,7 @@ func (m *multiLPath) String() string {
 }
 
 func (s *Server) loadModel(
+	ctx context.Context,
 	mpath string,
 	params ml.BackendParams,
 	lpath multiLPath,
@@ -686,7 +744,7 @@ func (s *Server) loadModel(
 	multiUserCache bool,
 ) {
 	var err error
-	s.model, err = model.New(mpath, params)
+	s.model, err = model.New(ctx, mpath, params)
 	if err != nil {
 		panic(err)
 	}
@@ -698,7 +756,7 @@ func (s *Server) loadModel(
 		panic("loras are not yet implemented")
 	}
 
-	s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, multiUserCache)
+	s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, s.batchSize, multiUserCache)
 	if err != nil {
 		panic(err)
 	}
@@ -782,6 +840,9 @@ func Execute(args []string) error {
 	}
 
 	params := ml.BackendParams{
+		Progress: func(progress float32) {
+			server.progress = progress
+		},
 		NumThreads:   *threads,
 		NumGPULayers: *numGPULayers,
 		MainGPU:      *mainGPU,
@@ -790,13 +851,13 @@ func Execute(args []string) error {
 	}
 
 	server.ready.Add(1)
-	go server.loadModel(*mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)
-
-	server.cond = sync.NewCond(&server.mu)
 
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
 
+	go server.loadModel(ctx, *mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)
+
+	server.cond = sync.NewCond(&server.mu)
+
 	go server.run(ctx)
 
 	addr := "127.0.0.1:" + strconv.Itoa(*port)
@@ -26,6 +26,10 @@ type Sampler struct {
 }
 
 func (s *Sampler) Sample(logits []float32) (int32, error) {
+	if len(logits) == 0 {
+		return -1, errors.New("sample: no logits provided to sample")
+	}
+
 	tokens := make([]token, len(logits))
 	for i := range logits {
 		tokens[i].id = int32(i)
@@ -94,13 +98,6 @@ func (s *Sampler) sample(tokens []token) (token, error) {
 	tokens = topP(tokens, s.topP)
 	tokens = minP(tokens, s.minP)
 
-	// TODO: this should fall back to greedy sampling
-	// or topP, topK values etc should be such that
-	// there are always tokens to sample from
-	if len(tokens) == 0 {
-		return token{}, errors.New("no tokens to sample from")
-	}
-
 	var r float32
 	if s.rng != nil {
 		r = s.rng.Float32()
@@ -123,6 +120,9 @@ func (s *Sampler) sample(tokens []token) (token, error) {
 		return 1
 	})
 
+	if math.IsNaN(float64(sum)) {
+		return token{}, errors.New("sample: logits sum to NaN, check model output")
+	}
 	return tokens[idx], nil
 }
 
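For context, a hedged sketch of why the NaN check above matters: if every logit is NaN, softmax produces NaN probabilities and the cumulative sum used for weighted sampling is NaN, so no index can be selected. The probability slice and names below are illustrative.

package main

import (
	"errors"
	"fmt"
	"math"
)

// sampleWeighted picks an index proportionally to probs, failing fast
// when the distribution is unusable (e.g. all-NaN logits upstream).
func sampleWeighted(probs []float32, r float32) (int, error) {
	var sum float32
	for _, p := range probs {
		sum += p
	}
	if math.IsNaN(float64(sum)) || sum <= 0 {
		return -1, errors.New("sample: probabilities sum to NaN or zero")
	}
	target := r * sum
	var cum float32
	for i, p := range probs {
		cum += p
		if cum >= target {
			return i, nil
		}
	}
	return len(probs) - 1, nil
}

func main() {
	i, err := sampleWeighted([]float32{0.1, 0.7, 0.2}, 0.5)
	fmt.Println(i, err) // 1 <nil>
}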
@@ -1,6 +1,7 @@
 package sample
 
 import (
+	"math"
 	"math/rand/v2"
 	"testing"
 )
@@ -29,6 +30,29 @@ func TestWeighted(t *testing.T) {
 	if want != got {
 		t.Errorf("index mismatch: want %d, got %d", want, got)
 	}
+
+	// Test very high p
+	logits = []float32{1.0, 0.9999999999999999, 0.5, 0.1}
+	// Use extremely small topP to filter out all tokens
+	sampler = NewSampler(1.0, 0, 1e-10, 0, 0, nil)
+	got, err = sampler.Sample(logits)
+	if err != nil {
+		t.Error(err)
+		return
+	}
+	// Should get the token with the highest logit
+	want = int32(0)
+	if want != got {
+		t.Errorf("index mismatch: want %d, got %d", want, got)
+	}
+
+	logits = []float32{float32(math.NaN()), float32(math.NaN()), float32(math.NaN())}
+	sampler = NewSampler(1, 0, 0.95, 0.05, 0, nil)
+	got, err = sampler.Sample(logits)
+	if err == nil {
+		t.Errorf("expected error, got %d", got)
+		return
+	}
 }
 
 func BenchmarkSample(b *testing.B) {
@@ -168,27 +168,53 @@ func TestTopP(t *testing.T) {
 	softmax(tokens)
 	tokens = topK(tokens, 20)
 
-	// Then apply topP
-	tokens = topP(tokens, 0.95)
+	// Test with very high p value
+	got := topP(tokens, 1.0)
 
-	// Should keep tokens until cumsum > 0.95
-	if len(tokens) > 3 {
+	// Should keep all tokens since p is 1
+	if len(got) != len(input) {
+		t.Errorf("topP(1.0): should keep all tokens, got %d, want %d", len(got), len(input))
+	}
+
+	// Test with normal p value
+	got = topP(tokens, 0.95)
+
+	if len(got) > 3 {
 		t.Errorf("topP(0.95): kept too many tokens: got %d", len(tokens))
-		t.Logf("got: %v", tokens)
+		t.Logf("got: %v", got)
 	}
 
 	// Test edge case - ensure at least one token remains
-	input = []float32{-1e6, -1e6, -1e6} // One dominant token
+	input = []float32{-1e6, -1e6, -1e7}
 	tokens = toTokens(input)
+	tokens = topK(tokens, 20)
 	softmax(tokens)
-	tokens = topP(tokens, 0.0) // Very small p
-	if len(tokens) < 1 {
+	got = topP(tokens, 0.0)
+	if len(got) < 1 {
 		t.Error("topP should keep at least one token")
 	}
+
+	// Test with zero p value
+	got = topP(tokens, 0.0)
+
+	// Should keep only the highest probability token
+	if len(got) != 1 {
+		t.Errorf("topP(0.0): should keep only one token, got %d", len(got))
+		t.Logf("got: %v", got)
+	}
+
+	tokens = toTokens(input)
+	tokens = topK(tokens, 20)
+	softmax(tokens)
+	got = topP(tokens, 1e-10)
+	if len(got) == 0 {
+		t.Errorf("topP(1e-10): should keep at least one token, got %d", len(got))
+		t.Logf("got: %v", got)
+	}
 }
 
 func TestMinP(t *testing.T) {
-	input := []float32{-3, -2, -1, 0, 1, 2, 4, 3}
+	input := []float32{-2, 0, -1, -3, 2, 1, 4, 3}
 	tokens := toTokens(input)
 
 	// First apply temperature and softmax
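A minimal sketch of the top-p (nucleus) transform these tests exercise, assuming tokens are already softmaxed and sorted by descending probability; the `tok` type and values are illustrative, not the package's internals.

package main

import "fmt"

type tok struct {
	id int32
	p  float32
}

// topP keeps the smallest prefix of descending-probability tokens whose
// cumulative probability reaches p, so at least one token always survives.
func topP(tokens []tok, p float32) []tok {
	var cum float32
	for i, t := range tokens {
		cum += t.p
		if cum >= p {
			return tokens[:i+1]
		}
	}
	return tokens
}

func main() {
	ts := []tok{{0, 0.7}, {1, 0.28}, {2, 0.02}}
	fmt.Println(len(topP(ts, 0.95))) // 2: 0.7+0.28 >= 0.95
	fmt.Println(len(topP(ts, 0.0)))  // 1: the first token alone satisfies p=0
}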
@@ -225,30 +251,48 @@ func TestMinP(t *testing.T) {
 		t.Logf("got: %v", tokens)
 	}
 
+	// Test with single token
+	tokens = toTokens(input[:1])
+	tokens = topK(tokens, 20)
+	softmax(tokens)
+	tokens = minP(tokens, 0.1)
+
+	// Should keep only the highest probability token
+	if len(tokens) != 1 {
+		t.Errorf("minP(0.1): should return single token, got %d", len(tokens))
+		t.Logf("got: %v", tokens)
+	}
+
 	input = []float32{1e-10, 1e-10, 1e-10}
 	tokens = toTokens(input)
 	softmax(tokens)
 	tokens = minP(tokens, 1.0)
 	if len(tokens) < 1 {
 		t.Error("minP should keep at least one token even with extreme probabilities")
 	}
-}
+	got := minP(tokens, 1.0)
 
-func TestSortLogits(t *testing.T) {
-	input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
-	tokens := toTokens(input)
-
-	tokens = topK(tokens, 20)
-
-	for i := 1; i < len(tokens); i++ {
-		if tokens[i].value > tokens[i-1].value {
-			t.Errorf("sortLogits: tokens not sorted in descending order at index %d: %f > %f",
-				i, tokens[i].value, tokens[i-1].value)
-		}
-	}
-
-	want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
-	compareLogits(t, "sortLogits", want, tokens)
+	if len(got) != 1 {
+		t.Errorf("minP(1.0): should keep all tokens, got %d, want %d", len(got), len(tokens))
+	}
+
+	// Test with normal p value
+	got = minP(tokens, 0.2)
+
+	// Should keep tokens with prob >= 0.2 * max_prob
+	if len(got) > 3 {
+		t.Errorf("minP(0.2): kept too many tokens: got %d", len(got))
+		t.Logf("got: %v", got)
+	}
+
+	// Test with zero p value
+	got = minP(tokens, 0.0)
+
+	// Should keep only the highest probability token
+	if len(got) != len(tokens) {
+		t.Errorf("minP(0.0): should keep only one token, got %d", len(got))
+		t.Logf("got: %v", got)
+	}
 }
 
 func BenchmarkTransforms(b *testing.B) {
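And a matching sketch of min-p, which the updated TestMinP covers: tokens are kept if their probability is at least p times the maximum probability, so the best token always survives. This reuses the illustrative `tok` type from the top-p sketch above and is an assumption-labeled reading of the tests, not the package's implementation.

// minP filters tokens whose probability falls below p*maxProb. With
// p=1.0 only tokens tied with the maximum survive; with p=0.0 nothing
// is filtered.
func minP(tokens []tok, p float32) []tok {
	var maxProb float32
	for _, t := range tokens {
		if t.p > maxProb {
			maxProb = t.p
		}
	}
	threshold := p * maxProb
	kept := make([]tok, 0, len(tokens))
	for _, t := range tokens {
		if t.p >= threshold {
			kept = append(kept, t)
		}
	}
	return kept
}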
@@ -31,6 +31,7 @@ const maxRetries = 6
 var (
 	errMaxRetriesExceeded = errors.New("max retries exceeded")
 	errPartStalled        = errors.New("part stalled")
+	errMaxRedirectsExceeded = errors.New("maximum redirects exceeded (10) for directURL")
 )
 
 var blobDownloadManager sync.Map
@@ -236,7 +237,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *registryOptions) {
 
 	newOpts.CheckRedirect = func(req *http.Request, via []*http.Request) error {
 		if len(via) > 10 {
-			return errors.New("maximum redirects exceeded (10) for directURL")
+			return errMaxRedirectsExceeded
 		}
 
 		// if the hostname is the same, allow the redirect
@@ -35,14 +35,9 @@ var (
 	errCapabilityCompletion = errors.New("completion")
 	errCapabilityTools      = errors.New("tools")
 	errCapabilityInsert     = errors.New("insert")
-)
-
-type Capability string
-
-const (
-	CapabilityCompletion = Capability("completion")
-	CapabilityTools      = Capability("tools")
-	CapabilityInsert     = Capability("insert")
+	errCapabilityVision     = errors.New("vision")
+	errCapabilityEmbedding  = errors.New("embedding")
+	errInsecureProtocol     = errors.New("insecure protocol http")
 )
 
 type registryOptions struct {
@@ -65,52 +60,83 @@ type Model struct {
 	System    string
 	License   []string
 	Digest    string
-	Options   map[string]interface{}
+	Options   map[string]any
 	Messages  []api.Message
 
 	Template *template.Template
 }
 
-// CheckCapabilities checks if the model has the specified capabilities returning an error describing
-// any missing or unknown capabilities
-func (m *Model) CheckCapabilities(caps ...Capability) error {
-	var errs []error
-	for _, cap := range caps {
-		switch cap {
-		case CapabilityCompletion:
-			r, err := os.Open(m.ModelPath)
-			if err != nil {
-				slog.Error("couldn't open model file", "error", err)
-				continue
-			}
-			defer r.Close()
-
-			// TODO(mxyng): decode the GGML into model to avoid doing this multiple times
-			f, _, err := ggml.Decode(r, 0)
-			if err != nil {
-				slog.Error("couldn't decode ggml", "error", err)
-				continue
-			}
-
-			if _, ok := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]; ok {
-				errs = append(errs, errCapabilityCompletion)
-			}
-		case CapabilityTools:
-			if !slices.Contains(m.Template.Vars(), "tools") {
-				errs = append(errs, errCapabilityTools)
-			}
-		case CapabilityInsert:
-			vars := m.Template.Vars()
-			if !slices.Contains(vars, "suffix") {
-				errs = append(errs, errCapabilityInsert)
-			}
-		default:
-			slog.Error("unknown capability", "capability", cap)
-			return fmt.Errorf("unknown capability: %s", cap)
-		}
-	}
+// Capabilities returns the capabilities that the model supports
+func (m *Model) Capabilities() []model.Capability {
+	capabilities := []model.Capability{}
+
+	// Check for completion capability
+	r, err := os.Open(m.ModelPath)
+	if err == nil {
+		defer r.Close()
+
+		f, _, err := ggml.Decode(r, 0)
+		if err == nil {
+			if _, ok := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]; ok {
+				capabilities = append(capabilities, model.CapabilityEmbedding)
+			} else {
+				capabilities = append(capabilities, model.CapabilityCompletion)
+			}
+			if _, ok := f.KV()[fmt.Sprintf("%s.vision.block_count", f.KV().Architecture())]; ok {
+				capabilities = append(capabilities, model.CapabilityVision)
+			}
+		} else {
+			slog.Error("couldn't decode ggml", "error", err)
+		}
+	} else {
+		slog.Error("couldn't open model file", "error", err)
+	}
+
+	if m.Template == nil {
+		return capabilities
+	}
+
+	// Check for tools capability
+	if slices.Contains(m.Template.Vars(), "tools") {
+		capabilities = append(capabilities, model.CapabilityTools)
+	}
+
+	// Check for insert capability
+	if slices.Contains(m.Template.Vars(), "suffix") {
+		capabilities = append(capabilities, model.CapabilityInsert)
+	}
+
+	return capabilities
+}
+
+// CheckCapabilities checks if the model has the specified capabilities returning an error describing
+// any missing or unknown capabilities
+func (m *Model) CheckCapabilities(want ...model.Capability) error {
+	available := m.Capabilities()
+	var errs []error
+
+	// Map capabilities to their corresponding error
+	capToErr := map[model.Capability]error{
+		model.CapabilityCompletion: errCapabilityCompletion,
+		model.CapabilityTools:      errCapabilityTools,
+		model.CapabilityInsert:     errCapabilityInsert,
+		model.CapabilityVision:     errCapabilityVision,
+		model.CapabilityEmbedding:  errCapabilityEmbedding,
+	}
+
+	for _, cap := range want {
+		err, ok := capToErr[cap]
+		if !ok {
+			slog.Error("unknown capability", "capability", cap)
+			return fmt.Errorf("unknown capability: %s", cap)
+		}
+
+		if !slices.Contains(available, cap) {
+			errs = append(errs, err)
+		}
+	}
 
-	if err := errors.Join(errs...); err != nil {
+	if len(errs) > 0 {
 		return fmt.Errorf("%w %w", errCapabilities, errors.Join(errs...))
 	}
 
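A hedged usage sketch of the refactored API above: callers derive what a model supports with Capabilities(), then gate a request with CheckCapabilities. The wrapper function and its error message are illustrative, not part of the package.

// ensureToolSupport is illustrative only: gate a tools request on the
// model's detected capabilities before building a prompt.
func ensureToolSupport(m *Model) error {
	// CheckCapabilities returns a joined error naming each missing
	// capability, wrapped in errCapabilities.
	if err := m.CheckCapabilities(model.CapabilityTools); err != nil {
		return fmt.Errorf("model cannot serve this request: %w", err)
	}
	return nil
}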
@@ -479,7 +505,7 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
 	fn(api.ProgressResponse{Status: "retrieving manifest"})
 
 	if mp.ProtocolScheme == "http" && !regOpts.Insecure {
-		return errors.New("insecure protocol http")
+		return errInsecureProtocol
 	}
 
 	manifest, _, err := GetManifest(mp)
@@ -543,7 +569,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
 	}
 
 	if mp.ProtocolScheme == "http" && !regOpts.Insecure {
-		return errors.New("insecure protocol http")
+		return errInsecureProtocol
 	}
 
 	fn(api.ProgressResponse{Status: "pulling manifest"})
server/images_test.go (new file, 360 lines)

package server

import (
	"bytes"
	"encoding/binary"
	"os"
	"path/filepath"
	"strings"
	"testing"

	"github.com/ollama/ollama/template"
	"github.com/ollama/ollama/types/model"
)

// Constants for GGUF magic bytes and version
var (
	ggufMagic = []byte{0x47, 0x47, 0x55, 0x46} // "GGUF"
	ggufVer   = uint32(3)                      // Version 3
)

// Helper function to create mock GGUF data
func createMockGGUFData(architecture string, vision bool) []byte {
	var buf bytes.Buffer

	// Write GGUF header
	buf.Write(ggufMagic)
	binary.Write(&buf, binary.LittleEndian, ggufVer)

	// Write tensor count (0 for our test)
	var numTensors uint64 = 0
	binary.Write(&buf, binary.LittleEndian, numTensors)

	// Calculate number of metadata entries
	numMetaEntries := uint64(1) // architecture entry
	if vision {
		numMetaEntries++
	}
	// Add embedding entry if architecture is "bert"
	if architecture == "bert" {
		numMetaEntries++
	}
	binary.Write(&buf, binary.LittleEndian, numMetaEntries)

	// Write architecture metadata
	archKey := "general.architecture"
	keyLen := uint64(len(archKey))
	binary.Write(&buf, binary.LittleEndian, keyLen)
	buf.WriteString(archKey)

	// String type (8)
	var strType uint32 = 8
	binary.Write(&buf, binary.LittleEndian, strType)

	// String length
	strLen := uint64(len(architecture))
	binary.Write(&buf, binary.LittleEndian, strLen)
	buf.WriteString(architecture)

	if vision {
		visionKey := architecture + ".vision.block_count"
		keyLen = uint64(len(visionKey))
		binary.Write(&buf, binary.LittleEndian, keyLen)
		buf.WriteString(visionKey)

		// uint32 type (4)
		var uint32Type uint32 = 4
		binary.Write(&buf, binary.LittleEndian, uint32Type)

		// uint32 value (1)
		var countVal uint32 = 1
		binary.Write(&buf, binary.LittleEndian, countVal)
	}
	// Write embedding metadata if architecture is "bert"
	if architecture == "bert" {
		poolKey := architecture + ".pooling_type"
		keyLen = uint64(len(poolKey))
		binary.Write(&buf, binary.LittleEndian, keyLen)
		buf.WriteString(poolKey)

		// uint32 type (4)
		var uint32Type uint32 = 4
		binary.Write(&buf, binary.LittleEndian, uint32Type)

		// uint32 value (1)
		var poolingVal uint32 = 1
		binary.Write(&buf, binary.LittleEndian, poolingVal)
	}

	return buf.Bytes()
}

func TestModelCapabilities(t *testing.T) {
	// Create a temporary directory for test files
	tempDir, err := os.MkdirTemp("", "model_capabilities_test")
	if err != nil {
		t.Fatalf("Failed to create temp directory: %v", err)
	}
	defer os.RemoveAll(tempDir)

	// Create different types of mock model files
	completionModelPath := filepath.Join(tempDir, "model.bin")
	visionModelPath := filepath.Join(tempDir, "vision_model.bin")
	embeddingModelPath := filepath.Join(tempDir, "embedding_model.bin")
	// Create a simple model file for tests that don't depend on GGUF content
	simpleModelPath := filepath.Join(tempDir, "simple_model.bin")

	err = os.WriteFile(completionModelPath, createMockGGUFData("llama", false), 0o644)
	if err != nil {
		t.Fatalf("Failed to create completion model file: %v", err)
	}
	err = os.WriteFile(visionModelPath, createMockGGUFData("llama", true), 0o644)
	if err != nil {
		t.Fatalf("Failed to create vision model file: %v", err)
	}
	err = os.WriteFile(embeddingModelPath, createMockGGUFData("bert", false), 0o644)
	if err != nil {
		t.Fatalf("Failed to create embedding model file: %v", err)
	}
	err = os.WriteFile(simpleModelPath, []byte("dummy model data"), 0o644)
	if err != nil {
		t.Fatalf("Failed to create simple model file: %v", err)
	}

	toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}")
	if err != nil {
		t.Fatalf("Failed to parse template: %v", err)
	}
	chatTemplate, err := template.Parse("{{ .prompt }}")
	if err != nil {
		t.Fatalf("Failed to parse template: %v", err)
	}
	toolsTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}")
	if err != nil {
		t.Fatalf("Failed to parse template: %v", err)
	}

	testModels := []struct {
		name         string
		model        Model
		expectedCaps []model.Capability
	}{
		{
			name: "model with completion capability",
			model: Model{
				ModelPath: completionModelPath,
				Template:  chatTemplate,
			},
			expectedCaps: []model.Capability{model.CapabilityCompletion},
		},
		{
			name: "model with completion, tools, and insert capability",
			model: Model{
				ModelPath: completionModelPath,
				Template:  toolsInsertTemplate,
			},
			expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityTools, model.CapabilityInsert},
		},
		{
			name: "model with tools and insert capability",
			model: Model{
				ModelPath: simpleModelPath,
				Template:  toolsInsertTemplate,
			},
			expectedCaps: []model.Capability{model.CapabilityTools, model.CapabilityInsert},
		},
		{
			name: "model with tools capability",
			model: Model{
				ModelPath: simpleModelPath,
				Template:  toolsTemplate,
			},
			expectedCaps: []model.Capability{model.CapabilityTools},
		},
		{
			name: "model with vision capability",
			model: Model{
				ModelPath: visionModelPath,
				Template:  chatTemplate,
			},
			expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision},
		},
		{
			name: "model with vision, tools, and insert capability",
			model: Model{
				ModelPath: visionModelPath,
				Template:  toolsInsertTemplate,
			},
			expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision, model.CapabilityTools, model.CapabilityInsert},
		},
		{
			name: "model with embedding capability",
			model: Model{
				ModelPath: embeddingModelPath,
				Template:  chatTemplate,
			},
			expectedCaps: []model.Capability{model.CapabilityEmbedding},
		},
	}

	// compare two slices of model.Capability regardless of order
	compareCapabilities := func(a, b []model.Capability) bool {
		if len(a) != len(b) {
			return false
		}

		aCount := make(map[model.Capability]int)
		for _, cap := range a {
			aCount[cap]++
		}

		bCount := make(map[model.Capability]int)
		for _, cap := range b {
			bCount[cap]++
		}

		for cap, count := range aCount {
			if bCount[cap] != count {
				return false
			}
		}

		return true
	}

	for _, tt := range testModels {
		t.Run(tt.name, func(t *testing.T) {
			// Test Capabilities method
			caps := tt.model.Capabilities()
			if !compareCapabilities(caps, tt.expectedCaps) {
				t.Errorf("Expected capabilities %v, got %v", tt.expectedCaps, caps)
			}
		})
	}
}

func TestModelCheckCapabilities(t *testing.T) {
	// Create a temporary directory for test files
	tempDir, err := os.MkdirTemp("", "model_check_capabilities_test")
	if err != nil {
		t.Fatalf("Failed to create temp directory: %v", err)
	}
	defer os.RemoveAll(tempDir)

	visionModelPath := filepath.Join(tempDir, "vision_model.bin")
	simpleModelPath := filepath.Join(tempDir, "model.bin")
	embeddingModelPath := filepath.Join(tempDir, "embedding_model.bin")

	err = os.WriteFile(simpleModelPath, []byte("dummy model data"), 0o644)
	if err != nil {
		t.Fatalf("Failed to create simple model file: %v", err)
	}
	err = os.WriteFile(visionModelPath, createMockGGUFData("llama", true), 0o644)
	if err != nil {
		t.Fatalf("Failed to create vision model file: %v", err)
	}
	err = os.WriteFile(embeddingModelPath, createMockGGUFData("bert", false), 0o644)
	if err != nil {
		t.Fatalf("Failed to create embedding model file: %v", err)
	}

	toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}")
	if err != nil {
		t.Fatalf("Failed to parse template: %v", err)
	}
	chatTemplate, err := template.Parse("{{ .prompt }}")
	if err != nil {
		t.Fatalf("Failed to parse template: %v", err)
	}
	toolsTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}")
	if err != nil {
		t.Fatalf("Failed to parse template: %v", err)
	}

	tests := []struct {
		name           string
		model          Model
		checkCaps      []model.Capability
		expectedErrMsg string
	}{
		{
			name: "completion model without tools capability",
			model: Model{
				ModelPath: simpleModelPath,
				Template:  chatTemplate,
			},
			checkCaps:      []model.Capability{model.CapabilityTools},
			expectedErrMsg: "does not support tools",
		},
		{
			name: "model with all needed capabilities",
			model: Model{
				ModelPath: simpleModelPath,
				Template:  toolsInsertTemplate,
			},
			checkCaps: []model.Capability{model.CapabilityTools, model.CapabilityInsert},
		},
		{
			name: "model missing insert capability",
			model: Model{
				ModelPath: simpleModelPath,
				Template:  toolsTemplate,
			},
			checkCaps:      []model.Capability{model.CapabilityInsert},
			expectedErrMsg: "does not support insert",
		},
		{
			name: "model missing vision capability",
			model: Model{
				ModelPath: simpleModelPath,
				Template:  toolsTemplate,
			},
			checkCaps:      []model.Capability{model.CapabilityVision},
			expectedErrMsg: "does not support vision",
		},
		{
			name: "model with vision capability",
			model: Model{
				ModelPath: visionModelPath,
				Template:  chatTemplate,
			},
			checkCaps: []model.Capability{model.CapabilityVision},
		},
		{
			name: "model with embedding capability",
			model: Model{
				ModelPath: embeddingModelPath,
				Template:  chatTemplate,
			},
			checkCaps: []model.Capability{model.CapabilityEmbedding},
		},
		{
			name: "unknown capability",
			model: Model{
				ModelPath: simpleModelPath,
				Template:  chatTemplate,
			},
			checkCaps:      []model.Capability{"unknown"},
			expectedErrMsg: "unknown capability",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Test CheckCapabilities method
			err := tt.model.CheckCapabilities(tt.checkCaps...)
			if tt.expectedErrMsg == "" {
				if err != nil {
					t.Errorf("Expected no error, got: %v", err)
				}
			} else {
				if err == nil {
					t.Errorf("Expected error containing %q, got nil", tt.expectedErrMsg)
				} else if !strings.Contains(err.Error(), tt.expectedErrMsg) {
					t.Errorf("Expected error containing %q, got: %v", tt.expectedErrMsg, err)
				}
			}
		})
	}
}
@@ -37,7 +36,6 @@ import (
 	"golang.org/x/sync/errgroup"
 
 	"github.com/ollama/ollama/server/internal/cache/blob"
-	"github.com/ollama/ollama/server/internal/internal/backoff"
 	"github.com/ollama/ollama/server/internal/internal/names"
 
 	_ "embed"
@@ -60,6 +59,11 @@ var (
 	// ErrCached is passed to [Trace.PushUpdate] when a layer already
 	// exists. It is a non-fatal error and is never returned by [Registry.Push].
 	ErrCached = errors.New("cached")
+
+	// ErrIncomplete is returned by [Registry.Pull] when a model pull was
+	// incomplete due to one or more layer download failures. Users that
+	// want specific errors should use [WithTrace].
+	ErrIncomplete = errors.New("incomplete")
 )
 
 // Defaults
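A short sketch of how a caller might react to the new ErrIncomplete sentinel: since Pull now reports a partial download instead of failing per chunk, a simple retry loop can resume it (already-verified chunks are skipped via the per-chunk checkpoint keys introduced later in this diff). The retry bound here is an assumption.

// pullWithRetry is an illustrative wrapper, not part of the package:
// retry Registry.Pull while it reports an incomplete download.
func pullWithRetry(ctx context.Context, r *Registry, name string) error {
	var err error
	for attempt := 0; attempt < 3; attempt++ {
		err = r.Pull(ctx, name)
		if !errors.Is(err, ErrIncomplete) {
			return err
		}
	}
	return err
}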
@@ -213,12 +217,6 @@ type Registry struct {
 	// request. If zero, [DefaultChunkingThreshold] is used.
 	ChunkingThreshold int64
 
-	// MaxChunkSize is the maximum size of a chunk to download. If zero,
-	// the default is [DefaultMaxChunkSize].
-	//
-	// It is only used when a layer is larger than [MaxChunkingThreshold].
-	MaxChunkSize int64
-
 	// Mask, if set, is the name used to convert non-fully qualified names
 	// to fully qualified names. If empty, [DefaultMask] is used.
 	Mask string
@@ -278,8 +276,19 @@ func DefaultRegistry() (*Registry, error) {
 
 func UserAgent() string {
 	buildinfo, _ := debug.ReadBuildInfo()
+
+	version := buildinfo.Main.Version
+	if version == "(devel)" {
+		// When using `go run .` the version is "(devel)". This is seen
+		// as an invalid version by ollama.com and so it defaults to
+		// "needs upgrade" for some requests, such as pulls. These
+		// checks can be skipped by using the special version "v0.0.0",
+		// so we set it to that here.
+		version = "v0.0.0"
+	}
+
 	return fmt.Sprintf("ollama/%s (%s %s) Go/%s",
-		buildinfo.Main.Version,
+		version,
 		runtime.GOARCH,
 		runtime.GOOS,
 		runtime.Version(),
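For illustration, with the fallback applied a development build (run via `go run .`) would now report a user agent such as `ollama/v0.0.0 (amd64 linux) Go/go1.24.1` instead of one carrying the invalid `(devel)` version; the platform and Go version values in this example are illustrative.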
@@ -412,26 +421,19 @@ func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error {
 	return err
 }
 
-func canRetry(err error) bool {
-	var re *Error
-	if !errors.As(err, &re) {
-		return false
-	}
-	return re.Status >= 500
-}
-
 // trackingReader is an io.Reader that tracks the number of bytes read and
 // calls the update function with the layer, the number of bytes read.
 //
 // It always calls update with a nil error.
 type trackingReader struct {
+	l *Layer
 	r io.Reader
-	n *atomic.Int64
+	update func(l *Layer, n int64, err error)
}
 
 func (r *trackingReader) Read(p []byte) (n int, err error) {
 	n, err = r.r.Read(p)
-	r.n.Add(int64(n))
+	r.update(r.l, int64(n), nil)
 	return
 }
 
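A minimal sketch of the updated trackingReader pattern: wrap a response body so every Read forwards its byte count to the trace callback while the bytes stream to their destination. The helper name and parameters are illustrative.

// copyWithProgress is illustrative: stream src to dst through a
// trackingReader so each Read reports incremental progress for layer l.
func copyWithProgress(dst io.Writer, src io.Reader, l *Layer, update func(*Layer, int64, error)) (int64, error) {
	tr := &trackingReader{l: l, r: src, update: update}
	return io.Copy(dst, tr)
}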
@@ -447,6 +449,11 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 	if err != nil {
 		return err
 	}
+
+	// TODO(bmizerany): decide if this should be considered valid. Maybe
+	// server-side we special case '{}' to have some special meaning? Maybe
+	// "archiving" a tag (which is how we reason about it in the registry
+	// already, just with a different twist).
 	if len(m.Layers) == 0 {
 		return fmt.Errorf("%w: no layers", ErrManifestInvalid)
 	}
@@ -456,11 +463,7 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 		return err
 	}
 
-	exists := func(l *Layer) bool {
-		info, err := c.Get(l.Digest)
-		return err == nil && info.Size == l.Size
-	}
-
+	// TODO(bmizerany): work to remove the need to do this
 	layers := m.Layers
 	if m.Config != nil && m.Config.Digest.IsValid() {
 		layers = append(layers, m.Config)
@@ -468,45 +471,79 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 
 	// Send initial layer trace events to allow clients to have an
 	// understanding of work to be done before work starts.
+	var expected int64
 	t := traceFromContext(ctx)
-	skip := make([]bool, len(layers))
-	for i, l := range layers {
+	for _, l := range layers {
 		t.update(l, 0, nil)
-		if exists(l) {
-			skip[i] = true
-			t.update(l, l.Size, ErrCached)
-		}
+		expected += l.Size
 	}
 
-	g, ctx := errgroup.WithContext(ctx)
+	var received atomic.Int64
+	var g errgroup.Group
 	g.SetLimit(r.maxStreams())
-	for i, l := range layers {
-		if skip[i] {
+	for _, l := range layers {
+		info, err := c.Get(l.Digest)
+		if err == nil && info.Size == l.Size {
+			received.Add(l.Size)
+			t.update(l, l.Size, ErrCached)
 			continue
 		}
 
+		var wg sync.WaitGroup
 		chunked, err := c.Chunked(l.Digest, l.Size)
 		if err != nil {
 			t.update(l, 0, err)
 			continue
 		}
-		defer chunked.Close()
 
-		var progress atomic.Int64
 		for cs, err := range r.chunksums(ctx, name, l) {
 			if err != nil {
-				t.update(l, progress.Load(), err)
+				// Chunksum stream interrupted. Note in trace
+				// log and let in-flight downloads complete.
+				// This will naturally trigger ErrIncomplete
+				// since received < expected bytes.
+				t.update(l, 0, err)
 				break
 			}
 
-			g.Go(func() (err error) {
-				defer func() { t.update(l, progress.Load(), err) }()
-
-				for _, err := range backoff.Loop(ctx, 3*time.Second) {
-					if err != nil {
-						return err
-					}
-					err := func() error {
+			cacheKey := fmt.Sprintf(
+				"v1 pull chunksum %s %s %d-%d",
+				l.Digest,
+				cs.Digest,
+				cs.Chunk.Start,
+				cs.Chunk.End,
+			)
+			cacheKeyDigest := blob.DigestFromBytes(cacheKey)
+			_, err := c.Get(cacheKeyDigest)
+			if err == nil {
+				received.Add(cs.Chunk.Size())
+				t.update(l, cs.Chunk.Size(), ErrCached)
+				continue
+			}
+
+			wg.Add(1)
+			g.Go(func() (err error) {
+				defer func() {
+					if err == nil {
+						// Ignore cache key write errors for now. We've already
+						// reported to trace that the chunk is complete.
+						//
+						// Ideally, we should only report completion to trace
+						// after successful cache commit. This current approach
+						// works but could trigger unnecessary redownloads if
+						// the checkpoint key is missing on next pull.
+						//
+						// Not incorrect, just suboptimal - fix this in a
+						// future update.
+						_ = blob.PutBytes(c, cacheKeyDigest, cacheKey)
+
+						received.Add(cs.Chunk.Size())
+					} else {
+						t.update(l, 0, err)
+					}
+					wg.Done()
+				}()
+
 				req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
 				if err != nil {
 					return err
@@ -518,49 +555,40 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 				}
 				defer res.Body.Close()
 
-				// Count bytes towards
-				// progress, as they arrive, so
-				// that our bytes piggyback
-				// other chunk updates on
-				// completion.
-				//
-				// This tactic is enough to
-				// show "smooth" progress given
-				// the current CLI client. In
-				// the near future, the server
-				// should report download rate
-				// since it knows better than
-				// a client that is measuring
-				// rate based on wall-clock
-				// time-since-last-update.
-				body := &trackingReader{r: res.Body, n: &progress}
-
-				err = chunked.Put(cs.Chunk, cs.Digest, body)
-				if err != nil {
-					return err
-				}
-
-				return nil
-					}()
-					if !canRetry(err) {
-						return err
-					}
-				}
-				return nil
+				body := &trackingReader{l: l, r: res.Body, update: t.update}
+				return chunked.Put(cs.Chunk, cs.Digest, body)
 			})
 		}
+
+		// Close writer immediately after downloads finish, not at Pull
+		// exit. Using defer would keep file descriptors open until all
+		// layers complete, potentially exhausting system limits with
+		// many layers.
+		//
+		// The WaitGroup tracks when all chunks finish downloading,
+		// allowing precise writer closure in a background goroutine.
+		// Each layer briefly uses one extra goroutine while at most
+		// maxStreams()-1 chunks download in parallel.
+		//
+		// This caps file descriptors at maxStreams() instead of
+		// growing with layer count.
+		g.Go(func() error {
			wg.Wait()
+			chunked.Close()
+			return nil
+		})
 	}
 	if err := g.Wait(); err != nil {
 		return err
 	}
+	if received.Load() != expected {
+		return fmt.Errorf("%w: received %d/%d bytes", ErrIncomplete, received.Load(), expected)
+	}
 
-	// store the manifest blob
 	md := blob.DigestFromBytes(m.Data)
 	if err := blob.PutBytes(c, md, m.Data); err != nil {
 		return err
 	}
 
-	// commit the manifest with a link
 	return c.Link(m.Name, md)
 }
 
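The checkpoint trick above can be summarized in a hedged sketch: a deterministic cache key names each (layer, chunk) pair, and the presence of that key in the blob cache marks the chunk as already downloaded and verified, so a resumed pull skips it. The helper function below is illustrative, not part of the package API; the key format mirrors the diff.

// chunkCheckpointKey is illustrative: derive the per-chunk checkpoint
// digest used to mark a chunk complete across pulls.
func chunkCheckpointKey(layerDigest, chunkDigest blob.Digest, start, end int64) blob.Digest {
	key := fmt.Sprintf("v1 pull chunksum %s %s %d-%d", layerDigest, chunkDigest, start, end)
	return blob.DigestFromBytes(key)
}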
@@ -599,6 +627,30 @@ func (m *Manifest) Layer(d blob.Digest) *Layer {
 	return nil
 }
 
+func (m *Manifest) All() iter.Seq[*Layer] {
+	return func(yield func(*Layer) bool) {
+		if !yield(m.Config) {
+			return
+		}
+		for _, l := range m.Layers {
+			if !yield(l) {
+				return
+			}
+		}
+	}
+}
+
+func (m *Manifest) Size() int64 {
+	var size int64
+	if m.Config != nil {
+		size += m.Config.Size
+	}
+	for _, l := range m.Layers {
+		size += l.Size
+	}
+	return size
+}
+
 // MarshalJSON implements json.Marshaler.
 //
 // NOTE: It adds an empty config object to the manifest, which is required by
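The new All method is a range-over-func iterator (Go 1.23 iter.Seq); a hedged usage sketch, noting that the yielded config layer may be nil, which is why the nil check mirrors Size():

// manifestBytes is illustrative: total up layer sizes by ranging over
// Manifest.All, equivalent to calling m.Size().
func manifestBytes(m *Manifest) int64 {
	var n int64
	for l := range m.All() {
		if l != nil {
			n += l.Size
		}
	}
	return n
}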
@@ -741,20 +793,32 @@ func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Seq2[chunksum, error] {
 		return
 	}
 
-	// A chunksums response is a sequence of chunksums in a
-	// simple, easy to parse line-oriented format.
+	// The response is a sequence of chunksums.
 	//
-	// Example:
+	// Chunksums are chunks of a larger blob that can be
+	// downloaded and verified independently.
 	//
-	// >> GET /v2/<namespace>/<model>/chunksums/<digest>
+	// The chunksums endpoint is a GET request that returns a
+	// sequence of chunksums in the following format:
 	//
-	// << HTTP/1.1 200 OK
-	// << Content-Location: <blobURL>
-	// <<
-	// << <digest> <start>-<end>
-	// << ...
+	// > GET /v2/<namespace>/<model>/chunksums/<digest>
 	//
-	// The blobURL is the URL to download the chunks from.
+	// < HTTP/1.1 200 OK
+	// < Content-Location: <blobURL>
+	// <
+	// < <digest> <start>-<end>
+	// < ...
+	//
+	// The <blobURL> is the URL to download the chunks from and
+	// each <digest> is the digest of the chunk, and <start>-<end>
+	// is the range the chunk in the blob.
+	//
+	// Ranges may be used directly in Range headers like
+	// "bytes=<start>-<end>".
+	//
+	// The chunksums returned are guaranteed to be contiguous and
+	// include all bytes of the layer. If the stream is cut short,
+	// clients should retry.
 
 	chunksumsURL := fmt.Sprintf("%s://%s/v2/%s/%s/chunksums/%s",
 		scheme,
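Per the documented format, each chunk range can be dropped straight into an HTTP Range header; a small sketch, with the helper name and parameters purely illustrative:

// fetchChunk is illustrative: request one chunk of a blob using the
// <start>-<end> range taken from a chunksums line.
func fetchChunk(ctx context.Context, blobURL string, start, end int64) (*http.Response, error) {
	req, err := http.NewRequestWithContext(ctx, "GET", blobURL, nil)
	if err != nil {
		return nil, err
	}
	req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end))
	return http.DefaultClient.Do(req)
}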
@@ -9,21 +9,41 @@ import (
 	"fmt"
 	"io"
 	"io/fs"
-	"math/rand/v2"
+	"net"
 	"net/http"
 	"net/http/httptest"
 	"os"
-	"path"
 	"reflect"
-	"slices"
 	"strings"
+	"sync/atomic"
 	"testing"
-	"time"
 
 	"github.com/ollama/ollama/server/internal/cache/blob"
 	"github.com/ollama/ollama/server/internal/testutil"
 )
 
+func ExampleRegistry_cancelOnFirstError() {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	ctx = WithTrace(ctx, &Trace{
+		Update: func(l *Layer, n int64, err error) {
+			if err != nil {
+				// Discontinue pulling layers if there is an
+				// error instead of continuing to pull more
+				// data.
+				cancel()
+			}
+		},
+	})
+
+	var r Registry
+	if err := r.Pull(ctx, "model"); err != nil {
+		// panic for demo purposes
+		panic(err)
+	}
+}
+
 func TestManifestMarshalJSON(t *testing.T) {
 	// All manifests should contain an "empty" config object.
 	var m Manifest
@@ -70,7 +90,7 @@ func (rr recordRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
 // communication is attempted.
 //
 // To simulate a network error, pass a handler that returns a 499 status code.
-func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
+func newClient(t *testing.T, upstreamRegistry http.HandlerFunc) (*Registry, *blob.DiskCache) {
 	t.Helper()
 
 	c, err := blob.Open(t.TempDir())
@@ -88,7 +108,7 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
 	r := &Registry{
 		Cache: c,
 		HTTPClient: &http.Client{
-			Transport: recordRoundTripper(h),
+			Transport: recordRoundTripper(upstreamRegistry),
 		},
 	}
 
@@ -315,15 +335,8 @@ func TestPushCommitRoundtripError(t *testing.T) {
 	}
 }
 
-func checkNotExist(t *testing.T, err error) {
-	t.Helper()
-	if !errors.Is(err, fs.ErrNotExist) {
-		t.Fatalf("err = %v; want fs.ErrNotExist", err)
-	}
-}
-
 func TestRegistryPullInvalidName(t *testing.T) {
-	rc, _ := newClient(t, nil)
+	rc, _ := newRegistryClient(t, nil)
 	err := rc.Pull(t.Context(), "://")
 	if !errors.Is(err, ErrNameInvalid) {
 		t.Errorf("err = %v; want %v", err, ErrNameInvalid)
@@ -339,197 +352,16 @@ func TestRegistryPullInvalidManifest(t *testing.T) {
 	}
 
 	for _, resp := range cases {
-		rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+		rc, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
 			io.WriteString(w, resp)
 		})
-		err := rc.Pull(t.Context(), "x")
+		err := rc.Pull(t.Context(), "http://example.com/a/b")
 		if !errors.Is(err, ErrManifestInvalid) {
 			t.Errorf("err = %v; want invalid manifest", err)
 		}
 	}
 }
 
-func TestRegistryPullNotCached(t *testing.T) {
-	check := testutil.Checker(t)
-
-	var c *blob.DiskCache
-	var rc *Registry
-
-	d := blob.DigestFromBytes("some data")
-	rc, c = newClient(t, func(w http.ResponseWriter, r *http.Request) {
-		if strings.Contains(r.URL.Path, "/blobs/") {
-			io.WriteString(w, "some data")
-			return
-		}
-		fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":9}]}`, d)
-	})
-
-	// Confirm that the layer does not exist locally
-	_, err := rc.ResolveLocal("model")
-	checkNotExist(t, err)
-
-	_, err = c.Get(d)
-	checkNotExist(t, err)
-
-	err = rc.Pull(t.Context(), "model")
-	check(err)
-
-	mw, err := rc.Resolve(t.Context(), "model")
-	check(err)
-	mg, err := rc.ResolveLocal("model")
-	check(err)
-	if !reflect.DeepEqual(mw, mg) {
-		t.Errorf("mw = %v; mg = %v", mw, mg)
-	}
-
-	// Confirm successful download
-	info, err := c.Get(d)
-	check(err)
-	if info.Digest != d {
-		t.Errorf("info.Digest = %v; want %v", info.Digest, d)
-	}
-	if info.Size != 9 {
-		t.Errorf("info.Size = %v; want %v", info.Size, 9)
-	}
-
-	data, err := os.ReadFile(c.GetFile(d))
-	check(err)
-	if string(data) != "some data" {
-		t.Errorf("data = %q; want %q", data, "exists")
-	}
-}
-
-func TestRegistryPullCached(t *testing.T) {
-	cached := blob.DigestFromBytes("exists")
-	rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
-		if strings.Contains(r.URL.Path, "/blobs/") {
-			w.WriteHeader(499) // should not be called
-			return
-		}
-		if strings.Contains(r.URL.Path, "/manifests/") {
-			fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":6}]}`, cached)
-		}
-	})
-
-	var errs []error
-	var reads []int64
-	ctx := WithTrace(t.Context(), &Trace{
-		Update: func(d *Layer, n int64, err error) {
-			t.Logf("update %v %d %v", d, n, err)
-			reads = append(reads, n)
-			errs = append(errs, err)
-		},
-	})
-
-	ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
-	defer cancel()
-
-	err := rc.Pull(ctx, "single")
-	testutil.Check(t, err)
-
-	want := []int64{0, 6}
-	if !errors.Is(errors.Join(errs...), ErrCached) {
-		t.Errorf("errs = %v; want %v", errs, ErrCached)
-	}
-	if !slices.Equal(reads, want) {
-		t.Errorf("pairs = %v; want %v", reads, want)
-	}
-}
-
-func TestRegistryPullManifestNotFound(t *testing.T) {
-	rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
-		w.WriteHeader(http.StatusNotFound)
-	})
-	err := rc.Pull(t.Context(), "notfound")
-	checkErrCode(t, err, 404, "")
-}
-
-func TestRegistryPullResolveRemoteError(t *testing.T) {
-	rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
-		w.WriteHeader(http.StatusInternalServerError)
-		io.WriteString(w, `{"errors":[{"code":"an_error"}]}`)
-	})
-	err := rc.Pull(t.Context(), "single")
-	checkErrCode(t, err, 500, "an_error")
-}
-
-func TestRegistryPullResolveRoundtripError(t *testing.T) {
-	rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
-		if strings.Contains(r.URL.Path, "/manifests/") {
-			w.WriteHeader(499) // force RoundTrip error
-			return
-		}
-	})
-	err := rc.Pull(t.Context(), "single")
-	if !errors.Is(err, errRoundTrip) {
-		t.Errorf("err = %v; want %v", err, errRoundTrip)
-	}
-}
-
-// TestRegistryPullMixedCachedNotCached tests that cached layers do not
-// interfere with pulling layers that are not cached
-func TestRegistryPullMixedCachedNotCached(t *testing.T) {
-	x := blob.DigestFromBytes("xxxxxx")
-	e := blob.DigestFromBytes("exists")
-	y := blob.DigestFromBytes("yyyyyy")
-
-	for i := range 10 {
-		t.Logf("iteration %d", i)
-
-		digests := []blob.Digest{x, e, y}
-
-		rand.Shuffle(len(digests), func(i, j int) {
-			digests[i], digests[j] = digests[j], digests[i]
-		})
-
-		manifest := fmt.Sprintf(`{
-			"layers": [
-				{"digest":"%s","size":6},
-				{"digest":"%s","size":6},
-				{"digest":"%s","size":6}
-			]
-		}`, digests[0], digests[1], digests[2])
-
-		rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
-			switch path.Base(r.URL.Path) {
-			case "latest":
-				io.WriteString(w, manifest)
-			case x.String():
-				io.WriteString(w, "xxxxxx")
-			case e.String():
-				io.WriteString(w, "exists")
-			case y.String():
-				io.WriteString(w, "yyyyyy")
-			default:
-				panic(fmt.Sprintf("unexpected request: %v", r))
-			}
-		})
-
-		ctx := WithTrace(t.Context(), &Trace{
-			Update: func(l *Layer, n int64, err error) {
-				t.Logf("update %v %d %v", l, n, err)
-			},
-		})
-
-		// Check that we pull all layers that we can.
-
-		err := rc.Pull(ctx, "mixed")
-		if err != nil {
-			t.Fatal(err)
-		}
-
-		for _, d := range digests {
-			info, err := c.Get(d)
-			if err != nil {
-				t.Fatalf("Get(%v): %v", d, err)
-			}
-			if info.Size != 6 {
-				t.Errorf("info.Size = %v; want %v", info.Size, 6)
-			}
-		}
-	}
-}
-
 func TestRegistryResolveByDigest(t *testing.T) {
 	check := testutil.Checker(t)
 
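The deleted TestRegistryPullCached leaned on a compact stdlib idiom worth noting: errors.Join plus errors.Is answers whether any error in a batch matches a sentinel. A minimal, self-contained sketch (stdlib only; ErrCached here is a local stand-in for the package's sentinel):

package main

import (
	"errors"
	"fmt"
)

var ErrCached = errors.New("cached")

func main() {
	// Per-layer errors collected during a pull; nil entries are
	// successful layers.
	errs := []error{nil, ErrCached, nil}

	// errors.Join drops the nils, and errors.Is walks the joined
	// tree, so this reports whether any layer was cached.
	if errors.Is(errors.Join(errs...), ErrCached) {
		fmt.Println("at least one layer was cached")
	}
}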
@@ -567,26 +399,6 @@ func TestInsecureSkipVerify(t *testing.T) {
 	testutil.Check(t, err)
 }
 
-func TestCanRetry(t *testing.T) {
-	cases := []struct {
-		err  error
-		want bool
-	}{
-		{nil, false},
-		{errors.New("x"), false},
-		{ErrCached, false},
-		{ErrManifestInvalid, false},
-		{ErrNameInvalid, false},
-		{&Error{Status: 100}, false},
-		{&Error{Status: 500}, true},
-	}
-	for _, tt := range cases {
-		if got := canRetry(tt.err); got != tt.want {
-			t.Errorf("CanRetry(%v) = %v; want %v", tt.err, got, tt.want)
-		}
-	}
-}
-
 func TestErrorUnmarshal(t *testing.T) {
 	cases := []struct {
 		name string
@@ -738,17 +550,23 @@ func TestParseNameExtended(t *testing.T) {
 
 func TestUnlink(t *testing.T) {
 	t.Run("found by name", func(t *testing.T) {
-		rc, _ := newClient(t, nil)
+		check := testutil.Checker(t)
 
+		rc, _ := newRegistryClient(t, nil)
+		// make a blob and link it
+		d := blob.DigestFromBytes("{}")
+		err := blob.PutBytes(rc.Cache, d, "{}")
+		check(err)
+		err = rc.Cache.Link("registry.ollama.ai/library/single:latest", d)
+		check(err)
 
 		// confirm linked
-		_, err := rc.ResolveLocal("single")
-		if err != nil {
-			t.Errorf("unexpected error: %v", err)
-		}
+		_, err = rc.ResolveLocal("single")
+		check(err)
 
 		// unlink
 		_, err = rc.Unlink("single")
-		testutil.Check(t, err)
+		check(err)
 
 		// confirm unlinked
 		_, err = rc.ResolveLocal("single")
@@ -757,7 +575,7 @@ func TestUnlink(t *testing.T) {
 		}
 	})
 	t.Run("not found by name", func(t *testing.T) {
-		rc, _ := newClient(t, nil)
+		rc, _ := newRegistryClient(t, nil)
 		ok, err := rc.Unlink("manifestNotFound")
 		if err != nil {
 			t.Fatal(err)
@@ -767,3 +585,369 @@ func TestUnlink(t *testing.T) {
 		}
 	})
 }
+
+// Many tests from here out in this file are based on a single blob, "abc",
+// with the checksum of its sha256 hash. The checksum is:
+//
+//	"abc" -> sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad
+//
+// Using the literal value instead of a constant with fmt.Xprintf calls proved
+// to be the most readable and maintainable approach. The sum is consistently
+// used in the tests and unique so searches do not yield false positives.
+
+func checkRequest(t *testing.T, req *http.Request, method, path string) {
+	t.Helper()
+	if got := req.URL.Path; got != path {
+		t.Errorf("URL = %q, want %q", got, path)
+	}
+	if req.Method != method {
+		t.Errorf("Method = %q, want %q", req.Method, method)
+	}
+}
+
+func newRegistryClient(t *testing.T, h http.HandlerFunc) (*Registry, context.Context) {
+	s := httptest.NewServer(h)
+	t.Cleanup(s.Close)
+	cache, err := blob.Open(t.TempDir())
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	ctx := WithTrace(t.Context(), &Trace{
+		Update: func(l *Layer, n int64, err error) {
+			t.Log("trace:", l.Digest.Short(), n, err)
+		},
+	})
+
+	rc := &Registry{
+		Cache: cache,
+		HTTPClient: &http.Client{Transport: &http.Transport{
+			Dial: func(network, addr string) (net.Conn, error) {
+				return net.Dial(network, s.Listener.Addr().String())
+			},
+		}},
+	}
+	return rc, ctx
+}
+
+func TestPullChunked(t *testing.T) {
+	var steps atomic.Int64
+	c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
+		switch steps.Add(1) {
+		case 1:
+			checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
+			io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
+		case 2:
+			checkRequest(t, r, "GET", "/v2/library/abc/chunksums/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
+			w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
+			fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab"))
+			fmt.Fprintf(w, "%s 2-2\n", blob.DigestFromBytes("c"))
+		case 3, 4:
+			checkRequest(t, r, "GET", "/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
+			switch rng := r.Header.Get("Range"); rng {
+			case "bytes=0-1":
+				io.WriteString(w, "ab")
+			case "bytes=2-2":
+				t.Logf("writing c")
+				io.WriteString(w, "c")
+			default:
+				t.Errorf("unexpected range %q", rng)
+			}
+		default:
+			t.Errorf("unexpected steps %d: %v", steps.Load(), r)
+			http.Error(w, "unexpected steps", http.StatusInternalServerError)
+		}
+	})
+
+	c.ChunkingThreshold = 1 // force chunking
+
+	err := c.Pull(ctx, "http://o.com/library/abc")
+	testutil.Check(t, err)
+
+	_, err = c.Cache.Resolve("o.com/library/abc:latest")
+	testutil.Check(t, err)
+
+	if g := steps.Load(); g != 4 {
+		t.Fatalf("got %d steps, want 4", g)
+	}
+}
+
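TestPullChunked drives the client through the chunked download protocol: resolve the manifest, fetch the chunksum table, then issue ranged blob requests and verify each chunk. A reduced sketch of that Range-request flow, using nothing beyond the standard library (the server and chunk table here are stand-ins for illustration, not the registry API):

package main

import (
	"crypto/sha256"
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
)

// chunk describes one byte range of a blob and the digest its contents
// must hash to, mirroring the chunksum lines above ("<digest> <start>-<end>").
type chunk struct {
	start, end int64 // inclusive, as in an HTTP Range header
	sum        [32]byte
}

func main() {
	blobData := []byte("abc")

	// A stand-in blob server; it honors "Range: bytes=a-b" headers.
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var start, end int64
		fmt.Sscanf(r.Header.Get("Range"), "bytes=%d-%d", &start, &end)
		w.Write(blobData[start : end+1])
	}))
	defer srv.Close()

	chunks := []chunk{
		{0, 1, sha256.Sum256([]byte("ab"))},
		{2, 2, sha256.Sum256([]byte("c"))},
	}

	out := make([]byte, len(blobData))
	for _, c := range chunks {
		req, _ := http.NewRequest("GET", srv.URL, nil)
		req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", c.start, c.end))
		resp, err := http.DefaultClient.Do(req)
		if err != nil {
			panic(err)
		}
		b, _ := io.ReadAll(resp.Body)
		resp.Body.Close()
		// Each chunk is verified independently before being placed.
		if sha256.Sum256(b) != c.sum {
			panic("chunk digest mismatch")
		}
		copy(out[c.start:], b)
	}
	fmt.Printf("reassembled %q\n", out) // reassembled "abc"
}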
+func TestPullCached(t *testing.T) {
+	c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
+		checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
+		io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
+	})
+
+	check := testutil.Checker(t)
+
+	// Preemptively cache the blob
+	d, err := blob.ParseDigest("sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
+	check(err)
+	err = blob.PutBytes(c.Cache, d, []byte("abc"))
+	check(err)
+
+	// Pull only the manifest, which should be enough to resolve the cached blob
+	err = c.Pull(ctx, "http://o.com/library/abc")
+	check(err)
+}
+
+func TestPullManifestError(t *testing.T) {
+	c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
+		checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
+		w.WriteHeader(http.StatusNotFound)
+		io.WriteString(w, `{"errors":[{"code":"MANIFEST_UNKNOWN"}]}`)
+	})
+
+	err := c.Pull(ctx, "http://o.com/library/abc")
+	if err == nil {
+		t.Fatalf("expected error")
+	}
+	if !errors.Is(err, ErrModelNotFound) {
+		t.Fatalf("err = %v, want %v", err, ErrModelNotFound)
+	}
+}
+
+func TestPullLayerError(t *testing.T) {
+	c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
+		checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
+		io.WriteString(w, `!`)
+	})
+
+	err := c.Pull(ctx, "http://o.com/library/abc")
+	if err == nil {
+		t.Fatalf("expected error")
+	}
+	var want *json.SyntaxError
+	if !errors.As(err, &want) {
+		t.Fatalf("err = %T, want %T", err, want)
+	}
+}
+
+func TestPullLayerChecksumError(t *testing.T) {
+	var step atomic.Int64
+	c, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
+		switch step.Add(1) {
+		case 1:
+			checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
+			io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
+		case 2:
+			checkRequest(t, r, "GET", "/v2/library/abc/chunksums/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
+			w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
+			fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab"))
+			fmt.Fprintf(w, "%s 2-2\n", blob.DigestFromBytes("c"))
+		case 3:
+			w.WriteHeader(http.StatusNotFound)
+			io.WriteString(w, `{"errors":[{"code":"BLOB_UNKNOWN"}]}`)
+		case 4:
+			io.WriteString(w, "c")
+		default:
+			t.Errorf("unexpected steps %d: %v", step.Load(), r)
+			http.Error(w, "unexpected steps", http.StatusInternalServerError)
+		}
+	})
+
+	c.MaxStreams = 1
+	c.ChunkingThreshold = 1 // force chunking
+
+	var written atomic.Int64
+	ctx := WithTrace(t.Context(), &Trace{
+		Update: func(l *Layer, n int64, err error) {
+			t.Log("trace:", l.Digest.Short(), n, err)
+			written.Add(n)
+		},
+	})
+
+	err := c.Pull(ctx, "http://o.com/library/abc")
+	var got *Error
+	if !errors.As(err, &got) || got.Code != "BLOB_UNKNOWN" {
+		t.Fatalf("err = %v, want %v", err, got)
+	}
+
+	if g := written.Load(); g != 1 {
+		t.Fatalf("wrote %d bytes, want 1", g)
+	}
+}
+
+func TestPullChunksumStreamError(t *testing.T) {
+	var step atomic.Int64
+	c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
+		switch step.Add(1) {
+		case 1:
+			checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
+			io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
+		case 2:
+			w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
+
+			// Write one valid chunksum and one invalid chunksum
+			fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab")) // valid
+			fmt.Fprint(w, "sha256:!") // invalid
+		case 3:
+			io.WriteString(w, "ab")
+		default:
+			t.Errorf("unexpected steps %d: %v", step.Load(), r)
+			http.Error(w, "unexpected steps", http.StatusInternalServerError)
+		}
+	})
+
+	c.ChunkingThreshold = 1 // force chunking
+
+	got := c.Pull(ctx, "http://o.com/library/abc")
+	if !errors.Is(got, ErrIncomplete) {
+		t.Fatalf("err = %v, want %v", got, ErrIncomplete)
+	}
+}
+
+type flushAfterWriter struct {
+	w io.Writer
+}
+
+func (f *flushAfterWriter) Write(p []byte) (n int, err error) {
+	n, err = f.w.Write(p)
+	f.w.(http.Flusher).Flush() // panic if not a flusher
+	return
+}
+
+func TestPullChunksumStreaming(t *testing.T) {
+	csr, csw := io.Pipe()
+	defer csw.Close()
+
+	var step atomic.Int64
+	c, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
+		switch step.Add(1) {
+		case 1:
+			checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
+			io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
+		case 2:
+			w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
+			fw := &flushAfterWriter{w} // ensure client gets data as it arrives by aggressively flushing
+			_, err := io.Copy(fw, csr)
+			if err != nil {
+				t.Errorf("copy: %v", err)
+			}
+		case 3:
+			io.WriteString(w, "ab")
+		case 4:
+			io.WriteString(w, "c")
+		default:
+			t.Errorf("unexpected steps %d: %v", step.Load(), r)
+			http.Error(w, "unexpected steps", http.StatusInternalServerError)
+		}
+	})
+
+	c.ChunkingThreshold = 1 // force chunking
+
+	update := make(chan int64, 1)
+	ctx := WithTrace(t.Context(), &Trace{
+		Update: func(l *Layer, n int64, err error) {
+			t.Log("trace:", l.Digest.Short(), n, err)
+			if n > 0 {
+				update <- n
+			}
+		},
+	})
+
+	errc := make(chan error, 1)
+	go func() {
+		errc <- c.Pull(ctx, "http://o.com/library/abc")
+	}()
+
+	// Send first chunksum and ensure it kicks off work immediately
+	fmt.Fprintf(csw, "%s 0-1\n", blob.DigestFromBytes("ab"))
+	if g := <-update; g != 2 {
+		t.Fatalf("got %d, want 2", g)
+	}
+
+	// now send the second chunksum and ensure it kicks off work immediately
+	fmt.Fprintf(csw, "%s 2-2\n", blob.DigestFromBytes("c"))
+	if g := <-update; g != 1 {
+		t.Fatalf("got %d, want 1", g)
+	}
+	csw.Close()
+	testutil.Check(t, <-errc)
+}
+
+func TestPullChunksumsCached(t *testing.T) {
+	var step atomic.Int64
+	c, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
+		switch step.Add(1) {
+		case 1:
+			checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
+			io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
+		case 2:
+			w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
+			fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab"))
+			fmt.Fprintf(w, "%s 2-2\n", blob.DigestFromBytes("c"))
+		case 3, 4:
+			switch rng := r.Header.Get("Range"); rng {
+			case "bytes=0-1":
+				io.WriteString(w, "ab")
+			case "bytes=2-2":
+				io.WriteString(w, "c")
+			default:
+				t.Errorf("unexpected range %q", rng)
+			}
+		default:
+			t.Errorf("unexpected steps %d: %v", step.Load(), r)
+			http.Error(w, "unexpected steps", http.StatusInternalServerError)
+		}
+	})
+
+	c.MaxStreams = 1        // force serial processing of chunksums
+	c.ChunkingThreshold = 1 // force chunking
+
+	ctx, cancel := context.WithCancel(t.Context())
+	defer cancel()
+
+	// Cancel the pull after the first chunksum is processed, but before
+	// the second chunksum is processed (which is waiting because
+	// MaxStreams=1). This should cause the second chunksum to error out
+	// leaving the blob incomplete.
+	ctx = WithTrace(ctx, &Trace{
+		Update: func(l *Layer, n int64, err error) {
+			if n > 0 {
+				cancel()
+			}
+		},
+	})
+	err := c.Pull(ctx, "http://o.com/library/abc")
+	if !errors.Is(err, context.Canceled) {
+		t.Fatalf("err = %v, want %v", err, context.Canceled)
+	}
+
+	_, err = c.Cache.Resolve("o.com/library/abc:latest")
+	if !errors.Is(err, fs.ErrNotExist) {
+		t.Fatalf("err = %v, want fs.ErrNotExist", err)
+	}
+
+	// Reset state and pull again to ensure the blob chunks that should
+	// have been cached are, and the remaining chunk was downloaded, making
+	// the blob complete.
+	step.Store(0)
+	var written atomic.Int64
+	var cached atomic.Int64
+	ctx = WithTrace(t.Context(), &Trace{
+		Update: func(l *Layer, n int64, err error) {
+			t.Log("trace:", l.Digest.Short(), n, err)
+			if errors.Is(err, ErrCached) {
+				cached.Add(n)
+			}
+			written.Add(n)
+		},
+	})
+
+	check := testutil.Checker(t)
+
+	err = c.Pull(ctx, "http://o.com/library/abc")
+	check(err)
+
+	_, err = c.Cache.Resolve("o.com/library/abc:latest")
+	check(err)
+
+	if g := written.Load(); g != 3 {
+		t.Fatalf("wrote %d bytes, want 3", g)
+	}
+	if g := cached.Load(); g != 2 { // "ab" should have been cached
+		t.Fatalf("cached %d bytes, want 2", g)
+	}
+}
@@ -200,7 +200,7 @@ type params struct {
 	//
 	// Unfortunately, this API was designed to be a bit awkward. Stream is
 	// defined to default to true if not present, so we need a way to check
-	// if the client decisively it to false. So, we use a pointer to a
+	// if the client decisively set it to false. So, we use a pointer to a
 	// bool. Gross.
 	//
 	// Use [stream()] to get the correct value for this field.
@@ -280,17 +280,17 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
 	progress := make(map[*ollama.Layer]int64)
 
 	progressCopy := make(map[*ollama.Layer]int64, len(progress))
-	pushUpdate := func() {
+	flushProgress := func() {
 		defer maybeFlush()
 
-		// TODO(bmizerany): This scales poorly with more layers due to
-		// needing to flush out them all in one big update. We _could_
-		// just flush on the changed ones, or just track the whole
-		// download. Needs more thought. This is fine for now.
+		// TODO(bmizerany): Flushing every layer in one update doesn't
+		// scale well. We could flush only the modified layers or track
+		// the full download. Needs further consideration, though it's
+		// fine for now.
 		mu.Lock()
 		maps.Copy(progressCopy, progress)
 		mu.Unlock()
-		for l, n := range progress {
+		for l, n := range progressCopy {
 			enc.Encode(progressUpdateJSON{
 				Digest: l.Digest,
 				Total:  l.Size,
@@ -298,19 +298,26 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
 			})
 		}
 	}
+	defer flushProgress()
 
-	t := time.NewTicker(time.Hour) // "unstarted" timer
+	t := time.NewTicker(1000 * time.Hour) // "unstarted" timer
 	start := sync.OnceFunc(func() {
-		pushUpdate()
+		flushProgress() // flush initial state
 		t.Reset(100 * time.Millisecond)
 	})
 	ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
 		Update: func(l *ollama.Layer, n int64, err error) {
 			if n > 0 {
-				start() // flush initial state
+				// Block flushing progress updates until every
+				// layer is accounted for. Clients depend on a
+				// complete model size to calculate progress
+				// correctly; if they use an incomplete total,
+				// progress indicators would erratically jump
+				// as new layers are registered.
+				start()
 			}
 			mu.Lock()
-			progress[l] = n
+			progress[l] += n
 			mu.Unlock()
 		},
 	})
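The handler batches per-layer progress under a mutex and flushes it on a ticker that is only armed once the first byte arrives. A minimal sketch of that coalescing pattern in isolation (stdlib only; the names here are illustrative, not the handler's):

package main

import (
	"fmt"
	"sync"
	"time"
)

func main() {
	var mu sync.Mutex
	progress := map[string]int64{}

	flush := func() {
		// Snapshot under the lock, emit outside it.
		mu.Lock()
		snapshot := map[string]int64{}
		for k, v := range progress {
			snapshot[k] = v
		}
		mu.Unlock()
		fmt.Println("flush:", snapshot)
	}

	// "Unstarted" ticker: armed on the first real update so idle
	// requests emit nothing.
	t := time.NewTicker(1000 * time.Hour)
	defer t.Stop()
	start := sync.OnceFunc(func() {
		flush() // flush initial state
		t.Reset(100 * time.Millisecond)
	})

	done := make(chan struct{})
	go func() {
		defer close(done)
		for i := 0; i < 5; i++ {
			mu.Lock()
			progress["layer0"] += 10 // accumulate, like progress[l] += n
			mu.Unlock()
			start()
			time.Sleep(60 * time.Millisecond)
		}
	}()

	for {
		select {
		case <-t.C:
			flush()
		case <-done:
			flush() // final state
			return
		}
	}
}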
@@ -323,9 +330,9 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
 	for {
 		select {
 		case <-t.C:
-			pushUpdate()
+			flushProgress()
 		case err := <-done:
-			pushUpdate()
+			flushProgress()
 			if err != nil {
 				var status string
 				if errors.Is(err, ollama.ErrModelNotFound) {
@@ -82,7 +82,7 @@ func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
 	for _, layer := range layers {
 		if s := layer.GGML.KV().ChatTemplate(); s != "" {
 			if t, err := template.Named(s); err != nil {
-				slog.Debug("template detection", "error", err)
+				slog.Debug("template detection", "error", err, "template", s)
 			} else {
 				layer, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
 				if err != nil {
@@ -31,9 +31,10 @@ const (
 
 var (
 	ErrInvalidImageFormat  = errors.New("invalid image format")
+	ErrInvalidDigestFormat = errors.New("invalid digest format")
 	ErrInvalidProtocol     = errors.New("invalid protocol scheme")
 	ErrInsecureProtocol    = errors.New("insecure protocol http")
-	ErrInvalidDigestFormat = errors.New("invalid digest format")
+	ErrModelPathInvalid    = errors.New("invalid model path")
 )
 
 func ParseModelPath(name string) ModelPath {
@@ -73,8 +74,6 @@ func ParseModelPath(name string) ModelPath {
 	return mp
 }
 
-var errModelPathInvalid = errors.New("invalid model path")
-
 func (mp ModelPath) GetNamespaceRepository() string {
 	return fmt.Sprintf("%s/%s", mp.Namespace, mp.Repository)
 }
@@ -72,7 +72,7 @@ var (
 	errBadTemplate = errors.New("template error")
 )
 
-func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
+func modelOptions(model *Model, requestOpts map[string]any) (api.Options, error) {
 	opts := api.DefaultOptions()
 	if err := opts.FromMap(model.Options); err != nil {
 		return api.Options{}, err
@@ -87,7 +87,7 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options
 
 // scheduleRunner schedules a runner after validating inputs such as capabilities and model options.
 // It returns the allocated runner, model instance, and consolidated options if successful and error otherwise.
-func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
+func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
 	if name == "" {
 		return nil, nil, nil, fmt.Errorf("model %w", errRequired)
 	}
@@ -144,7 +144,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		return
 	}
 
-	model, err := GetModel(name.String())
+	m, err := GetModel(name.String())
 	if err != nil {
 		switch {
 		case errors.Is(err, fs.ErrNotExist):
@@ -159,7 +159,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 
 	// expire the runner
 	if req.Prompt == "" && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 {
-		s.sched.expireRunner(model)
+		s.sched.expireRunner(m)
 
 		c.JSON(http.StatusOK, api.GenerateResponse{
 			Model: req.Model,
@@ -176,9 +176,9 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		return
 	}
 
-	caps := []Capability{CapabilityCompletion}
+	caps := []model.Capability{model.CapabilityCompletion}
 	if req.Suffix != "" {
-		caps = append(caps, CapabilityInsert)
+		caps = append(caps, model.CapabilityInsert)
 	}
 
 	r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
@@ -203,7 +203,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		return
 	}
 
-	isMllama := checkMllamaModelFamily(model)
+	isMllama := checkMllamaModelFamily(m)
 	if isMllama && len(req.Images) > 1 {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "this model only supports one image: more than one image sent"})
 		return
@@ -211,7 +211,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 
 	images := make([]llm.ImageData, len(req.Images))
 	for i := range req.Images {
-		if isMllama && len(model.ProjectorPaths) > 0 {
+		if isMllama && len(m.ProjectorPaths) > 0 {
 			data, opts, err := mllama.Preprocess(bytes.NewReader(req.Images[i]))
 			if err != nil {
 				c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "error processing image"})
@@ -422,7 +422,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 		return
 	}
 
-	r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), []Capability{}, req.Options, req.KeepAlive)
+	r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
 	if err != nil {
 		handleScheduleError(c, req.Model, err)
 		return
@@ -530,7 +530,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 		return
 	}
 
-	r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []Capability{}, req.Options, req.KeepAlive)
+	r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
 	if err != nil {
 		handleScheduleError(c, req.Model, err)
 		return
@@ -777,7 +777,7 @@ func (s *Server) ShowHandler(c *gin.Context) {
 func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
 	name := model.ParseName(req.Model)
 	if !name.IsValid() {
-		return nil, errModelPathInvalid
+		return nil, ErrModelPathInvalid
 	}
 	name, err := getExistingName(name)
 	if err != nil {
@@ -818,6 +818,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
 		Template:     m.Template.String(),
 		Details:      modelDetails,
 		Messages:     msgs,
+		Capabilities: m.Capabilities(),
 		ModifiedAt:   manifest.fi.ModTime(),
 	}
 
@@ -825,7 +826,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
 	cs := 30
 	for k, v := range m.Options {
 		switch val := v.(type) {
-		case []interface{}:
+		case []any:
 			for _, nv := range val {
 				params = append(params, fmt.Sprintf("%-*s %#v", cs, k, nv))
 			}
@@ -1335,7 +1336,7 @@ func Serve(ln net.Listener) error {
 	return nil
 }
 
-func waitForStream(c *gin.Context, ch chan interface{}) {
+func waitForStream(c *gin.Context, ch chan any) {
 	c.Header("Content-Type", "application/json")
 	for resp := range ch {
 		switch r := resp.(type) {
@@ -1468,9 +1469,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		return
 	}
 
-	caps := []Capability{CapabilityCompletion}
+	caps := []model.Capability{model.CapabilityCompletion}
 	if len(req.Tools) > 0 {
-		caps = append(caps, CapabilityTools)
+		caps = append(caps, model.CapabilityTools)
 	}
 
 	name := model.ParseName(req.Model)
@@ -20,6 +20,7 @@ import (
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/fs/ggml"
 	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/types/model"
 )
 
 type LlmRequest struct {
@@ -37,7 +38,7 @@ type Scheduler struct {
 	pendingReqCh  chan *LlmRequest
 	finishedReqCh chan *LlmRequest
 	expiredCh     chan *runnerRef
-	unloadedCh    chan interface{}
+	unloadedCh    chan any
 
 	loaded   map[string]*runnerRef
 	loadedMu sync.Mutex
@@ -67,7 +68,7 @@ func InitScheduler(ctx context.Context) *Scheduler {
 		pendingReqCh:  make(chan *LlmRequest, maxQueue),
 		finishedReqCh: make(chan *LlmRequest, maxQueue),
 		expiredCh:     make(chan *runnerRef, maxQueue),
-		unloadedCh:    make(chan interface{}, maxQueue),
+		unloadedCh:    make(chan any, maxQueue),
 		loaded:        make(map[string]*runnerRef),
 		newServerFn:   llm.NewLlamaServer,
 		getGpuFn:      discover.GetGPUInfo,
@@ -195,7 +196,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
 	}
 
 	// Embedding models should always be loaded with parallel=1
-	if pending.model.CheckCapabilities(CapabilityCompletion) != nil {
+	if pending.model.CheckCapabilities(model.CapabilityCompletion) != nil {
 		numParallel = 1
 	}
 
@@ -617,8 +618,8 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
 // a before and after GPU memory allocation. The returned channel
 // will be notified when we're done waiting, or have timed out and should
 // proceed anyway
-func (runner *runnerRef) waitForVRAMRecovery() chan interface{} {
-	finished := make(chan interface{}, 1)
+func (runner *runnerRef) waitForVRAMRecovery() chan any {
+	finished := make(chan any, 1)
 
 	// CPU or Metal don't need checking, so no waiting required
 	// windows can page VRAM, only cuda currently can report accurate used vram usage
@@ -711,7 +712,7 @@ func pickBestFullFitByLibrary(req *LlmRequest, f *ggml.GGML, gpus discover.GpuIn
 		req.opts.NumCtx = req.origNumCtx * p
 		if !envconfig.SchedSpread() {
 			for _, g := range sgl {
-				if ok, estimatedVRAM = llm.PredictServerFit([]discover.GpuInfo{g}, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
+				if ok, estimatedVRAM = llm.PredictServerFit([]discover.GpuInfo{g}, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, p); ok {
 					slog.Info("new model will fit in available VRAM in single GPU, loading", "model", req.model.ModelPath, "gpu", g.ID, "parallel", p, "available", g.FreeMemory, "required", format.HumanBytes2(estimatedVRAM))
 					*numParallel = p
 					return []discover.GpuInfo{g}
@@ -727,7 +728,7 @@ func pickBestFullFitByLibrary(req *LlmRequest, f *ggml.GGML, gpus discover.GpuIn
 	// Now try all the GPUs
 	for _, p := range numParallelToTry {
 		req.opts.NumCtx = req.origNumCtx * p
-		if ok, estimatedVRAM = llm.PredictServerFit(sgl, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
+		if ok, estimatedVRAM = llm.PredictServerFit(sgl, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, p); ok {
 			slog.Info("new model will fit in available VRAM, loading", "model", req.model.ModelPath, "library", sgl[0].Library, "parallel", p, "required", format.HumanBytes2(estimatedVRAM))
 			*numParallel = p
 			return sgl
@@ -750,7 +751,7 @@ func pickBestPartialFitByLibrary(req *LlmRequest, f *ggml.GGML, gpus discover.Gp
 	var bestEstimate uint64
 	var bestFit int
 	for i, gl := range byLibrary {
-		_, estimatedVRAM := llm.PredictServerFit(gl, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts)
+		_, estimatedVRAM := llm.PredictServerFit(gl, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, *numParallel)
 		if estimatedVRAM > bestEstimate {
 			bestEstimate = estimatedVRAM
 			bestFit = i
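The extra argument threaded through these call sites is the parallel request count: the effective context, and therefore the KV-cache allocation, scales with it, so the fit estimate must see the multiplied value. A sketch of the scaling assumption only (not the estimator itself):

package main

import "fmt"

// effectiveNumCtx sketches the scaling the scheduler applies: with p
// parallel sequences of origNumCtx tokens each, the runner allocates
// KV cache for origNumCtx * p tokens, and EstimateGPULayers can
// recover p as NumCtx / origNumCtx.
func effectiveNumCtx(origNumCtx, numParallel int) int {
	return origNumCtx * numParallel
}

func main() {
	fmt.Println(effectiveNumCtx(2048, 4)) // 8192
}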
@@ -825,7 +826,7 @@ func (s *Scheduler) expireRunner(model *Model) {
 // If not, pick a runner to unload, else return nil and the request can be loaded
 func (s *Scheduler) maybeFindCPURunnerToUnload(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList) *runnerRef {
 	slog.Debug("evaluating if CPU model load will fit in available system memory")
-	estimate := llm.EstimateGPULayers(gpus, f, req.model.ProjectorPaths, req.opts)
+	estimate := llm.EstimateGPULayers(gpus, f, req.model.ProjectorPaths, req.opts, req.opts.NumCtx/req.origNumCtx)
 	if estimate.TotalSize <= gpus[0].FreeMemory {
 		slog.Debug("cpu inference mode, model fits in available system memory", "model", format.HumanBytes2(estimate.TotalSize), "available", format.HumanBytes2(gpus[0].FreeMemory))
 		return nil
13	template/gemma3-instruct.gotmpl	Normal file
@@ -0,0 +1,13 @@
+{{- range $i, $_ := .Messages }}
+{{- $last := eq (len (slice $.Messages $i)) 1 }}
+{{- if eq .Role "user" }}<start_of_turn>user
+{{- if and (eq $i 1) $.System }}
+{{ $.System }}
+{{ end }}
+{{ .Content }}<end_of_turn>
+{{ else if eq .Role "assistant" }}<start_of_turn>model
+{{ .Content }}<end_of_turn>
+{{ end }}
+{{- if $last }}<start_of_turn>model
+{{ end }}
+{{- end }}
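The new template is ordinary Go text/template. A small harness (hypothetical Message type; template trimmed to the user/assistant turns, without the system handling) shows how it produces the testdata files below:

package main

import (
	"os"
	"text/template"
)

type Message struct {
	Role    string
	Content string
}

func main() {
	// Trimmed-down version of template/gemma3-instruct.gotmpl.
	const tmpl = `{{- range $i, $_ := .Messages }}
{{- $last := eq (len (slice $.Messages $i)) 1 }}
{{- if eq .Role "user" }}<start_of_turn>user
{{ .Content }}<end_of_turn>
{{ else if eq .Role "assistant" }}<start_of_turn>model
{{ .Content }}<end_of_turn>
{{ end }}
{{- if $last }}<start_of_turn>model
{{ end }}
{{- end }}`

	t := template.Must(template.New("gemma3").Parse(tmpl))
	data := struct{ Messages []Message }{
		Messages: []Message{{Role: "user", Content: "Hello, how are you?"}},
	}
	if err := t.Execute(os.Stdout, data); err != nil {
		panic(err)
	}
	// Output matches template/testdata/gemma3-instruct.gotmpl/user:
	// <start_of_turn>user
	// Hello, how are you?<end_of_turn>
	// <start_of_turn>model
}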
6	template/gemma3-instruct.json	Normal file
@@ -0,0 +1,6 @@
+{
+  "stop": [
+    "<end_of_turn>"
+  ],
+  "temperature": 0.1
+}
@@ -87,6 +87,10 @@
     "template": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
     "name": "gemma-instruct"
   },
+  {
+    "template": "{{ bos_token }}\n{%- if messages[0]['role'] == 'system' -%}\n    {%- if messages[0]['content'] is string -%}\n        {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}\n    {%- else -%}\n        {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}\n    {%- endif -%}\n    {%- set loop_messages = messages[1:] -%}\n{%- else -%}\n    {%- set first_user_prefix = \"\" -%}\n    {%- set loop_messages = messages -%}\n{%- endif -%}\n{%- for message in loop_messages -%}\n    {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}\n        {{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}\n    {%- endif -%}\n    {%- if (message['role'] == 'assistant') -%}\n        {%- set role = \"model\" -%}\n    {%- else -%}\n        {%- set role = message['role'] -%}\n    {%- endif -%}\n    {{ '<start_of_turn>' + role + '\n' + (first_user_prefix if loop.first else \"\") }}\n    {%- if message['content'] is string -%}\n        {{ message['content'] | trim }}\n    {%- elif message['content'] is iterable -%}\n        {%- for item in message['content'] -%}\n            {%- if item['type'] == 'image' -%}\n                {{ '<start_of_image>' }}\n            {%- elif item['type'] == 'text' -%}\n                {{ item['text'] | trim }}\n            {%- endif -%}\n        {%- endfor -%}\n    {%- else -%}\n        {{ raise_exception(\"Invalid content type\") }}\n    {%- endif -%}\n    {{ '<end_of_turn>\n' }}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n    {{'<start_of_turn>model\n'}}\n{%- endif -%}\n",
+    "name": "gemma3-instruct"
+  },
   {
     "template": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}",
     "name": "llama3-instruct"
10	template/testdata/gemma3-instruct.gotmpl/system-user-assistant-user	vendored	Normal file
@@ -0,0 +1,10 @@
+<start_of_turn>user
+You are a helpful assistant.
+
+Hello, how are you?<end_of_turn>
+<start_of_turn>model
+I'm doing great. How can I help you today?<end_of_turn>
+<start_of_turn>user
+I'd like to show off how chat templating works!<end_of_turn>
+<start_of_turn>model
+
4	template/testdata/gemma3-instruct.gotmpl/user	vendored	Normal file
@@ -0,0 +1,4 @@
+<start_of_turn>user
+Hello, how are you?<end_of_turn>
+<start_of_turn>model
+
8	template/testdata/gemma3-instruct.gotmpl/user-assistant-user	vendored	Normal file
@@ -0,0 +1,8 @@
+<start_of_turn>user
+Hello, how are you?<end_of_turn>
+<start_of_turn>model
+I'm doing great. How can I help you today?<end_of_turn>
+<start_of_turn>user
+I'd like to show off how chat templating works!<end_of_turn>
+<start_of_turn>model
+
15	types/model/capability.go	Normal file
@@ -0,0 +1,15 @@
+package model
+
+type Capability string
+
+const (
+	CapabilityCompletion = Capability("completion")
+	CapabilityTools      = Capability("tools")
+	CapabilityInsert     = Capability("insert")
+	CapabilityVision     = Capability("vision")
+	CapabilityEmbedding  = Capability("embedding")
+)
+
+func (c Capability) String() string {
+	return string(c)
+}
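With capabilities now living in types/model, a caller-side check over this type might look like the following sketch; checkCapabilities here is illustrative, not the server package's actual CheckCapabilities:

package main

import (
	"errors"
	"fmt"
	"slices"
)

type Capability string

const (
	CapabilityCompletion = Capability("completion")
	CapabilityTools      = Capability("tools")
)

var errCapabilityMissing = errors.New("does not support")

// checkCapabilities sketches the kind of validation scheduleRunner
// performs before loading a model: every requested capability must be
// present on the model.
func checkCapabilities(have []Capability, want ...Capability) error {
	for _, w := range want {
		if !slices.Contains(have, w) {
			return fmt.Errorf("%w %s", errCapabilityMissing, w)
		}
	}
	return nil
}

func main() {
	have := []Capability{CapabilityCompletion}
	if err := checkCapabilities(have, CapabilityCompletion, CapabilityTools); err != nil {
		fmt.Println(err) // does not support tools
	}
}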