Compare commits

...

3 Commits

Author SHA1 Message Date
Eva Ho
010af4e730 clean up 2026-03-12 14:32:19 -04:00
Eva Ho
6287a80587 add tests 2026-03-12 14:31:15 -04:00
Eva Ho
ebd4d0e498 server: use server's context length to set as part of config for local models 2026-03-12 14:26:35 -04:00
2 changed files with 110 additions and 13 deletions

View File

@@ -1147,7 +1147,7 @@ func (s *Server) ShowHandler(c *gin.Context) {
req.Model = modelRef.Base
resp, err := GetModelInfo(req)
resp, m, err := GetModelInfo(req)
if err != nil {
var statusErr api.StatusError
switch {
@@ -1168,27 +1168,50 @@ func (s *Server) ShowHandler(c *gin.Context) {
return
}
// For local models, override the context_length in ModelInfo
// with the server's context length so clients see what the model
// will actually run with, not just the context from the GGUF file.
if resp.RemoteHost == "" && resp.ModelInfo != nil && m != nil {
effectiveCtx := int(envconfig.ContextLength())
if effectiveCtx == 0 {
effectiveCtx = s.defaultNumCtx
}
if numCtx, ok := m.Options["num_ctx"]; ok {
if v, ok := numCtx.(float64); ok {
effectiveCtx = int(v)
}
}
for k := range resp.ModelInfo {
if strings.HasSuffix(k, ".context_length") {
resp.ModelInfo[k] = effectiveCtx
break
}
}
}
c.JSON(http.StatusOK, resp)
}
func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, *Model, error) {
name := model.ParseName(req.Model)
if !name.IsValid() {
return nil, model.Unqualified(name)
return nil, nil, model.Unqualified(name)
}
name, err := getExistingName(name)
if err != nil {
return nil, err
return nil, nil, err
}
m, err := GetModel(name.String())
if err != nil {
return nil, err
return nil, nil, err
}
if m.Config.RemoteHost != "" {
if disabled, _ := internalcloud.Status(); disabled {
return nil, api.StatusError{
return nil, nil, api.StatusError{
StatusCode: http.StatusForbidden,
ErrorMessage: internalcloud.DisabledError(cloudErrRemoteModelDetailsUnavailable),
}
@@ -1240,7 +1263,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
mf, err := manifest.ParseNamedManifest(name)
if err != nil {
return nil, err
return nil, nil, err
}
resp := &api.ShowResponse{
@@ -1311,7 +1334,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
// skip loading tensor information if this is a remote model
if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
return resp, nil
return resp, m, nil
}
if slices.Contains(m.Capabilities(), model.CapabilityImage) {
@@ -1321,7 +1344,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
resp.Tensors = tensors
}
}
return resp, nil
return resp, m, nil
}
// For safetensors LLM models (experimental), populate ModelInfo from config.json
@@ -1335,12 +1358,12 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
resp.Tensors = tensors
}
}
return resp, nil
return resp, m, nil
}
kvData, tensors, err := getModelData(m.ModelPath, req.Verbose)
if err != nil {
return nil, err
return nil, nil, err
}
delete(kvData, "general.name")
@@ -1356,12 +1379,12 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
if len(m.ProjectorPaths) > 0 {
projectorData, _, err := getModelData(m.ProjectorPaths[0], req.Verbose)
if err != nil {
return nil, err
return nil, nil, err
}
resp.ProjectorInfo = projectorData
}
return resp, nil
return resp, m, nil
}
func getModelData(digest string, verbose bool) (ggml.KV, ggml.Tensors, error) {

View File

@@ -721,6 +721,80 @@ func TestShow(t *testing.T) {
}
}
// TestShowContextLengthUsesServerDefault verifies that ShowHandler reports
// the server's default context length (Server.defaultNumCtx) in ModelInfo
// for a local model when OLLAMA_CONTEXT_LENGTH is unset, rather than the
// (much larger) training context length stored in the GGUF metadata.
func TestShowContextLengthUsesServerDefault(t *testing.T) {
// Isolate model storage so the test does not touch the real model dir.
t.Setenv("OLLAMA_MODELS", t.TempDir())
s := Server{defaultNumCtx: 4096}
// Create a model with a training context of 131072 (128K)
_, digest := createBinFile(t, ggml.KV{
"general.architecture": "llama",
"llama.context_length": uint32(131072),
}, nil)
createRequest(t, s.CreateHandler, api.CreateRequest{
Name: "ctx-test-model",
Files: map[string]string{"model.gguf": digest},
})
// Without OLLAMA_CONTEXT_LENGTH set, should use VRAM-based default (4096)
w := createRequest(t, s.ShowHandler, api.ShowRequest{Name: "ctx-test-model"})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
var resp api.ShowResponse
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatal(err)
}
// ShowHandler overwrites the "<arch>.context_length" key in ModelInfo;
// for this model the architecture is llama.
ctxLen, ok := resp.ModelInfo["llama.context_length"]
if !ok {
t.Fatal("expected llama.context_length in ModelInfo")
}
// JSON decodes numbers as float64
if int(ctxLen.(float64)) != 4096 {
t.Errorf("expected context_length 4096 (server default), got %v", ctxLen)
}
}
// TestShowContextLengthUsesEnvVar verifies that an explicit
// OLLAMA_CONTEXT_LENGTH environment variable takes precedence over both the
// server's default context length and the model's GGUF training context when
// ShowHandler populates ModelInfo.
func TestShowContextLengthUsesEnvVar(t *testing.T) {
// Isolate model storage so the test does not touch the real model dir.
t.Setenv("OLLAMA_MODELS", t.TempDir())
// Env var (8192) should win over defaultNumCtx (4096) below.
t.Setenv("OLLAMA_CONTEXT_LENGTH", "8192")
s := Server{defaultNumCtx: 4096}
// Model's GGUF metadata advertises a 128K training context; the response
// should NOT echo this value back.
_, digest := createBinFile(t, ggml.KV{
"general.architecture": "llama",
"llama.context_length": uint32(131072),
}, nil)
createRequest(t, s.CreateHandler, api.CreateRequest{
Name: "ctx-test-model",
Files: map[string]string{"model.gguf": digest},
})
w := createRequest(t, s.ShowHandler, api.ShowRequest{Name: "ctx-test-model"})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
var resp api.ShowResponse
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatal(err)
}
ctxLen, ok := resp.ModelInfo["llama.context_length"]
if !ok {
t.Fatal("expected llama.context_length in ModelInfo")
}
// JSON numbers decode as float64; compare against the env-var value.
if int(ctxLen.(float64)) != 8192 {
t.Errorf("expected context_length 8192 (env var), got %v", ctxLen)
}
}
func TestNormalize(t *testing.T) {
type testCase struct {
input []float32