mirror of
https://github.com/ollama/ollama.git
synced 2026-04-17 15:53:27 +02:00
launch: warn when server context length is below 64k for local models (#15044)
A stop-gap for now to guide users better. We'll add more in-depth recommendations per integration as well.
---------
Co-authored-by: Parth Sareen <parth.sareen@ollama.com>
This commit is contained in:
@@ -841,7 +841,8 @@ type CloudStatus struct {
|
||||
|
||||
// StatusResponse is the response from [Client.CloudStatusExperimental].
|
||||
type StatusResponse struct {
|
||||
Cloud CloudStatus `json:"cloud"`
|
||||
Cloud CloudStatus `json:"cloud"`
|
||||
ContextLength int `json:"context_length,omitempty"`
|
||||
}
|
||||
|
||||
// GenerateResponse is the response passed into [GenerateResponseFunc].
|
||||
|
||||
@@ -472,6 +472,10 @@ func (c *launcherClient) launchSingleIntegration(ctx context.Context, name strin
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := lowContextLength(ctx, c.apiClient, []string{target}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if target != current {
|
||||
if err := config.SaveIntegration(name, []string{target}); err != nil {
|
||||
return fmt.Errorf("failed to save: %w", err)
|
||||
@@ -498,6 +502,10 @@ func (c *launcherClient) launchEditorIntegration(ctx context.Context, name strin
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := lowContextLength(ctx, c.apiClient, models); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if needsConfigure || req.ModelOverride != "" {
|
||||
if err := prepareEditorIntegration(name, runner, editor, models); err != nil {
|
||||
return err
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@@ -13,6 +14,7 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
)
|
||||
|
||||
@@ -1496,3 +1498,232 @@ func compareStringSlices(got, want [][]string) string {
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func TestConfirmLowContextLength(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
models []string
|
||||
statusBody string
|
||||
statusCode int
|
||||
showParams string // Parameters field returned by /api/show
|
||||
wantWarning bool
|
||||
wantModelfile bool // true if warning should mention Modelfile
|
||||
}{
|
||||
{
|
||||
name: "no warning when server context meets recommended",
|
||||
models: []string{"llama3.2"},
|
||||
statusBody: `{"cloud":{},"context_length":65536}`,
|
||||
statusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "no warning when server context exceeds recommended",
|
||||
models: []string{"llama3.2"},
|
||||
statusBody: `{"cloud":{},"context_length":131072}`,
|
||||
statusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "warns when server context is below recommended",
|
||||
models: []string{"llama3.2"},
|
||||
statusBody: `{"cloud":{},"context_length":4096}`,
|
||||
statusCode: http.StatusOK,
|
||||
wantWarning: true,
|
||||
},
|
||||
{
|
||||
name: "no warning when status endpoint fails",
|
||||
models: []string{"llama3.2"},
|
||||
statusCode: http.StatusInternalServerError,
|
||||
},
|
||||
{
|
||||
name: "no warning for cloud-only models even with low context",
|
||||
models: []string{"gpt-4o:cloud"},
|
||||
statusBody: `{"cloud":{},"context_length":4096}`,
|
||||
statusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "no warning when models list is empty",
|
||||
models: []string{},
|
||||
statusBody: `{"cloud":{},"context_length":4096}`,
|
||||
statusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "no warning when modelfile num_ctx meets recommended",
|
||||
models: []string{"llama3.2"},
|
||||
statusBody: `{"cloud":{},"context_length":4096}`,
|
||||
statusCode: http.StatusOK,
|
||||
showParams: "num_ctx 65536",
|
||||
},
|
||||
{
|
||||
name: "no warning when modelfile num_ctx exceeds recommended",
|
||||
models: []string{"llama3.2"},
|
||||
statusBody: `{"cloud":{},"context_length":4096}`,
|
||||
statusCode: http.StatusOK,
|
||||
showParams: "num_ctx 131072",
|
||||
},
|
||||
{
|
||||
name: "warns with modelfile hint when modelfile num_ctx is below recommended",
|
||||
models: []string{"llama3.2"},
|
||||
statusBody: `{"cloud":{},"context_length":131072}`,
|
||||
statusCode: http.StatusOK,
|
||||
showParams: "num_ctx 4096",
|
||||
wantWarning: true,
|
||||
wantModelfile: true,
|
||||
},
|
||||
{
|
||||
name: "warns with modelfile hint when both server and modelfile are low",
|
||||
models: []string{"llama3.2"},
|
||||
statusBody: `{"cloud":{},"context_length":2048}`,
|
||||
statusCode: http.StatusOK,
|
||||
showParams: "num_ctx 4096",
|
||||
wantWarning: true,
|
||||
wantModelfile: true,
|
||||
},
|
||||
{
|
||||
name: "no warning when status returns malformed JSON",
|
||||
models: []string{"llama3.2"},
|
||||
statusBody: `{invalid json`,
|
||||
statusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "no warning when status returns empty body with 200",
|
||||
models: []string{"llama3.2"},
|
||||
statusBody: "",
|
||||
statusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "no warning when show endpoint fails",
|
||||
models: []string{"llama3.2"},
|
||||
statusBody: `{"cloud":{},"context_length":65536}`,
|
||||
statusCode: http.StatusOK,
|
||||
showParams: "SHOW_ERROR", // sentinel to make show return 500
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/status" {
|
||||
w.WriteHeader(tt.statusCode)
|
||||
if tt.statusBody != "" {
|
||||
fmt.Fprint(w, tt.statusBody)
|
||||
}
|
||||
return
|
||||
}
|
||||
if r.URL.Path == "/api/show" {
|
||||
if tt.showParams == "SHOW_ERROR" {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{"parameters":%q}`, tt.showParams)
|
||||
return
|
||||
}
|
||||
http.NotFound(w, r)
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
client, err := newTestClient(srv.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create client: %v", err)
|
||||
}
|
||||
|
||||
// capture stderr
|
||||
oldStderr := os.Stderr
|
||||
r, w, _ := os.Pipe()
|
||||
os.Stderr = w
|
||||
|
||||
err = lowContextLength(context.Background(), client, tt.models)
|
||||
|
||||
w.Close()
|
||||
var buf bytes.Buffer
|
||||
buf.ReadFrom(r)
|
||||
os.Stderr = oldStderr
|
||||
output := buf.String()
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
hasWarning := strings.Contains(output, "Warning:")
|
||||
if hasWarning != tt.wantWarning {
|
||||
t.Fatalf("expected warning=%v, got output: %q", tt.wantWarning, output)
|
||||
}
|
||||
if tt.wantWarning && tt.wantModelfile {
|
||||
if !strings.Contains(output, "Use the model:") {
|
||||
t.Fatalf("expected parent model hint in output: %q", output)
|
||||
}
|
||||
}
|
||||
if tt.wantWarning && !tt.wantModelfile {
|
||||
if strings.Contains(output, "Use the model:") {
|
||||
t.Fatalf("expected server hint, not parent model hint in output: %q", output)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseNumCtxFromParameters(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
parameters string
|
||||
want int
|
||||
}{
|
||||
{
|
||||
name: "extracts num_ctx",
|
||||
parameters: "num_ctx 65536",
|
||||
want: 65536,
|
||||
},
|
||||
{
|
||||
name: "extracts num_ctx among other parameters",
|
||||
parameters: "temperature 0.7\nnum_ctx 131072\nstop \"<|end|>\"",
|
||||
want: 131072,
|
||||
},
|
||||
{
|
||||
name: "returns zero when no num_ctx",
|
||||
parameters: "temperature 0.7\nstop \"<|end|>\"",
|
||||
want: 0,
|
||||
},
|
||||
{
|
||||
name: "returns zero for empty string",
|
||||
parameters: "",
|
||||
want: 0,
|
||||
},
|
||||
{
|
||||
name: "handles float representation",
|
||||
parameters: "num_ctx 65536.0",
|
||||
want: 65536,
|
||||
},
|
||||
{
|
||||
name: "returns zero when num_ctx value is not a number",
|
||||
parameters: "num_ctx abc",
|
||||
want: 0,
|
||||
},
|
||||
{
|
||||
name: "returns zero for completely garbled input",
|
||||
parameters: "!@#$%^&*()_+{}|:<>?",
|
||||
want: 0,
|
||||
},
|
||||
{
|
||||
name: "returns zero when num_ctx has no value",
|
||||
parameters: "num_ctx",
|
||||
want: 0,
|
||||
},
|
||||
{
|
||||
name: "returns zero when num_ctx has extra fields",
|
||||
parameters: "num_ctx 65536 extra_stuff",
|
||||
want: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := parseNumCtx(tt.parameters)
|
||||
if got != tt.want {
|
||||
t.Fatalf("parseNumCtx(%q) = %d, want %d", tt.parameters, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newTestClient(url string) (*api.Client, error) {
|
||||
return api.ClientFromEnvironment()
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -443,6 +444,77 @@ func cloudStatusDisabled(ctx context.Context, client *api.Client) (disabled bool
|
||||
return status.Cloud.Disabled, true
|
||||
}
|
||||
|
||||
// TODO(ParthSareen): make this controllable on an integration level as well
|
||||
// recommendedContextLength is the context window size, in tokens, below which
// lowContextLength prints its warning.
const recommendedContextLength = 64000
|
||||
|
||||
func hasLocalModel(models []string) bool {
|
||||
for _, m := range models {
|
||||
if !isCloudModelName(m) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func lowContextLength(ctx context.Context, client *api.Client, models []string) error {
|
||||
if !hasLocalModel(models) {
|
||||
return nil
|
||||
}
|
||||
|
||||
status, err := client.CloudStatusExperimental(ctx)
|
||||
if err != nil {
|
||||
return nil //nolint:nilerr // best-effort check; ignore if status endpoint is unavailable
|
||||
}
|
||||
serverCtx := status.ContextLength
|
||||
if serverCtx == 0 {
|
||||
return nil // couldn't determine context length, skip check
|
||||
}
|
||||
|
||||
for _, m := range models {
|
||||
if isCloudModelName(m) {
|
||||
continue
|
||||
}
|
||||
// A Modelfile can override num_ctx, which takes precedence over the server default.
|
||||
effectiveCtx := serverCtx
|
||||
modelfileOverride := false
|
||||
var info *api.ShowResponse
|
||||
if info, err = client.Show(ctx, &api.ShowRequest{Model: m}); err == nil {
|
||||
if numCtx := parseNumCtx(info.Parameters); numCtx > 0 {
|
||||
effectiveCtx = numCtx
|
||||
modelfileOverride = true
|
||||
}
|
||||
}
|
||||
if effectiveCtx < recommendedContextLength {
|
||||
fmt.Fprintf(os.Stderr, "\n%sWarning: context window is %d tokens (recommended: %d+)%s\n", ansiYellow, effectiveCtx, recommendedContextLength, ansiReset)
|
||||
if modelfileOverride {
|
||||
parentModel := info.Details.ParentModel
|
||||
fmt.Fprintf(os.Stderr, "%sUse the model: %s and increase the context length to at least %d in Ollama App Settings.%s\n\n", ansiYellow, parentModel, recommendedContextLength, ansiReset)
|
||||
} else {
|
||||
if runtime.GOOS == "windows" {
|
||||
fmt.Fprintf(os.Stderr, "%sIncrease it in Ollama App Settings or with $env:OLLAMA_CONTEXT_LENGTH=%d; ollama serve%s\n\n", ansiYellow, recommendedContextLength, ansiReset)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "%sIncrease it in Ollama App Settings or with OLLAMA_CONTEXT_LENGTH=%d ollama serve%s\n\n", ansiYellow, recommendedContextLength, ansiReset)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseNumCtx extracts num_ctx from the Show response Parameters string.
//
// parameters is scanned line by line. The first line that consists of exactly
// two whitespace-separated fields, whose first field is "num_ctx" and whose
// second field parses as a finite number within the accepted range, decides
// the result; the value is truncated toward zero. Lines that do not qualify
// are skipped. It returns 0 when no line qualifies.
func parseNumCtx(parameters string) int {
	// Bound on the accepted magnitude. It fits in a 32-bit int, so the
	// float-to-int conversion below is well defined on every platform.
	const maxNumCtx = 1<<31 - 1

	for _, line := range strings.Split(parameters, "\n") {
		fields := strings.Fields(line)
		if len(fields) != 2 || fields[0] != "num_ctx" {
			continue
		}

		v, err := strconv.ParseFloat(fields[1], 64)
		if err != nil {
			continue
		}
		// ParseFloat accepts "NaN" and "Inf" without error, and converting
		// those (or any value outside the int range) to int yields an
		// implementation-dependent result. Every comparison with NaN is
		// false, so the negated range test rejects NaN as well.
		if !(v >= -maxNumCtx && v <= maxNumCtx) {
			continue
		}
		return int(v)
	}
	return 0
}
|
||||
|
||||
// TODO(parthsareen): this duplicates the pull progress UI in cmd.PullHandler.
|
||||
// Move the shared pull rendering to a small utility once the package boundary settles.
|
||||
func pullModel(ctx context.Context, client *api.Client, model string, insecure bool) error {
|
||||
|
||||
@@ -1935,11 +1935,19 @@ func streamResponse(c *gin.Context, ch chan any) {
|
||||
|
||||
func (s *Server) StatusHandler(c *gin.Context) {
|
||||
disabled, source := internalcloud.Status()
|
||||
|
||||
contextLength := int(envconfig.ContextLength())
|
||||
if contextLength == 0 {
|
||||
slog.Warn("OLLAMA_CONTEXT_LENGTH is not set, using default", "default", s.defaultNumCtx)
|
||||
contextLength = s.defaultNumCtx
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, api.StatusResponse{
|
||||
Cloud: api.CloudStatus{
|
||||
Disabled: disabled,
|
||||
Source: source,
|
||||
},
|
||||
ContextLength: contextLength,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -44,6 +44,62 @@ func TestStatusHandler(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatusHandlerContextLength(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setTestHome(t, t.TempDir())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
envContextLen string
|
||||
defaultNumCtx int
|
||||
wantContextLen int
|
||||
}{
|
||||
{
|
||||
name: "env var takes precedence over VRAM default",
|
||||
envContextLen: "8192",
|
||||
defaultNumCtx: 32768,
|
||||
wantContextLen: 8192,
|
||||
},
|
||||
{
|
||||
name: "falls back to VRAM default when env not set",
|
||||
envContextLen: "",
|
||||
defaultNumCtx: 32768,
|
||||
wantContextLen: 32768,
|
||||
},
|
||||
{
|
||||
name: "zero when neither is set",
|
||||
envContextLen: "",
|
||||
defaultNumCtx: 0,
|
||||
wantContextLen: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.envContextLen != "" {
|
||||
t.Setenv("OLLAMA_CONTEXT_LENGTH", tt.envContextLen)
|
||||
} else {
|
||||
t.Setenv("OLLAMA_CONTEXT_LENGTH", "")
|
||||
}
|
||||
|
||||
s := Server{defaultNumCtx: tt.defaultNumCtx}
|
||||
w := createRequest(t, s.StatusHandler, nil)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var resp api.StatusResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if resp.ContextLength != tt.wantContextLen {
|
||||
t.Fatalf("expected context_length %d, got %d", tt.wantContextLen, resp.ContextLength)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloudDisabledBlocksRemoteOperations(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setTestHome(t, t.TempDir())
|
||||
|
||||
Reference in New Issue
Block a user