mirror of
https://github.com/ollama/ollama.git
synced 2026-04-17 15:53:27 +02:00
* don't require pulling stubs for cloud models This is the first in a series of PRs that will better integrate Ollama's cloud into the API and CLI. Previously, we had a layer of indirection where you'd first have to pull a "stub" model that contains a reference to a cloud model. With this change, you don't have to pull first; you can just use a cloud model in various routes like `/api/chat` and `/api/show`. This change respects <https://github.com/ollama/ollama/pull/14221>, so if cloud is disabled, these models won't be accessible. There's also a new, simpler pass-through proxy that doesn't convert the requests ahead of hitting the cloud models, which themselves already support various formats (e.g., `v1/chat/completions` or Open Responses, etc.). This will help prevent issues caused by double converting (e.g., `v1/chat/completions` converted to `api/chat` on the client, then calling cloud and converting back to a `v1/chat/completions` response instead of the cloud model handling the original `v1/chat/completions` request first). There's now a notion of "source tags", which can be mixed with existing tags. So instead of having different formats like `gpt-oss:20b-cloud` vs. `kimi-k2.5:cloud` (`-cloud` suffix vs. `:cloud`), you can now specify cloud by simply appending `:cloud`. This PR doesn't change model resolution yet, but sets us up to allow for things like omitting the non-source tag, which would make something like `ollama run gpt-oss:cloud` work the same way that `ollama run gpt-oss` already works today.
More detailed changes: - Added a shared model selector parser in `types/modelselector`: - supports `:cloud` and `:local` - accepts source tags in any position - supports legacy `:<tag>-cloud` - rejects conflicting source tags - Integrated selector handling across server inference/show routes: - `GenerateHandler`, `ChatHandler`, `EmbedHandler`, `EmbeddingsHandler`, `ShowHandler` - Added explicit-cloud passthrough proxy for ollama.com: - same-endpoint forwarding for `/api/*`, `/v1/*`, and `/v1/messages` - normalizes `model` (and `name` for `/api/show`) before forwarding - forwards request headers except hop-by-hop/proxy-managed headers - uses bounded response-header timeout - handles auth failures in a friendly way - Preserved cloud-disable behavior (`OLLAMA_NO_CLOUD`) - Updated create flow to support `FROM ...:cloud` model sources (though this flow uses the legacy proxy still, supporting Modelfile overrides is more complicated with the direct proxy approach) - Updated CLI/TUI/config cloud detection to use shared selector logic - Updated CLI preflight behavior so explicit cloud requests do not auto-pull local stubs What's next? - Cloud discovery/listing and cache-backed `ollama ls` / `/api/tags` - Modelfile overlay support for virtual cloud models on OpenAI/Anthropic request families - Recommender/default-selection behavior for ambiguous model families - Fully remove the legacy flow Fixes: https://github.com/ollama/ollama/issues/13801 * consolidate pull logic into confirmAndPull helper pullIfNeeded and ShowOrPull shared identical confirm-and-pull logic. Extract confirmAndPull to eliminate the duplication. * skip local existence checks for cloud models ModelExists and the TUI's modelExists both check the local model list, which causes cloud models to appear missing. Return true early for explicit cloud models so the TUI displays them beside the integration name and skips re-prompting the model picker on relaunch. 
* support optionally pulling stubs for new-style names We now normalize names like `<family>:<size>:cloud` into legacy-style names like `<family>:<size>-cloud` for pulling and deleting (this also supports stripping `:local`). Support for pulling cloud models is temporary; once we integrate properly into `/api/tags` we won't need this anymore. * Fix server alias syncing * Update cmd/cmd.go Co-authored-by: Parth Sareen <parth.sareen@ollama.com> * address comments * improve some naming --------- Co-authored-by: ParthSareen <parth.sareen@ollama.com>
461 lines
11 KiB
Go
461 lines
11 KiB
Go
package server
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"log/slog"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
|
|
"github.com/ollama/ollama/auth"
|
|
"github.com/ollama/ollama/envconfig"
|
|
internalcloud "github.com/ollama/ollama/internal/cloud"
|
|
)
|
|
|
|
// Defaults and identifiers used by the cloud passthrough proxy.
const (
	// cloudProxyBaseURLEnv names the environment variable that init reads to
	// override the proxy base URL.
	cloudProxyBaseURLEnv = "OLLAMA_CLOUD_BASE_URL"

	// defaultCloudProxyBaseURL is the upstream base URL used when no valid
	// override is configured.
	defaultCloudProxyBaseURL = "https://ollama.com:443"

	// defaultCloudProxySigningHost is the hostname for which outgoing proxy
	// requests are signed when no valid override is configured.
	defaultCloudProxySigningHost = "ollama.com"

	// legacyCloudAnthropicKey is the gin context key that
	// cloudPassthroughMiddleware sets when it leaves an Anthropic web search
	// request on the local middleware path.
	legacyCloudAnthropicKey = "legacy_cloud_anthropic_web_search"
)
|
|
|
|
var (
|
|
cloudProxyBaseURL = defaultCloudProxyBaseURL
|
|
cloudProxySigningHost = defaultCloudProxySigningHost
|
|
cloudProxySignRequest = signCloudProxyRequest
|
|
cloudProxySigninURL = signinURL
|
|
)
|
|
|
|
// hopByHopHeaders is the set of lowercase header names that the proxy
// header-copy helpers never forward in either direction. Lookups go through
// isHopByHopHeader, which lowercases the name first.
var hopByHopHeaders = map[string]struct{}{
	"connection":          {},
	"content-length":      {},
	"keep-alive":          {},
	"proxy-authenticate":  {},
	"proxy-authorization": {},
	"proxy-connection":    {},
	"te":                  {},
	"trailer":             {},
	"transfer-encoding":   {},
	"upgrade":             {},
}
|
|
|
|
func init() {
|
|
baseURL, signingHost, overridden, err := resolveCloudProxyBaseURL(envconfig.Var(cloudProxyBaseURLEnv), mode)
|
|
if err != nil {
|
|
slog.Warn("ignoring cloud base URL override", "env", cloudProxyBaseURLEnv, "error", err)
|
|
return
|
|
}
|
|
|
|
cloudProxyBaseURL = baseURL
|
|
cloudProxySigningHost = signingHost
|
|
|
|
if overridden {
|
|
slog.Info("cloud base URL override enabled", "env", cloudProxyBaseURLEnv, "url", cloudProxyBaseURL, "mode", mode)
|
|
}
|
|
}
|
|
|
|
func cloudPassthroughMiddleware(disabledOperation string) gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
if c.Request.Method != http.MethodPost {
|
|
c.Next()
|
|
return
|
|
}
|
|
|
|
// TODO(drifkin): Avoid full-body buffering here for model detection.
|
|
// A future optimization can parse just enough JSON to read "model" (and
|
|
// optionally short-circuit cloud-disabled explicit-cloud requests) while
|
|
// preserving raw passthrough semantics.
|
|
body, err := readRequestBody(c.Request)
|
|
if err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
model, ok := extractModelField(body)
|
|
if !ok {
|
|
c.Next()
|
|
return
|
|
}
|
|
|
|
modelRef, err := parseAndValidateModelRef(model)
|
|
if err != nil || modelRef.Source != modelSourceCloud {
|
|
c.Next()
|
|
return
|
|
}
|
|
|
|
normalizedBody, err := replaceJSONModelField(body, modelRef.Base)
|
|
if err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
// TEMP(drifkin): keep Anthropic web search requests on the local middleware
|
|
// path so WebSearchAnthropicWriter can orchestrate follow-up calls.
|
|
if c.Request.URL.Path == "/v1/messages" {
|
|
if hasAnthropicWebSearchTool(body) {
|
|
c.Set(legacyCloudAnthropicKey, true)
|
|
c.Next()
|
|
return
|
|
}
|
|
}
|
|
|
|
proxyCloudRequest(c, normalizedBody, disabledOperation)
|
|
c.Abort()
|
|
}
|
|
}
|
|
|
|
func cloudModelPathPassthroughMiddleware(disabledOperation string) gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
modelName := strings.TrimSpace(c.Param("model"))
|
|
if modelName == "" {
|
|
c.Next()
|
|
return
|
|
}
|
|
|
|
modelRef, err := parseAndValidateModelRef(modelName)
|
|
if err != nil || modelRef.Source != modelSourceCloud {
|
|
c.Next()
|
|
return
|
|
}
|
|
|
|
proxyPath := "/v1/models/" + modelRef.Base
|
|
proxyCloudRequestWithPath(c, nil, proxyPath, disabledOperation)
|
|
c.Abort()
|
|
}
|
|
}
|
|
|
|
func proxyCloudJSONRequest(c *gin.Context, payload any, disabledOperation string) {
|
|
// TEMP(drifkin): we currently split out this `WithPath` method because we are
|
|
// mapping `/v1/messages` + web_search to `/api/chat` temporarily. Once we
|
|
// stop doing this, we can inline this method.
|
|
proxyCloudJSONRequestWithPath(c, payload, c.Request.URL.Path, disabledOperation)
|
|
}
|
|
|
|
func proxyCloudJSONRequestWithPath(c *gin.Context, payload any, path string, disabledOperation string) {
|
|
body, err := json.Marshal(payload)
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
proxyCloudRequestWithPath(c, body, path, disabledOperation)
|
|
}
|
|
|
|
func proxyCloudRequest(c *gin.Context, body []byte, disabledOperation string) {
|
|
proxyCloudRequestWithPath(c, body, c.Request.URL.Path, disabledOperation)
|
|
}
|
|
|
|
func proxyCloudRequestWithPath(c *gin.Context, body []byte, path string, disabledOperation string) {
|
|
if disabled, _ := internalcloud.Status(); disabled {
|
|
c.JSON(http.StatusForbidden, gin.H{"error": internalcloud.DisabledError(disabledOperation)})
|
|
return
|
|
}
|
|
|
|
baseURL, err := url.Parse(cloudProxyBaseURL)
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
targetURL := baseURL.ResolveReference(&url.URL{
|
|
Path: path,
|
|
RawQuery: c.Request.URL.RawQuery,
|
|
})
|
|
|
|
outReq, err := http.NewRequestWithContext(c.Request.Context(), c.Request.Method, targetURL.String(), bytes.NewReader(body))
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
copyProxyRequestHeaders(outReq.Header, c.Request.Header)
|
|
if outReq.Header.Get("Content-Type") == "" && len(body) > 0 {
|
|
outReq.Header.Set("Content-Type", "application/json")
|
|
}
|
|
|
|
if err := cloudProxySignRequest(outReq.Context(), outReq); err != nil {
|
|
slog.Warn("cloud proxy signing failed", "error", err)
|
|
writeCloudUnauthorized(c)
|
|
return
|
|
}
|
|
|
|
// TODO(drifkin): Add phase-specific proxy timeouts.
|
|
// Connect/TLS/TTFB should have bounded timeouts, but once streaming starts
|
|
// we should not enforce a short total timeout for long-lived responses.
|
|
resp, err := http.DefaultClient.Do(outReq)
|
|
if err != nil {
|
|
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
copyProxyResponseHeaders(c.Writer.Header(), resp.Header)
|
|
c.Status(resp.StatusCode)
|
|
|
|
if err := copyProxyResponseBody(c.Writer, resp.Body); err != nil {
|
|
c.Error(err) //nolint:errcheck
|
|
}
|
|
}
|
|
|
|
// replaceJSONModelField returns body with its top-level "model" field set to
// model, adding the field if it is absent. All other top-level fields are
// carried over as raw JSON, so their values are not reinterpreted.
//
// An empty body is returned unchanged. A body that is not a JSON object
// yields an error; this includes the JSON literal `null`, which decodes
// without error into a nil map and would otherwise panic on assignment.
func replaceJSONModelField(body []byte, model string) ([]byte, error) {
	if len(body) == 0 {
		return body, nil
	}

	var payload map[string]json.RawMessage
	if err := json.Unmarshal(body, &payload); err != nil {
		return nil, err
	}
	if payload == nil {
		// json.Unmarshal leaves the map nil for a `null` document.
		return nil, fmt.Errorf("request body must be a JSON object")
	}

	modelJSON, err := json.Marshal(model)
	if err != nil {
		return nil, err
	}
	payload["model"] = modelJSON

	return json.Marshal(payload)
}
|
|
|
|
// readRequestBody reads the whole request body and returns its bytes. The
// request's Body is then replaced with a fresh reader over the same bytes so
// later handlers can read it again. A request without a body yields
// (nil, nil); a read failure is returned and leaves r.Body as it was.
func readRequestBody(r *http.Request) ([]byte, error) {
	original := r.Body
	if original == nil {
		return nil, nil
	}

	data, readErr := io.ReadAll(original)
	if readErr != nil {
		return nil, readErr
	}

	r.Body = io.NopCloser(bytes.NewReader(data))
	return data, nil
}
|
|
|
|
// extractModelField returns the trimmed top-level "model" string from a JSON
// object body. The boolean is false when the body is empty, is not valid
// JSON for an object, has no "model" key, holds a non-string "model", or the
// trimmed value is empty.
func extractModelField(body []byte) (string, bool) {
	if len(body) == 0 {
		return "", false
	}

	var fields map[string]json.RawMessage
	if json.Unmarshal(body, &fields) != nil {
		return "", false
	}

	rawModel, present := fields["model"]
	if !present {
		return "", false
	}

	var name string
	if json.Unmarshal(rawModel, &name) != nil {
		return "", false
	}

	if name = strings.TrimSpace(name); name == "" {
		return "", false
	}
	return name, true
}
|
|
|
|
// hasAnthropicWebSearchTool reports whether the JSON body has a top-level
// "tools" array containing at least one entry whose trimmed "type" starts
// with "web_search". An empty or undecodable body reports false.
func hasAnthropicWebSearchTool(body []byte) bool {
	if len(body) == 0 {
		return false
	}

	type toolEntry struct {
		Type string `json:"type"`
	}
	var request struct {
		Tools []toolEntry `json:"tools"`
	}
	if err := json.Unmarshal(body, &request); err != nil {
		return false
	}

	for i := range request.Tools {
		kind := strings.TrimSpace(request.Tools[i].Type)
		if strings.HasPrefix(kind, "web_search") {
			return true
		}
	}
	return false
}
|
|
|
|
func writeCloudUnauthorized(c *gin.Context) {
|
|
signinURL, err := cloudProxySigninURL()
|
|
if err != nil {
|
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized", "signin_url": signinURL})
|
|
}
|
|
|
|
func signCloudProxyRequest(ctx context.Context, req *http.Request) error {
|
|
if !strings.EqualFold(req.URL.Hostname(), cloudProxySigningHost) {
|
|
return nil
|
|
}
|
|
|
|
ts := strconv.FormatInt(time.Now().Unix(), 10)
|
|
challenge := buildCloudSignatureChallenge(req, ts)
|
|
signature, err := auth.Sign(ctx, []byte(challenge))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
req.Header.Set("Authorization", signature)
|
|
return nil
|
|
}
|
|
|
|
// buildCloudSignatureChallenge adds ts as the "ts" query parameter of req's
// URL and returns the string "<method>,<request URI>".
//
// Note that req is modified: its RawQuery is replaced with the re-encoded
// query, so the request that is sent carries the same timestamp that was
// signed.
func buildCloudSignatureChallenge(req *http.Request, ts string) string {
	params := req.URL.Query()
	params.Set("ts", ts)
	req.URL.RawQuery = params.Encode()

	uri := req.URL.RequestURI()
	return fmt.Sprintf("%s,%s", req.Method, uri)
}
|
|
|
|
func resolveCloudProxyBaseURL(rawOverride string, runMode string) (baseURL string, signingHost string, overridden bool, err error) {
|
|
baseURL = defaultCloudProxyBaseURL
|
|
signingHost = defaultCloudProxySigningHost
|
|
|
|
rawOverride = strings.TrimSpace(rawOverride)
|
|
if rawOverride == "" {
|
|
return baseURL, signingHost, false, nil
|
|
}
|
|
|
|
u, err := url.Parse(rawOverride)
|
|
if err != nil {
|
|
return "", "", false, fmt.Errorf("invalid URL: %w", err)
|
|
}
|
|
if u.Scheme == "" || u.Host == "" {
|
|
return "", "", false, fmt.Errorf("invalid URL: scheme and host are required")
|
|
}
|
|
if u.User != nil {
|
|
return "", "", false, fmt.Errorf("invalid URL: userinfo is not allowed")
|
|
}
|
|
if u.Path != "" && u.Path != "/" {
|
|
return "", "", false, fmt.Errorf("invalid URL: path is not allowed")
|
|
}
|
|
if u.RawQuery != "" || u.Fragment != "" {
|
|
return "", "", false, fmt.Errorf("invalid URL: query and fragment are not allowed")
|
|
}
|
|
|
|
host := u.Hostname()
|
|
if host == "" {
|
|
return "", "", false, fmt.Errorf("invalid URL: host is required")
|
|
}
|
|
|
|
loopback := isLoopbackHost(host)
|
|
if runMode == gin.ReleaseMode && !loopback {
|
|
return "", "", false, fmt.Errorf("non-loopback cloud override is not allowed in release mode")
|
|
}
|
|
if !loopback && !strings.EqualFold(u.Scheme, "https") {
|
|
return "", "", false, fmt.Errorf("non-loopback cloud override must use https")
|
|
}
|
|
|
|
u.Path = ""
|
|
u.RawPath = ""
|
|
u.RawQuery = ""
|
|
u.Fragment = ""
|
|
|
|
return u.String(), strings.ToLower(host), true, nil
|
|
}
|
|
|
|
// isLoopbackHost reports whether host is "localhost" (in any letter case) or
// an IP address literal in a loopback range. Any other hostname reports
// false; no DNS resolution is performed.
func isLoopbackHost(host string) bool {
	if strings.EqualFold(host, "localhost") {
		return true
	}

	if addr := net.ParseIP(host); addr != nil {
		return addr.IsLoopback()
	}
	return false
}
|
|
|
|
func copyProxyRequestHeaders(dst, src http.Header) {
|
|
connectionTokens := connectionHeaderTokens(src)
|
|
for key, values := range src {
|
|
if isHopByHopHeader(key) || isConnectionTokenHeader(key, connectionTokens) {
|
|
continue
|
|
}
|
|
|
|
dst.Del(key)
|
|
for _, value := range values {
|
|
dst.Add(key, value)
|
|
}
|
|
}
|
|
}
|
|
|
|
func copyProxyResponseHeaders(dst, src http.Header) {
|
|
connectionTokens := connectionHeaderTokens(src)
|
|
for key, values := range src {
|
|
if isHopByHopHeader(key) || isConnectionTokenHeader(key, connectionTokens) {
|
|
continue
|
|
}
|
|
|
|
dst.Del(key)
|
|
for _, value := range values {
|
|
dst.Add(key, value)
|
|
}
|
|
}
|
|
}
|
|
|
|
// copyProxyResponseBody streams src to dst in chunks of up to 32 KiB,
// flushing after every write when dst implements http.Flusher. It returns nil
// once src reports io.EOF, and otherwise the first write or read error.
func copyProxyResponseBody(dst http.ResponseWriter, src io.Reader) error {
	chunk := make([]byte, 32*1024)
	flusher, _ := dst.(http.Flusher)

	for {
		n, readErr := src.Read(chunk)
		if n > 0 {
			if _, err := dst.Write(chunk[:n]); err != nil {
				return err
			}
			if flusher != nil {
				// TODO(drifkin): Consider conditional flushing so non-streaming
				// responses don't flush every write and can optimize throughput.
				flusher.Flush()
			}
		}

		// Data returned alongside an error has already been written above.
		switch readErr {
		case nil:
			continue
		case io.EOF:
			return nil
		default:
			return readErr
		}
	}
}
|
|
|
|
func isHopByHopHeader(name string) bool {
|
|
_, ok := hopByHopHeaders[strings.ToLower(name)]
|
|
return ok
|
|
}
|
|
|
|
// connectionHeaderTokens returns the set of lowercased, trimmed tokens listed
// across all Connection header values in header. Empty tokens are dropped.
// The returned map is never nil.
func connectionHeaderTokens(header http.Header) map[string]struct{} {
	set := make(map[string]struct{})

	for _, line := range header.Values("Connection") {
		for _, part := range strings.Split(line, ",") {
			if name := strings.TrimSpace(strings.ToLower(part)); name != "" {
				set[name] = struct{}{}
			}
		}
	}

	return set
}
|
|
|
|
// isConnectionTokenHeader reports whether name, lowercased, is one of the
// given Connection header tokens. An empty or nil token set always reports
// false.
func isConnectionTokenHeader(name string, tokens map[string]struct{}) bool {
	if len(tokens) > 0 {
		if _, listed := tokens[strings.ToLower(name)]; listed {
			return true
		}
	}
	return false
}
|