mirror of
https://github.com/ollama/ollama.git
synced 2026-04-17 21:54:08 +02:00
It's possible to have a local model with a `:cloud` suffix. If the cloud responds with a 404, then we now fall back to looking for it locally before returning to the client. The cloud proxy has been slightly refactored so parts of it can be reused here (in this flow we inspect the response, so it's no longer a pure passthrough)
580 lines
15 KiB
Go
580 lines
15 KiB
Go
package server
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log/slog"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/klauspost/compress/zstd"
|
|
|
|
"github.com/ollama/ollama/auth"
|
|
"github.com/ollama/ollama/envconfig"
|
|
internalcloud "github.com/ollama/ollama/internal/cloud"
|
|
"github.com/ollama/ollama/version"
|
|
)
|
|
|
|
const (
|
|
defaultCloudProxyBaseURL = "https://ollama.com:443"
|
|
defaultCloudProxySigningHost = "ollama.com"
|
|
cloudProxyBaseURLEnv = "OLLAMA_CLOUD_BASE_URL"
|
|
legacyCloudAnthropicKey = "legacy_cloud_anthropic_web_search"
|
|
cloudProxyClientVersionHeader = "X-Ollama-Client-Version"
|
|
|
|
// maxDecompressedBodySize limits the size of a decompressed request body
|
|
maxDecompressedBodySize = 20 << 20
|
|
)
|
|
|
|
var (
|
|
cloudProxyBaseURL = defaultCloudProxyBaseURL
|
|
cloudProxySigningHost = defaultCloudProxySigningHost
|
|
cloudProxySignRequest = signCloudProxyRequest
|
|
cloudProxySigninURL = signinURL
|
|
)
|
|
|
|
var hopByHopHeaders = map[string]struct{}{
|
|
"connection": {},
|
|
"content-length": {},
|
|
"proxy-connection": {},
|
|
"keep-alive": {},
|
|
"proxy-authenticate": {},
|
|
"proxy-authorization": {},
|
|
"te": {},
|
|
"trailer": {},
|
|
"transfer-encoding": {},
|
|
"upgrade": {},
|
|
}
|
|
|
|
func init() {
|
|
baseURL, signingHost, overridden, err := resolveCloudProxyBaseURL(envconfig.Var(cloudProxyBaseURLEnv), mode)
|
|
if err != nil {
|
|
slog.Warn("ignoring cloud base URL override", "env", cloudProxyBaseURLEnv, "error", err)
|
|
return
|
|
}
|
|
|
|
cloudProxyBaseURL = baseURL
|
|
cloudProxySigningHost = signingHost
|
|
|
|
if overridden {
|
|
slog.Info("cloud base URL override enabled", "env", cloudProxyBaseURLEnv, "url", cloudProxyBaseURL, "mode", mode)
|
|
}
|
|
}
|
|
|
|
func cloudPassthroughMiddleware(disabledOperation string) gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
if c.Request.Method != http.MethodPost {
|
|
c.Next()
|
|
return
|
|
}
|
|
|
|
// Decompress zstd-encoded request bodies so we can inspect the model
|
|
if c.GetHeader("Content-Encoding") == "zstd" {
|
|
reader, err := zstd.NewReader(c.Request.Body, zstd.WithDecoderMaxMemory(8<<20))
|
|
if err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to decompress request body"})
|
|
c.Abort()
|
|
return
|
|
}
|
|
defer reader.Close()
|
|
c.Request.Body = http.MaxBytesReader(c.Writer, io.NopCloser(reader), maxDecompressedBodySize)
|
|
c.Request.Header.Del("Content-Encoding")
|
|
}
|
|
|
|
// TODO(drifkin): Avoid full-body buffering here for model detection.
|
|
// A future optimization can parse just enough JSON to read "model" (and
|
|
// optionally short-circuit cloud-disabled explicit-cloud requests) while
|
|
// preserving raw passthrough semantics.
|
|
body, err := readRequestBody(c.Request)
|
|
if err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
model, ok := extractModelField(body)
|
|
if !ok {
|
|
c.Next()
|
|
return
|
|
}
|
|
|
|
modelRef, err := parseAndValidateModelRef(model)
|
|
if err != nil || modelRef.Source != modelSourceCloud {
|
|
c.Next()
|
|
return
|
|
}
|
|
|
|
normalizedBody, err := replaceJSONModelField(body, modelRef.Base)
|
|
if err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
// TEMP(drifkin): keep Anthropic web search requests on the local middleware
|
|
// path so WebSearchAnthropicWriter can orchestrate follow-up calls.
|
|
if c.Request.URL.Path == "/v1/messages" {
|
|
if hasAnthropicWebSearchTool(body) {
|
|
c.Set(legacyCloudAnthropicKey, true)
|
|
c.Next()
|
|
return
|
|
}
|
|
}
|
|
|
|
proxyCloudRequest(c, normalizedBody, disabledOperation)
|
|
c.Abort()
|
|
}
|
|
}
|
|
|
|
func cloudModelPathPassthroughMiddleware(disabledOperation string) gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
modelName := strings.TrimSpace(c.Param("model"))
|
|
if modelName == "" {
|
|
c.Next()
|
|
return
|
|
}
|
|
|
|
modelRef, err := parseAndValidateModelRef(modelName)
|
|
if err != nil || modelRef.Source != modelSourceCloud {
|
|
c.Next()
|
|
return
|
|
}
|
|
|
|
proxyPath := "/v1/models/" + modelRef.Base
|
|
proxyCloudRequestWithPath(c, nil, proxyPath, disabledOperation)
|
|
c.Abort()
|
|
}
|
|
}
|
|
|
|
func proxyCloudJSONRequest(c *gin.Context, payload any, disabledOperation string) {
|
|
// TEMP(drifkin): we currently split out this `WithPath` method because we are
|
|
// mapping `/v1/messages` + web_search to `/api/chat` temporarily. Once we
|
|
// stop doing this, we can inline this method.
|
|
proxyCloudJSONRequestWithPath(c, payload, c.Request.URL.Path, disabledOperation)
|
|
}
|
|
|
|
func proxyCloudJSONRequestWithPath(c *gin.Context, payload any, path string, disabledOperation string) {
|
|
body, err := json.Marshal(payload)
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
proxyCloudRequestWithPath(c, body, path, disabledOperation)
|
|
}
|
|
|
|
func proxyCloudRequest(c *gin.Context, body []byte, disabledOperation string) {
|
|
proxyCloudRequestWithPath(c, body, c.Request.URL.Path, disabledOperation)
|
|
}
|
|
|
|
func buildCloudProxyRequest(c *gin.Context, path string, body []byte) (*http.Request, error) {
|
|
baseURL, err := url.Parse(cloudProxyBaseURL)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
targetURL := baseURL.ResolveReference(&url.URL{
|
|
Path: path,
|
|
RawQuery: c.Request.URL.RawQuery,
|
|
})
|
|
|
|
outReq, err := http.NewRequestWithContext(c.Request.Context(), c.Request.Method, targetURL.String(), bytes.NewReader(body))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
copyProxyRequestHeaders(outReq.Header, c.Request.Header)
|
|
if clientVersion := strings.TrimSpace(version.Version); clientVersion != "" {
|
|
outReq.Header.Set(cloudProxyClientVersionHeader, clientVersion)
|
|
}
|
|
if outReq.Header.Get("Content-Type") == "" && len(body) > 0 {
|
|
outReq.Header.Set("Content-Type", "application/json")
|
|
}
|
|
|
|
return outReq, nil
|
|
}
|
|
|
|
func writeCloudProxyResponse(c *gin.Context, path string, resp *http.Response) {
|
|
copyProxyResponseHeaders(c.Writer.Header(), resp.Header)
|
|
c.Status(resp.StatusCode)
|
|
|
|
var bodyWriter http.ResponseWriter = c.Writer
|
|
var framedWriter *jsonlFramingResponseWriter
|
|
// TEMP(drifkin): only needed on the cloud-proxied first leg of Anthropic
|
|
// web_search fallback (which is a path we're removing soon). Local
|
|
// /v1/messages writes one JSON value per streamResponse callback directly
|
|
// into WebSearchAnthropicWriter, but this proxy copy loop may coalesce
|
|
// multiple jsonl records into one Write. WebSearchAnthropicWriter currently
|
|
// unmarshals one JSON value per Write.
|
|
if path == "/api/chat" && resp.StatusCode == http.StatusOK && c.GetBool(legacyCloudAnthropicKey) {
|
|
framedWriter = &jsonlFramingResponseWriter{ResponseWriter: c.Writer}
|
|
bodyWriter = framedWriter
|
|
}
|
|
|
|
err := copyProxyResponseBody(bodyWriter, resp.Body)
|
|
if err == nil && framedWriter != nil {
|
|
err = framedWriter.FlushPending()
|
|
}
|
|
if err != nil {
|
|
ctxErr := c.Request.Context().Err()
|
|
if errors.Is(err, context.Canceled) && errors.Is(ctxErr, context.Canceled) {
|
|
slog.Debug(
|
|
"cloud proxy response stream closed by client",
|
|
"path", c.Request.URL.Path,
|
|
"status", resp.StatusCode,
|
|
)
|
|
return
|
|
}
|
|
|
|
slog.Warn(
|
|
"cloud proxy response copy failed",
|
|
"path", c.Request.URL.Path,
|
|
"upstream_path", path,
|
|
"status", resp.StatusCode,
|
|
"request_context_canceled", ctxErr != nil,
|
|
"request_context_err", ctxErr,
|
|
"error", err,
|
|
)
|
|
}
|
|
}
|
|
|
|
func proxyCloudRequestWithPath(c *gin.Context, body []byte, path string, disabledOperation string) {
|
|
if disabled, _ := internalcloud.Status(); disabled {
|
|
c.JSON(http.StatusForbidden, gin.H{"error": internalcloud.DisabledError(disabledOperation)})
|
|
return
|
|
}
|
|
|
|
outReq, err := buildCloudProxyRequest(c, path, body)
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
if err := cloudProxySignRequest(outReq.Context(), outReq); err != nil {
|
|
slog.Warn("cloud proxy signing failed", "error", err)
|
|
writeCloudUnauthorized(c)
|
|
return
|
|
}
|
|
|
|
// TODO(drifkin): Add phase-specific proxy timeouts.
|
|
// Connect/TLS/TTFB should have bounded timeouts, but once streaming starts
|
|
// we should not enforce a short total timeout for long-lived responses.
|
|
resp, err := http.DefaultClient.Do(outReq)
|
|
if err != nil {
|
|
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
writeCloudProxyResponse(c, path, resp)
|
|
}
|
|
|
|
func replaceJSONModelField(body []byte, model string) ([]byte, error) {
|
|
if len(body) == 0 {
|
|
return body, nil
|
|
}
|
|
|
|
var payload map[string]json.RawMessage
|
|
if err := json.Unmarshal(body, &payload); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
modelJSON, err := json.Marshal(model)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
payload["model"] = modelJSON
|
|
|
|
return json.Marshal(payload)
|
|
}
|
|
|
|
func readRequestBody(r *http.Request) ([]byte, error) {
|
|
if r.Body == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
body, err := io.ReadAll(r.Body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
r.Body = io.NopCloser(bytes.NewReader(body))
|
|
return body, nil
|
|
}
|
|
|
|
func extractModelField(body []byte) (string, bool) {
|
|
if len(body) == 0 {
|
|
return "", false
|
|
}
|
|
|
|
var payload map[string]json.RawMessage
|
|
if err := json.Unmarshal(body, &payload); err != nil {
|
|
return "", false
|
|
}
|
|
|
|
raw, ok := payload["model"]
|
|
if !ok {
|
|
return "", false
|
|
}
|
|
|
|
var model string
|
|
if err := json.Unmarshal(raw, &model); err != nil {
|
|
return "", false
|
|
}
|
|
|
|
model = strings.TrimSpace(model)
|
|
return model, model != ""
|
|
}
|
|
|
|
func hasAnthropicWebSearchTool(body []byte) bool {
|
|
if len(body) == 0 {
|
|
return false
|
|
}
|
|
|
|
var payload struct {
|
|
Tools []struct {
|
|
Type string `json:"type"`
|
|
} `json:"tools"`
|
|
}
|
|
if err := json.Unmarshal(body, &payload); err != nil {
|
|
return false
|
|
}
|
|
|
|
for _, tool := range payload.Tools {
|
|
if strings.HasPrefix(strings.TrimSpace(tool.Type), "web_search") {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
func writeCloudUnauthorized(c *gin.Context) {
|
|
signinURL, err := cloudProxySigninURL()
|
|
if err != nil {
|
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized", "signin_url": signinURL})
|
|
}
|
|
|
|
func signCloudProxyRequest(ctx context.Context, req *http.Request) error {
|
|
if !strings.EqualFold(req.URL.Hostname(), cloudProxySigningHost) {
|
|
return nil
|
|
}
|
|
|
|
ts := strconv.FormatInt(time.Now().Unix(), 10)
|
|
challenge := buildCloudSignatureChallenge(req, ts)
|
|
signature, err := auth.Sign(ctx, []byte(challenge))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
req.Header.Set("Authorization", signature)
|
|
return nil
|
|
}
|
|
|
|
func buildCloudSignatureChallenge(req *http.Request, ts string) string {
|
|
query := req.URL.Query()
|
|
query.Set("ts", ts)
|
|
req.URL.RawQuery = query.Encode()
|
|
|
|
return fmt.Sprintf("%s,%s", req.Method, req.URL.RequestURI())
|
|
}
|
|
|
|
func resolveCloudProxyBaseURL(rawOverride string, runMode string) (baseURL string, signingHost string, overridden bool, err error) {
|
|
baseURL = defaultCloudProxyBaseURL
|
|
signingHost = defaultCloudProxySigningHost
|
|
|
|
rawOverride = strings.TrimSpace(rawOverride)
|
|
if rawOverride == "" {
|
|
return baseURL, signingHost, false, nil
|
|
}
|
|
|
|
u, err := url.Parse(rawOverride)
|
|
if err != nil {
|
|
return "", "", false, fmt.Errorf("invalid URL: %w", err)
|
|
}
|
|
if u.Scheme == "" || u.Host == "" {
|
|
return "", "", false, fmt.Errorf("invalid URL: scheme and host are required")
|
|
}
|
|
if u.User != nil {
|
|
return "", "", false, fmt.Errorf("invalid URL: userinfo is not allowed")
|
|
}
|
|
if u.Path != "" && u.Path != "/" {
|
|
return "", "", false, fmt.Errorf("invalid URL: path is not allowed")
|
|
}
|
|
if u.RawQuery != "" || u.Fragment != "" {
|
|
return "", "", false, fmt.Errorf("invalid URL: query and fragment are not allowed")
|
|
}
|
|
|
|
host := u.Hostname()
|
|
if host == "" {
|
|
return "", "", false, fmt.Errorf("invalid URL: host is required")
|
|
}
|
|
|
|
loopback := isLoopbackHost(host)
|
|
if runMode == gin.ReleaseMode && !loopback {
|
|
return "", "", false, fmt.Errorf("non-loopback cloud override is not allowed in release mode")
|
|
}
|
|
if !loopback && !strings.EqualFold(u.Scheme, "https") {
|
|
return "", "", false, fmt.Errorf("non-loopback cloud override must use https")
|
|
}
|
|
|
|
u.Path = ""
|
|
u.RawPath = ""
|
|
u.RawQuery = ""
|
|
u.Fragment = ""
|
|
|
|
return u.String(), strings.ToLower(host), true, nil
|
|
}
|
|
|
|
func isLoopbackHost(host string) bool {
|
|
if strings.EqualFold(host, "localhost") {
|
|
return true
|
|
}
|
|
|
|
ip := net.ParseIP(host)
|
|
return ip != nil && ip.IsLoopback()
|
|
}
|
|
|
|
func copyProxyRequestHeaders(dst, src http.Header) {
|
|
connectionTokens := connectionHeaderTokens(src)
|
|
for key, values := range src {
|
|
if isHopByHopHeader(key) || isConnectionTokenHeader(key, connectionTokens) {
|
|
continue
|
|
}
|
|
|
|
dst.Del(key)
|
|
for _, value := range values {
|
|
dst.Add(key, value)
|
|
}
|
|
}
|
|
}
|
|
|
|
func copyProxyResponseHeaders(dst, src http.Header) {
|
|
connectionTokens := connectionHeaderTokens(src)
|
|
for key, values := range src {
|
|
if isHopByHopHeader(key) || isConnectionTokenHeader(key, connectionTokens) {
|
|
continue
|
|
}
|
|
|
|
dst.Del(key)
|
|
for _, value := range values {
|
|
dst.Add(key, value)
|
|
}
|
|
}
|
|
}
|
|
|
|
func copyProxyResponseBody(dst http.ResponseWriter, src io.Reader) error {
|
|
flusher, canFlush := dst.(http.Flusher)
|
|
buf := make([]byte, 32*1024)
|
|
|
|
for {
|
|
n, err := src.Read(buf)
|
|
if n > 0 {
|
|
if _, writeErr := dst.Write(buf[:n]); writeErr != nil {
|
|
return writeErr
|
|
}
|
|
if canFlush {
|
|
// TODO(drifkin): Consider conditional flushing so non-streaming
|
|
// responses don't flush every write and can optimize throughput.
|
|
flusher.Flush()
|
|
}
|
|
}
|
|
|
|
if err != nil {
|
|
if err == io.EOF {
|
|
return nil
|
|
}
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
type jsonlFramingResponseWriter struct {
|
|
http.ResponseWriter
|
|
pending []byte
|
|
}
|
|
|
|
func (w *jsonlFramingResponseWriter) Flush() {
|
|
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
|
|
flusher.Flush()
|
|
}
|
|
}
|
|
|
|
func (w *jsonlFramingResponseWriter) Write(p []byte) (int, error) {
|
|
w.pending = append(w.pending, p...)
|
|
if err := w.flushCompleteLines(); err != nil {
|
|
return len(p), err
|
|
}
|
|
return len(p), nil
|
|
}
|
|
|
|
func (w *jsonlFramingResponseWriter) FlushPending() error {
|
|
trailing := bytes.TrimSpace(w.pending)
|
|
w.pending = nil
|
|
if len(trailing) == 0 {
|
|
return nil
|
|
}
|
|
|
|
_, err := w.ResponseWriter.Write(trailing)
|
|
return err
|
|
}
|
|
|
|
func (w *jsonlFramingResponseWriter) flushCompleteLines() error {
|
|
for {
|
|
newline := bytes.IndexByte(w.pending, '\n')
|
|
if newline < 0 {
|
|
return nil
|
|
}
|
|
|
|
line := bytes.TrimSpace(w.pending[:newline])
|
|
w.pending = w.pending[newline+1:]
|
|
if len(line) == 0 {
|
|
continue
|
|
}
|
|
|
|
if _, err := w.ResponseWriter.Write(line); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
func isHopByHopHeader(name string) bool {
|
|
_, ok := hopByHopHeaders[strings.ToLower(name)]
|
|
return ok
|
|
}
|
|
|
|
func connectionHeaderTokens(header http.Header) map[string]struct{} {
|
|
tokens := map[string]struct{}{}
|
|
for _, raw := range header.Values("Connection") {
|
|
for _, token := range strings.Split(raw, ",") {
|
|
token = strings.TrimSpace(strings.ToLower(token))
|
|
if token == "" {
|
|
continue
|
|
}
|
|
tokens[token] = struct{}{}
|
|
}
|
|
}
|
|
return tokens
|
|
}
|
|
|
|
func isConnectionTokenHeader(name string, tokens map[string]struct{}) bool {
|
|
if len(tokens) == 0 {
|
|
return false
|
|
}
|
|
_, ok := tokens[strings.ToLower(name)]
|
|
return ok
|
|
}
|