diff --git a/app/server/server.go b/app/server/server.go index 74913828a..89a8df58f 100644 --- a/app/server/server.go +++ b/app/server/server.go @@ -41,6 +41,11 @@ type InferenceCompute struct { VRAM string } +type InferenceInfo struct { + Computes []InferenceCompute + DefaultContextLength int +} + func New(s *store.Store, devMode bool) *Server { p := resolvePath("ollama") return &Server{store: s, bin: p, dev: devMode} @@ -272,9 +277,12 @@ func openRotatingLog() (io.WriteCloser, error) { // Attempt to retrieve inference compute information from the server // log. Set ctx to timeout to control how long to wait for the logs to appear -func GetInferenceComputer(ctx context.Context) ([]InferenceCompute, error) { - inference := []InferenceCompute{} - marker := regexp.MustCompile(`inference compute.*library=`) +func GetInferenceInfo(ctx context.Context) (*InferenceInfo, error) { + info := &InferenceInfo{} + computeMarker := regexp.MustCompile(`inference compute.*library=`) + defaultCtxMarker := regexp.MustCompile(`vram-based default context`) + defaultCtxRegex := regexp.MustCompile(`default_num_ctx=(\d+)`) + q := `inference compute.*%s=["]([^"]*)["]` nq := `inference compute.*%s=(\S+)\s` type regex struct { @@ -340,8 +348,8 @@ func GetInferenceComputer(ctx context.Context) ([]InferenceCompute, error) { scanner := bufio.NewScanner(file) for scanner.Scan() { line := scanner.Text() - match := marker.FindStringSubmatch(line) - if len(match) > 0 { + // Check for inference compute lines + if computeMarker.MatchString(line) { ic := InferenceCompute{ Library: get("library", line), Variant: get("variant", line), @@ -352,12 +360,25 @@ func GetInferenceComputer(ctx context.Context) ([]InferenceCompute, error) { } slog.Info("Matched", "inference compute", ic) - inference = append(inference, ic) - } else { - // Break out on first non matching line after we start matching - if len(inference) > 0 { - return inference, nil + info.Computes = append(info.Computes, ic) + continue + } + // Check for default context length line + if defaultCtxMarker.MatchString(line) { + match := defaultCtxRegex.FindStringSubmatch(line) + if len(match) > 1 { + numCtx, err := strconv.Atoi(match[1]) + if err == nil { + info.DefaultContextLength = numCtx + slog.Info("Matched default context length", "default_num_ctx", numCtx) + } } + return info, nil + } + // If we've found compute info but hit a non-matching line, return what we have + // This handles older server versions that don't log the default context line + if len(info.Computes) > 0 { + return info, nil } } time.Sleep(100 * time.Millisecond) diff --git a/app/server/server_test.go b/app/server/server_test.go index 8d3a6f27e..34d12dfd6 100644 --- a/app/server/server_test.go +++ b/app/server/server_test.go @@ -205,44 +205,50 @@ func TestServerCmdCloudSettingEnv(t *testing.T) { } } -func TestGetInferenceComputer(t *testing.T) { +func TestGetInferenceInfo(t *testing.T) { tests := []struct { - name string - log string - exp []InferenceCompute + name string + log string + expComputes []InferenceCompute + expDefaultCtxLen int }{ { name: "metal", log: `time=2025-06-30T09:23:07.374-07:00 level=DEBUG source=sched.go:108 msg="starting llm scheduler" time=2025-06-30T09:23:07.416-07:00 level=INFO source=types.go:130 msg="inference compute" id=0 library=metal variant="" compute="" driver=0.0 name="" total="96.0 GiB" available="96.0 GiB" +time=2025-06-30T09:23:07.417-07:00 level=INFO source=routes.go:1721 msg="vram-based default context" total_vram="96.0 GiB" default_num_ctx=262144 time=2025-06-30T09:25:56.197-07:00 level=DEBUG source=ggml.go:155 msg="key not found" key=general.alignment default=32 `, - exp: []InferenceCompute{{ + expComputes: []InferenceCompute{{ Library: "metal", Driver: "0.0", VRAM: "96.0 GiB", }}, + expDefaultCtxLen: 262144, }, { name: "cpu", log: `time=2025-07-01T17:59:51.470Z level=INFO source=gpu.go:377 msg="no compatible GPUs were discovered" time=2025-07-01T17:59:51.470Z level=INFO source=types.go:130 msg="inference compute" id=0 library=cpu variant="" compute="" driver=0.0 name="" total="31.3 GiB" available="30.4 GiB" +time=2025-07-01T17:59:51.471Z level=INFO source=routes.go:1721 msg="vram-based default context" total_vram="31.3 GiB" default_num_ctx=32768 [GIN] 2025/07/01 - 18:00:09 | 200 | 50.263µs | 100.126.204.152 | HEAD "/" `, - exp: []InferenceCompute{{ + expComputes: []InferenceCompute{{ Library: "cpu", Driver: "0.0", VRAM: "31.3 GiB", }}, + expDefaultCtxLen: 32768, }, { name: "cuda1", log: `time=2025-07-01T19:33:43.162Z level=DEBUG source=amd_linux.go:419 msg="amdgpu driver not detected /sys/module/amdgpu" releasing cuda driver library time=2025-07-01T19:33:43.162Z level=INFO source=types.go:130 msg="inference compute" id=GPU-452cac9f-6960-839c-4fb3-0cec83699196 library=cuda variant=v12 compute=6.1 driver=12.7 name="NVIDIA GeForce GT 1030" total="3.9 GiB" available="3.9 GiB" +time=2025-07-01T19:33:43.163Z level=INFO source=routes.go:1721 msg="vram-based default context" total_vram="3.9 GiB" default_num_ctx=4096 [GIN] 2025/07/01 - 18:00:09 | 200 | 50.263µs | 100.126.204.152 | HEAD "/" `, - exp: []InferenceCompute{{ + expComputes: []InferenceCompute{{ Library: "cuda", Variant: "v12", Compute: "6.1", @@ -250,6 +256,7 @@ time=2025-07-01T19:33:43.162Z level=INFO source=types.go:130 msg="inference comp Name: "NVIDIA GeForce GT 1030", VRAM: "3.9 GiB", }}, + expDefaultCtxLen: 4096, }, { name: "frank", @@ -257,9 +264,10 @@ time=2025-07-01T19:33:43.162Z level=INFO source=types.go:130 msg="inference comp releasing cuda driver library time=2025-07-01T19:36:13.315Z level=INFO source=types.go:130 msg="inference compute" id=GPU-d6de3398-9932-6902-11ec-fee8e424c8a2 library=cuda variant=v12 compute=7.5 driver=12.8 name="NVIDIA GeForce RTX 2080 Ti" total="10.6 GiB" available="10.4 GiB" time=2025-07-01T19:36:13.315Z level=INFO source=types.go:130 msg="inference compute" id=GPU-9abb57639fa80c50 library=rocm variant="" compute=gfx1030 driver=6.3 name=1002:73bf total="16.0 GiB" available="1.3 GiB" + time=2025-07-01T19:36:13.316Z level=INFO source=routes.go:1721 msg="vram-based default context" total_vram="26.6 GiB" default_num_ctx=32768 [GIN] 2025/07/01 - 18:00:09 | 200 | 50.263µs | 100.126.204.152 | HEAD "/" `, - exp: []InferenceCompute{ + expComputes: []InferenceCompute{ { Library: "cuda", Variant: "v12", @@ -276,6 +284,20 @@ time=2025-07-01T19:33:43.162Z level=INFO source=types.go:130 msg="inference comp VRAM: "16.0 GiB", }, }, + expDefaultCtxLen: 32768, + }, + { + name: "missing_default_context", + log: `time=2025-06-30T09:23:07.374-07:00 level=DEBUG source=sched.go:108 msg="starting llm scheduler" +time=2025-06-30T09:23:07.416-07:00 level=INFO source=types.go:130 msg="inference compute" id=0 library=metal variant="" compute="" driver=0.0 name="" total="96.0 GiB" available="96.0 GiB" +time=2025-06-30T09:25:56.197-07:00 level=DEBUG source=ggml.go:155 msg="key not found" key=general.alignment default=32 +`, + expComputes: []InferenceCompute{{ + Library: "metal", + Driver: "0.0", + VRAM: "96.0 GiB", + }}, + expDefaultCtxLen: 0, // No default context line, should return 0 }, } for _, tt := range tests { @@ -288,18 +310,21 @@ time=2025-07-01T19:33:43.162Z level=INFO source=types.go:130 msg="inference comp } ctx, cancel := context.WithTimeout(t.Context(), 10*time.Millisecond) defer cancel() - ics, err := GetInferenceComputer(ctx) + info, err := GetInferenceInfo(ctx) if err != nil { - t.Fatalf(" failed to get inference compute: %v", err) + t.Fatalf("failed to get inference info: %v", err) } - if !reflect.DeepEqual(ics, tt.exp) { - t.Fatalf("got:\n%#v\nwant:\n%#v", ics, tt.exp) + if !reflect.DeepEqual(info.Computes, tt.expComputes) { + t.Fatalf("computes mismatch\ngot:\n%#v\nwant:\n%#v", info.Computes, tt.expComputes) + } + if info.DefaultContextLength != tt.expDefaultCtxLen { + t.Fatalf("default context length mismatch: got %d, want %d", info.DefaultContextLength, tt.expDefaultCtxLen) } }) } } -func TestGetInferenceComputerTimeout(t *testing.T) { +func TestGetInferenceInfoTimeout(t *testing.T) { ctx, cancel := context.WithTimeout(t.Context(), 10*time.Millisecond) defer cancel() tmpDir := t.TempDir() @@ -308,7 +333,7 @@ func TestGetInferenceComputerTimeout(t *testing.T) { if err != nil { t.Fatalf("failed to write log file %s: %s", serverLogPath, err) } - _, err = GetInferenceComputer(ctx) + _, err = GetInferenceInfo(ctx) if err == nil { t.Fatal("expected timeout") } diff --git a/app/store/database.go b/app/store/database.go index 8e97b9c8c..20f384bb7 100644 --- a/app/store/database.go +++ b/app/store/database.go @@ -14,7 +14,7 @@ import ( // currentSchemaVersion defines the current database schema version. // Increment this when making schema changes that require migrations. -const currentSchemaVersion = 13 +const currentSchemaVersion = 14 // database wraps the SQLite connection. // SQLite handles its own locking for concurrent access: @@ -73,7 +73,7 @@ func (db *database) init() error { agent BOOLEAN NOT NULL DEFAULT 0, tools BOOLEAN NOT NULL DEFAULT 0, working_dir TEXT NOT NULL DEFAULT '', - context_length INTEGER NOT NULL DEFAULT 4096, + context_length INTEGER NOT NULL DEFAULT 0, window_width INTEGER NOT NULL DEFAULT 0, window_height INTEGER NOT NULL DEFAULT 0, config_migrated BOOLEAN NOT NULL DEFAULT 0, @@ -251,6 +251,12 @@ func (db *database) migrate() error { return fmt.Errorf("migrate v12 to v13: %w", err) } version = 13 + case 13: + // change default context_length from 4096 to 0 (VRAM-based tiered defaults) + if err := db.migrateV13ToV14(); err != nil { + return fmt.Errorf("migrate v13 to v14: %w", err) + } + version = 14 default: // If we have a version we don't recognize, just set it to current // This might happen during development @@ -474,6 +480,22 @@ func (db *database) migrateV12ToV13() error { return nil } +// migrateV13ToV14 changes the default context_length from 4096 to 0. +// When context_length is 0, the ollama server uses VRAM-based tiered defaults. +func (db *database) migrateV13ToV14() error { + _, err := db.conn.Exec(`UPDATE settings SET context_length = 0 WHERE context_length = 4096`) + if err != nil { + return fmt.Errorf("update context_length default: %w", err) + } + + _, err = db.conn.Exec(`UPDATE settings SET schema_version = 14`) + if err != nil { + return fmt.Errorf("update schema version: %w", err) + } + + return nil +} + // cleanupOrphanedData removes orphaned records that may exist due to the foreign key bug func (db *database) cleanupOrphanedData() error { _, err := db.conn.Exec(` diff --git a/app/store/database_test.go b/app/store/database_test.go index 1b037a75d..a5ebabd2e 100644 --- a/app/store/database_test.go +++ b/app/store/database_test.go @@ -98,6 +98,43 @@ func TestSchemaMigrations(t *testing.T) { }) } +func TestMigrationV13ToV14ContextLength(t *testing.T) { + tmpDir := t.TempDir() + dbPath := filepath.Join(tmpDir, "test.db") + + db, err := newDatabase(dbPath) + if err != nil { + t.Fatalf("failed to create database: %v", err) + } + defer db.Close() + + _, err = db.conn.Exec("UPDATE settings SET context_length = 4096, schema_version = 13") + if err != nil { + t.Fatalf("failed to seed v13 settings row: %v", err) + } + + if err := db.migrate(); err != nil { + t.Fatalf("migration from v13 to v14 failed: %v", err) + } + + var contextLength int + if err := db.conn.QueryRow("SELECT context_length FROM settings").Scan(&contextLength); err != nil { + t.Fatalf("failed to read context_length: %v", err) + } + + if contextLength != 0 { + t.Fatalf("expected context_length to migrate to 0, got %d", contextLength) + } + + version, err := db.getSchemaVersion() + if err != nil { + t.Fatalf("failed to get schema version: %v", err) + } + if version != currentSchemaVersion { + t.Fatalf("expected schema version %d, got %d", currentSchemaVersion, version) + } +} + func TestChatDeletionWithCascade(t *testing.T) { t.Run("chat deletion cascades to related messages", func(t *testing.T) { tmpDir := t.TempDir() diff --git a/app/store/testdata/schema.sql b/app/store/testdata/schema.sql index 8f944ff85..9ed23c28b 100644 --- a/app/store/testdata/schema.sql +++ b/app/store/testdata/schema.sql @@ -13,7 +13,7 @@ CREATE TABLE IF NOT EXISTS settings ( agent BOOLEAN NOT NULL DEFAULT 0, tools BOOLEAN NOT NULL DEFAULT 0, working_dir TEXT NOT NULL DEFAULT '', - context_length INTEGER NOT NULL DEFAULT 4096, + context_length INTEGER NOT NULL DEFAULT 0, window_width INTEGER NOT NULL DEFAULT 0, window_height INTEGER NOT NULL DEFAULT 0, config_migrated BOOLEAN NOT NULL DEFAULT 0, diff --git a/app/ui/app/codegen/gotypes.gen.ts b/app/ui/app/codegen/gotypes.gen.ts index 0f4594209..61140bf7f 100644 --- a/app/ui/app/codegen/gotypes.gen.ts +++ b/app/ui/app/codegen/gotypes.gen.ts @@ -289,10 +289,12 @@ export class InferenceCompute { } export class InferenceComputeResponse { inferenceComputes: InferenceCompute[]; + defaultContextLength: number; constructor(source: any = {}) { if ('string' === typeof source) source = JSON.parse(source); this.inferenceComputes = this.convertValues(source["inferenceComputes"], InferenceCompute); + this.defaultContextLength = source["defaultContextLength"]; } convertValues(a: any, classs: any, asMap: boolean = false): any { diff --git a/app/ui/app/src/api.ts b/app/ui/app/src/api.ts index 739dfb09d..062f5bdd2 100644 --- a/app/ui/app/src/api.ts +++ b/app/ui/app/src/api.ts @@ -4,7 +4,6 @@ import { ChatEvent, DownloadEvent, ErrorEvent, - InferenceCompute, InferenceComputeResponse, ModelCapabilitiesResponse, Model, @@ -407,7 +406,7 @@ export async function* pullModel( } } -export async function getInferenceCompute(): Promise { +export async function getInferenceCompute(): Promise { const response = await fetch(`${API_BASE}/api/v1/inference-compute`); if (!response.ok) { throw new Error( @@ -416,8 +415,7 @@ export async function getInferenceCompute(): Promise { } const data = await response.json(); - const inferenceComputeResponse = new InferenceComputeResponse(data); - return inferenceComputeResponse.inferenceComputes || []; + return new InferenceComputeResponse(data); } export async function fetchHealth(): Promise { diff --git a/app/ui/app/src/components/Settings.tsx b/app/ui/app/src/components/Settings.tsx index c5153a994..ef0bf4c53 100644 --- a/app/ui/app/src/components/Settings.tsx +++ b/app/ui/app/src/components/Settings.tsx @@ -26,6 +26,7 @@ import { type CloudStatusResponse, updateCloudSetting, updateSettings, + getInferenceCompute, } from "@/api"; function AnimatedDots() { @@ -77,6 +78,13 @@ export default function Settings() { const settings = settingsData?.settings || null; + const { data: inferenceComputeResponse } = useQuery({ + queryKey: ["inferenceCompute"], + queryFn: getInferenceCompute, + }); + + const defaultContextLength = inferenceComputeResponse?.defaultContextLength; + const updateSettingsMutation = useMutation({ mutationFn: updateSettings, onSuccess: () => { @@ -204,7 +212,7 @@ export default function Settings() { Models: "", Agent: false, Tools: false, - ContextLength: 4096, + ContextLength: 0, }); updateSettingsMutation.mutate(defaultSettings); } @@ -507,13 +515,11 @@ export default function Settings() {
{ - // Otherwise use the settings value - return settings.ContextLength || 4096; - })()} + value={settings.ContextLength || defaultContextLength || 0} onChange={(value) => { handleChange("ContextLength", value); }} + disabled={!defaultContextLength} options={[ { value: 4096, label: "4k" }, { value: 8192, label: "8k" }, diff --git a/app/ui/app/src/components/ui/slider.tsx b/app/ui/app/src/components/ui/slider.tsx index 75ce1767b..11c0d4fa0 100644 --- a/app/ui/app/src/components/ui/slider.tsx +++ b/app/ui/app/src/components/ui/slider.tsx @@ -6,10 +6,11 @@ export interface SliderProps { value?: number; onChange?: (value: number) => void; className?: string; + disabled?: boolean; } const Slider = React.forwardRef( - ({ label, options, value = 0, onChange }, ref) => { + ({ label, options, value = 0, onChange, disabled = false }, ref) => { const [selectedValue, setSelectedValue] = React.useState(value); const [isDragging, setIsDragging] = React.useState(false); const containerRef = React.useRef(null); @@ -20,6 +21,7 @@ const Slider = React.forwardRef( }, [value]); const handleClick = (optionValue: number) => { + if (disabled) return; setSelectedValue(optionValue); onChange?.(optionValue); }; @@ -39,6 +41,7 @@ const Slider = React.forwardRef( }; const handleMouseDown = (e: React.MouseEvent) => { + if (disabled) return; setIsDragging(true); e.preventDefault(); }; @@ -77,7 +80,7 @@ const Slider = React.forwardRef( } return ( -
+
{label && }
@@ -88,10 +91,11 @@ const Slider = React.forwardRef(