Files
ollama/server/routes_test.go
2026-04-23 17:03:03 -07:00

1423 lines
39 KiB
Go

package server
import (
"bytes"
"context"
"encoding/binary"
"encoding/json"
"fmt"
"io"
"io/fs"
"math"
"math/rand/v2"
"net"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"reflect"
"slices"
"sort"
"strings"
"testing"
"time"
"unicode"
"github.com/gin-gonic/gin"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/openai"
"github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
)
// TestPsHandlerUsesRunningManifestAndRunner verifies that PsHandler reports a
// loaded model using the model's ManifestDigest (not its Digest) and the MLX
// runner recorded for it.
func TestPsHandlerUsesRunningManifestAndRunner(t *testing.T) {
	gin.SetMode(gin.TestMode)

	// Digest and ManifestDigest are deliberately different so the assertion
	// below can tell which one the handler returned.
	wantDigest := strings.Repeat("a", 64)
	loadedModel := &Model{
		ShortName:      "test-model:latest",
		Digest:         strings.Repeat("b", 64),
		ManifestDigest: wantDigest,
		Runner:         manifest.RunnerMLX,
		Config: model.ConfigV2{
			ModelFormat: manifest.FormatSafetensors,
		},
	}
	ref := &runnerRef{
		model:           loadedModel,
		runner:          manifest.RunnerMLX,
		totalSize:       1024,
		vramSize:        1024,
		expiresAt:       time.Now().Add(time.Hour),
		sessionDuration: time.Hour,
	}
	srv := Server{
		sched: &Scheduler{
			loaded: map[string]*runnerRef{"test": ref},
		},
	}

	rec := httptest.NewRecorder()
	ctx, _ := gin.CreateTestContext(rec)
	ctx.Request = httptest.NewRequest(http.MethodGet, "/api/ps", nil)
	srv.PsHandler(ctx)

	if rec.Code != http.StatusOK {
		t.Fatalf("status = %d, want %d: %s", rec.Code, http.StatusOK, rec.Body.String())
	}

	var resp api.ProcessResponse
	if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
		t.Fatal(err)
	}
	if n := len(resp.Models); n != 1 {
		t.Fatalf("model count = %d, want 1", n)
	}

	if got := resp.Models[0].Digest; got != wantDigest {
		t.Fatalf("digest = %q, want child digest %q", got, wantDigest)
	}
	if got := resp.Models[0].Runner; got != manifest.RunnerMLX {
		t.Fatalf("runner = %q, want %q", got, manifest.RunnerMLX)
	}
}
// createTestFile writes a small GGUF-looking file into a per-test temp
// directory, links it into $OLLAMA_MODELS/blobs under its sha256 digest, and
// returns the temp file path together with the digest string.
//
// The test fails immediately if OLLAMA_MODELS is unset or any file operation
// fails. name is used as the os.CreateTemp pattern.
func createTestFile(t *testing.T, name string) (string, string) {
	t.Helper()

	modelDir := os.Getenv("OLLAMA_MODELS")
	if modelDir == "" {
		t.Fatalf("OLLAMA_MODELS not specified")
	}

	f, err := os.CreateTemp(t.TempDir(), name)
	if err != nil {
		t.Fatalf("failed to create temp file: %v", err)
	}
	defer f.Close()

	// Header fields, written little-endian in order: the "GGUF" magic, a
	// uint32 3, and two zero uint64 values (presumably version 3 with zero
	// tensors and zero key/value pairs — confirm against the GGUF layout).
	header := []any{[]byte("GGUF"), uint32(3), uint64(0), uint64(0)}
	for _, field := range header {
		if err := binary.Write(f, binary.LittleEndian, field); err != nil {
			t.Fatalf("failed to write to file: %v", err)
		}
	}

	// Rewind so the digest covers the whole file.
	if _, err := f.Seek(0, 0); err != nil {
		t.Fatal(err)
	}
	digest, _ := GetSHA256Digest(f)

	if err := f.Close(); err != nil {
		t.Fatal(err)
	}

	// Blob files are named "sha256-<hex>" rather than "sha256:<hex>".
	blobName := "sha256-" + strings.TrimPrefix(digest, "sha256:")
	if err := createLink(f.Name(), filepath.Join(modelDir, "blobs", blobName)); err != nil {
		t.Fatal(err)
	}

	return f.Name(), digest
}
// panicTransport is an http.RoundTripper that refuses to perform any request.
type panicTransport struct{}

// RoundTrip panics unconditionally so that any unexpected outbound HTTP call
// made through this transport fails loudly.
func (*panicTransport) RoundTrip(*http.Request) (*http.Response, error) {
	panic("unexpected RoundTrip call")
}

