mirror of
https://github.com/ollama/ollama.git
synced 2026-04-17 15:53:27 +02:00
launch: warn when server context length is below 64k for local models (#15044)
A stop-gap for now to guide users better. We'll add more in-depth recommendations per integration as well. --------- Co-authored-by: Parth Sareen <parth.sareen@ollama.com>
This commit is contained in:
@@ -841,7 +841,8 @@ type CloudStatus struct {
|
|||||||
|
|
||||||
// StatusResponse is the response from [Client.CloudStatusExperimental].
type StatusResponse struct {
	// Cloud carries the cloud status portion of the response.
	Cloud CloudStatus `json:"cloud"`
	// ContextLength is the context length reported by the server.
	// It is omitted from the JSON encoding when zero.
	ContextLength int `json:"context_length,omitempty"`
}
|
||||||
|
|
||||||
// GenerateResponse is the response passed into [GenerateResponseFunc].
|
// GenerateResponse is the response passed into [GenerateResponseFunc].
|
||||||
|
|||||||
@@ -472,6 +472,10 @@ func (c *launcherClient) launchSingleIntegration(ctx context.Context, name strin
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := lowContextLength(ctx, c.apiClient, []string{target}); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
if target != current {
|
if target != current {
|
||||||
if err := config.SaveIntegration(name, []string{target}); err != nil {
|
if err := config.SaveIntegration(name, []string{target}); err != nil {
|
||||||
return fmt.Errorf("failed to save: %w", err)
|
return fmt.Errorf("failed to save: %w", err)
|
||||||
@@ -498,6 +502,10 @@ func (c *launcherClient) launchEditorIntegration(ctx context.Context, name strin
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := lowContextLength(ctx, c.apiClient, models); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
if needsConfigure || req.ModelOverride != "" {
|
if needsConfigure || req.ModelOverride != "" {
|
||||||
if err := prepareEditorIntegration(name, runner, editor, models); err != nil {
|
if err := prepareEditorIntegration(name, runner, editor, models); err != nil {
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package launch
|
package launch
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -13,6 +14,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/cmd/config"
|
"github.com/ollama/ollama/cmd/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1496,3 +1498,232 @@ func compareStringSlices(got, want [][]string) string {
|
|||||||
}
|
}
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConfirmLowContextLength(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
models []string
|
||||||
|
statusBody string
|
||||||
|
statusCode int
|
||||||
|
showParams string // Parameters field returned by /api/show
|
||||||
|
wantWarning bool
|
||||||
|
wantModelfile bool // true if warning should mention Modelfile
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no warning when server context meets recommended",
|
||||||
|
models: []string{"llama3.2"},
|
||||||
|
statusBody: `{"cloud":{},"context_length":65536}`,
|
||||||
|
statusCode: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no warning when server context exceeds recommended",
|
||||||
|
models: []string{"llama3.2"},
|
||||||
|
statusBody: `{"cloud":{},"context_length":131072}`,
|
||||||
|
statusCode: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "warns when server context is below recommended",
|
||||||
|
models: []string{"llama3.2"},
|
||||||
|
statusBody: `{"cloud":{},"context_length":4096}`,
|
||||||
|
statusCode: http.StatusOK,
|
||||||
|
wantWarning: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no warning when status endpoint fails",
|
||||||
|
models: []string{"llama3.2"},
|
||||||
|
statusCode: http.StatusInternalServerError,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no warning for cloud-only models even with low context",
|
||||||
|
models: []string{"gpt-4o:cloud"},
|
||||||
|
statusBody: `{"cloud":{},"context_length":4096}`,
|
||||||
|
statusCode: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no warning when models list is empty",
|
||||||
|
models: []string{},
|
||||||
|
statusBody: `{"cloud":{},"context_length":4096}`,
|
||||||
|
statusCode: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no warning when modelfile num_ctx meets recommended",
|
||||||
|
models: []string{"llama3.2"},
|
||||||
|
statusBody: `{"cloud":{},"context_length":4096}`,
|
||||||
|
statusCode: http.StatusOK,
|
||||||
|
showParams: "num_ctx 65536",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no warning when modelfile num_ctx exceeds recommended",
|
||||||
|
models: []string{"llama3.2"},
|
||||||
|
statusBody: `{"cloud":{},"context_length":4096}`,
|
||||||
|
statusCode: http.StatusOK,
|
||||||
|
showParams: "num_ctx 131072",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "warns with modelfile hint when modelfile num_ctx is below recommended",
|
||||||
|
models: []string{"llama3.2"},
|
||||||
|
statusBody: `{"cloud":{},"context_length":131072}`,
|
||||||
|
statusCode: http.StatusOK,
|
||||||
|
showParams: "num_ctx 4096",
|
||||||
|
wantWarning: true,
|
||||||
|
wantModelfile: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "warns with modelfile hint when both server and modelfile are low",
|
||||||
|
models: []string{"llama3.2"},
|
||||||
|
statusBody: `{"cloud":{},"context_length":2048}`,
|
||||||
|
statusCode: http.StatusOK,
|
||||||
|
showParams: "num_ctx 4096",
|
||||||
|
wantWarning: true,
|
||||||
|
wantModelfile: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no warning when status returns malformed JSON",
|
||||||
|
models: []string{"llama3.2"},
|
||||||
|
statusBody: `{invalid json`,
|
||||||
|
statusCode: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no warning when status returns empty body with 200",
|
||||||
|
models: []string{"llama3.2"},
|
||||||
|
statusBody: "",
|
||||||
|
statusCode: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no warning when show endpoint fails",
|
||||||
|
models: []string{"llama3.2"},
|
||||||
|
statusBody: `{"cloud":{},"context_length":65536}`,
|
||||||
|
statusCode: http.StatusOK,
|
||||||
|
showParams: "SHOW_ERROR", // sentinel to make show return 500
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path == "/api/status" {
|
||||||
|
w.WriteHeader(tt.statusCode)
|
||||||
|
if tt.statusBody != "" {
|
||||||
|
fmt.Fprint(w, tt.statusBody)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if r.URL.Path == "/api/show" {
|
||||||
|
if tt.showParams == "SHOW_ERROR" {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
fmt.Fprintf(w, `{"parameters":%q}`, tt.showParams)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
client, err := newTestClient(srv.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create client: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// capture stderr
|
||||||
|
oldStderr := os.Stderr
|
||||||
|
r, w, _ := os.Pipe()
|
||||||
|
os.Stderr = w
|
||||||
|
|
||||||
|
err = lowContextLength(context.Background(), client, tt.models)
|
||||||
|
|
||||||
|
w.Close()
|
||||||
|
var buf bytes.Buffer
|
||||||
|
buf.ReadFrom(r)
|
||||||
|
os.Stderr = oldStderr
|
||||||
|
output := buf.String()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
hasWarning := strings.Contains(output, "Warning:")
|
||||||
|
if hasWarning != tt.wantWarning {
|
||||||
|
t.Fatalf("expected warning=%v, got output: %q", tt.wantWarning, output)
|
||||||
|
}
|
||||||
|
if tt.wantWarning && tt.wantModelfile {
|
||||||
|
if !strings.Contains(output, "Use the model:") {
|
||||||
|
t.Fatalf("expected parent model hint in output: %q", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if tt.wantWarning && !tt.wantModelfile {
|
||||||
|
if strings.Contains(output, "Use the model:") {
|
||||||
|
t.Fatalf("expected server hint, not parent model hint in output: %q", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseNumCtxFromParameters(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
parameters string
|
||||||
|
want int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "extracts num_ctx",
|
||||||
|
parameters: "num_ctx 65536",
|
||||||
|
want: 65536,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "extracts num_ctx among other parameters",
|
||||||
|
parameters: "temperature 0.7\nnum_ctx 131072\nstop \"<|end|>\"",
|
||||||
|
want: 131072,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "returns zero when no num_ctx",
|
||||||
|
parameters: "temperature 0.7\nstop \"<|end|>\"",
|
||||||
|
want: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "returns zero for empty string",
|
||||||
|
parameters: "",
|
||||||
|
want: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "handles float representation",
|
||||||
|
parameters: "num_ctx 65536.0",
|
||||||
|
want: 65536,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "returns zero when num_ctx value is not a number",
|
||||||
|
parameters: "num_ctx abc",
|
||||||
|
want: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "returns zero for completely garbled input",
|
||||||
|
parameters: "!@#$%^&*()_+{}|:<>?",
|
||||||
|
want: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "returns zero when num_ctx has no value",
|
||||||
|
parameters: "num_ctx",
|
||||||
|
want: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "returns zero when num_ctx has extra fields",
|
||||||
|
parameters: "num_ctx 65536 extra_stuff",
|
||||||
|
want: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := parseNumCtx(tt.parameters)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Fatalf("parseNumCtx(%q) = %d, want %d", tt.parameters, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestClient(url string) (*api.Client, error) {
|
||||||
|
return api.ClientFromEnvironment()
|
||||||
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"os/exec"
|
"os/exec"
|
||||||
"runtime"
|
"runtime"
|
||||||
"slices"
|
"slices"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -443,6 +444,77 @@ func cloudStatusDisabled(ctx context.Context, client *api.Client) (disabled bool
|
|||||||
return status.Cloud.Disabled, true
|
return status.Cloud.Disabled, true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
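// recommendedContextLength is the context window size, in tokens, below which
// launch prints a low-context warning for local models.
// TODO(ParthSareen): make this controllable on an integration level as well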
|
||||||
|
const recommendedContextLength = 64000
|
||||||
|
|
||||||
|
func hasLocalModel(models []string) bool {
|
||||||
|
for _, m := range models {
|
||||||
|
if !isCloudModelName(m) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func lowContextLength(ctx context.Context, client *api.Client, models []string) error {
|
||||||
|
if !hasLocalModel(models) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
status, err := client.CloudStatusExperimental(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil //nolint:nilerr // best-effort check; ignore if status endpoint is unavailable
|
||||||
|
}
|
||||||
|
serverCtx := status.ContextLength
|
||||||
|
if serverCtx == 0 {
|
||||||
|
return nil // couldn't determine context length, skip check
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, m := range models {
|
||||||
|
if isCloudModelName(m) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// A Modelfile can override num_ctx, which takes precedence over the server default.
|
||||||
|
effectiveCtx := serverCtx
|
||||||
|
modelfileOverride := false
|
||||||
|
var info *api.ShowResponse
|
||||||
|
if info, err = client.Show(ctx, &api.ShowRequest{Model: m}); err == nil {
|
||||||
|
if numCtx := parseNumCtx(info.Parameters); numCtx > 0 {
|
||||||
|
effectiveCtx = numCtx
|
||||||
|
modelfileOverride = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if effectiveCtx < recommendedContextLength {
|
||||||
|
fmt.Fprintf(os.Stderr, "\n%sWarning: context window is %d tokens (recommended: %d+)%s\n", ansiYellow, effectiveCtx, recommendedContextLength, ansiReset)
|
||||||
|
if modelfileOverride {
|
||||||
|
parentModel := info.Details.ParentModel
|
||||||
|
fmt.Fprintf(os.Stderr, "%sUse the model: %s and increase the context length to at least %d in Ollama App Settings.%s\n\n", ansiYellow, parentModel, recommendedContextLength, ansiReset)
|
||||||
|
} else {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
fmt.Fprintf(os.Stderr, "%sIncrease it in Ollama App Settings or with $env:OLLAMA_CONTEXT_LENGTH=%d; ollama serve%s\n\n", ansiYellow, recommendedContextLength, ansiReset)
|
||||||
|
} else {
|
||||||
|
fmt.Fprintf(os.Stderr, "%sIncrease it in Ollama App Settings or with OLLAMA_CONTEXT_LENGTH=%d ollama serve%s\n\n", ansiYellow, recommendedContextLength, ansiReset)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseNumCtx extracts num_ctx from the Show response Parameters string.
//
// Parameters is scanned line by line. A line matches only when it has exactly
// two whitespace-separated fields and the first is "num_ctx". The value is
// parsed as a float (so "65536.0" is accepted) and truncated to an int.
//
// It returns 0 when no matching line yields a usable number. Values that are
// not representable as an int (NaN, ±Inf, or out of range) are skipped rather
// than converted, because converting them is implementation-defined in Go and
// can produce different results on different architectures.
func parseNumCtx(parameters string) int {
	// limit is the platform's max int rounded to float64 (2^63 on 64-bit);
	// only values strictly inside (-limit, limit) convert safely.
	const limit = float64(int(^uint(0) >> 1))

	for _, line := range strings.Split(parameters, "\n") {
		fields := strings.Fields(line)
		if len(fields) != 2 || fields[0] != "num_ctx" {
			continue
		}
		v, err := strconv.ParseFloat(fields[1], 64)
		if err != nil {
			continue
		}
		// Comparisons with NaN are always false, so NaN is rejected here
		// together with ±Inf and out-of-range magnitudes.
		if v > -limit && v < limit {
			return int(v)
		}
	}
	return 0
}
|
||||||
|
|
||||||
// TODO(parthsareen): this duplicates the pull progress UI in cmd.PullHandler.
|
// TODO(parthsareen): this duplicates the pull progress UI in cmd.PullHandler.
|
||||||
// Move the shared pull rendering to a small utility once the package boundary settles.
|
// Move the shared pull rendering to a small utility once the package boundary settles.
|
||||||
func pullModel(ctx context.Context, client *api.Client, model string, insecure bool) error {
|
func pullModel(ctx context.Context, client *api.Client, model string, insecure bool) error {
|
||||||
|
|||||||
@@ -1935,11 +1935,19 @@ func streamResponse(c *gin.Context, ch chan any) {
|
|||||||
|
|
||||||
func (s *Server) StatusHandler(c *gin.Context) {
|
func (s *Server) StatusHandler(c *gin.Context) {
|
||||||
disabled, source := internalcloud.Status()
|
disabled, source := internalcloud.Status()
|
||||||
|
|
||||||
|
contextLength := int(envconfig.ContextLength())
|
||||||
|
if contextLength == 0 {
|
||||||
|
slog.Warn("OLLAMA_CONTEXT_LENGTH is not set, using default", "default", s.defaultNumCtx)
|
||||||
|
contextLength = s.defaultNumCtx
|
||||||
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, api.StatusResponse{
|
c.JSON(http.StatusOK, api.StatusResponse{
|
||||||
Cloud: api.CloudStatus{
|
Cloud: api.CloudStatus{
|
||||||
Disabled: disabled,
|
Disabled: disabled,
|
||||||
Source: source,
|
Source: source,
|
||||||
},
|
},
|
||||||
|
ContextLength: contextLength,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -44,6 +44,62 @@ func TestStatusHandler(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStatusHandlerContextLength(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
setTestHome(t, t.TempDir())
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
envContextLen string
|
||||||
|
defaultNumCtx int
|
||||||
|
wantContextLen int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "env var takes precedence over VRAM default",
|
||||||
|
envContextLen: "8192",
|
||||||
|
defaultNumCtx: 32768,
|
||||||
|
wantContextLen: 8192,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "falls back to VRAM default when env not set",
|
||||||
|
envContextLen: "",
|
||||||
|
defaultNumCtx: 32768,
|
||||||
|
wantContextLen: 32768,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "zero when neither is set",
|
||||||
|
envContextLen: "",
|
||||||
|
defaultNumCtx: 0,
|
||||||
|
wantContextLen: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if tt.envContextLen != "" {
|
||||||
|
t.Setenv("OLLAMA_CONTEXT_LENGTH", tt.envContextLen)
|
||||||
|
} else {
|
||||||
|
t.Setenv("OLLAMA_CONTEXT_LENGTH", "")
|
||||||
|
}
|
||||||
|
|
||||||
|
s := Server{defaultNumCtx: tt.defaultNumCtx}
|
||||||
|
w := createRequest(t, s.StatusHandler, nil)
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp api.StatusResponse
|
||||||
|
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.ContextLength != tt.wantContextLen {
|
||||||
|
t.Fatalf("expected context_length %d, got %d", tt.wantContextLen, resp.ContextLength)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestCloudDisabledBlocksRemoteOperations(t *testing.T) {
|
func TestCloudDisabledBlocksRemoteOperations(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
setTestHome(t, t.TempDir())
|
setTestHome(t, t.TempDir())
|
||||||
|
|||||||
Reference in New Issue
Block a user