diff --git a/anthropic/anthropic.go b/anthropic/anthropic.go index a5ff995f5..f5799fc1e 100755 --- a/anthropic/anthropic.go +++ b/anthropic/anthropic.go @@ -518,26 +518,24 @@ func mapStopReason(reason string, hasToolCalls bool) string { // StreamConverter manages state for converting Ollama streaming responses to Anthropic format type StreamConverter struct { - ID string - Model string - firstWrite bool - contentIndex int - inputTokens int - outputTokens int - estimatedInputTokens int // Estimated tokens from request (used when actual metrics are 0) - thinkingStarted bool - thinkingDone bool - textStarted bool - toolCallsSent map[string]bool + ID string + Model string + firstWrite bool + contentIndex int + inputTokens int + outputTokens int + thinkingStarted bool + thinkingDone bool + textStarted bool + toolCallsSent map[string]bool } -func NewStreamConverter(id, model string, estimatedInputTokens int) *StreamConverter { +func NewStreamConverter(id, model string) *StreamConverter { return &StreamConverter{ - ID: id, - Model: model, - firstWrite: true, - estimatedInputTokens: estimatedInputTokens, - toolCallsSent: make(map[string]bool), + ID: id, + Model: model, + firstWrite: true, + toolCallsSent: make(map[string]bool), } } @@ -553,11 +551,7 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent { if c.firstWrite { c.firstWrite = false - // Use actual metrics if available, otherwise use estimate c.inputTokens = r.Metrics.PromptEvalCount - if c.inputTokens == 0 && c.estimatedInputTokens > 0 { - c.inputTokens = c.estimatedInputTokens - } events = append(events, StreamEvent{ Event: "message_start", @@ -785,121 +779,3 @@ func mapToArgs(m map[string]any) api.ToolCallFunctionArguments { } return args } - -// CountTokensRequest represents an Anthropic count_tokens request -type CountTokensRequest struct { - Model string `json:"model"` - Messages []MessageParam `json:"messages"` - System any `json:"system,omitempty"` - Tools []Tool `json:"tools,omitempty"` - Thinking *ThinkingConfig `json:"thinking,omitempty"` -} - -// EstimateInputTokens estimates input tokens from a MessagesRequest (reuses CountTokensRequest logic) -func EstimateInputTokens(req MessagesRequest) int { - return estimateTokens(CountTokensRequest{ - Model: req.Model, - Messages: req.Messages, - System: req.System, - Tools: req.Tools, - Thinking: req.Thinking, - }) -} - -// CountTokensResponse represents an Anthropic count_tokens response -type CountTokensResponse struct { - InputTokens int `json:"input_tokens"` -} - -// estimateTokens returns a rough estimate of tokens (len/4) -func estimateTokens(req CountTokensRequest) int { - var totalLen int - - // Count system prompt - if req.System != nil { - totalLen += countAnyContent(req.System) - } - - // Count messages - for _, msg := range req.Messages { - // Count role (always present) - totalLen += len(msg.Role) - // Count content - contentLen := countAnyContent(msg.Content) - totalLen += contentLen - } - - for _, tool := range req.Tools { - totalLen += len(tool.Name) + len(tool.Description) + len(tool.InputSchema) - } - - // Return len/4 as rough token estimate, minimum 1 if there's any content - tokens := totalLen / 4 - if tokens == 0 && (len(req.Messages) > 0 || req.System != nil) { - tokens = 1 - } - return tokens -} - -func countAnyContent(content any) int { - if content == nil { - return 0 - } - - switch c := content.(type) { - case string: - return len(c) - case []any: - total := 0 - for _, block := range c { - total += countContentBlock(block) - } - return total - default: - if data, err := json.Marshal(content); err == nil { - return len(data) - } - return 0 - } -} - -func countContentBlock(block any) int { - blockMap, ok := block.(map[string]any) - if !ok { - if s, ok := block.(string); ok { - return len(s) - } - return 0 - } - - total := 0 - blockType, _ := blockMap["type"].(string) - - if text, ok := blockMap["text"].(string); ok { - total += len(text) - } - - if thinking, ok := blockMap["thinking"].(string); ok { - total += len(thinking) - } - - if blockType == "tool_use" { - if data, err := json.Marshal(blockMap); err == nil { - total += len(data) - } - } - - if blockType == "tool_result" { - if data, err := json.Marshal(blockMap); err == nil { - total += len(data) - } - } - - if source, ok := blockMap["source"].(map[string]any); ok { - if data, ok := source["data"].(string); ok { - total += len(data) - } - } - - return total -} diff --git a/anthropic/anthropic_test.go b/anthropic/anthropic_test.go index a60327e6a..1c2a4a868 100755 --- a/anthropic/anthropic_test.go +++ b/anthropic/anthropic_test.go @@ -605,7 +605,7 @@ func TestGenerateMessageID(t *testing.T) { } func TestStreamConverter_Basic(t *testing.T) { - conv := NewStreamConverter("msg_123", "test-model", 0) + conv := NewStreamConverter("msg_123", "test-model") // First chunk resp1 := api.ChatResponse{ @@ -678,7 +678,7 @@ func TestStreamConverter_Basic(t *testing.T) { } func TestStreamConverter_WithToolCalls(t *testing.T) { - conv := NewStreamConverter("msg_123", "test-model", 0) + conv := NewStreamConverter("msg_123", "test-model") resp := api.ChatResponse{ Model: "test-model", @@ -731,7 +731,7 @@ func TestStreamConverter_WithToolCalls(t *testing.T) { func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) { // Test that unmarshalable arguments (like channels) are handled gracefully // and don't cause a panic or corrupt stream - conv := NewStreamConverter("msg_123", "test-model", 0) + conv := NewStreamConverter("msg_123", "test-model") // Create a channel which cannot be JSON marshaled unmarshalable := make(chan int) @@ -778,7 +778,7 @@ func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) { func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) { // Test that valid tool calls still work when mixed with invalid ones - conv := NewStreamConverter("msg_123", "test-model", 0) + conv := NewStreamConverter("msg_123", "test-model") unmarshalable := make(chan int) badArgs := api.NewToolCallFunctionArguments() @@ -903,7 +903,7 @@ func TestContentBlockJSON_EmptyFieldsPresent(t *testing.T) { // events include the required empty fields for SDK compatibility. func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) { t.Run("text block start includes empty text", func(t *testing.T) { - conv := NewStreamConverter("msg_123", "test-model", 0) + conv := NewStreamConverter("msg_123", "test-model") resp := api.ChatResponse{ Model: "test-model", @@ -937,7 +937,7 @@ func TestStreamConverter_ContentBlockStartIncludesEmptyFields(t *testing.T) { }) t.Run("thinking block start includes empty thinking", func(t *testing.T) { - conv := NewStreamConverter("msg_123", "test-model", 0) + conv := NewStreamConverter("msg_123", "test-model") resp := api.ChatResponse{ Model: "test-model", diff --git a/api/client.go b/api/client.go index eec720b93..d70672a6b 100644 --- a/api/client.go +++ b/api/client.go @@ -466,15 +466,3 @@ func (c *Client) Whoami(ctx context.Context) (*UserResponse, error) { } return &resp, nil } - -// AliasRequest is the request body for creating or updating a model alias. -type AliasRequest struct { - Alias string `json:"alias"` - Target string `json:"target"` - PrefixMatching bool `json:"prefix_matching,omitempty"` -} - -// SetAliasExperimental creates or updates a model alias via the experimental aliases API. -func (c *Client) SetAliasExperimental(ctx context.Context, req *AliasRequest) error { - return c.do(ctx, http.MethodPost, "/api/experimental/aliases", req, nil) -} diff --git a/cmd/config/claude.go b/cmd/config/claude.go index c26bd6ccb..80a72f564 100644 --- a/cmd/config/claude.go +++ b/cmd/config/claude.go @@ -1,23 +1,18 @@ package config import ( - "context" "fmt" "os" "os/exec" "path/filepath" "runtime" - "github.com/ollama/ollama/api" "github.com/ollama/ollama/envconfig" ) -// Claude implements Runner and AliasConfigurer for Claude Code integration +// Claude implements Runner for Claude Code integration type Claude struct{} -// Compile-time check that Claude implements AliasConfigurer -var _ AliasConfigurer = (*Claude)(nil) - func (c *Claude) String() string { return "Claude Code" } func (c *Claude) args(model string, extra []string) []string { @@ -65,96 +60,3 @@ func (c *Claude) Run(model string, args []string) error { ) return cmd.Run() } - -// ConfigureAliases sets up Primary and Fast model aliases for Claude Code. -func (c *Claude) ConfigureAliases(ctx context.Context, primaryModel string, existing map[string]string, force bool) (map[string]string, bool, error) { - aliases := make(map[string]string) - for k, v := range existing { - aliases[k] = v - } - - if primaryModel != "" { - aliases["primary"] = primaryModel - } - - if !force && aliases["primary"] != "" && aliases["fast"] != "" { - return aliases, false, nil - } - - items, existingModels, cloudModels, client, err := listModels(ctx) - if err != nil { - return nil, false, err - } - - fmt.Fprintf(os.Stderr, "\n%sModel Configuration%s\n", ansiBold, ansiReset) - fmt.Fprintf(os.Stderr, "%sClaude Code uses multiple models for various tasks%s\n\n", ansiGray, ansiReset) - - fmt.Fprintf(os.Stderr, "%sPrimary%s\n", ansiBold, ansiReset) - fmt.Fprintf(os.Stderr, "%sHandles complex reasoning: planning, code generation, debugging.%s\n\n", ansiGray, ansiReset) - - if aliases["primary"] == "" || force { - primary, err := selectPrompt("Select Primary model:", items) - if err != nil { - return nil, false, err - } - if err := pullIfNeeded(ctx, client, existingModels, primary); err != nil { - return nil, false, err - } - if err := ensureAuth(ctx, client, cloudModels, []string{primary}); err != nil { - return nil, false, err - } - aliases["primary"] = primary - } else { - fmt.Fprintf(os.Stderr, " %s\n\n", aliases["primary"]) - } - - fmt.Fprintf(os.Stderr, "%sFast%s\n", ansiBold, ansiReset) - fmt.Fprintf(os.Stderr, "%sHandles quick operations: file searches, simple edits, status checks.%s\n", ansiGray, ansiReset) - fmt.Fprintf(os.Stderr, "%sSmaller models work well and respond faster.%s\n\n", ansiGray, ansiReset) - - if aliases["fast"] == "" || force { - fast, err := selectPrompt("Select Fast model:", items) - if err != nil { - return nil, false, err - } - if err := pullIfNeeded(ctx, client, existingModels, fast); err != nil { - return nil, false, err - } - if err := ensureAuth(ctx, client, cloudModels, []string{fast}); err != nil { - return nil, false, err - } - aliases["fast"] = fast - } - - return aliases, true, nil -} - -// SetAliases syncs the configured aliases to the Ollama server using prefix matching. -func (c *Claude) SetAliases(ctx context.Context, aliases map[string]string) error { - client, err := api.ClientFromEnvironment() - if err != nil { - return err - } - - prefixAliases := map[string]string{ - "claude-sonnet-": aliases["primary"], - "claude-haiku-": aliases["fast"], - } - - var errs []string - for prefix, target := range prefixAliases { - req := &api.AliasRequest{ - Alias: prefix, - Target: target, - PrefixMatching: true, - } - if err := client.SetAliasExperimental(ctx, req); err != nil { - errs = append(errs, prefix) - } - } - - if len(errs) > 0 { - return fmt.Errorf("failed to set aliases: %v", errs) - } - return nil -} diff --git a/cmd/config/config.go b/cmd/config/config.go index 1c183e7b9..5f98bd5ed 100644 --- a/cmd/config/config.go +++ b/cmd/config/config.go @@ -13,8 +13,7 @@ import ( ) type integration struct { - Models []string `json:"models"` - Aliases map[string]string `json:"aliases,omitempty"` + Models []string `json:"models"` } type config struct { @@ -134,16 +133,8 @@ func saveIntegration(appName string, models []string) error { return err } - key := strings.ToLower(appName) - existing := cfg.Integrations[key] - var aliases map[string]string - if existing != nil && existing.Aliases != nil { - aliases = existing.Aliases - } - - cfg.Integrations[key] = &integration{ - Models: models, - Aliases: aliases, + cfg.Integrations[strings.ToLower(appName)] = &integration{ + Models: models, } return save(cfg) @@ -163,33 +154,6 @@ func loadIntegration(appName string) (*integration, error) { return ic, nil } -func saveAliases(appName string, aliases map[string]string) error { - if appName == "" { - return errors.New("app name cannot be empty") - } - - cfg, err := load() - if err != nil { - return err - } - - key := strings.ToLower(appName) - existing := cfg.Integrations[key] - if existing == nil { - existing = &integration{} - } - - if existing.Aliases == nil { - existing.Aliases = make(map[string]string) - } - for k, v := range aliases { - existing.Aliases[k] = v - } - - cfg.Integrations[key] = existing - return save(cfg) -} - func listIntegrations() ([]integration, error) { cfg, err := load() if err != nil { diff --git a/cmd/config/config_test.go b/cmd/config/config_test.go index a491a276f..ae87c6a40 100644 --- a/cmd/config/config_test.go +++ b/cmd/config/config_test.go @@ -46,53 +46,6 @@ func TestIntegrationConfig(t *testing.T) { } }) - t.Run("save and load aliases", func(t *testing.T) { - models := []string{"llama3.2"} - if err := saveIntegration("claude", models); err != nil { - t.Fatal(err) - } - aliases := map[string]string{ - "primary": "llama3.2:70b", - "fast": "llama3.2:8b", - } - if err := saveAliases("claude", aliases); err != nil { - t.Fatal(err) - } - - config, err := loadIntegration("claude") - if err != nil { - t.Fatal(err) - } - if config.Aliases == nil { - t.Fatal("expected aliases to be saved") - } - for k, v := range aliases { - if config.Aliases[k] != v { - t.Errorf("alias %s: expected %s, got %s", k, v, config.Aliases[k]) - } - } - }) - - t.Run("saveIntegration preserves aliases", func(t *testing.T) { - if err := saveIntegration("claude", []string{"model-a"}); err != nil { - t.Fatal(err) - } - if err := saveAliases("claude", map[string]string{"primary": "model-a", "fast": "model-small"}); err != nil { - t.Fatal(err) - } - - if err := saveIntegration("claude", []string{"model-b"}); err != nil { - t.Fatal(err) - } - config, err := loadIntegration("claude") - if err != nil { - t.Fatal(err) - } - if config.Aliases["primary"] != "model-a" { - t.Errorf("expected aliases to be preserved, got %v", config.Aliases) - } - }) - t.Run("defaultModel returns first model", func(t *testing.T) { saveIntegration("codex", []string{"model-a", "model-b"}) diff --git a/cmd/config/integrations.go b/cmd/config/integrations.go index 6991609e3..714eae625 100644 --- a/cmd/config/integrations.go +++ b/cmd/config/integrations.go @@ -39,15 +39,6 @@ type Editor interface { Models() []string } -// AliasConfigurer can configure model aliases (e.g., for subagent routing). -// Integrations like Claude and Codex use this to route model requests to local models. -type AliasConfigurer interface { - // ConfigureAliases prompts the user to configure aliases and returns the updated map. - ConfigureAliases(ctx context.Context, primaryModel string, existing map[string]string, force bool) (map[string]string, bool, error) - // SetAliases syncs the configured aliases to the server - SetAliases(ctx context.Context, aliases map[string]string) error -} - // integrations is the registry of available integrations. var integrations = map[string]Runner{ "claude": &Claude{}, @@ -138,11 +129,7 @@ func selectModels(ctx context.Context, name, current string) ([]string, error) { return nil, err } } else { - prompt := fmt.Sprintf("Select model for %s:", r) - if _, ok := r.(AliasConfigurer); ok { - prompt = fmt.Sprintf("Select Primary model for %s:", r) - } - model, err := selectPrompt(prompt, items) + model, err := selectPrompt(fmt.Sprintf("Select model for %s:", r), items) if err != nil { return nil, err } @@ -170,146 +157,73 @@ func selectModels(ctx context.Context, name, current string) ([]string, error) { } } - if err := ensureAuth(ctx, client, cloudModels, selected); err != nil { - return nil, err - } - - return selected, nil -} - -func pullIfNeeded(ctx context.Context, client *api.Client, existingModels map[string]bool, model string) error { - if existingModels[model] { - return nil - } - msg := fmt.Sprintf("Download %s?", model) - if ok, err := confirmPrompt(msg); err != nil { - return err - } else if !ok { - return errCancelled - } - fmt.Fprintf(os.Stderr, "\n") - if err := pullModel(ctx, client, model); err != nil { - return fmt.Errorf("failed to pull %s: %w", model, err) - } - return nil -} - -func listModels(ctx context.Context) ([]selectItem, map[string]bool, map[string]bool, *api.Client, error) { - client, err := api.ClientFromEnvironment() - if err != nil { - return nil, nil, nil, nil, err - } - - models, err := client.List(ctx) - if err != nil { - return nil, nil, nil, nil, err - } - - var existing []modelInfo - for _, m := range models.Models { - existing = append(existing, modelInfo{Name: m.Name, Remote: m.RemoteModel != ""}) - } - - items, _, existingModels, cloudModels := buildModelList(existing, nil, "") - - if len(items) == 0 { - return nil, nil, nil, nil, fmt.Errorf("no models available, run 'ollama pull ' first") - } - - return items, existingModels, cloudModels, client, nil -} - -func ensureAuth(ctx context.Context, client *api.Client, cloudModels map[string]bool, selected []string) error { var selectedCloudModels []string for _, m := range selected { if cloudModels[m] { selectedCloudModels = append(selectedCloudModels, m) } } - if len(selectedCloudModels) == 0 { - return nil - } + if len(selectedCloudModels) > 0 { + // ensure user is signed in + user, err := client.Whoami(ctx) + if err == nil && user != nil && user.Name != "" { + return selected, nil + } - user, err := client.Whoami(ctx) - if err == nil && user != nil && user.Name != "" { - return nil - } + var aErr api.AuthorizationError + if !errors.As(err, &aErr) || aErr.SigninURL == "" { + return nil, err + } - var aErr api.AuthorizationError - if !errors.As(err, &aErr) || aErr.SigninURL == "" { - return err - } + modelList := strings.Join(selectedCloudModels, ", ") + yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", modelList)) + if err != nil || !yes { + return nil, fmt.Errorf("%s requires sign in", modelList) + } - modelList := strings.Join(selectedCloudModels, ", ") - yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", modelList)) - if err != nil || !yes { - return fmt.Errorf("%s requires sign in", modelList) - } + fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL) - fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL) + // TODO(parthsareen): extract into auth package for cmd + // Auto-open browser (best effort, fail silently) + switch runtime.GOOS { + case "darwin": + _ = exec.Command("open", aErr.SigninURL).Start() + case "linux": + _ = exec.Command("xdg-open", aErr.SigninURL).Start() + case "windows": + _ = exec.Command("rundll32", "url.dll,FileProtocolHandler", aErr.SigninURL).Start() + } - switch runtime.GOOS { - case "darwin": - _ = exec.Command("open", aErr.SigninURL).Start() - case "linux": - _ = exec.Command("xdg-open", aErr.SigninURL).Start() - case "windows": - _ = exec.Command("rundll32", "url.dll,FileProtocolHandler", aErr.SigninURL).Start() - } + spinnerFrames := []string{"|", "/", "-", "\\"} + frame := 0 - spinnerFrames := []string{"|", "/", "-", "\\"} - frame := 0 + fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0]) - fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0]) + ticker := time.NewTicker(200 * time.Millisecond) + defer ticker.Stop() - ticker := time.NewTicker(200 * time.Millisecond) - defer ticker.Stop() + for { + select { + case <-ctx.Done(): + fmt.Fprintf(os.Stderr, "\r\033[K") + return nil, ctx.Err() + case <-ticker.C: + frame++ + fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)]) - for { - select { - case <-ctx.Done(): - fmt.Fprintf(os.Stderr, "\r\033[K") - return ctx.Err() - case <-ticker.C: - frame++ - fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)]) - - // poll every 10th frame (~2 seconds) - if frame%10 == 0 { - u, err := client.Whoami(ctx) - if err == nil && u != nil && u.Name != "" { - fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name) - return nil + // poll every 10th frame (~2 seconds) + if frame%10 == 0 { + u, err := client.Whoami(ctx) + if err == nil && u != nil && u.Name != "" { + fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name) + return selected, nil + } } } } } -} -func ensureAliases(ctx context.Context, r Runner, name string, primaryModel string, existing map[string]string, force bool) (bool, error) { - ac, ok := r.(AliasConfigurer) - if !ok { - return false, nil - } - - aliases, updated, err := ac.ConfigureAliases(ctx, primaryModel, existing, force) - if err != nil { - return false, err - } - if !updated { - return false, nil - } - - if err := saveAliases(name, aliases); err != nil { - return false, err - } - - if err := ac.SetAliases(ctx, aliases); err != nil { - fmt.Fprintf(os.Stderr, "%sWarning: Could not sync aliases to server: %v%s\n", ansiGray, err, ansiReset) - fmt.Fprintf(os.Stderr, "%sAliases saved locally. Server sync will retry on next launch.%s\n\n", ansiGray, ansiReset) - } - - return true, nil + return selected, nil } func runIntegration(name, modelName string, args []string) error { @@ -317,17 +231,6 @@ func runIntegration(name, modelName string, args []string) error { if !ok { return fmt.Errorf("unknown integration: %s", name) } - - if _, ok := r.(AliasConfigurer); ok { - if config, err := loadIntegration(name); err == nil && config.Aliases != nil { - primary, fast := config.Aliases["primary"], config.Aliases["fast"] - if primary != "" && fast != "" { - fmt.Fprintf(os.Stderr, "\nLaunching %s with Primary: %s, Fast: %s...\n", r, primary, fast) - return r.Run(modelName, args) - } - } - } - fmt.Fprintf(os.Stderr, "\nLaunching %s with %s...\n", r, modelName) return r.Run(modelName, args) } @@ -401,50 +304,10 @@ Examples: if !configFlag && modelFlag == "" { if config, err := loadIntegration(name); err == nil && len(config.Models) > 0 { - if _, err := ensureAliases(cmd.Context(), r, name, config.Models[0], config.Aliases, false); errors.Is(err, errCancelled) { - return nil - } else if err != nil { - return err - } return runIntegration(name, config.Models[0], passArgs) } } - if ac, ok := r.(AliasConfigurer); ok { - var existingAliases map[string]string - if existing, err := loadIntegration(name); err == nil { - existingAliases = existing.Aliases - } - aliases, updated, err := ac.ConfigureAliases(cmd.Context(), "", existingAliases, configFlag) - if errors.Is(err, errCancelled) { - return nil - } - if err != nil { - return err - } - if updated { - if err := saveAliases(name, aliases); err != nil { - return err - } - if err := ac.SetAliases(cmd.Context(), aliases); err != nil { - fmt.Fprintf(os.Stderr, "%sWarning: Could not sync aliases to server: %v%s\n", ansiGray, err, ansiReset) - } - fmt.Fprintf(os.Stderr, "\n%sConfiguration Complete%s\n", ansiBold, ansiReset) - fmt.Fprintf(os.Stderr, "Primary: %s\n", aliases["primary"]) - fmt.Fprintf(os.Stderr, "Fast: %s\n\n", aliases["fast"]) - } - if err := saveIntegration(name, []string{aliases["primary"]}); err != nil { - return fmt.Errorf("failed to save: %w", err) - } - if configFlag { - if launch, _ := confirmPrompt(fmt.Sprintf("Launch %s now?", r)); launch { - return runIntegration(name, aliases["primary"], passArgs) - } - return nil - } - return runIntegration(name, aliases["primary"], passArgs) - } - var models []string if modelFlag != "" { models = []string{modelFlag} diff --git a/cmd/config/integrations_test.go b/cmd/config/integrations_test.go index e4b213e64..dd2056e98 100644 --- a/cmd/config/integrations_test.go +++ b/cmd/config/integrations_test.go @@ -509,19 +509,3 @@ func TestBuildModelList_ReturnsExistingAndCloudMaps(t *testing.T) { t.Error("llama3.2 should not be in cloudModels") } } - -func TestAliasConfigurerInterface(t *testing.T) { - t.Run("claude implements AliasConfigurer", func(t *testing.T) { - claude := &Claude{} - if _, ok := interface{}(claude).(AliasConfigurer); !ok { - t.Error("Claude should implement AliasConfigurer") - } - }) - - t.Run("codex does not implement AliasConfigurer", func(t *testing.T) { - codex := &Codex{} - if _, ok := interface{}(codex).(AliasConfigurer); ok { - t.Error("Codex should not implement AliasConfigurer") - } - }) -} diff --git a/cmd/config/selector.go b/cmd/config/selector.go index f617a7c1d..956e1f1ea 100644 --- a/cmd/config/selector.go +++ b/cmd/config/selector.go @@ -65,10 +65,6 @@ func (s *selectState) handleInput(event inputEvent, char byte) (done bool, resul if len(filtered) > 0 && s.selected < len(filtered) { return true, filtered[s.selected].Name, nil } - // No matches but user typed something - return filter for pull prompt - if len(filtered) == 0 && s.filter != "" { - return true, s.filter, nil - } case eventEscape: return true, "", errCancelled case eventBackspace: @@ -287,11 +283,7 @@ func renderSelect(w io.Writer, prompt string, s *selectState) int { lineCount := 1 if len(filtered) == 0 { - if s.filter != "" { - fmt.Fprintf(w, " %s→ Download model: '%s'? Press Enter%s\r\n", ansiGray, s.filter, ansiReset) - } else { - fmt.Fprintf(w, " %s(no matches)%s\r\n", ansiGray, ansiReset) - } + fmt.Fprintf(w, " %s(no matches)%s\r\n", ansiGray, ansiReset) lineCount++ } else { displayCount := min(len(filtered), maxDisplayedItems) diff --git a/cmd/config/selector_test.go b/cmd/config/selector_test.go index a6bd64465..74e8796ee 100644 --- a/cmd/config/selector_test.go +++ b/cmd/config/selector_test.go @@ -87,18 +87,10 @@ func TestSelectState(t *testing.T) { } }) - t.Run("Enter_EmptyFilteredList_ReturnsFilter", func(t *testing.T) { + t.Run("Enter_EmptyFilteredList_DoesNothing", func(t *testing.T) { s := newSelectState(items) s.filter = "nonexistent" done, result, err := s.handleInput(eventEnter, 0) - if !done || result != "nonexistent" || err != nil { - t.Errorf("expected (true, 'nonexistent', nil), got (%v, %v, %v)", done, result, err) - } - }) - - t.Run("Enter_EmptyFilteredList_EmptyFilter_DoesNothing", func(t *testing.T) { - s := newSelectState([]selectItem{}) - done, result, err := s.handleInput(eventEnter, 0) if done || result != "" || err != nil { t.Errorf("expected (false, '', nil), got (%v, %v, %v)", done, result, err) } @@ -576,25 +568,14 @@ func TestRenderSelect(t *testing.T) { } }) - t.Run("EmptyFilteredList_ShowsPullPrompt", func(t *testing.T) { + t.Run("EmptyFilteredList_ShowsNoMatches", func(t *testing.T) { s := newSelectState(items) s.filter = "xyz" var buf bytes.Buffer renderSelect(&buf, "Select:", s) - output := buf.String() - if !strings.Contains(output, "Download model: 'xyz'?") { - t.Errorf("expected 'Download model: xyz?' message, got: %s", output) - } - }) - - t.Run("EmptyFilteredList_EmptyFilter_ShowsNoMatches", func(t *testing.T) { - s := newSelectState([]selectItem{}) - var buf bytes.Buffer - renderSelect(&buf, "Select:", s) - if !strings.Contains(buf.String(), "no matches") { - t.Error("expected 'no matches' message for empty list with no filter") + t.Error("expected 'no matches' message") } }) diff --git a/middleware/anthropic.go b/middleware/anthropic.go index 5df87a84a..ff55b6ebf 100644 --- a/middleware/anthropic.go +++ b/middleware/anthropic.go @@ -131,15 +131,12 @@ func AnthropicMessagesMiddleware() gin.HandlerFunc { messageID := anthropic.GenerateMessageID() - // Estimate input tokens for streaming (actual count not available until generation completes) - estimatedTokens := anthropic.EstimateInputTokens(req) - w := &AnthropicWriter{ BaseWriter: BaseWriter{ResponseWriter: c.Writer}, stream: req.Stream, id: messageID, model: req.Model, - converter: anthropic.NewStreamConverter(messageID, req.Model, estimatedTokens), + converter: anthropic.NewStreamConverter(messageID, req.Model), } if req.Stream { diff --git a/server/aliases.go b/server/aliases.go deleted file mode 100644 index 9757a33fe..000000000 --- a/server/aliases.go +++ /dev/null @@ -1,422 +0,0 @@ -package server - -import ( - "encoding/json" - "errors" - "fmt" - "log/slog" - "os" - "path/filepath" - "sort" - "strings" - "sync" - - "github.com/ollama/ollama/manifest" - "github.com/ollama/ollama/types/model" -) - -const ( - routerConfigFilename = "server.json" - routerConfigVersion = 1 -) - -var errAliasCycle = errors.New("alias cycle detected") - -type aliasEntry struct { - Alias string `json:"alias"` - Target string `json:"target"` - PrefixMatching bool `json:"prefix_matching,omitempty"` -} - -type routerConfig struct { - Version int `json:"version"` - Aliases []aliasEntry `json:"aliases"` -} - -type aliasStore struct { - mu sync.RWMutex - path string - entries map[string]aliasEntry // normalized alias -> entry (exact matches) - prefixEntries []aliasEntry // prefix matches, sorted longest-first -} - -func newAliasStore(path string) (*aliasStore, error) { - store := &aliasStore{ - path: path, - entries: make(map[string]aliasEntry), - } - if err := store.load(); err != nil { - return nil, err - } - return store, nil -} - -func (s *aliasStore) load() error { - data, err := os.ReadFile(s.path) - if err != nil { - if errors.Is(err, os.ErrNotExist) { - return nil - } - return err - } - - var cfg routerConfig - if err := json.Unmarshal(data, &cfg); err != nil { - return err - } - - if cfg.Version != 0 && cfg.Version != routerConfigVersion { - return fmt.Errorf("unsupported router config version %d", cfg.Version) - } - - for _, entry := range cfg.Aliases { - targetName := model.ParseName(entry.Target) - if !targetName.IsValid() { - slog.Warn("invalid alias target in router config", "target", entry.Target) - continue - } - canonicalTarget := displayAliasName(targetName) - - if entry.PrefixMatching { - // Prefix aliases don't need to be valid model names - alias := strings.TrimSpace(entry.Alias) - if alias == "" { - slog.Warn("empty prefix alias in router config") - continue - } - s.prefixEntries = append(s.prefixEntries, aliasEntry{ - Alias: alias, - Target: canonicalTarget, - PrefixMatching: true, - }) - } else { - aliasName := model.ParseName(entry.Alias) - if !aliasName.IsValid() { - slog.Warn("invalid alias name in router config", "alias", entry.Alias) - continue - } - canonicalAlias := displayAliasName(aliasName) - s.entries[normalizeAliasKey(aliasName)] = aliasEntry{ - Alias: canonicalAlias, - Target: canonicalTarget, - } - } - } - - // Sort prefix entries by alias length descending (longest prefix wins) - s.sortPrefixEntriesLocked() - - return nil -} - -func (s *aliasStore) saveLocked() error { - dir := filepath.Dir(s.path) - if err := os.MkdirAll(dir, 0o755); err != nil { - return err - } - - // Combine exact and prefix entries - entries := make([]aliasEntry, 0, len(s.entries)+len(s.prefixEntries)) - for _, entry := range s.entries { - entries = append(entries, entry) - } - entries = append(entries, s.prefixEntries...) - - sort.Slice(entries, func(i, j int) bool { - return strings.Compare(entries[i].Alias, entries[j].Alias) < 0 - }) - - cfg := routerConfig{ - Version: routerConfigVersion, - Aliases: entries, - } - - f, err := os.CreateTemp(dir, "router-*.json") - if err != nil { - return err - } - - enc := json.NewEncoder(f) - enc.SetIndent("", " ") - if err := enc.Encode(cfg); err != nil { - _ = f.Close() - _ = os.Remove(f.Name()) - return err - } - - if err := f.Close(); err != nil { - _ = os.Remove(f.Name()) - return err - } - - if err := os.Chmod(f.Name(), 0o644); err != nil { - _ = os.Remove(f.Name()) - return err - } - - return os.Rename(f.Name(), s.path) -} - -func (s *aliasStore) ResolveName(name model.Name) (model.Name, bool, error) { - // If a local model exists, do not allow alias shadowing (highest priority). - exists, err := localModelExists(name) - if err != nil { - return name, false, err - } - if exists { - return name, false, nil - } - - key := normalizeAliasKey(name) - - s.mu.RLock() - entry, exactMatch := s.entries[key] - var prefixMatch *aliasEntry - if !exactMatch { - // Try prefix matching - prefixEntries is sorted longest-first - nameStr := strings.ToLower(displayAliasName(name)) - for i := range s.prefixEntries { - prefix := strings.ToLower(s.prefixEntries[i].Alias) - if strings.HasPrefix(nameStr, prefix) { - prefixMatch = &s.prefixEntries[i] - break // First match is longest due to sorting - } - } - } - s.mu.RUnlock() - - if !exactMatch && prefixMatch == nil { - return name, false, nil - } - - var current string - var visited map[string]struct{} - - if exactMatch { - visited = map[string]struct{}{key: {}} - current = entry.Target - } else { - // For prefix match, use the target as-is - visited = map[string]struct{}{} - current = prefixMatch.Target - } - - targetKey := normalizeAliasKeyString(current) - - for { - targetName := model.ParseName(current) - if !targetName.IsValid() { - return name, false, fmt.Errorf("alias target %q is invalid", current) - } - - if _, seen := visited[targetKey]; seen { - return name, false, errAliasCycle - } - visited[targetKey] = struct{}{} - - s.mu.RLock() - next, ok := s.entries[targetKey] - s.mu.RUnlock() - if !ok { - return targetName, true, nil - } - - current = next.Target - targetKey = normalizeAliasKeyString(current) - } -} - -func (s *aliasStore) Set(alias, target model.Name, prefixMatching bool) error { - targetKey := normalizeAliasKey(target) - - s.mu.Lock() - defer s.mu.Unlock() - - if prefixMatching { - // For prefix aliases, we skip cycle detection since prefix matching - // works differently and the target is a specific model - aliasStr := displayAliasName(alias) - - // Remove any existing prefix entry with the same alias - for i, e := range s.prefixEntries { - if strings.EqualFold(e.Alias, aliasStr) { - s.prefixEntries = append(s.prefixEntries[:i], s.prefixEntries[i+1:]...) - break - } - } - - s.prefixEntries = append(s.prefixEntries, aliasEntry{ - Alias: aliasStr, - Target: displayAliasName(target), - PrefixMatching: true, - }) - s.sortPrefixEntriesLocked() - return s.saveLocked() - } - - aliasKey := normalizeAliasKey(alias) - - if aliasKey == targetKey { - return fmt.Errorf("alias cannot point to itself") - } - - visited := map[string]struct{}{aliasKey: {}} - currentKey := targetKey - for { - if _, seen := visited[currentKey]; seen { - return errAliasCycle - } - visited[currentKey] = struct{}{} - - next, ok := s.entries[currentKey] - if !ok { - break - } - currentKey = normalizeAliasKeyString(next.Target) - } - - s.entries[aliasKey] = aliasEntry{ - Alias: displayAliasName(alias), - Target: displayAliasName(target), - } - - return s.saveLocked() -} - -func (s *aliasStore) Delete(alias model.Name) (bool, error) { - aliasKey := normalizeAliasKey(alias) - - s.mu.Lock() - defer s.mu.Unlock() - - // Try exact match first - if _, ok := s.entries[aliasKey]; ok { - delete(s.entries, aliasKey) - return true, s.saveLocked() - } - - // Try prefix entries - aliasStr := displayAliasName(alias) - for i, e := range s.prefixEntries { - if strings.EqualFold(e.Alias, aliasStr) { - s.prefixEntries = append(s.prefixEntries[:i], s.prefixEntries[i+1:]...) - return true, s.saveLocked() - } - } - - return false, nil -} - -// DeleteByString deletes an alias by its raw string value, useful for prefix -// aliases that may not be valid model names. -func (s *aliasStore) DeleteByString(alias string) (bool, error) { - alias = strings.TrimSpace(alias) - aliasLower := strings.ToLower(alias) - - s.mu.Lock() - defer s.mu.Unlock() - - // Try prefix entries first (since this is mainly for prefix aliases) - for i, e := range s.prefixEntries { - if strings.EqualFold(e.Alias, alias) { - s.prefixEntries = append(s.prefixEntries[:i], s.prefixEntries[i+1:]...) - return true, s.saveLocked() - } - } - - // Also check exact entries by normalized key - if _, ok := s.entries[aliasLower]; ok { - delete(s.entries, aliasLower) - return true, s.saveLocked() - } - - return false, nil -} - -func (s *aliasStore) List() []aliasEntry { - s.mu.RLock() - defer s.mu.RUnlock() - - entries := make([]aliasEntry, 0, len(s.entries)+len(s.prefixEntries)) - for _, entry := range s.entries { - entries = append(entries, entry) - } - entries = append(entries, s.prefixEntries...) - - sort.Slice(entries, func(i, j int) bool { - return strings.Compare(entries[i].Alias, entries[j].Alias) < 0 - }) - return entries -} - -func normalizeAliasKey(name model.Name) string { - return strings.ToLower(displayAliasName(name)) -} - -func (s *aliasStore) sortPrefixEntriesLocked() { - sort.Slice(s.prefixEntries, func(i, j int) bool { - // Sort by length descending (longest prefix first) - return len(s.prefixEntries[i].Alias) > len(s.prefixEntries[j].Alias) - }) -} - -func normalizeAliasKeyString(value string) string { - n := model.ParseName(value) - if !n.IsValid() { - return strings.ToLower(strings.TrimSpace(value)) - } - return normalizeAliasKey(n) -} - -func displayAliasName(n model.Name) string { - display := n.DisplayShortest() - if strings.EqualFold(n.Tag, "latest") { - if idx := strings.LastIndex(display, ":"); idx != -1 { - return display[:idx] - } - } - return display -} - -func localModelExists(name model.Name) (bool, error) { - manifests, err := manifest.Manifests(true) - if err != nil { - return false, err - } - needle := name.String() - for existing := range manifests { - if strings.EqualFold(existing.String(), needle) { - return true, nil - } - } - return false, nil -} - -func routerConfigPath() string { - home, err := os.UserHomeDir() - if err != nil { - return filepath.Join(".ollama", routerConfigFilename) - } - return filepath.Join(home, ".ollama", routerConfigFilename) -} - -func (s *Server) aliasStore() (*aliasStore, error) { - s.aliasesOnce.Do(func() { - s.aliases, s.aliasesErr = newAliasStore(routerConfigPath()) - }) - - return s.aliases, s.aliasesErr -} - -func (s *Server) resolveModelAliasName(name model.Name) (model.Name, bool, error) { - store, err := s.aliasStore() - if err != nil { - return name, false, err - } - - if store == nil { - return name, false, nil - } - - return store.ResolveName(name) -} diff --git a/server/routes.go b/server/routes.go index 34c1350a7..910b8e954 100644 --- a/server/routes.go +++ b/server/routes.go @@ -22,7 +22,6 @@ import ( "os/signal" "slices" "strings" - "sync" "sync/atomic" "syscall" "time" @@ -82,9 +81,6 @@ type Server struct { addr net.Addr sched *Scheduler defaultNumCtx int - aliasesOnce sync.Once - aliases *aliasStore - aliasesErr error } func init() { @@ -195,16 +191,9 @@ func (s *Server) GenerateHandler(c *gin.Context) { return } - resolvedName, _, err := s.resolveModelAliasName(name) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - name = resolvedName - // We cannot currently consolidate this into GetModel because all we'll // induce infinite recursion given the current code structure. - name, err = getExistingName(name) + name, err := getExistingName(name) if err != nil { c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)}) return @@ -1591,9 +1580,6 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) { r.POST("/api/blobs/:digest", s.CreateBlobHandler) r.HEAD("/api/blobs/:digest", s.HeadBlobHandler) r.POST("/api/copy", s.CopyHandler) - r.GET("/api/experimental/aliases", s.ListAliasesHandler) - r.POST("/api/experimental/aliases", s.CreateAliasHandler) - r.DELETE("/api/experimental/aliases", s.DeleteAliasHandler) // Inference r.GET("/api/ps", s.PsHandler) @@ -1964,20 +1950,13 @@ func (s *Server) ChatHandler(c *gin.Context) { return } - resolvedName, _, err := s.resolveModelAliasName(name) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - name = resolvedName - - name, err = getExistingName(name) + name, err := getExistingName(name) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"}) return } - m, err := GetModel(name.String()) + m, err := GetModel(req.Model) if err != nil { switch { case os.IsNotExist(err): diff --git a/server/routes_aliases.go b/server/routes_aliases.go deleted file mode 100644 index d68514e9c..000000000 --- a/server/routes_aliases.go +++ /dev/null @@ -1,159 +0,0 @@ -package server - -import ( - "errors" - "fmt" - "io" - "net/http" - "strings" - - "github.com/gin-gonic/gin" - - "github.com/ollama/ollama/types/model" -) - -type aliasListResponse struct { - Aliases []aliasEntry `json:"aliases"` -} - -type aliasDeleteRequest struct { - Alias string `json:"alias"` -} - -func (s *Server) ListAliasesHandler(c *gin.Context) { - store, err := s.aliasStore() - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - var aliases []aliasEntry - if store != nil { - aliases = store.List() - } - - c.JSON(http.StatusOK, aliasListResponse{Aliases: aliases}) -} - -func (s *Server) CreateAliasHandler(c *gin.Context) { - var req aliasEntry - if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) - return - } else if err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - req.Alias = strings.TrimSpace(req.Alias) - req.Target = strings.TrimSpace(req.Target) - if req.Alias == "" || req.Target == "" { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "alias and target are required"}) - return - } - - // Target must always be a valid model name - targetName := model.ParseName(req.Target) - if !targetName.IsValid() { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("target %q is invalid", req.Target)}) - return - } - - var aliasName model.Name - if req.PrefixMatching { - // For prefix aliases, we still parse the alias to normalize it, - // but we allow any non-empty string since prefix patterns may not be valid model names - aliasName = model.ParseName(req.Alias) - // Even if not valid as a model name, we accept it for prefix matching - } else { - aliasName = model.ParseName(req.Alias) - if !aliasName.IsValid() { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("alias %q is invalid", req.Alias)}) - return - } - - if normalizeAliasKey(aliasName) == normalizeAliasKey(targetName) { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "alias cannot point to itself"}) - return - } - - exists, err := localModelExists(aliasName) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - if exists { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("alias %q conflicts with existing model", req.Alias)}) - return - } - } - - store, err := s.aliasStore() - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - if err := store.Set(aliasName, targetName, req.PrefixMatching); err != nil { - status := http.StatusInternalServerError - if errors.Is(err, errAliasCycle) { - status = http.StatusBadRequest - } - c.AbortWithStatusJSON(status, gin.H{"error": err.Error()}) - return - } - - resp := aliasEntry{ - Alias: displayAliasName(aliasName), - Target: displayAliasName(targetName), - PrefixMatching: req.PrefixMatching, - } - if req.PrefixMatching && !aliasName.IsValid() { - // For prefix aliases that aren't valid model names, use the raw alias - resp.Alias = req.Alias - } - c.JSON(http.StatusOK, resp) -} - -func (s *Server) DeleteAliasHandler(c *gin.Context) { - var req aliasDeleteRequest - if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) - return - } else if err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - req.Alias = strings.TrimSpace(req.Alias) - if req.Alias == "" { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "alias is required"}) - return - } - - store, err := s.aliasStore() - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - aliasName := model.ParseName(req.Alias) - var deleted bool - if aliasName.IsValid() { - deleted, err = store.Delete(aliasName) - } else { - // For invalid model names (like prefix aliases), try deleting by raw string - deleted, err = store.DeleteByString(req.Alias) - } - - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - if !deleted { - c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("alias %q not found", req.Alias)}) - return - } - - c.JSON(http.StatusOK, gin.H{"deleted": true}) -} diff --git a/server/routes_aliases_test.go b/server/routes_aliases_test.go deleted file mode 100644 index f4cfb4be7..000000000 --- a/server/routes_aliases_test.go +++ /dev/null @@ -1,426 +0,0 @@ -package server - -import ( - "encoding/json" - "net/http" - "net/http/httptest" - "net/url" - "path/filepath" - "testing" - - "github.com/gin-gonic/gin" - - "github.com/ollama/ollama/api" - "github.com/ollama/ollama/types/model" -) - -func TestAliasShadowingRejected(t *testing.T) { - gin.SetMode(gin.TestMode) - t.Setenv("HOME", t.TempDir()) - - s := Server{} - w := createRequest(t, s.CreateHandler, api.CreateRequest{ - Model: "shadowed-model", - RemoteHost: "example.com", - From: "test", - Info: map[string]any{ - "capabilities": []string{"completion"}, - }, - Stream: &stream, - }) - if w.Code != http.StatusOK { - t.Fatalf("expected status 200, got %d", w.Code) - } - - w = createRequest(t, s.CreateAliasHandler, aliasEntry{Alias: "shadowed-model", Target: "other-model"}) - if w.Code != http.StatusBadRequest { - t.Fatalf("expected status 400, got %d", w.Code) - } -} - -func TestAliasResolvesForChatRemote(t *testing.T) { - gin.SetMode(gin.TestMode) - t.Setenv("HOME", t.TempDir()) - - var remoteModel string - rs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - var req api.ChatRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - t.Fatal(err) - } - remoteModel = req.Model - - w.Header().Set("Content-Type", "application/json") - resp := api.ChatResponse{ - Model: req.Model, - Done: true, - DoneReason: "load", - } - if err := json.NewEncoder(w).Encode(&resp); err != nil { - t.Fatal(err) - } - })) - defer rs.Close() - - p, err := url.Parse(rs.URL) - if err != nil { - t.Fatal(err) - } - - t.Setenv("OLLAMA_REMOTES", p.Hostname()) - - s := Server{} - w := createRequest(t, s.CreateHandler, api.CreateRequest{ - Model: "target-model", - RemoteHost: rs.URL, - From: "test", - Info: map[string]any{ - "capabilities": []string{"completion"}, - }, - Stream: &stream, - }) - if w.Code != http.StatusOK { - t.Fatalf("expected status 200, got %d", w.Code) - } - - w = createRequest(t, s.CreateAliasHandler, aliasEntry{Alias: "alias-model", Target: "target-model"}) - if w.Code != http.StatusOK { - t.Fatalf("expected status 200, got %d", w.Code) - } - - w = createRequest(t, s.ChatHandler, api.ChatRequest{ - Model: "alias-model", - Messages: []api.Message{{Role: "user", Content: "hi"}}, - Stream: &stream, - }) - if w.Code != http.StatusOK { - t.Fatalf("expected status 200, got %d", w.Code) - } - - var resp api.ChatResponse - if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { - t.Fatal(err) - } - - if resp.Model != "alias-model" { - t.Fatalf("expected response model to be alias-model, got %q", resp.Model) - } - - if remoteModel != "test" { - t.Fatalf("expected remote model to be 'test', got %q", remoteModel) - } -} - -func TestPrefixAliasBasicMatching(t *testing.T) { - tmpDir := t.TempDir() - store, err := newAliasStore(filepath.Join(tmpDir, "server.json")) - if err != nil { - t.Fatal(err) - } - - // Create a prefix alias: "myprefix-" -> "targetmodel" - targetName := model.ParseName("targetmodel") - - // Set a prefix alias (using "myprefix-" as the pattern) - store.mu.Lock() - store.prefixEntries = append(store.prefixEntries, aliasEntry{ - Alias: "myprefix-", - Target: "targetmodel", - PrefixMatching: true, - }) - store.mu.Unlock() - - // Test that "myprefix-foo" resolves to "targetmodel" - testName := model.ParseName("myprefix-foo") - resolved, wasResolved, err := store.ResolveName(testName) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if !wasResolved { - t.Fatal("expected name to be resolved") - } - if resolved.DisplayShortest() != targetName.DisplayShortest() { - t.Fatalf("expected resolved name to be %q, got %q", targetName.DisplayShortest(), resolved.DisplayShortest()) - } - - // Test that "otherprefix-foo" does not resolve - otherName := model.ParseName("otherprefix-foo") - _, wasResolved, err = store.ResolveName(otherName) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if wasResolved { - t.Fatal("expected name not to be resolved") - } - - // Test that exact alias takes precedence - exactAlias := model.ParseName("myprefix-exact") - exactTarget := model.ParseName("exacttarget") - if err := store.Set(exactAlias, exactTarget, false); err != nil { - t.Fatal(err) - } - - resolved, wasResolved, err = store.ResolveName(exactAlias) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if !wasResolved { - t.Fatal("expected name to be resolved") - } - if resolved.DisplayShortest() != exactTarget.DisplayShortest() { - t.Fatalf("expected resolved name to be %q (exact match), got %q", exactTarget.DisplayShortest(), resolved.DisplayShortest()) - } -} - -func TestPrefixAliasLongestMatchWins(t *testing.T) { - tmpDir := t.TempDir() - store, err := newAliasStore(filepath.Join(tmpDir, "server.json")) - if err != nil { - t.Fatal(err) - } - - // Add two prefix aliases with overlapping patterns - store.mu.Lock() - store.prefixEntries = []aliasEntry{ - {Alias: "abc-", Target: "short-target", PrefixMatching: true}, - {Alias: "abc-def-", Target: "long-target", PrefixMatching: true}, - } - store.sortPrefixEntriesLocked() - store.mu.Unlock() - - // "abc-def-ghi" should match the longer prefix "abc-def-" - testName := model.ParseName("abc-def-ghi") - resolved, wasResolved, err := store.ResolveName(testName) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if !wasResolved { - t.Fatal("expected name to be resolved") - } - expectedLongTarget := model.ParseName("long-target") - if resolved.DisplayShortest() != expectedLongTarget.DisplayShortest() { - t.Fatalf("expected resolved name to be %q (longest prefix match), got %q", expectedLongTarget.DisplayShortest(), resolved.DisplayShortest()) - } - - // "abc-xyz" should match the shorter prefix "abc-" - testName2 := model.ParseName("abc-xyz") - resolved, wasResolved, err = store.ResolveName(testName2) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if !wasResolved { - t.Fatal("expected name to be resolved") - } - expectedShortTarget := model.ParseName("short-target") - if resolved.DisplayShortest() != expectedShortTarget.DisplayShortest() { - t.Fatalf("expected resolved name to be %q, got %q", expectedShortTarget.DisplayShortest(), resolved.DisplayShortest()) - } -} - -func TestPrefixAliasChain(t *testing.T) { - tmpDir := t.TempDir() - store, err := newAliasStore(filepath.Join(tmpDir, "server.json")) - if err != nil { - t.Fatal(err) - } - - // Create a chain: prefix "test-" -> "intermediate" -> "final" - intermediate := model.ParseName("intermediate") - final := model.ParseName("final") - - // Add prefix alias - store.mu.Lock() - store.prefixEntries = []aliasEntry{ - {Alias: "test-", Target: "intermediate", PrefixMatching: true}, - } - store.mu.Unlock() - - // Add exact alias for the intermediate step - if err := store.Set(intermediate, final, false); err != nil { - t.Fatal(err) - } - - // "test-foo" should resolve through the chain to "final" - testName := model.ParseName("test-foo") - resolved, wasResolved, err := store.ResolveName(testName) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if !wasResolved { - t.Fatal("expected name to be resolved") - } - if resolved.DisplayShortest() != final.DisplayShortest() { - t.Fatalf("expected resolved name to be %q, got %q", final.DisplayShortest(), resolved.DisplayShortest()) - } -} - -func TestPrefixAliasCRUD(t *testing.T) { - gin.SetMode(gin.TestMode) - t.Setenv("HOME", t.TempDir()) - - s := Server{} - - // Create a prefix alias via API - w := createRequest(t, s.CreateAliasHandler, aliasEntry{ - Alias: "myprefix-", - Target: "llama2", - PrefixMatching: true, - }) - if w.Code != http.StatusOK { - t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String()) - } - - var createResp aliasEntry - if err := json.NewDecoder(w.Body).Decode(&createResp); err != nil { - t.Fatal(err) - } - if !createResp.PrefixMatching { - t.Fatal("expected prefix_matching to be true in response") - } - - // List aliases and verify the prefix alias is included - w = createRequest(t, s.ListAliasesHandler, nil) - if w.Code != http.StatusOK { - t.Fatalf("expected status 200, got %d", w.Code) - } - - var listResp aliasListResponse - if err := json.NewDecoder(w.Body).Decode(&listResp); err != nil { - t.Fatal(err) - } - - found := false - for _, a := range listResp.Aliases { - if a.PrefixMatching && a.Target == "llama2" { - found = true - break - } - } - if !found { - t.Fatal("expected to find prefix alias in list") - } - - // Delete the prefix alias - w = createRequest(t, s.DeleteAliasHandler, aliasDeleteRequest{Alias: "myprefix-"}) - if w.Code != http.StatusOK { - t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String()) - } - - // Verify it's deleted - w = createRequest(t, s.ListAliasesHandler, nil) - if w.Code != http.StatusOK { - t.Fatalf("expected status 200, got %d", w.Code) - } - - if err := json.NewDecoder(w.Body).Decode(&listResp); err != nil { - t.Fatal(err) - } - - for _, a := range listResp.Aliases { - if a.PrefixMatching { - t.Fatal("expected prefix alias to be deleted") - } - } -} - -func TestPrefixAliasCaseInsensitive(t *testing.T) { - tmpDir := t.TempDir() - store, err := newAliasStore(filepath.Join(tmpDir, "server.json")) - if err != nil { - t.Fatal(err) - } - - // Add a prefix alias with mixed case - store.mu.Lock() - store.prefixEntries = []aliasEntry{ - {Alias: "MyPrefix-", Target: "targetmodel", PrefixMatching: true}, - } - store.mu.Unlock() - - // Test that matching is case-insensitive - testName := model.ParseName("myprefix-foo") - resolved, wasResolved, err := store.ResolveName(testName) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if !wasResolved { - t.Fatal("expected name to be resolved (case-insensitive)") - } - expectedTarget := model.ParseName("targetmodel") - if resolved.DisplayShortest() != expectedTarget.DisplayShortest() { - t.Fatalf("expected resolved name to be %q, got %q", expectedTarget.DisplayShortest(), resolved.DisplayShortest()) - } - - // Test uppercase request - testName2 := model.ParseName("MYPREFIX-BAR") - _, wasResolved, err = store.ResolveName(testName2) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if !wasResolved { - t.Fatal("expected name to be resolved (uppercase)") - } -} - -func TestPrefixAliasLocalModelPrecedence(t *testing.T) { - gin.SetMode(gin.TestMode) - t.Setenv("HOME", t.TempDir()) - - s := Server{} - - // Create a local model that would match a prefix alias - w := createRequest(t, s.CreateHandler, api.CreateRequest{ - Model: "myprefix-localmodel", - RemoteHost: "example.com", - From: "test", - Info: map[string]any{ - "capabilities": []string{"completion"}, - }, - Stream: &stream, - }) - if w.Code != http.StatusOK { - t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String()) - } - - // Create a prefix alias that would match the local model name - w = createRequest(t, s.CreateAliasHandler, aliasEntry{ - Alias: "myprefix-", - Target: "someothermodel", - PrefixMatching: true, - }) - if w.Code != http.StatusOK { - t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String()) - } - - // Verify that resolving "myprefix-localmodel" returns the local model, not the alias target - store, err := s.aliasStore() - if err != nil { - t.Fatal(err) - } - - localModelName := model.ParseName("myprefix-localmodel") - resolved, wasResolved, err := store.ResolveName(localModelName) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if wasResolved { - t.Fatalf("expected local model to take precedence (wasResolved should be false), but got resolved to %q", resolved.DisplayShortest()) - } - if resolved.DisplayShortest() != localModelName.DisplayShortest() { - t.Fatalf("expected resolved name to be local model %q, got %q", localModelName.DisplayShortest(), resolved.DisplayShortest()) - } - - // Also verify that a non-local model matching the prefix DOES resolve to the alias target - nonLocalName := model.ParseName("myprefix-nonexistent") - resolved, wasResolved, err = store.ResolveName(nonLocalName) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if !wasResolved { - t.Fatal("expected non-local model to resolve via prefix alias") - } - expectedTarget := model.ParseName("someothermodel") - if resolved.DisplayShortest() != expectedTarget.DisplayShortest() { - t.Fatalf("expected resolved name to be %q, got %q", expectedTarget.DisplayShortest(), resolved.DisplayShortest()) - } -}