// panicOnRoundTrip is an HTTP client whose transport panics on every request.
var panicOnRoundTrip = &http.Client{Transport: new(panicTransport)}
// TestRoutes exercises the HTTP routes produced by Server.GenerateRoutes
// against a real httptest server backed by a temporary OLLAMA_MODELS
// directory.
//
// The cases run in order and share the same models directory, so later cases
// observe models created (or deleted) by earlier ones; for example "openai
// list models with tags" expects to see only the model created by
// "Tags Handler (yes tags)".
func TestRoutes(t *testing.T) {
	type testCase struct {
		Name     string
		Method   string
		Path     string
		Setup    func(t *testing.T, req *http.Request)
		Expected func(t *testing.T, resp *http.Response)
	}

	// createTestModel creates a model called name directly through
	// createModel (bypassing HTTP) with a fixed set of parameters.
	createTestModel := func(t *testing.T, name string) {
		t.Helper()

		_, digest := createTestFile(t, "ollama-model")

		fn := func(resp api.ProgressResponse) {
			t.Logf("Status: %s", resp.Status)
		}

		r := api.CreateRequest{
			Name:  name,
			Files: map[string]string{"test.gguf": digest},
			Parameters: map[string]any{
				"seed":  42,
				"top_p": 0.9,
				"stop":  []string{"foo", "bar"},
			},
		}

		modelName := model.ParseName(name)

		baseLayers, err := ggufLayers(digest, fn)
		if err != nil {
			t.Fatalf("failed to create model: %v", err)
		}

		config := &model.ConfigV2{
			OS:           "linux",
			Architecture: "amd64",
			RootFS: model.RootFS{
				Type: "layers",
			},
		}

		if err := createModel(r, modelName, baseLayers, config, fn); err != nil {
			t.Fatal(err)
		}
	}

	testCases := []testCase{
		{
			Name:   "Version Handler",
			Method: http.MethodGet,
			Path:   "/api/version",
			Setup: func(t *testing.T, req *http.Request) {
			},
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
				if contentType != "application/json; charset=utf-8" {
					t.Errorf("expected content type application/json; charset=utf-8, got %s", contentType)
				}
				body, err := io.ReadAll(resp.Body)
				if err != nil {
					t.Fatalf("failed to read response body: %v", err)
				}
				expectedBody := fmt.Sprintf(`{"version":"%s"}`, version.Version)
				if string(body) != expectedBody {
					t.Errorf("expected body %s, got %s", expectedBody, string(body))
				}
			},
		},
		{
			Name:   "Tags Handler (no tags)",
			Method: http.MethodGet,
			Path:   "/api/tags",
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
				if contentType != "application/json; charset=utf-8" {
					t.Errorf("expected content type application/json; charset=utf-8, got %s", contentType)
				}
				body, err := io.ReadAll(resp.Body)
				if err != nil {
					t.Fatalf("failed to read response body: %v", err)
				}

				var modelList api.ListResponse

				err = json.Unmarshal(body, &modelList)
				if err != nil {
					t.Fatalf("failed to unmarshal response body: %v", err)
				}

				// The list must be present (non-nil after decoding) and empty.
				if modelList.Models == nil || len(modelList.Models) != 0 {
					t.Errorf("expected empty model list, got %v", modelList.Models)
				}
			},
		},
		{
			Name:   "openai empty list",
			Method: http.MethodGet,
			Path:   "/v1/models",
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
				if contentType != "application/json" {
					t.Errorf("expected content type application/json, got %s", contentType)
				}
				body, err := io.ReadAll(resp.Body)
				if err != nil {
					t.Fatalf("failed to read response body: %v", err)
				}

				var modelList openai.ListCompletion
				err = json.Unmarshal(body, &modelList)
				if err != nil {
					t.Fatalf("failed to unmarshal response body: %v", err)
				}

				if modelList.Object != "list" || len(modelList.Data) != 0 {
					t.Errorf("expected empty model list, got %v", modelList.Data)
				}
			},
		},
		{
			Name:   "Tags Handler (yes tags)",
			Method: http.MethodGet,
			Path:   "/api/tags",
			Setup: func(t *testing.T, req *http.Request) {
				createTestModel(t, "test-model")
			},
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
				if contentType != "application/json; charset=utf-8" {
					t.Errorf("expected content type application/json; charset=utf-8, got %s", contentType)
				}
				body, err := io.ReadAll(resp.Body)
				if err != nil {
					t.Fatalf("failed to read response body: %v", err)
				}

				if strings.Contains(string(body), "expires_at") {
					t.Errorf("response body should not contain 'expires_at'")
				}

				var modelList api.ListResponse
				err = json.Unmarshal(body, &modelList)
				if err != nil {
					t.Fatalf("failed to unmarshal response body: %v", err)
				}

				if len(modelList.Models) != 1 || modelList.Models[0].Name != "test-model:latest" {
					t.Errorf("expected model 'test-model:latest', got %v", modelList.Models)
				}
			},
		},
		{
			Name:   "Delete Model Handler",
			Method: http.MethodDelete,
			Path:   "/api/delete",
			Setup: func(t *testing.T, req *http.Request) {
				createTestModel(t, "model_to_delete")

				deleteReq := api.DeleteRequest{
					Name: "model_to_delete",
				}
				jsonData, err := json.Marshal(deleteReq)
				if err != nil {
					t.Fatalf("failed to marshal delete request: %v", err)
				}

				req.Body = io.NopCloser(bytes.NewReader(jsonData))
			},
			Expected: func(t *testing.T, resp *http.Response) {
				if resp.StatusCode != http.StatusOK {
					t.Errorf("expected status code 200, got %d", resp.StatusCode)
				}

				// Verify the model was deleted. The lookup must use the same
				// name that Setup created and asked to delete; checking any
				// other name would pass even if the delete did nothing.
				_, err := GetModel("model_to_delete")
				if err == nil || !os.IsNotExist(err) {
					t.Errorf("expected model to be deleted, got error %v", err)
				}
			},
		},
		{
			Name:   "Delete Non-existent Model",
			Method: http.MethodDelete,
			Path:   "/api/delete",
			Setup: func(t *testing.T, req *http.Request) {
				deleteReq := api.DeleteRequest{
					Name: "non_existent_model",
				}
				jsonData, err := json.Marshal(deleteReq)
				if err != nil {
					t.Fatalf("failed to marshal delete request: %v", err)
				}

				req.Body = io.NopCloser(bytes.NewReader(jsonData))
			},
			Expected: func(t *testing.T, resp *http.Response) {
				if resp.StatusCode != http.StatusNotFound {
					t.Errorf("expected status code 404, got %d", resp.StatusCode)
				}

				body, err := io.ReadAll(resp.Body)
				if err != nil {
					t.Fatalf("failed to read response body: %v", err)
				}

				var errorResp map[string]string
				err = json.Unmarshal(body, &errorResp)
				if err != nil {
					t.Fatalf("failed to unmarshal response body: %v", err)
				}

				if !strings.Contains(errorResp["error"], "not found") {
					t.Errorf("expected error message to contain 'not found', got %s", errorResp["error"])
				}
			},
		},
		{
			Name:   "openai list models with tags",
			Method: http.MethodGet,
			Path:   "/v1/models",
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
				if contentType != "application/json" {
					t.Errorf("expected content type application/json, got %s", contentType)
				}
				body, err := io.ReadAll(resp.Body)
				if err != nil {
					t.Fatalf("failed to read response body: %v", err)
				}

				var modelList openai.ListCompletion
				err = json.Unmarshal(body, &modelList)
				if err != nil {
					t.Fatalf("failed to unmarshal response body: %v", err)
				}

				if len(modelList.Data) != 1 || modelList.Data[0].Id != "test-model:latest" || modelList.Data[0].OwnedBy != "library" {
					t.Errorf("expected model 'test-model:latest' owned by 'library', got %v", modelList.Data)
				}
			},
		},
		{
			Name:   "Create Model Handler",
			Method: http.MethodPost,
			Path:   "/api/create",
			Setup: func(t *testing.T, req *http.Request) {
				_, digest := createTestFile(t, "ollama-model")
				stream := false
				createReq := api.CreateRequest{
					Name:   "t-bone",
					Files:  map[string]string{"test.gguf": digest},
					Stream: &stream,
				}
				jsonData, err := json.Marshal(createReq)
				if err != nil {
					t.Fatalf("failed to marshal create request: %v", err)
				}

				req.Body = io.NopCloser(bytes.NewReader(jsonData))
			},
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
				if contentType != "application/json" {
					t.Errorf("expected content type application/json, got %s", contentType)
				}
				_, err := io.ReadAll(resp.Body)
				if err != nil {
					t.Fatalf("failed to read response body: %v", err)
				}
				if resp.StatusCode != http.StatusOK { // Updated line
					t.Errorf("expected status code 200, got %d", resp.StatusCode)
				}

				model, err := GetModel("t-bone")
				if err != nil {
					t.Fatalf("failed to get model: %v", err)
				}
				if model.ShortName != "t-bone:latest" {
					t.Errorf("expected model name 't-bone:latest', got %s", model.ShortName)
				}
			},
		},
		{
			Name:   "Copy Model Handler",
			Method: http.MethodPost,
			Path:   "/api/copy",
			Setup: func(t *testing.T, req *http.Request) {
				createTestModel(t, "hamshank")
				copyReq := api.CopyRequest{
					Source:      "hamshank",
					Destination: "beefsteak",
				}
				jsonData, err := json.Marshal(copyReq)
				if err != nil {
					t.Fatalf("failed to marshal copy request: %v", err)
				}

				req.Body = io.NopCloser(bytes.NewReader(jsonData))
			},
			Expected: func(t *testing.T, resp *http.Response) {
				model, err := GetModel("beefsteak")
				if err != nil {
					t.Fatalf("failed to get model: %v", err)
				}
				if model.ShortName != "beefsteak:latest" {
					t.Errorf("expected model name 'beefsteak:latest', got %s", model.ShortName)
				}
			},
		},
		{
			Name:   "Show Model Handler",
			Method: http.MethodPost,
			Path:   "/api/show",
			Setup: func(t *testing.T, req *http.Request) {
				createTestModel(t, "show-model")
				showReq := api.ShowRequest{Model: "show-model"}
				jsonData, err := json.Marshal(showReq)
				if err != nil {
					t.Fatalf("failed to marshal show request: %v", err)
				}
				req.Body = io.NopCloser(bytes.NewReader(jsonData))
			},
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
				if contentType != "application/json; charset=utf-8" {
					t.Errorf("expected content type application/json; charset=utf-8, got %s", contentType)
				}
				body, err := io.ReadAll(resp.Body)
				if err != nil {
					t.Fatalf("failed to read response body: %v", err)
				}

				var showResp api.ShowResponse
				err = json.Unmarshal(body, &showResp)
				if err != nil {
					t.Fatalf("failed to unmarshal response body: %v", err)
				}

				// Collapse runs of whitespace in each parameter line and sort
				// so the comparison is independent of column alignment and
				// ordering in the response.
				var params []string
				paramsSplit := strings.Split(showResp.Parameters, "\n")
				for _, p := range paramsSplit {
					params = append(params, strings.Join(strings.Fields(p), " "))
				}
				sort.Strings(params)
				expectedParams := []string{
					"seed 42",
					"stop \"bar\"",
					"stop \"foo\"",
					"top_p 0.9",
				}
				if !slices.Equal(params, expectedParams) {
					t.Errorf("expected parameters %v, got %v", expectedParams, params)
				}
				paramCount, ok := showResp.ModelInfo["general.parameter_count"].(float64)
				if !ok {
					t.Fatalf("expected parameter count to be a float64, got %T", showResp.ModelInfo["general.parameter_count"])
				}
				if math.Abs(paramCount) > 1e-9 {
					t.Errorf("expected parameter count to be 0, got %f", paramCount)
				}
			},
		},
		{
			Name: "openai retrieve model handler",
			Setup: func(t *testing.T, req *http.Request) {
				createTestModel(t, "show-model")
			},
			Method: http.MethodGet,
			Path:   "/v1/models/show-model",
			Expected: func(t *testing.T, resp *http.Response) {
				contentType := resp.Header.Get("Content-Type")
				if contentType != "application/json" {
					t.Errorf("expected content type application/json, got %s", contentType)
				}
				body, err := io.ReadAll(resp.Body)
				if err != nil {
					t.Fatalf("failed to read response body: %v", err)
				}

				var m openai.Model
				err = json.Unmarshal(body, &m)
				if err != nil {
					t.Fatalf("failed to unmarshal response body: %v", err)
				}

				if m.Id != "show-model" || m.OwnedBy != "library" {
					t.Errorf("expected model 'show-model' owned by 'library', got %v", m)
				}
			},
		},
		{
			Name:   "Method Not Allowed",
			Method: http.MethodGet,
			Path:   "/api/show",
			Expected: func(t *testing.T, resp *http.Response) {
				if resp.StatusCode != http.StatusMethodNotAllowed {
					t.Errorf("expected status code 405, got %d", resp.StatusCode)
				}
			},
		},
	}

	modelsDir := t.TempDir()
	t.Setenv("OLLAMA_MODELS", modelsDir)

	rc := &ollama.Registry{
		// This is a temporary measure to allow us to move forward,
		// surfacing any code that contacts ollama.com when we do not
		// intend it to.
		//
		// Currently, this only handles DELETE /api/delete, which
		// should not make any contact with the ollama.com registry, so
		// be clear about that.
		//
		// Tests that do need to contact the registry here, will be
		// consumed into our new server/api code packages and removed
		// from here.
		HTTPClient: panicOnRoundTrip,
	}

	s := &Server{}
	router, err := s.GenerateRoutes(rc)
	if err != nil {
		t.Fatalf("failed to generate routes: %v", err)
	}

	httpSrv := httptest.NewServer(router)
	t.Cleanup(httpSrv.Close)

	for _, tc := range testCases {
		t.Run(tc.Name, func(t *testing.T) {
			u := httpSrv.URL + tc.Path
			req, err := http.NewRequestWithContext(t.Context(), tc.Method, u, nil)
			if err != nil {
				t.Fatalf("failed to create request: %v", err)
			}

			// Setup may create fixtures and attach a request body.
			if tc.Setup != nil {
				tc.Setup(t, req)
			}

			resp, err := httpSrv.Client().Do(req)
			if err != nil {
				t.Fatalf("failed to do request: %v", err)
			}
			defer resp.Body.Close()

			if tc.Expected != nil {
				tc.Expected(t, resp)
			}
		})
	}
}
// TestGetModelInfo_SafetensorsUsesStoredFileType writes a config-only
// safetensors manifest whose config carries FileType "mxfp8" and checks that
// GetModelInfo reports that value as the quantization level.
func TestGetModelInfo_SafetensorsUsesStoredFileType(t *testing.T) {
	t.Setenv("OLLAMA_MODELS", t.TempDir())

	cfg := model.ConfigV2{
		ModelFormat:  "safetensors",
		FileType:     "mxfp8",
		Capabilities: []string{"completion"},
	}
	raw, err := json.Marshal(cfg)
	if err != nil {
		t.Fatalf("failed to marshal config: %v", err)
	}

	cfgLayer, err := manifest.NewLayer(bytes.NewReader(raw), "application/vnd.docker.container.image.v1+json")
	if err != nil {
		t.Fatalf("failed to create config layer: %v", err)
	}

	// The manifest has a config layer and no other layers.
	modelName := model.ParseName("show-safetensors")
	err = manifest.WriteManifest(modelName, cfgLayer, nil)
	if err != nil {
		t.Fatalf("failed to write manifest: %v", err)
	}

	info, err := GetModelInfo(api.ShowRequest{Model: modelName.String()})
	if err != nil {
		t.Fatalf("GetModelInfo() error = %v", err)
	}

	if got := info.Details.QuantizationLevel; got != "mxfp8" {
		t.Fatalf("QuantizationLevel = %q, want %q", got, "mxfp8")
	}
}
// TestGetModelInfo_SafetensorsModelfileUsesShortName checks that, for a
// config-only safetensors manifest, the Modelfile returned by GetModelInfo
// has a FROM line naming the model itself and omits the "replace FROM with"
// hint comment.
func TestGetModelInfo_SafetensorsModelfileUsesShortName(t *testing.T) {
	t.Setenv("OLLAMA_MODELS", t.TempDir())

	cfg := model.ConfigV2{
		ModelFormat:  "safetensors",
		Capabilities: []string{"completion"},
	}
	raw, err := json.Marshal(cfg)
	if err != nil {
		t.Fatalf("failed to marshal config: %v", err)
	}

	cfgLayer, err := manifest.NewLayer(bytes.NewReader(raw), "application/vnd.docker.container.image.v1+json")
	if err != nil {
		t.Fatalf("failed to create config layer: %v", err)
	}

	modelName := model.ParseName("show-safetensors")
	err = manifest.WriteManifest(modelName, cfgLayer, nil)
	if err != nil {
		t.Fatalf("failed to write manifest: %v", err)
	}

	info, err := GetModelInfo(api.ShowRequest{Model: modelName.String()})
	if err != nil {
		t.Fatalf("GetModelInfo() error = %v", err)
	}

	modelfile := info.Modelfile
	hasFrom := strings.Contains(modelfile, "FROM show-safetensors:latest\n")
	if !hasFrom {
		t.Fatalf("Modelfile = %q, want FROM show-safetensors:latest", modelfile)
	}

	hasHint := strings.Contains(modelfile, "# To build a new Modelfile based on this, replace FROM with:")
	if hasHint {
		t.Fatalf("Modelfile should not include replacement hint: %q", modelfile)
	}
}
func casingShuffle(s string) string {
rr := []rune(s)
for i := range rr {
if rand.N(2) == 0 {
rr[i] = unicode.ToUpper(rr[i])
} else {
rr[i] = unicode.ToLower(rr[i])
}
}
return string(rr)
}
// TestManifestCaseSensitivity calls the create, pull, copy and push handlers
// with differently-cased spellings of one model name. After each create, pull
// and copy step it checks that exactly one manifest file exists and that it
// sits at the path derived from the first ("stable") spelling. The final push
// must return 200 and report success.
func TestManifestCaseSensitivity(t *testing.T) {
t.Setenv("OLLAMA_MODELS", t.TempDir())
// Stub HTTP server: answers every request with 200 and an empty JSON object.
r := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
io.WriteString(w, `{}`) //nolint:errcheck
}))
defer r.Close()
// name returns a random case-shuffled spelling of fqmn that it has not
// returned before, so every call yields a distinct spelling.
nameUsed := make(map[string]bool)
name := func() string {
const fqmn = "example/namespace/model:tag"
for {
v := casingShuffle(fqmn)
if nameUsed[v] {
continue
}
nameUsed[v] = true
return v
}
}
wantStableName := name()
t.Logf("stable name: %s", wantStableName)
// checkManifestList tests that there is strictly one manifest in the
// models directory, and that the manifest is for the model under test.
checkManifestList := func() {
t.Helper()
mandir, err := manifest.V2Path()
if err != nil {
t.Fatalf("failed to resolve v2 manifest path: %v", err)
}
var entries []string
t.Logf("dir entries:")
fsys := os.DirFS(mandir)
err = fs.WalkDir(fsys, ".", func(path string, info fs.DirEntry, err error) error {
if err != nil {
return err
}
t.Logf(" %s", fs.FormatDirEntry(info))
// Only regular entries (non-directories) count as manifests.
if info.IsDir() {
return nil
}
// NOTE(review): paths handed to this callback come from walking
// os.DirFS(mandir) from ".", so they are presumably already relative
// and this trim is likely a no-op — confirm before relying on it.
path = strings.TrimPrefix(path, mandir)
entries = append(entries, path)
return nil
})
if err != nil {
t.Fatalf("failed to walk directory: %v", err)
}
if len(entries) != 1 {
t.Errorf("len(got) = %d, want 1", len(entries))
return // do not use Fatal so following steps run
}
g := entries[0] // raw path
g = filepath.ToSlash(g)
// Expected location: the manifest path for the stable name, expressed
// relative to mandir with forward slashes so it is comparable to g.
wp, err := manifest.V2PathForName(model.ParseName(wantStableName))
if err != nil {
t.Fatalf("failed to resolve expected manifest path: %v", err)
}
w, err := filepath.Rel(mandir, wp)
if err != nil {
t.Fatalf("failed to make expected manifest path relative: %v", err)
}
w = filepath.ToSlash(w)
if g != w {
t.Errorf("\ngot: %s\nwant: %s", g, w)
}
}
// checkOK reports (without aborting) any response whose status is not 200
// and logs its body.
checkOK := func(w *httptest.ResponseRecorder) {
t.Helper()
if w.Code != http.StatusOK {
t.Errorf("code = %d, want 200", w.Code)
t.Logf("body: %s", w.Body.String())
}
}
var s Server
// Override the package-level dial hook so every dial, regardless of the
// requested network and address, connects to the stub server above. The
// hook is reset to nil when the test finishes.
testMakeRequestDialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
var d net.Dialer
return d.DialContext(ctx, "tcp", r.Listener.Addr().String())
}
t.Cleanup(func() { testMakeRequestDialContext = nil })
t.Logf("creating")
_, digest := createBinFile(t, nil, nil)
checkOK(createRequest(t, s.CreateHandler, api.CreateRequest{
// Start with the stable name, and later use a case-shuffled
// version.
Name: wantStableName,
Files: map[string]string{"test.gguf": digest},
Stream: &stream,
}))
checkManifestList()
t.Logf("creating (again)")
checkOK(createRequest(t, s.CreateHandler, api.CreateRequest{
Name: name(),
Files: map[string]string{"test.gguf": digest},
Stream: &stream,
}))
checkManifestList()
t.Logf("pulling")
checkOK(createRequest(t, s.PullHandler, api.PullRequest{
Name: name(),
Stream: &stream,
Insecure: true,
}))
checkManifestList()
t.Logf("copying")
// Source and destination are two different spellings of the same name.
checkOK(createRequest(t, s.CopyHandler, api.CopyRequest{
Source: name(),
Destination: name(),
}))
checkManifestList()
t.Logf("pushing")
rr := createRequest(t, s.PushHandler, api.PushRequest{
Model: name(),
Insecure: true,
Username: "alice",
Password: "x",
})
checkOK(rr)
if !strings.Contains(rr.Body.String(), `"status":"success"`) {
t.Errorf("got = %q, want success", rr.Body.String())
}
}
// TestShow creates a model from a main GGUF file and a projector GGUF file,
// then checks that the show response reports each file's
// general.architecture under ModelInfo and ProjectorInfo respectively.
func TestShow(t *testing.T) {
	t.Setenv("OLLAMA_MODELS", t.TempDir())

	var s Server

	_, modelDigest := createBinFile(t, ggml.KV{"general.architecture": "test"}, nil)
	_, projectorDigest := createBinFile(t, ggml.KV{"general.type": "projector", "general.architecture": "clip"}, nil)

	files := map[string]string{
		"model.gguf":     modelDigest,
		"projector.gguf": projectorDigest,
	}
	createRequest(t, s.CreateHandler, api.CreateRequest{
		Name:  "show-model",
		Files: files,
	})

	rec := createRequest(t, s.ShowHandler, api.ShowRequest{
		Name: "show-model",
	})
	if rec.Code != http.StatusOK {
		t.Fatalf("expected status code 200, actual %d", rec.Code)
	}

	var resp api.ShowResponse
	if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
		t.Fatal(err)
	}

	if arch := resp.ModelInfo["general.architecture"]; arch != "test" {
		t.Fatal("Expected model architecture to be 'test', but got", arch)
	}
	if arch := resp.ProjectorInfo["general.architecture"]; arch != "clip" {
		t.Fatal("Expected projector architecture to be 'clip', but got", arch)
	}
}
// createShowSafetensorsLayer builds a tensor layer whose blob is a
// little-endian uint64 header length followed by a JSON header describing a
// single F32 tensor called tensorName with the given shape. No tensor data
// follows the header. The returned layer's Name is set to tensorName.
func createShowSafetensorsLayer(t *testing.T, tensorName string, shape []int64) manifest.Layer {
	t.Helper()

	tensorMeta := map[string]any{
		"dtype":        "F32",
		"shape":        shape,
		"data_offsets": []int64{0, 16},
	}
	headerData, err := json.Marshal(map[string]any{tensorName: tensorMeta})
	if err != nil {
		t.Fatal(err)
	}

	// Length prefix first, then the JSON header bytes.
	blob := binary.LittleEndian.AppendUint64(nil, uint64(len(headerData)))
	blob = append(blob, headerData...)

	layer, err := manifest.NewLayer(bytes.NewReader(blob), manifest.MediaTypeImageTensor)
	if err != nil {
		t.Fatal(err)
	}
	layer.Name = tensorName

	return layer
}
// writeShowManifestVariant writes a manifest for name tagged with runner and
// format. The manifest's config layer is the JSON encoding of cfg. Its layers
// are, in order:
//   - for the GGUF format, a model layer built from a generated file holding kv;
//   - for the safetensors format, a single "<name>.weight" tensor layer;
//   - any extraLayers.
//
// Any other format gets only the extra layers.
func writeShowManifestVariant(t *testing.T, name, runner, format string, cfg model.ConfigV2, kv map[string]any, extraLayers ...manifest.Layer) {
	t.Helper()

	rawCfg, err := json.Marshal(cfg)
	if err != nil {
		t.Fatal(err)
	}

	cfgLayer, err := manifest.NewLayer(bytes.NewReader(rawCfg), "application/vnd.docker.container.image.v1+json")
	if err != nil {
		t.Fatal(err)
	}

	all := make([]manifest.Layer, 0, len(extraLayers)+1)
	if format == manifest.FormatGGUF {
		_, digest := createBinFile(t, kv, nil)
		ggufLayer, err := manifest.NewLayerFromLayer(digest, "application/vnd.ollama.image.model", name)
		if err != nil {
			t.Fatal(err)
		}
		all = append(all, ggufLayer)
	} else if format == manifest.FormatSafetensors {
		tensorLayer := createShowSafetensorsLayer(t, name+".weight", []int64{2, 2})
		all = append(all, tensorLayer)
	}
	all = append(all, extraLayers...)

	err = manifest.WriteManifestWithMetadata(model.ParseName(name), cfgLayer, all, runner, format)
	if err != nil {
		t.Fatal(err)
	}
}
// TestShowAllManifestsNonListReturnsSingleManifest checks that requesting all
// manifests for a model created from a single GGUF file yields exactly one
// entry, reported with the GGML runner, GGUF format and the file's
// architecture.
func TestShowAllManifestsNonListReturnsSingleManifest(t *testing.T) {
	t.Setenv("OLLAMA_MODELS", t.TempDir())

	var s Server

	_, digest := createBinFile(t, ggml.KV{"general.architecture": "test"}, nil)
	createReq := api.CreateRequest{
		Name:  "show-model",
		Files: map[string]string{"model.gguf": digest},
	}
	createRequest(t, s.CreateHandler, createReq)

	showReq := api.ShowRequest{
		Model:        "show-model",
		AllManifests: true,
	}
	rec := createRequest(t, s.ShowHandler, showReq)
	if rec.Code != http.StatusOK {
		t.Fatalf("expected status code 200, actual %d: %s", rec.Code, rec.Body.String())
	}

	var resp api.ShowManifestsResponse
	if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
		t.Fatal(err)
	}

	if n := len(resp.Manifests); n != 1 {
		t.Fatalf("manifest count = %d, want 1", n)
	}

	if got := resp.Manifests[0].Runner; got != manifest.RunnerGGML {
		t.Fatalf("runner = %q, want %q", got, manifest.RunnerGGML)
	}
	if got := resp.Manifests[0].Details.Format; got != manifest.FormatGGUF {
		t.Fatalf("format = %q, want %q", got, manifest.FormatGGUF)
	}
	if got := resp.Manifests[0].ModelInfo["general.architecture"]; got != "test" {
		t.Fatalf("architecture = %v, want %q", got, "test")
	}
}
// TestShowAllManifestsManifestListDedupesLicenses builds a manifest list
// ("show-list") referencing an MLX/safetensors child and a GGML/GGUF child
// that share the same license layer. It then checks the all-manifests show
// response: children come back in list order, the top-level license is the
// single shared license text, each child still reports its own license, and
// the MLX child exposes its Requires value and its one tensor.
func TestShowAllManifestsManifestListDedupesLicenses(t *testing.T) {
	t.Setenv("OLLAMA_MODELS", t.TempDir())

	// One license layer shared by both child manifests.
	sharedLicense, err := manifest.NewLayer(bytes.NewReader([]byte("Apache-2.0")), "application/vnd.ollama.image.license")
	if err != nil {
		t.Fatal(err)
	}

	mlxCfg := model.ConfigV2{
		ModelFormat:  manifest.FormatSafetensors,
		ModelFamily:  "qwen3_5_moe",
		ModelType:    "35.1B",
		FileType:     "nvfp4",
		Requires:     "0.19.0",
		Capabilities: []string{"completion", "vision", "thinking", "tools"},
	}
	writeShowManifestVariant(t, "show-mlx", manifest.RunnerMLX, manifest.FormatSafetensors, mlxCfg, nil, sharedLicense)

	ggmlCfg := model.ConfigV2{
		ModelFormat:  manifest.FormatGGUF,
		ModelFamily:  "qwen35moe",
		ModelType:    "36.0B",
		FileType:     "Q4_K_M",
		Capabilities: []string{"completion", "vision", "thinking", "tools"},
	}
	writeShowManifestVariant(t, "show-ggml", manifest.RunnerGGML, manifest.FormatGGUF, ggmlCfg, ggml.KV{"general.architecture": "qwen35moe"}, sharedLicense)

	// Read the children back to obtain the blob digests the list refers to.
	mlxChild, err := manifest.ParseNamedManifestForRunner(model.ParseName("show-mlx"), manifest.RunnerMLX)
	if err != nil {
		t.Fatal(err)
	}
	ggmlChild, err := manifest.ParseNamedManifestForRunner(model.ParseName("show-ggml"), manifest.RunnerGGML)
	if err != nil {
		t.Fatal(err)
	}

	mlxEntry, err := manifest.NewManifestReference(mlxChild.BlobDigest(), manifest.RunnerMLX, manifest.FormatSafetensors)
	if err != nil {
		t.Fatal(err)
	}
	ggmlEntry, err := manifest.NewManifestReference(ggmlChild.BlobDigest(), manifest.RunnerGGML, manifest.FormatGGUF)
	if err != nil {
		t.Fatal(err)
	}

	// Parent manifest list: MLX first, GGML second.
	list := manifest.Manifest{
		SchemaVersion: 2,
		MediaType:     manifest.MediaTypeManifestList,
		Manifests:     []manifest.Manifest{mlxEntry, ggmlEntry},
	}
	listData, err := json.Marshal(list)
	if err != nil {
		t.Fatal(err)
	}
	err = manifest.WriteManifestData(model.ParseName("show-list"), listData)
	if err != nil {
		t.Fatal(err)
	}

	var s Server
	rec := createRequest(t, s.ShowHandler, api.ShowRequest{
		Model:        "show-list",
		AllManifests: true,
	})
	if rec.Code != http.StatusOK {
		t.Fatalf("expected status code 200, actual %d: %s", rec.Code, rec.Body.String())
	}

	var resp api.ShowManifestsResponse
	if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
		t.Fatal(err)
	}

	if n := len(resp.Manifests); n != 2 {
		t.Fatalf("manifest count = %d, want 2", n)
	}

	firstRunner, secondRunner := resp.Manifests[0].Runner, resp.Manifests[1].Runner
	if firstRunner != manifest.RunnerMLX || secondRunner != manifest.RunnerGGML {
		t.Fatalf("runner order = [%q %q], want [%q %q]", firstRunner, secondRunner, manifest.RunnerMLX, manifest.RunnerGGML)
	}

	if resp.License != "Apache-2.0" {
		t.Fatalf("license = %q, want %q", resp.License, "Apache-2.0")
	}
	firstLicense, secondLicense := resp.Manifests[0].License, resp.Manifests[1].License
	if firstLicense != "Apache-2.0" || secondLicense != "Apache-2.0" {
		t.Fatalf("child licenses = [%q %q], want both Apache-2.0", firstLicense, secondLicense)
	}

	if got := resp.Manifests[0].Requires; got != "0.19.0" {
		t.Fatalf("mlx requires = %q, want %q", got, "0.19.0")
	}
	if n := len(resp.Manifests[0].Tensors); n != 1 {
		t.Fatalf("mlx tensor count = %d, want 1", n)
	}
	if got := resp.Manifests[0].Tensors[0].Name; got != "show-mlx.weight" {
		t.Fatalf("mlx tensor name = %q, want %q", got, "show-mlx.weight")
	}
}
// TestShowAllManifestsRejectsRunnerSelection checks that a show request
// combining a runner with AllManifests is rejected with 400 and a fixed
// error body.
func TestShowAllManifestsRejectsRunnerSelection(t *testing.T) {
	var s Server

	req := api.ShowRequest{
		Model:        "show-model",
		Runner:       manifest.RunnerMLX,
		AllManifests: true,
	}
	rec := createRequest(t, s.ShowHandler, req)

	if rec.Code != http.StatusBadRequest {
		t.Fatalf("expected status code 400, actual %d: %s", rec.Code, rec.Body.String())
	}

	// Compare the body exactly, ignoring only surrounding whitespace.
	body := strings.TrimSpace(rec.Body.String())
	if body != `{"error":"runner cannot be used with all_manifests"}` {
		t.Fatalf("response = %s", body)
	}
}
// TestShowCopilotUserAgentOverwritesExistingBasename creates a remote-backed
// model whose stored info carries base_name "upstream-base-name", then shows
// it twice through the router. Without a User-Agent, general.basename is the
// stored value; with a GitHubCopilotChat User-Agent, it is the requested model
// name, while general.architecture still reflects the stored model family.
func TestShowCopilotUserAgentOverwritesExistingBasename(t *testing.T) {
	t.Setenv("OLLAMA_MODELS", t.TempDir())

	var s Server

	created := createRequest(t, s.CreateHandler, api.CreateRequest{
		Model:      "show-model",
		From:       "bob",
		RemoteHost: "https://ollama.com",
		Info: map[string]any{
			"model_family": "gptoss",
			"base_name":    "upstream-base-name",
		},
		Stream: &stream,
	})
	if created.Code != http.StatusOK {
		t.Fatalf("expected status code 200 creating model, actual %d", created.Code)
	}

	router, err := s.GenerateRoutes(nil)
	if err != nil {
		t.Fatal(err)
	}

	// show posts a show request for "show-model", setting the User-Agent
	// header only when userAgent is non-empty, and decodes the 200 response.
	show := func(userAgent string) api.ShowResponse {
		t.Helper()

		rec := httptest.NewRecorder()
		req := httptest.NewRequest(http.MethodPost, "/api/show", strings.NewReader(`{"model":"show-model"}`))
		req.Header.Set("Content-Type", "application/json")
		if userAgent != "" {
			req.Header.Set("User-Agent", userAgent)
		}

		router.ServeHTTP(rec, req)
		if rec.Code != http.StatusOK {
			t.Fatalf("expected status code 200, actual %d", rec.Code)
		}

		var decoded api.ShowResponse
		if err := json.NewDecoder(rec.Body).Decode(&decoded); err != nil {
			t.Fatal(err)
		}
		return decoded
	}

	plain := show("")
	if got := plain.ModelInfo["general.basename"]; got != "upstream-base-name" {
		t.Fatalf("expected general.basename to be %q, got %v", "upstream-base-name", got)
	}

	copilot := show("GitHubCopilotChat/0.41.1")
	if got := copilot.ModelInfo["general.basename"]; got != "show-model" {
		t.Fatalf("expected general.basename to be %q, got %v", "show-model", got)
	}
	if got := copilot.ModelInfo["general.architecture"]; got != "gptoss" {
		t.Fatalf("expected general.architecture to be %q, got %v", "gptoss", got)
	}
}
// TestShowCopilotUserAgentSetsBasenameWhenModelInfoIsEmpty creates a
// remote-backed model with no stored info and shows it with a
// GitHubCopilotChat User-Agent. The response's model info must then contain
// exactly one entry: general.basename set to the requested model name.
func TestShowCopilotUserAgentSetsBasenameWhenModelInfoIsEmpty(t *testing.T) {
	t.Setenv("OLLAMA_MODELS", t.TempDir())

	var s Server

	created := createRequest(t, s.CreateHandler, api.CreateRequest{
		Model:      "show-remote",
		From:       "bob",
		RemoteHost: "https://ollama.com",
		Stream:     &stream,
	})
	if created.Code != http.StatusOK {
		t.Fatalf("expected status code 200 creating model, actual %d", created.Code)
	}

	router, err := s.GenerateRoutes(nil)
	if err != nil {
		t.Fatal(err)
	}

	rec := httptest.NewRecorder()
	showReq := httptest.NewRequest(http.MethodPost, "/api/show", strings.NewReader(`{"model":"show-remote"}`))
	showReq.Header.Set("Content-Type", "application/json")
	showReq.Header.Set("User-Agent", "GitHubCopilotChat/0.41.1")
	router.ServeHTTP(rec, showReq)

	if rec.Code != http.StatusOK {
		t.Fatalf("expected status code 200, actual %d", rec.Code)
	}

	var resp api.ShowResponse
	if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
		t.Fatal(err)
	}

	if got := resp.ModelInfo["general.basename"]; got != "show-remote" {
		t.Fatalf("expected general.basename to be %q, got %v", "show-remote", got)
	}
	// Nothing besides the basename should be present.
	if len(resp.ModelInfo) != 1 {
		t.Fatalf("expected model_info to contain only general.basename, got %#v", resp.ModelInfo)
	}
}
// TestNormalize checks normalize against finite vectors, which must succeed
// and yield a vector that is either unit length (within 1e-6 on the sum of
// squares) or all zero, and against vectors containing NaN or an infinity,
// which must return an error.
func TestNormalize(t *testing.T) {
	nan := float32(math.NaN())
	posInf := float32(math.Inf(1))
	negInf := float32(math.Inf(-1))

	testCases := []struct {
		input       []float32
		expectError bool
	}{
		{input: []float32{1}},
		{input: []float32{0, 1, 2, 3}},
		{input: []float32{0.1, 0.2, 0.3}},
		{input: []float32{-0.1, 0.2, 0.3, -0.4}},
		{input: []float32{0, 0, 0}},
		{input: []float32{nan, 0.2, 0.3}, expectError: true},
		{input: []float32{0.1, nan, 0.3}, expectError: true},
		{input: []float32{posInf, 0.2, 0.3}, expectError: true},
		{input: []float32{negInf, 0.2, 0.3}, expectError: true},
	}

	// isNormalized accepts a vector whose sum of squares is not more than
	// 1e-6 away from 1; otherwise it accepts only a sum of exactly zero.
	isNormalized := func(vec []float32) bool {
		var sum float64
		for _, v := range vec {
			sum += float64(v * v)
		}
		if math.Abs(sum-1) > 1e-6 {
			return sum == 0
		}
		return true
	}

	for _, tc := range testCases {
		t.Run("", func(t *testing.T) {
			normalized, err := normalize(tc.input)

			if tc.expectError {
				if err == nil {
					t.Errorf("Expected error for input %v, but got none", tc.input)
				}
				return
			}

			if err != nil {
				t.Errorf("Unexpected error for input %v: %v", tc.input, err)
			}
			if !isNormalized(normalized) {
				t.Errorf("Vector %v is not normalized", tc.input)
			}
		})
	}
}
func TestFilterThinkTags(t *testing.T) {
type testCase struct {
msgs []api.Message
want []api.Message
model *Model
}
testCases := []testCase{
{
msgs: []api.Message{
{Role: "user", Content: "Hello, world!"},
{Role: "assistant", Content: "<think>Thinking... about the answer</think>abc"},
{Role: "user", Content: "What is the answer?"},
},
want: []api.Message{
{Role: "user", Content: "Hello, world!"},
{Role: "assistant", Content: "abc"},
{Role: "user", Content: "What is the answer?"},
},
model: &Model{
Config: model.ConfigV2{
ModelFamily: "qwen3",
},
},
},
// with newlines inside the think tag aned newlines after
{
msgs: []api.Message{
{Role: "user", Content: "Hello, world!"},
{Role: "assistant", Content: "<think>Thinking... \n\nabout \nthe answer</think>\n\nabc\ndef"},
{Role: "user", Content: "What is the answer?"},
},
want: []api.Message{
{Role: "user", Content: "Hello, world!"},
{Role: "assistant", Content: "abc\ndef"},
{Role: "user", Content: "What is the answer?"},
},
model: &Model{
Config: model.ConfigV2{
ModelFamily: "qwen3",
},
},
},
// should leave thinking tags if it's after the last user message
{
msgs: []api.Message{
{Role: "user", Content: "Hello, world!"},
{Role: "assistant", Content: "<think>Thinking...</think>after"},
{Role: "user", Content: "What is the answer?"},
{Role: "assistant", Content: "<think>thinking again</think>hjk"},
{Role: "assistant", Content: "<think>thinking yet again</think>hjk"},
},
want: []api.Message{
{Role: "user", Content: "Hello, world!"},
{Role: "assistant", Content: "after"},
{Role: "user", Content: "What is the answer?"},
{Role: "assistant", Content: "<think>thinking again</think>hjk"},
{Role: "assistant", Content: "<think>thinking yet again</think>hjk"},
},
model: &Model{
Config: model.ConfigV2{
ModelFamily: "qwen3",
},
},
},
{
// shouldn't strip anything because the model family isn't one of the hardcoded ones
msgs: []api.Message{
{Role: "user", Content: "Hello, world!"},
{Role: "assistant", Content: "<think>Thinking... about the answer</think>abc"},
{Role: "user", Content: "What is the answer?"},
},
want: []api.Message{
{Role: "user", Content: "Hello, world!"},
{Role: "assistant", Content: "<think>Thinking... about the answer</think>abc"},
{Role: "user", Content: "What is the answer?"},
},
model: &Model{
Config: model.ConfigV2{
ModelFamily: "llama3",
},
},
},
{
// deepseek-r1:-prefixed model
msgs: []api.Message{
{Role: "user", Content: "Hello, world!"},
{Role: "assistant", Content: "<think>Thinking... about the answer</think>abc"},
{Role: "user", Content: "What is the answer?"},
},
want: []api.Message{
{Role: "user", Content: "Hello, world!"},
{Role: "assistant", Content: "abc"},
{Role: "user", Content: "What is the answer?"},
},
model: &Model{
Name: "registry.ollama.ai/library/deepseek-r1:latest",
ShortName: "deepseek-r1:7b",
Config: model.ConfigV2{},
},
},
}
for i, tc := range testCases {
filtered := filterThinkTags(tc.msgs, tc.model)
if !reflect.DeepEqual(filtered, tc.want) {
t.Errorf("messages differ for case %d:", i)
for i := range tc.want {
if i >= len(filtered) {
t.Errorf(" missing message %d: %+v", i, tc.want[i])
continue
}
if !reflect.DeepEqual(filtered[i], tc.want[i]) {
t.Errorf(" message %d:\n want: %+v\n got: %+v", i, tc.want[i], filtered[i])
}
}
if len(filtered) > len(tc.want) {
for i := len(tc.want); i < len(filtered); i++ {
t.Errorf(" extra message %d: %+v", i, filtered[i])
}
}
}
}
}
// TestWaitForStream pre-loads a buffered channel with the messages of each
// case, closes it, hands it to waitForStream together with a gin test context,
// and then checks the status code and body captured by the response recorder.
func TestWaitForStream(t *testing.T) {
	gin.SetMode(gin.TestMode)

	cases := []struct {
		name       string
		messages   []any
		expectCode int
		expectBody string
	}{
		{
			name: "error",
			messages: []any{
				gin.H{"error": "internal server error"},
			},
			expectCode: http.StatusInternalServerError,
			expectBody: `{"error":"internal server error"}`,
		},
		{
			name: "error status",
			messages: []any{
				gin.H{"status": http.StatusNotFound, "error": "not found"},
			},
			expectCode: http.StatusNotFound,
			expectBody: `{"error":"not found"}`,
		},
		{
			name: "unknown error",
			messages: []any{
				gin.H{"msg": "something else"},
			},
			expectCode: http.StatusInternalServerError,
			expectBody: `{"error":"unknown error"}`,
		},
		{
			name: "unknown type",
			messages: []any{
				struct{}{},
			},
			expectCode: http.StatusInternalServerError,
			expectBody: `{"error":"unknown message type"}`,
		},
		{
			name: "progress success",
			messages: []any{
				api.ProgressResponse{Status: "success"},
			},
			expectCode: http.StatusOK,
			expectBody: `{"status":"success"}`,
		},
		{
			name: "progress more than success",
			messages: []any{
				api.ProgressResponse{Status: "success"},
				api.ProgressResponse{Status: "one more thing"},
			},
			expectCode: http.StatusOK,
			expectBody: `{"status":"one more thing"}`,
		},
	}

	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			w := httptest.NewRecorder()
			c, _ := gin.CreateTestContext(w)

			// The channel is sized to hold every message so the sends below
			// never block, and it is closed before the call so the receiver
			// sees end-of-stream after the last message.
			ch := make(chan any, len(tt.messages))
			for _, msg := range tt.messages {
				ch <- msg
			}
			close(ch)

			waitForStream(c, ch)

			if w.Code != tt.expectCode {
				t.Errorf("expected status %d, got %d", tt.expectCode, w.Code)
			}

			// cmp.Diff marks its first argument with "-" and its second with
			// "+", so the expected body must come first to match the
			// "(-want +got)" label in the failure message.
			if diff := cmp.Diff(tt.expectBody, w.Body.String()); diff != "" {
				t.Errorf("body mismatch (-want +got):\n%s", diff)
			}
		})
	}
}