Files
ollama/app/ui/ui_test.go
Devon Rifkin 79865e6c5a app: use the same client for inference and other requests (#15204)
Previously we were accidentally using different clients/UAs depending on
whether it was an inference call or a different call. This change makes
them consistent, other than the timeout being different.
2026-04-02 11:07:50 -07:00

843 lines
22 KiB
Go

//go:build windows || darwin
package ui
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"path/filepath"
"runtime"
"strings"
"sync/atomic"
"testing"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/app/store"
"github.com/ollama/ollama/app/updater"
)
// TestHandlePostApiSettings posts a series of settings payloads to the
// settings handler and verifies that each one is accepted with 200 OK and that
// the values read back from the store match what was requested.
func TestHandlePostApiSettings(t *testing.T) {
	cases := []struct {
		name      string
		requested store.Settings
		wantErr   bool
	}{
		{
			name: "valid settings update - all fields",
			requested: store.Settings{
				Expose:     true,
				Browser:    true,
				Models:     "/custom/models",
				Agent:      true,
				Tools:      true,
				WorkingDir: "/workspace",
			},
		},
		{
			// Tools is deliberately left at its zero value (false).
			name: "partial settings update",
			requested: store.Settings{
				Agent:      true,
				WorkingDir: "/new/path",
			},
		},
		{
			name: "settings with special characters in paths",
			requested: store.Settings{
				Models:     "/path with spaces/models",
				WorkingDir: "/tmp/work-dir_123",
				Agent:      true,
			},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Every case works against its own database in a fresh temp dir.
			db := &store.Store{DBPath: filepath.Join(t.TempDir(), "db.sqlite")}
			defer db.Close() // release the database before the temp dir is cleaned up

			payload, err := json.Marshal(tc.requested)
			if err != nil {
				t.Fatalf("failed to marshal test body: %v", err)
			}

			srv := &Server{
				Store:   db,
				Restart: func() {}, // restarting is a no-op under test
			}

			req := httptest.NewRequest(http.MethodPost, "/api/v1/settings", bytes.NewReader(payload))
			req.Header.Set("Content-Type", "application/json")
			rec := httptest.NewRecorder()

			err = srv.settings(rec, req)
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Errorf("handlePostApiSettings() error = %v, wantErr %v", err, tc.wantErr)
			}
			if rec.Code != http.StatusOK {
				t.Errorf("handlePostApiSettings() status = %v, want %v", rec.Code, http.StatusOK)
			}

			// The persisted values are only inspected when success was expected.
			if tc.wantErr {
				return
			}
			saved, err := db.Settings()
			if err != nil {
				t.Errorf("failed to retrieve saved settings: %v", err)
				return
			}

			// Compare individual fields rather than whole structs so that
			// defaults filled in by the store do not cause false failures.
			want := tc.requested
			if saved.Expose != want.Expose {
				t.Errorf("Expose: got %v, want %v", saved.Expose, want.Expose)
			}
			if saved.Browser != want.Browser {
				t.Errorf("Browser: got %v, want %v", saved.Browser, want.Browser)
			}
			if saved.Agent != want.Agent {
				t.Errorf("Agent: got %v, want %v", saved.Agent, want.Agent)
			}
			if saved.Tools != want.Tools {
				t.Errorf("Tools: got %v, want %v", saved.Tools, want.Tools)
			}
			if saved.WorkingDir != want.WorkingDir {
				t.Errorf("WorkingDir: got %q, want %q", saved.WorkingDir, want.WorkingDir)
			}
			// Models is only compared when the case sets it explicitly.
			if want.Models != "" && saved.Models != want.Models {
				t.Errorf("Models: got %q, want %q", saved.Models, want.Models)
			}
		})
	}
}
// TestHandlePostApiCloudSetting toggles the cloud setting off and then on
// through the POST handler. For each request it checks the HTTP status, the
// "disabled" field of the JSON response, and the value persisted in the store.
// It finally checks that the Restart callback ran once per request.
func TestHandlePostApiCloudSetting(t *testing.T) {
	// Redirect the home directory to a throwaway location. os.UserHomeDir
	// reads HOME on macOS but USERPROFILE on Windows, and this file is built
	// for both, so both variables are set.
	// NOTE(review): presumably the cloud-setting code resolves config under the
	// home directory; confirm against the handler implementation.
	tmpHome := t.TempDir()
	t.Setenv("HOME", tmpHome)
	t.Setenv("USERPROFILE", tmpHome)
	// Clear any ambient override so only the stored value is in play.
	t.Setenv("OLLAMA_NO_CLOUD", "")

	testStore := &store.Store{
		DBPath: filepath.Join(t.TempDir(), "db.sqlite"),
	}
	defer testStore.Close()

	restartCount := 0
	server := &Server{
		Store: testStore,
		Restart: func() {
			restartCount++
		},
	}

	// The cases share one store and run in order: disable first, then enable.
	for _, tc := range []struct {
		name        string
		body        string
		wantEnabled bool
	}{
		{name: "disable cloud", body: `{"enabled": false}`, wantEnabled: false},
		{name: "enable cloud", body: `{"enabled": true}`, wantEnabled: true},
	} {
		t.Run(tc.name, func(t *testing.T) {
			req := httptest.NewRequest(http.MethodPost, "/api/v1/cloud", bytes.NewBufferString(tc.body))
			req.Header.Set("Content-Type", "application/json")
			rr := httptest.NewRecorder()

			if err := server.cloudSetting(rr, req); err != nil {
				t.Fatalf("cloudSetting() error = %v", err)
			}
			if rr.Code != http.StatusOK {
				t.Fatalf("cloudSetting() status = %d, want %d", rr.Code, http.StatusOK)
			}

			var got map[string]any
			if err := json.Unmarshal(rr.Body.Bytes(), &got); err != nil {
				t.Fatalf("cloudSetting() invalid response JSON: %v", err)
			}
			// The response reports "disabled", the inverse of the request's "enabled".
			if got["disabled"] != !tc.wantEnabled {
				t.Fatalf("response disabled = %v, want %v", got["disabled"], !tc.wantEnabled)
			}

			disabled, err := testStore.CloudDisabled()
			if err != nil {
				t.Fatalf("CloudDisabled() error = %v", err)
			}
			if gotEnabled := !disabled; gotEnabled != tc.wantEnabled {
				t.Fatalf("cloud enabled = %v, want %v", gotEnabled, tc.wantEnabled)
			}
		})
	}

	// One restart is expected for each of the two requests above.
	if restartCount != 2 {
		t.Fatalf("Restart called %d times, want 2", restartCount)
	}
}
// TestHandleGetApiCloudSetting disables cloud in the store and then verifies
// that the GET handler reports disabled=true with source "config".
func TestHandleGetApiCloudSetting(t *testing.T) {
	// Redirect the home directory to a throwaway location. os.UserHomeDir
	// reads HOME on macOS but USERPROFILE on Windows, and this file is built
	// for both, so both variables are set.
	// NOTE(review): presumably the cloud-setting code resolves config under the
	// home directory; confirm against the handler implementation.
	tmpHome := t.TempDir()
	t.Setenv("HOME", tmpHome)
	t.Setenv("USERPROFILE", tmpHome)
	// Clear any ambient override so only the stored value is in play.
	t.Setenv("OLLAMA_NO_CLOUD", "")

	testStore := &store.Store{
		DBPath: filepath.Join(t.TempDir(), "db.sqlite"),
	}
	defer testStore.Close()

	if err := testStore.SetCloudEnabled(false); err != nil {
		t.Fatalf("SetCloudEnabled(false) error = %v", err)
	}

	server := &Server{
		Store:   testStore,
		Restart: func() {},
	}

	req := httptest.NewRequest(http.MethodGet, "/api/v1/cloud", nil)
	rr := httptest.NewRecorder()

	if err := server.getCloudSetting(rr, req); err != nil {
		t.Fatalf("getCloudSetting() error = %v", err)
	}
	if rr.Code != http.StatusOK {
		t.Fatalf("getCloudSetting() status = %d, want %d", rr.Code, http.StatusOK)
	}

	var got map[string]any
	if err := json.Unmarshal(rr.Body.Bytes(), &got); err != nil {
		t.Fatalf("getCloudSetting() invalid response JSON: %v", err)
	}
	if got["disabled"] != true {
		t.Fatalf("response disabled = %v, want true", got["disabled"])
	}
	if got["source"] != "config" {
		t.Fatalf("response source = %v, want config", got["source"])
	}
}
// TestAuthenticationMiddleware checks the token-cookie and Content-Type rules
// for API requests: OPTIONS is let through, POST must carry
// "Content-Type: application/json", and every other request needs a "token"
// cookie equal to the server token.
//
// NOTE(review): the handler exercised here is a replica of the auth checks
// written inside this test, not the production middleware returned by
// Server.Handler(). A regression in the real middleware would not be caught
// unless the two are kept in sync.
func TestAuthenticationMiddleware(t *testing.T) {
	tests := []struct {
		name        string
		method      string
		contentType string
		tokenCookie string
		serverToken string
		wantStatus  int
		wantError   string
	}{
		{
			name:        "missing token cookie",
			method:      http.MethodGet,
			tokenCookie: "",
			serverToken: "test-token-123",
			wantStatus:  http.StatusForbidden,
			wantError:   "Token is required",
		},
		{
			name:        "invalid token value",
			method:      http.MethodGet,
			tokenCookie: "wrong-token",
			serverToken: "test-token-123",
			wantStatus:  http.StatusForbidden,
			wantError:   "Token is required",
		},
		{
			name:        "valid token - GET request",
			method:      http.MethodGet,
			tokenCookie: "test-token-123",
			serverToken: "test-token-123",
			wantStatus:  http.StatusOK,
			wantError:   "",
		},
		{
			name:        "valid token - POST with application/json",
			method:      http.MethodPost,
			contentType: "application/json",
			tokenCookie: "test-token-123",
			serverToken: "test-token-123",
			wantStatus:  http.StatusOK,
			wantError:   "",
		},
		{
			name:        "POST without Content-Type header",
			method:      http.MethodPost,
			contentType: "",
			tokenCookie: "test-token-123",
			serverToken: "test-token-123",
			wantStatus:  http.StatusForbidden,
			wantError:   "Content-Type must be application/json",
		},
		{
			name:        "POST with wrong Content-Type",
			method:      http.MethodPost,
			contentType: "text/plain",
			tokenCookie: "test-token-123",
			serverToken: "test-token-123",
			wantStatus:  http.StatusForbidden,
			wantError:   "Content-Type must be application/json",
		},
		{
			name:        "OPTIONS request (CORS preflight) - should bypass auth",
			method:      http.MethodOptions,
			tokenCookie: "",
			serverToken: "test-token-123",
			wantStatus:  http.StatusOK,
			wantError:   "",
		},
	}

	// okHandler is the terminal handler, reached only when every check passes.
	// Encoding a small map into the recorder cannot meaningfully fail, so the
	// Encode error is deliberately ignored.
	okHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		_ = json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
	})

	// forbid writes a 403 response carrying msg in the JSON "error" field.
	forbid := func(w http.ResponseWriter, msg string) {
		w.WriteHeader(http.StatusForbidden)
		_ = json.NewEncoder(w).Encode(map[string]string{"error": msg})
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			server := &Server{
				Token: tt.serverToken,
			}

			// Replica of the auth checks, in the order they are applied:
			// preflight bypass, Content-Type gate for POST, then the cookie.
			authHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				if r.Method == http.MethodOptions {
					w.WriteHeader(http.StatusOK)
					return
				}

				if r.Method == http.MethodPost && r.Header.Get("Content-Type") != "application/json" {
					forbid(w, "Content-Type must be application/json")
					return
				}

				// A missing cookie and a mismatched cookie are reported
				// identically.
				cookie, err := r.Cookie("token")
				if err != nil || cookie.Value != server.Token {
					forbid(w, "Token is required")
					return
				}

				okHandler.ServeHTTP(w, r)
			})

			req := httptest.NewRequest(tt.method, "/test", nil)
			if tt.contentType != "" {
				req.Header.Set("Content-Type", tt.contentType)
			}
			if tt.tokenCookie != "" {
				req.AddCookie(&http.Cookie{
					Name:  "token",
					Value: tt.tokenCookie,
				})
			}

			rr := httptest.NewRecorder()
			authHandler.ServeHTTP(rr, req)

			if rr.Code != tt.wantStatus {
				t.Errorf("handler returned wrong status code: got %v want %v", rr.Code, tt.wantStatus)
			}

			// The error body is only inspected when a message is expected.
			if tt.wantError != "" {
				var response map[string]string
				if err := json.NewDecoder(rr.Body).Decode(&response); err != nil {
					t.Fatalf("failed to decode response body: %v", err)
				}
				if response["error"] != tt.wantError {
					t.Errorf("handler returned wrong error message: got %v want %v", response["error"], tt.wantError)
				}
			}
		})
	}
}
// TestUserAgent checks that userAgent() produces a non-empty string that the
// parser below splits into the expected product, version, OS and architecture.
//
// The string is expected to look like
// "ollama/<version> (<arch> <os>) app/<version> Go/<goversion>".
// NOTE(review): the app/ segment is treated as optional by the parser; confirm
// whether userAgent() always emits it.
func TestUserAgent(t *testing.T) {
	ua := userAgent()
	if ua == "" {
		t.Fatal("userAgent returned empty string")
	}

	req := httptest.NewRequest("GET", "/", nil)
	req.Header.Set("User-Agent", ua)

	// clientInfo holds the pieces extracted from a User-Agent header.
	type clientInfo struct {
		Product    string
		Version    string
		OS         string
		Arch       string
		AppVersion string
	}

	// This is a copy of the logic ollama.com uses to parse the user agent.
	// It returns the zero clientInfo when the leading token has no "/".
	clientInfoFromRequest := func(r *http.Request) clientInfo {
		product, rest, _ := strings.Cut(r.UserAgent(), " ")
		client, version, ok := strings.Cut(product, "/")
		if !ok {
			return clientInfo{}
		}
		// Normalize versions to carry a leading "v".
		if version != "" && version[0] != 'v' {
			version = "v" + version
		}

		// The parenthesized group is "(<arch> <os>)".
		arch, rest, _ := strings.Cut(rest, " ")
		arch = strings.Trim(arch, "(")
		osName, rest, _ := strings.Cut(rest, ")")

		var appVersion string
		if _, appPart, found := strings.Cut(rest, "app/"); found {
			// Guard against nothing following "app/": indexing the result of
			// Fields unconditionally would panic on an empty slice.
			if fields := strings.Fields(appPart); len(fields) > 0 {
				appVersion = fields[0]
				if appVersion[0] != 'v' {
					appVersion = "v" + appVersion
				}
			}
		}

		return clientInfo{
			Product:    client,
			Version:    version,
			OS:         osName,
			Arch:       arch,
			AppVersion: appVersion,
		}
	}

	info := clientInfoFromRequest(req)

	if info.Product != "ollama" {
		t.Errorf("Expected Product to be 'ollama', got '%s'", info.Product)
	}
	if info.Version != "" && info.Version[0] != 'v' {
		t.Errorf("Expected Version to start with 'v', got '%s'", info.Version)
	}

	expectedOS := runtime.GOOS
	if info.OS != expectedOS {
		t.Errorf("Expected OS to be '%s', got '%s'", expectedOS, info.OS)
	}

	expectedArch := runtime.GOARCH
	if info.Arch != expectedArch {
		t.Errorf("Expected Arch to be '%s', got '%s'", expectedArch, info.Arch)
	}

	if info.AppVersion != "" && info.AppVersion[0] != 'v' {
		t.Errorf("Expected AppVersion to start with 'v', got '%s'", info.AppVersion)
	}

	t.Logf("User Agent: %s", ua)
	t.Logf("Parsed - Product: %s, Version: %s, OS: %s, Arch: %s",
		info.Product, info.Version, info.OS, info.Arch)
}
// TestUserAgentTransport sends a request through the client returned by
// Server.httpClient() to a server that echoes the User-Agent header back, and
// verifies the echoed value equals userAgent().
func TestUserAgentTransport(t *testing.T) {
	// Echo server: the response body is the User-Agent it received.
	echo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "text/plain")
		w.Write([]byte(r.Header.Get("User-Agent")))
	}))
	defer echo.Close()

	srv := &Server{}
	resp, err := srv.httpClient().Get(echo.URL)
	if err != nil {
		t.Fatalf("Failed to make request: %v", err)
	}
	defer resp.Body.Close()

	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("Failed to read response: %v", err)
	}

	got, want := string(raw), userAgent()
	if got != want {
		t.Errorf("User-Agent mismatch\nExpected: %s\nReceived: %s", want, got)
	}
	if !strings.HasPrefix(got, "ollama/") {
		t.Errorf("User-Agent should start with 'ollama/', got: %s", got)
	}

	t.Logf("User-Agent transport successfully set: %s", got)
}
// TestInferenceClientUsesUserAgent points OLLAMA_HOST at a local test server,
// issues a Show request through Server.inferenceClient(), and verifies the
// User-Agent header seen by the server equals userAgent().
func TestInferenceClientUsesUserAgent(t *testing.T) {
	// The handler runs on the server's goroutine, so the captured header is
	// passed back through an atomic.Value.
	var seenUA atomic.Value

	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		seenUA.Store(r.Header.Get("User-Agent"))
		w.Header().Set("Content-Type", "application/json")
		w.Write([]byte(`{}`))
	}))
	defer backend.Close()

	t.Setenv("OLLAMA_HOST", backend.URL)

	srv := &Server{}
	inference := srv.inferenceClient()
	if _, err := inference.Show(context.Background(), &api.ShowRequest{Model: "test"}); err != nil {
		t.Fatalf("show request failed: %v", err)
	}

	// If nothing was stored, the assertion yields "" and the check below fails.
	got, _ := seenUA.Load().(string)
	if want := userAgent(); got != want {
		t.Errorf("User-Agent mismatch\nExpected: %s\nReceived: %s", want, got)
	}
}
// TestSupportsBrowserTools is a table test for supportsBrowserTools. The table
// expects true for names beginning with "gpt-oss" in any letter case and false
// for the other listed models and for the empty string.
func TestSupportsBrowserTools(t *testing.T) {
	cases := []struct {
		model string
		want  bool
	}{
		// Expected to support browser tools.
		{model: "gpt-oss", want: true},
		{model: "gpt-oss-latest", want: true},
		{model: "GPT-OSS", want: true},
		{model: "Gpt-Oss-v2", want: true},
		// Expected not to.
		{model: "qwen3", want: false},
		{model: "deepseek-v3", want: false},
		{model: "llama3.3", want: false},
		{model: "", want: false},
	}

	for _, tc := range cases {
		t.Run(tc.model, func(t *testing.T) {
			got := supportsBrowserTools(tc.model)
			if got == tc.want {
				return
			}
			t.Errorf("supportsBrowserTools(%q) = %v, want %v", tc.model, got, tc.want)
		})
	}
}
// TestWebSearchToolRegistration checks the rule for choosing which tool set to
// register: nothing unless web search is enabled and the model has the tools
// capability; otherwise browser tools for models accepted by
// supportsBrowserTools and basic web search tools for everything else.
//
// The rule is re-implemented inside the test; the chat() handler itself is not
// invoked.
func TestWebSearchToolRegistration(t *testing.T) {
	cases := []struct {
		name             string
		webSearchEnabled bool
		hasToolsCap      bool
		model            string
		wantBrowser      bool // browser tools expected (gpt-oss)
		wantWebSearch    bool // basic web search/fetch tools expected
		wantNone         bool // no tools expected at all
	}{
		{
			name:             "web search enabled with tools capability - browser model",
			webSearchEnabled: true,
			hasToolsCap:      true,
			model:            "gpt-oss-latest",
			wantBrowser:      true,
		},
		{
			name:             "web search enabled with tools capability - non-browser model",
			webSearchEnabled: true,
			hasToolsCap:      true,
			model:            "qwen3",
			wantWebSearch:    true,
		},
		{
			name:             "web search enabled without tools capability",
			webSearchEnabled: true,
			model:            "llama3.3",
			wantNone:         true,
		},
		{
			name:        "web search disabled with tools capability",
			hasToolsCap: true,
			model:       "qwen3",
			wantNone:    true,
		},
		{
			name:     "web search disabled without tools capability",
			model:    "llama3.3",
			wantNone: true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Mirror of the decision made in the chat() handler. The model is
			// only consulted once tools are allowed at all.
			toolsAllowed := tc.webSearchEnabled && tc.hasToolsCap
			gotBrowser := toolsAllowed && supportsBrowserTools(tc.model)
			gotWebSearch := toolsAllowed && !gotBrowser

			// Missing registrations.
			if tc.wantBrowser && !gotBrowser {
				t.Error("expected browser tools to be registered")
			}
			if tc.wantWebSearch && !gotWebSearch {
				t.Error("expected web search tools to be registered")
			}

			// Registrations that should not have happened.
			anyRegistered := gotBrowser || gotWebSearch
			if tc.wantNone && anyRegistered {
				t.Error("expected no tools to be registered")
			}
			if gotBrowser && !tc.wantBrowser {
				t.Error("unexpected browser tools registered")
			}
			if gotWebSearch && !tc.wantWebSearch {
				t.Error("unexpected web search tools registered")
			}
		})
	}
}
// TestSettingsToggleAutoUpdateOff_CancelsDownload enables auto-update in the
// store, then posts settings with auto-update disabled and verifies the
// handler succeeds and the disabled value is persisted.
//
// The cancellation of an in-flight download is not asserted directly; the test
// only confirms the handler flow completes with an Updater attached.
func TestSettingsToggleAutoUpdateOff_CancelsDownload(t *testing.T) {
	db := &store.Store{DBPath: filepath.Join(t.TempDir(), "db.sqlite")}
	defer db.Close()

	// Seed the store with auto-update switched on.
	cfg, err := db.Settings()
	if err != nil {
		t.Fatal(err)
	}
	cfg.AutoUpdateEnabled = true
	if err := db.SetSettings(cfg); err != nil {
		t.Fatal(err)
	}

	// The updater is backed by its own, separate database file.
	updaterStore := &store.Store{DBPath: filepath.Join(t.TempDir(), "db2.sqlite")}
	upd := &updater.Updater{Store: updaterStore}
	defer upd.Store.Close()

	srv := &Server{
		Store:   db,
		Restart: func() {},
		Updater: upd,
	}

	// Flip auto-update off and send the whole settings object to the API.
	cfg.AutoUpdateEnabled = false
	payload, err := json.Marshal(cfg)
	if err != nil {
		t.Fatal(err)
	}

	req := httptest.NewRequest(http.MethodPost, "/api/v1/settings", bytes.NewReader(payload))
	req.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()

	if err := srv.settings(rec, req); err != nil {
		t.Fatalf("settings() error = %v", err)
	}
	if rec.Code != http.StatusOK {
		t.Fatalf("settings() status = %d, want %d", rec.Code, http.StatusOK)
	}

	// The persisted value must now read as disabled.
	persisted, err := db.Settings()
	if err != nil {
		t.Fatal(err)
	}
	if persisted.AutoUpdateEnabled {
		t.Fatal("expected AutoUpdateEnabled to be false after toggle off")
	}
}
// TestSettingsToggleAutoUpdateOn_WithPendingUpdate_ShowsNotification marks an
// update as already downloaded, re-enables auto-update through the settings
// handler, and verifies that UpdateAvailableFunc is invoked.
func TestSettingsToggleAutoUpdateOn_WithPendingUpdate_ShowsNotification(t *testing.T) {
	db := &store.Store{DBPath: filepath.Join(t.TempDir(), "db.sqlite")}
	defer db.Close()

	// Seed the store with auto-update switched off.
	cfg, err := db.Settings()
	if err != nil {
		t.Fatal(err)
	}
	cfg.AutoUpdateEnabled = false
	if err := db.SetSettings(cfg); err != nil {
		t.Fatal(err)
	}

	// Pretend an update has been downloaded. The flag is package-level state,
	// so the previous value is restored when the test returns.
	prevDownloaded := updater.UpdateDownloaded
	defer func() { updater.UpdateDownloaded = prevDownloaded }()
	updater.UpdateDownloaded = true

	var notified atomic.Bool
	srv := &Server{
		Store:               db,
		Restart:             func() {},
		UpdateAvailableFunc: func() { notified.Store(true) },
	}

	// Flip auto-update back on and send the whole settings object to the API.
	cfg.AutoUpdateEnabled = true
	payload, err := json.Marshal(cfg)
	if err != nil {
		t.Fatal(err)
	}

	req := httptest.NewRequest(http.MethodPost, "/api/v1/settings", bytes.NewReader(payload))
	req.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()

	if err := srv.settings(rec, req); err != nil {
		t.Fatalf("settings() error = %v", err)
	}
	if rec.Code != http.StatusOK {
		t.Fatalf("settings() status = %d, want %d", rec.Code, http.StatusOK)
	}

	if !notified.Load() {
		t.Fatal("expected UpdateAvailableFunc to be called when re-enabling with a downloaded update")
	}
}
// TestSettingsToggleAutoUpdateOn_NoPendingUpdate_TriggersCheck re-enables
// auto-update through the settings handler while no update is downloaded or
// staged, and verifies the handler succeeds without invoking
// UpdateAvailableFunc.
//
// That an immediate update check is actually triggered is not asserted
// directly; the test only confirms the handler flow completes.
func TestSettingsToggleAutoUpdateOn_NoPendingUpdate_TriggersCheck(t *testing.T) {
	db := &store.Store{DBPath: filepath.Join(t.TempDir(), "db.sqlite")}
	defer db.Close()

	// Seed the store with auto-update switched off.
	cfg, err := db.Settings()
	if err != nil {
		t.Fatal(err)
	}
	cfg.AutoUpdateEnabled = false
	if err := db.SetSettings(cfg); err != nil {
		t.Fatal(err)
	}

	// Make sure nothing looks pending: clear the downloaded flag and point the
	// stage directory at an empty temp dir (an empty dir means
	// IsUpdatePending() returns false). Both are package-level state and are
	// restored when the test returns.
	prevDownloaded := updater.UpdateDownloaded
	updater.UpdateDownloaded = false
	defer func() { updater.UpdateDownloaded = prevDownloaded }()

	prevStageDir := updater.UpdateStageDir
	updater.UpdateStageDir = t.TempDir()
	defer func() { updater.UpdateStageDir = prevStageDir }()

	// The updater is backed by its own, separate database file.
	updaterStore := &store.Store{DBPath: filepath.Join(t.TempDir(), "db2.sqlite")}
	upd := &updater.Updater{Store: updaterStore}
	defer upd.Store.Close()

	// Starting the background checker initializes the checkNow channel so
	// TriggerImmediateCheck doesn't panic on a nil channel. The checker's
	// context is cancelled when the test returns.
	ctx, cancel := context.WithCancel(t.Context())
	defer cancel()
	upd.StartBackgroundUpdaterChecker(ctx, func(string) error { return nil })

	var notified atomic.Bool
	srv := &Server{
		Store:               db,
		Restart:             func() {},
		Updater:             upd,
		UpdateAvailableFunc: func() { notified.Store(true) },
	}

	// Flip auto-update back on and send the whole settings object to the API.
	cfg.AutoUpdateEnabled = true
	payload, err := json.Marshal(cfg)
	if err != nil {
		t.Fatal(err)
	}

	req := httptest.NewRequest(http.MethodPost, "/api/v1/settings", bytes.NewReader(payload))
	req.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()

	if err := srv.settings(rec, req); err != nil {
		t.Fatalf("settings() error = %v", err)
	}
	if rec.Code != http.StatusOK {
		t.Fatalf("settings() status = %d, want %d", rec.Code, http.StatusOK)
	}

	// With nothing pending there is nothing to announce.
	if notified.Load() {
		t.Fatal("UpdateAvailableFunc should not be called when there is no pending update")
	}
}