mirror of
https://github.com/ollama/ollama.git
synced 2026-04-17 21:54:08 +02:00
server: add experimental web search and web fetch routes (#14753)
This commit is contained in:
@@ -62,6 +62,8 @@ const signinURLStr = "https://ollama.com/connect?name=%s&key=%s"
|
||||
const (
|
||||
cloudErrRemoteInferenceUnavailable = "remote model is unavailable"
|
||||
cloudErrRemoteModelDetailsUnavailable = "remote model details are unavailable"
|
||||
cloudErrWebSearchUnavailable = "web search is unavailable"
|
||||
cloudErrWebFetchUnavailable = "web fetch is unavailable"
|
||||
)
|
||||
|
||||
func writeModelRefParseError(c *gin.Context, err error, fallbackStatus int, fallbackMessage string) {
|
||||
@@ -1693,6 +1695,8 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||
r.GET("/api/experimental/aliases", s.ListAliasesHandler)
|
||||
r.POST("/api/experimental/aliases", s.CreateAliasHandler)
|
||||
r.DELETE("/api/experimental/aliases", s.DeleteAliasHandler)
|
||||
r.POST("/api/experimental/web_search", s.WebSearchExperimentalHandler)
|
||||
r.POST("/api/experimental/web_fetch", s.WebFetchExperimentalHandler)
|
||||
|
||||
// Inference
|
||||
r.GET("/api/ps", s.PsHandler)
|
||||
@@ -1937,6 +1941,29 @@ func (s *Server) StatusHandler(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) WebSearchExperimentalHandler(c *gin.Context) {
|
||||
s.webExperimentalProxyHandler(c, "/api/web_search", cloudErrWebSearchUnavailable)
|
||||
}
|
||||
|
||||
func (s *Server) WebFetchExperimentalHandler(c *gin.Context) {
|
||||
s.webExperimentalProxyHandler(c, "/api/web_fetch", cloudErrWebFetchUnavailable)
|
||||
}
|
||||
|
||||
func (s *Server) webExperimentalProxyHandler(c *gin.Context, proxyPath, disabledOperation string) {
|
||||
body, err := readRequestBody(c.Request)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if len(bytes.TrimSpace(body)) == 0 {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||
return
|
||||
}
|
||||
|
||||
proxyCloudRequestWithPath(c, body, proxyPath, disabledOperation)
|
||||
}
|
||||
|
||||
func (s *Server) WhoamiHandler(c *gin.Context) {
|
||||
// todo allow other hosts
|
||||
u, err := url.Parse("https://ollama.com")
|
||||
|
||||
335
server/routes_web_experimental_test.go
Normal file
335
server/routes_web_experimental_test.go
Normal file
@@ -0,0 +1,335 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
internalcloud "github.com/ollama/ollama/internal/cloud"
|
||||
)
|
||||
|
||||
type webExperimentalUpstreamCapture struct {
|
||||
path string
|
||||
body string
|
||||
header http.Header
|
||||
}
|
||||
|
||||
func newWebExperimentalUpstream(t *testing.T, responseBody string) (*httptest.Server, *webExperimentalUpstreamCapture) {
|
||||
t.Helper()
|
||||
|
||||
capture := &webExperimentalUpstreamCapture{}
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
payload, _ := io.ReadAll(r.Body)
|
||||
capture.path = r.URL.Path
|
||||
capture.body = string(payload)
|
||||
capture.header = r.Header.Clone()
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(responseBody))
|
||||
}))
|
||||
|
||||
return srv, capture
|
||||
}
|
||||
|
||||
func TestExperimentalWebEndpointsPassthrough(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setTestHome(t, t.TempDir())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
localPath string
|
||||
upstreamPath string
|
||||
requestBody string
|
||||
responseBody string
|
||||
assertBody string
|
||||
}{
|
||||
{
|
||||
name: "web_search",
|
||||
localPath: "/api/experimental/web_search",
|
||||
upstreamPath: "/api/web_search",
|
||||
requestBody: `{"query":"what is ollama?","max_results":3}`,
|
||||
responseBody: `{"results":[{"title":"Ollama","url":"https://ollama.com","content":"Cloud models are now available"}]}`,
|
||||
assertBody: `"query":"what is ollama?"`,
|
||||
},
|
||||
{
|
||||
name: "web_fetch",
|
||||
localPath: "/api/experimental/web_fetch",
|
||||
upstreamPath: "/api/web_fetch",
|
||||
requestBody: `{"url":"https://ollama.com"}`,
|
||||
responseBody: `{"title":"Ollama","content":"Cloud models are now available","links":["https://ollama.com/"]}`,
|
||||
assertBody: `"url":"https://ollama.com"`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
upstream, capture := newWebExperimentalUpstream(t, tt.responseBody)
|
||||
defer upstream.Close()
|
||||
|
||||
original := cloudProxyBaseURL
|
||||
cloudProxyBaseURL = upstream.URL
|
||||
t.Cleanup(func() { cloudProxyBaseURL = original })
|
||||
|
||||
s := &Server{}
|
||||
router, err := s.GenerateRoutes(nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
local := httptest.NewServer(router)
|
||||
defer local.Close()
|
||||
|
||||
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+tt.localPath, bytes.NewBufferString(tt.requestBody))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer should-forward")
|
||||
req.Header.Set("X-Test-Header", "web-experimental")
|
||||
|
||||
resp, err := local.Client().Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d (%s)", resp.StatusCode, string(body))
|
||||
}
|
||||
if capture.path != tt.upstreamPath {
|
||||
t.Fatalf("expected upstream path %q, got %q", tt.upstreamPath, capture.path)
|
||||
}
|
||||
if !bytes.Contains([]byte(capture.body), []byte(tt.assertBody)) {
|
||||
t.Fatalf("expected upstream body to contain %q, got %q", tt.assertBody, capture.body)
|
||||
}
|
||||
if got := capture.header.Get("Authorization"); got != "Bearer should-forward" {
|
||||
t.Fatalf("expected forwarded Authorization header, got %q", got)
|
||||
}
|
||||
if got := capture.header.Get("X-Test-Header"); got != "web-experimental" {
|
||||
t.Fatalf("expected forwarded X-Test-Header=web-experimental, got %q", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExperimentalWebEndpointsMissingBody(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setTestHome(t, t.TempDir())
|
||||
|
||||
s := &Server{}
|
||||
router, err := s.GenerateRoutes(nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
local := httptest.NewServer(router)
|
||||
defer local.Close()
|
||||
|
||||
tests := []string{
|
||||
"/api/experimental/web_search",
|
||||
"/api/experimental/web_fetch",
|
||||
}
|
||||
|
||||
for _, path := range tests {
|
||||
t.Run(path, func(t *testing.T) {
|
||||
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+path, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := local.Client().Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Fatalf("expected status 400, got %d (%s)", resp.StatusCode, string(body))
|
||||
}
|
||||
if string(body) != `{"error":"missing request body"}` {
|
||||
t.Fatalf("unexpected response body: %s", string(body))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExperimentalWebEndpointsCloudDisabled(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setTestHome(t, t.TempDir())
|
||||
t.Setenv("OLLAMA_NO_CLOUD", "1")
|
||||
|
||||
s := &Server{}
|
||||
router, err := s.GenerateRoutes(nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
local := httptest.NewServer(router)
|
||||
defer local.Close()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
request string
|
||||
operation string
|
||||
}{
|
||||
{
|
||||
name: "web_search",
|
||||
path: "/api/experimental/web_search",
|
||||
request: `{"query":"latest ollama release"}`,
|
||||
operation: cloudErrWebSearchUnavailable,
|
||||
},
|
||||
{
|
||||
name: "web_fetch",
|
||||
path: "/api/experimental/web_fetch",
|
||||
request: `{"url":"https://ollama.com"}`,
|
||||
operation: cloudErrWebFetchUnavailable,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+tt.path, bytes.NewBufferString(tt.request))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := local.Client().Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusForbidden {
|
||||
t.Fatalf("expected status 403, got %d (%s)", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var got map[string]string
|
||||
if err := json.Unmarshal(body, &got); err != nil {
|
||||
t.Fatalf("expected json error body, got: %q", string(body))
|
||||
}
|
||||
if got["error"] != internalcloud.DisabledError(tt.operation) {
|
||||
t.Fatalf("unexpected error message: %q", got["error"])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExperimentalWebEndpointSigningFailureReturnsUnauthorized(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setTestHome(t, t.TempDir())
|
||||
|
||||
origSignRequest := cloudProxySignRequest
|
||||
origSigninURL := cloudProxySigninURL
|
||||
cloudProxySignRequest = func(context.Context, *http.Request) error {
|
||||
return errors.New("ssh: no key found")
|
||||
}
|
||||
cloudProxySigninURL = func() (string, error) {
|
||||
return "https://ollama.com/signin/example", nil
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
cloudProxySignRequest = origSignRequest
|
||||
cloudProxySigninURL = origSigninURL
|
||||
})
|
||||
|
||||
s := &Server{}
|
||||
router, err := s.GenerateRoutes(nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
local := httptest.NewServer(router)
|
||||
defer local.Close()
|
||||
|
||||
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/api/experimental/web_search", bytes.NewBufferString(`{"query":"hello"}`))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := local.Client().Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
t.Fatalf("expected status 401, got %d (%s)", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var got map[string]any
|
||||
if err := json.Unmarshal(body, &got); err != nil {
|
||||
t.Fatalf("expected json error body, got: %q", string(body))
|
||||
}
|
||||
if got["error"] != "unauthorized" {
|
||||
t.Fatalf("unexpected error message: %v", got["error"])
|
||||
}
|
||||
if got["signin_url"] != "https://ollama.com/signin/example" {
|
||||
t.Fatalf("unexpected signin_url: %v", got["signin_url"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestExperimentalWebEndpointSigningFailureWithoutSigninURL(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setTestHome(t, t.TempDir())
|
||||
|
||||
origSignRequest := cloudProxySignRequest
|
||||
origSigninURL := cloudProxySigninURL
|
||||
cloudProxySignRequest = func(context.Context, *http.Request) error {
|
||||
return errors.New("ssh: no key found")
|
||||
}
|
||||
cloudProxySigninURL = func() (string, error) {
|
||||
return "", errors.New("key missing")
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
cloudProxySignRequest = origSignRequest
|
||||
cloudProxySigninURL = origSigninURL
|
||||
})
|
||||
|
||||
s := &Server{}
|
||||
router, err := s.GenerateRoutes(nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
local := httptest.NewServer(router)
|
||||
defer local.Close()
|
||||
|
||||
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/api/experimental/web_fetch", bytes.NewBufferString(`{"url":"https://ollama.com"}`))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := local.Client().Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
t.Fatalf("expected status 401, got %d (%s)", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var got map[string]any
|
||||
if err := json.Unmarshal(body, &got); err != nil {
|
||||
t.Fatalf("expected json error body, got: %q", string(body))
|
||||
}
|
||||
if got["error"] != "unauthorized" {
|
||||
t.Fatalf("unexpected error message: %v", got["error"])
|
||||
}
|
||||
if _, ok := got["signin_url"]; ok {
|
||||
t.Fatalf("did not expect signin_url when helper fails, got %v", got["signin_url"])
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user