ollama/server/routes.go
Daniel Hiltgen 96b202d34b Add support for gemma4 (#15214)
* bench: add prompt calibration, context size flag, and NumCtx reporting

Add --num-ctx flag to set context size, and report NumCtx in model info
header. Calibrate tokens-per-word ratio during warmup using actual
tokenization metrics from the model, replacing the fixed 1.3 heuristic.
This produces more accurate prompt token counts for --prompt-tokens.

Also add fetchContextLength() to query running model context via /api/ps.
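
A hedged sketch of the calibration idea (the function name and call site are illustrative, not the benchmark's actual code): measure how many tokens the warmup prompt actually produced and derive the ratio from its word count, keeping 1.3 only as a fallback.

```go
package main

import (
	"fmt"
	"strings"
)

// calibrateRatio derives tokens-per-word from the warmup request's
// reported prompt token count, keeping 1.3 only as a fallback.
func calibrateRatio(warmupPrompt string, promptEvalCount int) float64 {
	words := len(strings.Fields(warmupPrompt))
	if words == 0 || promptEvalCount <= 0 {
		return 1.3 // the old fixed heuristic
	}
	return float64(promptEvalCount) / float64(words)
}

func main() {
	// e.g. a 26-word warmup prompt tokenized into 39 tokens -> ratio 1.5
	fmt.Println(calibrateRatio(strings.Repeat("lorem ipsum ", 13), 39))
}
```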

* integration: improve vision test robustness and add thinking tests

Add skipIfNoVisionOverride() to skip vision tests when OLLAMA_TEST_MODEL
is set to a non-vision model. Add Think:false to context exhaustion test
to prevent thinking models from using all context before the test can
measure it. Add third test image (ollama homepage) and replace OCR test
with ImageDescription test using it. Relax match strings for broader
model compatibility. Add TestThinkingEnabled and TestThinkingSuppressed
to verify thinking output and channel tag handling.

* gemma4: add Gemma 4 GGML model support

Add full Gemma 4 model family support (E2B, E4B, 26B MoE, 31B Dense)
for the GGML backend including text, vision, converter, parser, and
renderer.

Text model features:
- Sliding window + full attention with per-layer patterns
- KV sharing across layers with donor map
- Per-layer embeddings (PLE) with learned projections
- MoE routing with RMSNorm + learned scale
- Proportional RoPE with freq_factors for global attention
- Final logit softcapping
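
For reference, final logit softcapping is typically implemented as limit * tanh(x/limit), which smoothly bounds logits while staying near-identity around zero; a minimal, self-contained sketch (illustrative, not the model code in this change):

```go
package main

import (
	"fmt"
	"math"
)

// softcap smoothly bounds each logit to (-limit, limit):
// y = limit * tanh(x/limit), approximately identity near zero.
func softcap(logits []float32, limit float32) {
	for i, x := range logits {
		logits[i] = limit * float32(math.Tanh(float64(x/limit)))
	}
}

func main() {
	logits := []float32{0.5, 12.0, -80.0}
	softcap(logits, 30.0) // Gemma 2 used a final cap of 30, for comparison
	fmt.Println(logits)   // all values now lie strictly inside (-30, 30)
}
```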

Vision model features:
- SigLIP vision encoder with 2D RoPE
- ClippableLinear with input/output clamping via packed v.clamp_data
- Adaptive average pooling with nMerge kernel
- Multi-modal projection with unweighted RMSNorm

Converter:
- Safetensors to GGUF with vision tensor renaming
- Fused MoE gate_up_proj splitting
- Vision patch embedding reshape (HF to Conv2D layout)
- Packed clamp data tensor for ClippableLinear bounds
- Proportional RoPE freq_factors generation

Also includes:
- BackendGet() on ml.Tensor for reading weight tensor data
- Q6_K CUDA get_rows kernel support
- MoE-aware ffn_down quantization layer counting
- Gemma4 parser with tool calling and thinking support
- Gemma4 renderer with structured tool format
- Architecture-based auto-detection of renderer/parser/stop tokens
- Integration test gemma4 model list additions

* gemma4: add audio support with USM conformer encoder

Add audio encoding for Gemma 4 using the USM conformer architecture:
- Converter: audio tensor mapping, SSCP/conformer/embedder name replacements,
  softplus repacker for per_dim_scale, F32 enforcement for conv weights
- GGML backend: Conv1DDW and PadExt tensor ops
- Audio encoder: SSCP Conv2D, 12 conformer blocks (FFW + block-local
  attention with relative position embeddings + LightConv1d + FFW),
  output projection, audio-to-text embedding projector
- Audio preprocessing: WAV decode, mel spectrogram, FFT (pure Go)
- Model wiring: WAV detection, audio token handling, unified PostTokenize

Correctly transcribes "why is the sky blue" from test audio.

* integration: add gemma4 audio tests including OpenAI API coverage

Test audio transcription and response via the Ollama native API, plus
two new tests exercising the OpenAI-compatible endpoints:
- /v1/audio/transcriptions (multipart form upload)
- /v1/chat/completions with input_audio content type

All tests use capability checks and skip models without audio support.

* gemma4: add OpenAI audio API support and capability detection

- Add CapabilityAudio and detect from audio.block_count in GGUF
- Add /v1/audio/transcriptions endpoint with TranscriptionMiddleware
- Add input_audio content type support in /v1/chat/completions
- Add TranscriptionRequest/Response types in openai package
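
As a hedged illustration of the request shape (the multipart field names "file" and "model" follow the OpenAI audio API convention; the endpoint is the one added above and the model name is a placeholder):

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"mime/multipart"
	"net/http"
	"os"
)

func main() {
	var body bytes.Buffer
	w := multipart.NewWriter(&body)

	// "file" and "model" are the OpenAI-convention field names
	fw, err := w.CreateFormFile("file", "clip.wav")
	if err != nil {
		panic(err)
	}
	f, err := os.Open("clip.wav")
	if err != nil {
		panic(err)
	}
	defer f.Close()
	if _, err := io.Copy(fw, f); err != nil {
		panic(err)
	}
	w.WriteField("model", "gemma4") // placeholder model name
	w.Close()

	resp, err := http.Post("http://localhost:11434/v1/audio/transcriptions",
		w.FormDataContentType(), &body)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	out, _ := io.ReadAll(resp.Body)
	fmt.Println(string(out)) // JSON body containing the transcribed text
}
```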

* gemma4: add audio input support for run command

- /audio toggle in interactive mode for voice chat
- Platform-specific microphone recording (AVFoundation on macOS,
  PulseAudio/ALSA on Linux, WASAPI on Windows)
- Space to start/stop recording, automatic chunking for long audio

* gemma4: add transcribe command (ollama transcribe MODEL)

- Interactive mode with readline prompt and slash commands
- Non-interactive mode for piped audio or record-until-Ctrl+C
- Chunked streaming transcription for long recordings
- Word-wrapped output matching run command style

* gemma4: add parser, renderer, and integration test plumbing

* gemma4: fix renderer to emit BOS token

* gemma4: add OpenAI audio transcription API and input_audio support

* gemma4: update converter for new weight drop naming

* gemma4: add per_expert_scale to MoE router and fix moe_intermediate_size config

* gemma4: rewrite renderer to match HF Jinja2 template exactly

Fix 8 bugs found by building 55 reference tests verified against the
HF Jinja2 chat template (VERIFY_JINJA2=1 shells out to Python):

- Tool responses use separate <|turn>tool turns (not inline tags)
- Tool calls emitted before content in assistant messages
- Thinking content stripped from assistant history (strip_thinking)
- User, tool, and system content trimmed (template does | trim)
- Empty system message still emits system turn (check role, not content)
- Nested object properties rendered recursively with required field
- Array items specification rendered for array-type properties
- OBJECT/ARRAY type-specific rendering comma logic matches template

Also adds Required field to api.ToolProperty for nested object schemas,
replaces old gemma4_test.go with comprehensive gemma4_reference_test.go,
and commits the Jinja2 template as testdata for verification.

* gemma4: fix MoE fused gate_up split and multiline tool-call arg parsing

- Text MoE: split `ffn_gate_up_exps` into contiguous `[gate|up]` halves instead of stride-2 slices.
- Parser: escape control characters in `<|"|>...<|"|>` string literals when converting tool-call args to JSON.
- Fixes warnings like `invalid character '\n' in string literal` for multiline tool arguments.
- Add Gemma4 parser regressions for multiline tool-call args and `gemma4ArgsToJSON`.
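
The escaping fix reduces to something like the following sketch (not the parser's actual code): a raw argument value must become a valid JSON string literal, with control characters such as newlines escaped, which json.Marshal handles per RFC 8259.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// toJSONStringLiteral turns a raw argument value into a valid JSON string
// literal; json.Marshal escapes control characters (\n, \t, ...) per RFC 8259.
func toJSONStringLiteral(raw string) (string, error) {
	b, err := json.Marshal(raw)
	if err != nil {
		return "", err
	}
	return string(b), nil
}

func main() {
	s, _ := toJSONStringLiteral("line one\nline two")
	fmt.Println(s) // "line one\nline two" -- newline escaped, now valid JSON
}
```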

* cmd: simplify audio input to dropped file attachments

* gemma4: use full SWA memory for better cache reuse

* gemma4: initialize clamps after backend load

* convert: align gemma4 audio tensor renames with llama.cpp

* Remove redundant comments in gemma4 vision model

* Format Gemma4 MoE block field alignment

* use 4096 kvcache.NewSWAMemCache

* convert: support new Gemma4 audio_tower tensor naming (#15221)

Co-authored-by: jmorganca <jmorganca@gmail.com>

* fix integration test defaults for audio

* review comments and lint fixes

* remove unused audio/video files

---------

Co-authored-by: jmorganca <jmorganca@gmail.com>
2026-04-02 11:33:33 -07:00


package server
import (
"bytes"
"cmp"
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"image"
"io"
"io/fs"
"log/slog"
"math"
"math/rand"
"net"
"net/http"
"net/netip"
"net/url"
"os"
"os/signal"
"slices"
"strings"
"sync/atomic"
"syscall"
"time"
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
"golang.org/x/image/webp"
"golang.org/x/sync/errgroup"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/discover"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/fs/ggml"
internalcloud "github.com/ollama/ollama/internal/cloud"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/middleware"
"github.com/ollama/ollama/model/parsers"
"github.com/ollama/ollama/model/renderers"
"github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/server/internal/registry"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/thinking"
"github.com/ollama/ollama/tools"
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
imagegenmanifest "github.com/ollama/ollama/x/imagegen/manifest"
xserver "github.com/ollama/ollama/x/server"
)
const signinURLStr = "https://ollama.com/connect?name=%s&key=%s"
const (
cloudErrRemoteInferenceUnavailable = "remote model is unavailable"
cloudErrRemoteModelDetailsUnavailable = "remote model details are unavailable"
cloudErrWebSearchUnavailable = "web search is unavailable"
cloudErrWebFetchUnavailable = "web fetch is unavailable"
copilotChatUserAgentPrefix = "GitHubCopilotChat/"
)
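// writeModelRefParseError writes a JSON error response for a model reference
// parse failure, mapping known error types to 400s and everything else to the
// caller-provided fallback status and message.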
func writeModelRefParseError(c *gin.Context, err error, fallbackStatus int, fallbackMessage string) {
switch {
case errors.Is(err, errConflictingModelSource):
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
case errors.Is(err, model.ErrUnqualifiedName):
c.JSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg})
default:
c.JSON(fallbackStatus, gin.H{"error": fallbackMessage})
}
}
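// shouldUseHarmony reports whether this model's output should be parsed with
// the harmony parser, based on the gpt-oss model family plus telltale tags in
// the template.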
func shouldUseHarmony(model *Model) bool {
if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) {
// heuristic to check whether the template expects to be parsed via harmony:
// search for harmony tags that are nearly always used
if model.Template.Contains("<|start|>") && model.Template.Contains("<|end|>") {
return true
}
}
return false
}
func experimentEnabled(name string) bool {
return slices.Contains(strings.Split(os.Getenv("OLLAMA_EXPERIMENT"), ","), name)
}
var useClient2 = experimentEnabled("client2")
var mode string = gin.DebugMode
type Server struct {
addr net.Addr
sched *Scheduler
defaultNumCtx int
requestLogger *inferenceRequestLogger
}
func init() {
switch mode {
case gin.DebugMode:
case gin.ReleaseMode:
case gin.TestMode:
default:
mode = gin.DebugMode
}
gin.SetMode(mode)
// Tell renderers to use [img] tags
renderers.RenderImgTags = true
}
var (
errRequired = errors.New("is required")
errBadTemplate = errors.New("template error")
)
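// modelOptions merges options in increasing precedence: built-in defaults
// (with the server's VRAM-based default NumCtx), the model's Modelfile
// options, then per-request options.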
func (s *Server) modelOptions(model *Model, requestOpts map[string]any) (api.Options, error) {
opts := api.DefaultOptions()
if opts.NumCtx == 0 {
opts.NumCtx = s.defaultNumCtx
}
if err := opts.FromMap(model.Options); err != nil {
return api.Options{}, err
}
if err := opts.FromMap(requestOpts); err != nil {
return api.Options{}, err
}
return opts, nil
}
// scheduleRunner schedules a runner after validating inputs such as capabilities and model options.
// It returns the allocated runner, model instance, and consolidated options if successful and error otherwise.
func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
if name == "" {
return nil, nil, nil, fmt.Errorf("model %w", errRequired)
}
model, err := GetModel(name)
if err != nil {
return nil, nil, nil, err
}
if slices.Contains(model.Config.ModelFamilies, "mllama") && len(model.ProjectorPaths) > 0 {
return nil, nil, nil, fmt.Errorf("'llama3.2-vision' is no longer compatible with your version of Ollama and has been replaced by a newer version. To re-download, run 'ollama pull llama3.2-vision'")
}
if err := model.CheckCapabilities(caps...); err != nil {
return nil, nil, nil, fmt.Errorf("%s %w", name, err)
}
// Deprecated runner override option; ignore if present.
delete(requestOpts, "use_imagegen_runner")
opts, err := s.modelOptions(model, requestOpts)
if err != nil {
return nil, nil, nil, err
}
runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive)
var runner *runnerRef
select {
case runner = <-runnerCh:
case err = <-errCh:
return nil, nil, nil, err
}
return runner.llama, model, &opts, nil
}
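// signinURL returns the ollama.com connect URL for this machine, embedding
// the hostname and the base64url-encoded public key.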
func signinURL() (string, error) {
pubKey, err := auth.GetPublicKey()
if err != nil {
return "", err
}
encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))
h, _ := os.Hostname()
return fmt.Sprintf(signinURLStr, url.PathEscape(h), encKey), nil
}
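// GenerateHandler implements POST /api/generate: it proxies cloud and remote
// models, handles load/unload requests, renders the prompt (raw, template, or
// chat-style), and streams completion responses.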
func (s *Server) GenerateHandler(c *gin.Context) {
checkpointStart := time.Now()
var req api.GenerateRequest
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
} else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if req.TopLogprobs < 0 || req.TopLogprobs > 20 {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "top_logprobs must be between 0 and 20"})
return
}
modelRef, err := parseAndValidateModelRef(req.Model)
if err != nil {
writeModelRefParseError(c, err, http.StatusNotFound, fmt.Sprintf("model '%s' not found", req.Model))
return
}
if modelRef.Source == modelSourceCloud {
// TODO(drifkin): evaluate an `/api/*` passthrough for cloud where the
// original body (modulo model name normalization) is sent to cloud.
req.Model = modelRef.Base
proxyCloudJSONRequest(c, req, cloudErrRemoteInferenceUnavailable)
return
}
name := modelRef.Name
// We can't currently consolidate this into GetModel because doing so
// would induce infinite recursion given the current code structure.
name, err = getExistingName(name)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
return
}
m, err := GetModel(name.String())
if err != nil {
switch {
case errors.Is(err, fs.ErrNotExist):
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
case err.Error() == errtypes.InvalidModelNameErrMsg:
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
if modelRef.Source == modelSourceLocal && m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
return
}
if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
if disabled, _ := internalcloud.Status(); disabled {
c.JSON(http.StatusForbidden, gin.H{"error": internalcloud.DisabledError(cloudErrRemoteInferenceUnavailable)})
return
}
origModel := req.Model
remoteURL, err := url.Parse(m.Config.RemoteHost)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if !slices.Contains(envconfig.Remotes(), remoteURL.Hostname()) {
slog.Info("remote model", "remotes", envconfig.Remotes(), "remoteURL", m.Config.RemoteHost, "hostname", remoteURL.Hostname())
c.JSON(http.StatusBadRequest, gin.H{"error": "this server cannot run this remote model"})
return
}
req.Model = m.Config.RemoteModel
if req.Template == "" && m.Template.String() != "" {
req.Template = m.Template.String()
}
if req.Options == nil {
req.Options = map[string]any{}
}
for k, v := range m.Options {
if _, ok := req.Options[k]; !ok {
req.Options[k] = v
}
}
// update the system prompt from the model if one isn't already specified
if req.System == "" && m.System != "" {
req.System = m.System
}
if len(m.Messages) > 0 {
slog.Warn("embedded messages in the model not supported with '/api/generate'; try '/api/chat' instead")
}
contentType := "application/x-ndjson"
if req.Stream != nil && !*req.Stream {
contentType = "application/json; charset=utf-8"
}
c.Header("Content-Type", contentType)
fn := func(resp api.GenerateResponse) error {
resp.Model = origModel
resp.RemoteModel = m.Config.RemoteModel
resp.RemoteHost = m.Config.RemoteHost
data, err := json.Marshal(resp)
if err != nil {
return err
}
if _, err = c.Writer.Write(append(data, '\n')); err != nil {
return err
}
c.Writer.Flush()
return nil
}
client := api.NewClient(remoteURL, http.DefaultClient)
err = client.Generate(c, &req, fn)
if err != nil {
var authError api.AuthorizationError
if errors.As(err, &authError) {
sURL, sErr := signinURL()
if sErr != nil {
slog.Error(sErr.Error())
c.JSON(http.StatusInternalServerError, gin.H{"error": "error getting authorization details"})
return
}
c.JSON(authError.StatusCode, gin.H{"error": "unauthorized", "signin_url": sURL})
return
}
var apiError api.StatusError
if errors.As(err, &apiError) {
c.JSON(apiError.StatusCode, apiError)
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
return
}
// expire the runner if unload is requested (empty prompt, keep alive is 0)
if req.Prompt == "" && req.KeepAlive != nil && req.KeepAlive.Duration == 0 {
s.sched.expireRunner(m)
c.JSON(http.StatusOK, api.GenerateResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Response: "",
Done: true,
DoneReason: "unload",
})
return
}
// Handle image generation models
if slices.Contains(m.Capabilities(), model.CapabilityImage) {
s.handleImageGenerate(c, req, name.String(), checkpointStart)
return
}
if req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
return
}
var builtinParser parsers.Parser
if shouldUseHarmony(m) && m.Config.Parser == "" {
m.Config.Parser = "harmony"
}
if !req.Raw && m.Config.Parser != "" {
builtinParser = parsers.ParserForName(m.Config.Parser)
if builtinParser != nil {
// no tools or last message for generate endpoint
builtinParser.Init(nil, nil, req.Think)
}
}
caps := []model.Capability{model.CapabilityCompletion}
if req.Suffix != "" {
caps = append(caps, model.CapabilityInsert)
}
modelCaps := m.Capabilities()
if slices.Contains(modelCaps, model.CapabilityThinking) {
caps = append(caps, model.CapabilityThinking)
if req.Think == nil {
req.Think = &api.ThinkValue{Value: true}
}
} else {
if req.Think != nil && req.Think.Bool() {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support thinking", req.Model)})
return
}
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
if errors.Is(err, errCapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
return
} else if err != nil {
handleScheduleError(c, req.Model, err)
return
}
checkpointLoaded := time.Now()
// load the model
if req.Prompt == "" {
c.JSON(http.StatusOK, api.GenerateResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Done: true,
DoneReason: "load",
})
return
}
if slices.Contains(m.Config.ModelFamilies, "mllama") && len(req.Images) > 1 {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "this model only supports one image while more than one image requested"})
return
}
images := make([]llm.ImageData, len(req.Images))
for i := range req.Images {
images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
}
prompt := req.Prompt
if !req.Raw {
tmpl := m.Template
if req.Template != "" {
tmpl, err = template.Parse(req.Template)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
var values template.Values
if req.Suffix != "" {
values.Prompt = prompt
values.Suffix = req.Suffix
} else {
var msgs []api.Message
if req.System != "" {
msgs = append(msgs, api.Message{Role: "system", Content: req.System})
} else if m.System != "" {
msgs = append(msgs, api.Message{Role: "system", Content: m.System})
}
if req.Context == nil {
msgs = append(msgs, m.Messages...)
}
userMsg := api.Message{Role: "user", Content: req.Prompt}
for _, i := range images {
userMsg.Images = append(userMsg.Images, i.Data)
}
values.Messages = append(msgs, userMsg)
}
values.Think = req.Think != nil && req.Think.Bool()
values.ThinkLevel = ""
if req.Think != nil {
values.ThinkLevel = req.Think.String()
}
values.IsThinkSet = req.Think != nil
var b bytes.Buffer
if req.Context != nil {
slog.Warn("the context field is deprecated and will be removed in a future version of Ollama")
s, err := r.Detokenize(c.Request.Context(), req.Context)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
b.WriteString(s)
}
// check that we're in the `api/chat`-like flow, and if so, generate the
// prompt the same way
// TEMP(drifkin): we should really just detect the chat-like flow and call
// the real chat handler, but doing this as a stopgap to get renderer
// support for generate
if values.Messages != nil && values.Suffix == "" && req.Template == "" {
genTruncate := (req.Truncate == nil || *req.Truncate) && !m.IsMLX()
prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, genTruncate)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// TEMP(drifkin): req.Context will be removed very soon, but we're temporarily supporting it in this flow here
if req.Context != nil {
b.WriteString(prompt)
prompt = b.String()
}
} else {
// legacy flow
if err := tmpl.Execute(&b, values); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
prompt = b.String()
}
}
// If debug mode is enabled, return the rendered template instead of calling the model
if req.DebugRenderOnly {
c.JSON(http.StatusOK, api.GenerateResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
DebugInfo: &api.DebugInfo{
RenderedTemplate: prompt,
ImageCount: len(images),
},
})
return
}
var thinkingState *thinking.Parser
if builtinParser == nil {
openingTag, closingTag := thinking.InferTags(m.Template.Template)
if req.Think != nil && req.Think.Bool() && openingTag != "" && closingTag != "" {
thinkingState = &thinking.Parser{
OpeningTag: openingTag,
ClosingTag: closingTag,
}
if strings.HasSuffix(strings.TrimSpace(prompt), openingTag) {
thinkingState.AddContent(openingTag)
}
}
}
ch := make(chan any)
go func() {
// TODO (jmorganca): avoid building the response twice both here and below
var sb strings.Builder
defer close(ch)
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Images: images,
Format: req.Format,
Options: opts,
Shift: req.Shift == nil || *req.Shift,
Truncate: req.Truncate == nil || *req.Truncate,
Logprobs: req.Logprobs,
TopLogprobs: req.TopLogprobs,
}, func(cr llm.CompletionResponse) {
res := api.GenerateResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Response: cr.Content,
Done: cr.Done,
Metrics: api.Metrics{
PromptEvalCount: cr.PromptEvalCount,
PromptEvalDuration: cr.PromptEvalDuration,
EvalCount: cr.EvalCount,
EvalDuration: cr.EvalDuration,
},
Logprobs: toAPILogprobs(cr.Logprobs),
}
if builtinParser != nil {
content, thinking, toolCalls, err := builtinParser.Add(cr.Content, cr.Done)
if err != nil {
ch <- gin.H{"error": err.Error()}
return
}
res.Response = content
res.Thinking = thinking
if cr.Done && len(toolCalls) > 0 {
res.ToolCalls = toolCalls
}
} else if thinkingState != nil {
thinking, content := thinkingState.AddContent(cr.Content)
res.Thinking = thinking
res.Response = content
}
if _, err := sb.WriteString(cr.Content); err != nil {
ch <- gin.H{"error": err.Error()}
}
if cr.Done {
res.DoneReason = cr.DoneReason.String()
res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
if !req.Raw {
tokens, err := r.Tokenize(c.Request.Context(), prompt+sb.String())
if err != nil {
ch <- gin.H{"error": err.Error()}
return
}
res.Context = tokens
}
}
if builtinParser != nil {
// only send messages with meaningful content (empty messages confuse clients)
if res.Response != "" || res.Thinking != "" || res.Done || len(res.ToolCalls) > 0 {
ch <- res
}
return
}
ch <- res
}); err != nil {
var serr api.StatusError
if errors.As(err, &serr) {
ch <- gin.H{"error": serr.ErrorMessage, "status": serr.StatusCode}
} else {
ch <- gin.H{"error": err.Error()}
}
}
}()
if req.Stream != nil && !*req.Stream {
var r api.GenerateResponse
var allLogprobs []api.Logprob
var sbThinking strings.Builder
var sbContent strings.Builder
for rr := range ch {
switch t := rr.(type) {
case api.GenerateResponse:
sbThinking.WriteString(t.Thinking)
sbContent.WriteString(t.Response)
r = t
// Accumulate logprobs from all chunks for non-streaming response
if len(t.Logprobs) > 0 {
allLogprobs = append(allLogprobs, t.Logprobs...)
}
case gin.H:
msg, ok := t["error"].(string)
if !ok {
msg = "unexpected error format in response"
}
status, ok := t["status"].(int)
if !ok {
status = http.StatusInternalServerError
}
c.JSON(status, gin.H{"error": msg})
return
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
return
}
}
r.Thinking = sbThinking.String()
r.Response = sbContent.String()
r.Logprobs = allLogprobs
c.JSON(http.StatusOK, r)
return
}
streamResponse(c, ch)
}
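// EmbedHandler implements POST /api/embed. Inputs are embedded concurrently;
// if the runner rejects an input with a 400 and truncation is allowed, the
// input is retried after being truncated to fit the context window.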
func (s *Server) EmbedHandler(c *gin.Context) {
checkpointStart := time.Now()
var req api.EmbedRequest
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
modelRef, err := parseAndValidateModelRef(req.Model)
if err != nil {
writeModelRefParseError(c, err, http.StatusNotFound, fmt.Sprintf("model '%s' not found", req.Model))
return
}
if modelRef.Source == modelSourceCloud {
req.Model = modelRef.Base
proxyCloudJSONRequest(c, req, cloudErrRemoteInferenceUnavailable)
return
}
var input []string
switch i := req.Input.(type) {
case string:
if len(i) > 0 {
input = append(input, i)
}
case []any:
for _, v := range i {
s, ok := v.(string)
if !ok {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
return
}
input = append(input, s)
}
default:
if req.Input != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
return
}
}
name, err := getExistingName(modelRef.Name)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
return
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
if err != nil {
handleScheduleError(c, req.Model, err)
return
}
checkpointLoaded := time.Now()
if len(input) == 0 {
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
return
}
kvData, _, err := getModelData(m.ModelPath, false)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
ctx := c.Request.Context()
embedWithRetry := func(text string) ([]float32, int, error) {
emb, tokCount, err := r.Embedding(ctx, text)
if err == nil {
return emb, tokCount, nil
}
var serr api.StatusError
if !errors.As(err, &serr) || serr.StatusCode != http.StatusBadRequest {
return nil, 0, err
}
if req.Truncate != nil && !*req.Truncate {
return nil, 0, err
}
tokens, err := r.Tokenize(ctx, text)
if err != nil {
return nil, 0, err
}
// TODO @nicolepardal: avoid reaching into kvData here; pass required tokenizer metadata via model/options instead
ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
if bos := kvData.Uint("tokenizer.ggml.bos_token_id"); len(tokens) > 0 && tokens[0] != int(bos) && kvData.Bool("add_bos_token", true) {
ctxLen--
}
if eos := kvData.Uint("tokenizer.ggml.eos_token_id"); len(tokens) > 0 && tokens[len(tokens)-1] != int(eos) && kvData.Bool("add_eos_token", true) {
ctxLen--
}
if len(tokens) <= ctxLen {
// the input already fits within the context window, so the failure
// was not a length problem; surface the original error
return nil, 0, serr
}
if ctxLen <= 0 {
return nil, 0, fmt.Errorf("input exceeds maximum context length and cannot be truncated further")
}
truncatedTokens := tokens[:ctxLen]
truncated, err := r.Detokenize(ctx, truncatedTokens)
if err != nil {
return nil, 0, err
}
return r.Embedding(ctx, truncated)
}
var g errgroup.Group
embeddings := make([][]float32, len(input))
var totalTokens uint64
for i, text := range input {
g.Go(func() error {
embedding, tokenCount, err := embedWithRetry(text)
if err != nil {
return err
}
// TODO: this first normalization should be done by the model
embedding, err = normalize(embedding)
if err != nil {
return err
}
if req.Dimensions > 0 && req.Dimensions < len(embedding) {
embedding, err = normalize(embedding[:req.Dimensions])
if err != nil {
return err
}
}
embeddings[i] = embedding
atomic.AddUint64(&totalTokens, uint64(tokenCount))
return nil
})
}
if err := g.Wait(); err != nil {
var serr api.StatusError
if errors.As(err, &serr) {
c.AbortWithStatusJSON(serr.StatusCode, gin.H{
"error": strings.TrimSpace(serr.ErrorMessage),
})
return
}
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{
"error": strings.TrimSpace(err.Error()),
})
return
}
resp := api.EmbedResponse{
Model: req.Model,
Embeddings: embeddings,
TotalDuration: time.Since(checkpointStart),
LoadDuration: checkpointLoaded.Sub(checkpointStart),
PromptEvalCount: int(totalTokens),
}
c.JSON(http.StatusOK, resp)
}
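// normalize scales vec to unit L2 norm in place, rejecting vectors that
// contain NaN or Inf and guarding the division with a small epsilon.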
func normalize(vec []float32) ([]float32, error) {
var sum float32
for _, v := range vec {
if math.IsNaN(float64(v)) || math.IsInf(float64(v), 0) {
return nil, errors.New("embedding contains NaN or Inf values")
}
sum += v * v
}
norm := float32(1.0 / max(math.Sqrt(float64(sum)), 1e-12))
for i := range vec {
vec[i] *= norm
}
return vec, nil
}
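// EmbeddingsHandler implements the legacy POST /api/embeddings endpoint,
// returning a single float64 embedding for one prompt.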
func (s *Server) EmbeddingsHandler(c *gin.Context) {
var req api.EmbeddingRequest
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
} else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
modelRef, err := parseAndValidateModelRef(req.Model)
if err != nil {
writeModelRefParseError(c, err, http.StatusBadRequest, "model is required")
return
}
if modelRef.Source == modelSourceCloud {
req.Model = modelRef.Base
proxyCloudJSONRequest(c, req, cloudErrRemoteInferenceUnavailable)
return
}
name := modelRef.Name
r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
if err != nil {
handleScheduleError(c, req.Model, err)
return
}
// an empty request loads the model
if req.Prompt == "" {
c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: []float64{}})
return
}
embedding, _, err := r.Embedding(c.Request.Context(), req.Prompt)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())})
return
}
var e []float64
for _, v := range embedding {
e = append(e, float64(v))
}
resp := api.EmbeddingResponse{
Embedding: e,
}
c.JSON(http.StatusOK, resp)
}
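// PullHandler implements POST /api/pull, streaming progress while a model is
// downloaded from the registry.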
func (s *Server) PullHandler(c *gin.Context) {
var req api.PullRequest
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// TEMP(drifkin): we temporarily still allow cloud model stub-files to be
// pulled until we integrate cloud models into `/api/tags` (at which point
// this roundabout way of "adding" cloud models won't be needed anymore). So
// right here we normalize any `:cloud` models into the legacy-style suffixes
// `:<tag>-cloud` and `:cloud`
modelRef, err := parseNormalizePullModelRef(cmp.Or(req.Model, req.Name))
if err != nil {
writeModelRefParseError(c, err, http.StatusBadRequest, errtypes.InvalidModelNameErrMsg)
return
}
name := modelRef.Name
name, err = getExistingName(name)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ch := make(chan any)
go func() {
defer close(ch)
fn := func(r api.ProgressResponse) {
ch <- r
}
regOpts := &registryOptions{
Insecure: req.Insecure,
}
ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()
if err := PullModel(ctx, name.DisplayShortest(), regOpts, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
if req.Stream != nil && !*req.Stream {
waitForStream(c, ch)
return
}
streamResponse(c, ch)
}
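// PushHandler implements POST /api/push, streaming progress while a model is
// uploaded to the registry.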
func (s *Server) PushHandler(c *gin.Context) {
var req api.PushRequest
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
var mname string
if req.Model != "" {
mname = req.Model
} else if req.Name != "" {
mname = req.Name
} else {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
ch := make(chan any)
go func() {
defer close(ch)
fn := func(r api.ProgressResponse) {
ch <- r
}
regOpts := &registryOptions{
Insecure: req.Insecure,
}
ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()
name, err := getExistingName(model.ParseName(mname))
if err != nil {
ch <- gin.H{"error": err.Error()}
return
}
if err := PushModel(ctx, name.DisplayShortest(), regOpts, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
if req.Stream != nil && !*req.Stream {
waitForStream(c, ch)
return
}
streamResponse(c, ch)
}
// getExistingName searches the existing manifests for parts that match the
// input name case-insensitively and returns the input name with each matching
// part replaced by its existing, canonical casing. If no parts match, the
// input name is returned as is.
func getExistingName(n model.Name) (model.Name, error) {
var zero model.Name
existing, err := manifest.Manifests(true)
if err != nil {
return zero, err
}
var set model.Name // tracks parts already canonicalized
for e := range existing {
if set.Host == "" && strings.EqualFold(e.Host, n.Host) {
n.Host = e.Host
}
if set.Namespace == "" && strings.EqualFold(e.Namespace, n.Namespace) {
n.Namespace = e.Namespace
}
if set.Model == "" && strings.EqualFold(e.Model, n.Model) {
n.Model = e.Model
}
if set.Tag == "" && strings.EqualFold(e.Tag, n.Tag) {
n.Tag = e.Tag
}
}
return n, nil
}
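// DeleteHandler implements DELETE /api/delete, removing the model's manifest
// and its layers.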
func (s *Server) DeleteHandler(c *gin.Context) {
var r api.DeleteRequest
if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
} else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
modelRef, err := parseNormalizePullModelRef(cmp.Or(r.Model, r.Name))
if err != nil {
switch {
case errors.Is(err, errConflictingModelSource):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
case errors.Is(err, model.ErrUnqualifiedName):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))})
default:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
}
return
}
n, err := getExistingName(modelRef.Name)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", cmp.Or(r.Model, r.Name))})
return
}
m, err := manifest.ParseNamedManifest(n)
if err != nil {
switch {
case os.IsNotExist(err):
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", cmp.Or(r.Model, r.Name))})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
if err := m.Remove(); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if err := m.RemoveLayers(); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
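// ShowHandler implements POST /api/show, proxying cloud models to the remote
// and otherwise returning details for a local model.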
func (s *Server) ShowHandler(c *gin.Context) {
var req api.ShowRequest
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if req.Model != "" {
// noop
} else if req.Name != "" {
req.Model = req.Name
} else {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
modelRef, err := parseAndValidateModelRef(req.Model)
if err != nil {
writeModelRefParseError(c, err, http.StatusBadRequest, err.Error())
return
}
if modelRef.Source == modelSourceCloud {
req.Model = modelRef.Base
proxyCloudJSONRequest(c, req, cloudErrRemoteModelDetailsUnavailable)
return
}
req.Model = modelRef.Base
resp, err := GetModelInfo(req)
if err != nil {
var statusErr api.StatusError
switch {
case os.IsNotExist(err):
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
case errors.As(err, &statusErr):
c.JSON(statusErr.StatusCode, gin.H{"error": statusErr.ErrorMessage})
case err.Error() == errtypes.InvalidModelNameErrMsg:
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
if modelRef.Source == modelSourceLocal && resp.RemoteHost != "" {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", modelRef.Original)})
return
}
userAgent := c.Request.UserAgent()
if strings.HasPrefix(userAgent, copilotChatUserAgentPrefix) {
if resp.ModelInfo == nil {
resp.ModelInfo = map[string]any{}
}
// Copilot Chat prefers `general.basename`, but this is usually not what
// users are familiar with, so let's just echo back what we had returned in
// `/api/tags`
resp.ModelInfo["general.basename"] = req.Model
}
c.JSON(http.StatusOK, resp)
}
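// GetModelInfo assembles the /api/show response for a local model: details,
// license, template, parameters, capabilities, GGUF key/value metadata, and
// (when Verbose is set) tensor information.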
func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
name := model.ParseName(req.Model)
if !name.IsValid() {
return nil, model.Unqualified(name)
}
name, err := getExistingName(name)
if err != nil {
return nil, err
}
m, err := GetModel(name.String())
if err != nil {
return nil, err
}
if m.Config.RemoteHost != "" {
if disabled, _ := internalcloud.Status(); disabled {
return nil, api.StatusError{
StatusCode: http.StatusForbidden,
ErrorMessage: internalcloud.DisabledError(cloudErrRemoteModelDetailsUnavailable),
}
}
}
modelDetails := api.ModelDetails{
ParentModel: m.ParentModel,
Format: m.Config.ModelFormat,
Family: m.Config.ModelFamily,
Families: m.Config.ModelFamilies,
ParameterSize: m.Config.ModelType,
QuantizationLevel: m.Config.FileType,
}
// For image generation models, populate details from imagegen package
if slices.Contains(m.Capabilities(), model.CapabilityImage) {
if info, err := imagegenmanifest.GetModelInfo(name.String()); err == nil {
modelDetails.Family = info.Architecture
modelDetails.ParameterSize = format.HumanNumber(uint64(info.ParameterCount))
modelDetails.QuantizationLevel = info.Quantization
}
}
// For safetensors LLM models (experimental), populate details from config.json
if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") {
if info, err := xserver.GetSafetensorsLLMInfo(name); err == nil {
if arch, ok := info["general.architecture"].(string); ok && arch != "" {
modelDetails.Family = arch
}
if paramCount, ok := info["general.parameter_count"].(int64); ok && paramCount > 0 {
modelDetails.ParameterSize = format.HumanNumber(uint64(paramCount))
}
}
// Older manifests may not have file_type populated for safetensors models.
if modelDetails.QuantizationLevel == "" {
if dtype, err := xserver.GetSafetensorsDtype(name); err == nil && dtype != "" {
modelDetails.QuantizationLevel = dtype
}
}
}
if req.System != "" {
m.System = req.System
}
msgs := make([]api.Message, len(m.Messages))
for i, msg := range m.Messages {
msgs[i] = api.Message{Role: msg.Role, Content: msg.Content}
}
mf, err := manifest.ParseNamedManifest(name)
if err != nil {
return nil, err
}
resp := &api.ShowResponse{
License: strings.Join(m.License, "\n"),
System: m.System,
Template: m.Template.String(),
Details: modelDetails,
Messages: msgs,
Capabilities: m.Capabilities(),
ModifiedAt: mf.FileInfo().ModTime(),
Requires: m.Config.Requires,
// Several integrations crash on a nil/omitempty+empty ModelInfo, so by
// default we return an empty map.
ModelInfo: make(map[string]any),
}
if m.Config.RemoteHost != "" {
resp.RemoteHost = m.Config.RemoteHost
resp.RemoteModel = m.Config.RemoteModel
if m.Config.ModelFamily != "" {
resp.ModelInfo = make(map[string]any)
resp.ModelInfo["general.architecture"] = m.Config.ModelFamily
if m.Config.BaseName != "" {
resp.ModelInfo["general.basename"] = m.Config.BaseName
}
if m.Config.ContextLen > 0 {
resp.ModelInfo[fmt.Sprintf("%s.context_length", m.Config.ModelFamily)] = m.Config.ContextLen
}
if m.Config.EmbedLen > 0 {
resp.ModelInfo[fmt.Sprintf("%s.embedding_length", m.Config.ModelFamily)] = m.Config.EmbedLen
}
}
}
var params []string
cs := 30
for k, v := range m.Options {
switch val := v.(type) {
case []any:
for _, nv := range val {
params = append(params, fmt.Sprintf("%-*s %#v", cs, k, nv))
}
default:
params = append(params, fmt.Sprintf("%-*s %#v", cs, k, v))
}
}
resp.Parameters = strings.Join(params, "\n")
if len(req.Options) > 0 {
if m.Options == nil {
m.Options = make(map[string]any)
}
for k, v := range req.Options {
m.Options[k] = v
}
}
var sb strings.Builder
fmt.Fprintln(&sb, "# Modelfile generated by \"ollama show\"")
fmt.Fprintln(&sb, "# To build a new Modelfile based on this, replace FROM with:")
fmt.Fprintf(&sb, "# FROM %s\n\n", m.ShortName)
fmt.Fprint(&sb, m.String())
resp.Modelfile = sb.String()
// skip loading tensor information if this is a remote model
if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
return resp, nil
}
if slices.Contains(m.Capabilities(), model.CapabilityImage) {
// Populate tensor info if verbose
if req.Verbose {
if tensors, err := xserver.GetSafetensorsTensorInfo(name); err == nil {
resp.Tensors = tensors
}
}
return resp, nil
}
// For safetensors LLM models (experimental), populate ModelInfo from config.json
if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") {
if info, err := xserver.GetSafetensorsLLMInfo(name); err == nil {
resp.ModelInfo = info
}
// Populate tensor info if verbose
if req.Verbose {
if tensors, err := xserver.GetSafetensorsTensorInfo(name); err == nil {
resp.Tensors = tensors
}
}
return resp, nil
}
kvData, tensors, err := getModelData(m.ModelPath, req.Verbose)
if err != nil {
return nil, err
}
delete(kvData, "general.name")
delete(kvData, "tokenizer.chat_template")
resp.ModelInfo = kvData
tensorData := make([]api.Tensor, len(tensors.Items()))
for cnt, t := range tensors.Items() {
tensorData[cnt] = api.Tensor{Name: t.Name, Type: t.Type(), Shape: t.Shape}
}
resp.Tensors = tensorData
if len(m.ProjectorPaths) > 0 {
projectorData, _, err := getModelData(m.ProjectorPaths[0], req.Verbose)
if err != nil {
return nil, err
}
resp.ProjectorInfo = projectorData
}
return resp, nil
}
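// getModelData loads a GGUF model and returns its KV metadata and tensors;
// unless verbose, large array values are elided to keep responses small.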
func getModelData(digest string, verbose bool) (ggml.KV, ggml.Tensors, error) {
maxArraySize := 0
if verbose {
maxArraySize = -1
}
data, err := llm.LoadModel(digest, maxArraySize)
if err != nil {
return nil, ggml.Tensors{}, err
}
kv := data.KV()
if !verbose {
for k := range kv {
if t, ok := kv[k].([]any); len(t) > 5 && ok {
kv[k] = []any{}
}
}
}
return kv, data.Tensors(), nil
}
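// ListHandler implements GET /api/tags, listing local models, most recently
// modified first.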
func (s *Server) ListHandler(c *gin.Context) {
ms, err := manifest.Manifests(true)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
models := []api.ListModelResponse{}
for n, m := range ms {
var cf model.ConfigV2
if m.Config.Digest != "" {
f, err := m.Config.Open()
if err != nil {
slog.Warn("bad manifest filepath", "name", n, "error", err)
continue
}
defer f.Close()
if err := json.NewDecoder(f).Decode(&cf); err != nil {
slog.Warn("bad manifest config", "name", n, "error", err)
continue
}
}
// tag should never be masked
models = append(models, api.ListModelResponse{
Model: n.DisplayShortest(),
Name: n.DisplayShortest(),
RemoteModel: cf.RemoteModel,
RemoteHost: cf.RemoteHost,
Size: m.Size(),
Digest: m.Digest(),
ModifiedAt: m.FileInfo().ModTime(),
Details: api.ModelDetails{
Format: cf.ModelFormat,
Family: cf.ModelFamily,
Families: cf.ModelFamilies,
ParameterSize: cf.ModelType,
QuantizationLevel: cf.FileType,
},
})
}
slices.SortStableFunc(models, func(i, j api.ListModelResponse) int {
// most recently modified first
return cmp.Compare(j.ModifiedAt.Unix(), i.ModifiedAt.Unix())
})
c.JSON(http.StatusOK, api.ListResponse{Models: models})
}
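// CopyHandler implements POST /api/copy, copying a model manifest from a
// source name to a destination name.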
func (s *Server) CopyHandler(c *gin.Context) {
var r api.CopyRequest
if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
} else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
src := model.ParseName(r.Source)
if !src.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("source %q is invalid", r.Source)})
return
}
src, err := getExistingName(src)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
dst := model.ParseName(r.Destination)
if !dst.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("destination %q is invalid", r.Destination)})
return
}
dst, err = getExistingName(dst)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := CopyModel(src, dst); errors.Is(err, os.ErrNotExist) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found", r.Source)})
} else if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
}
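// HeadBlobHandler implements HEAD /api/blobs/:digest, reporting whether the
// blob exists locally.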
func (s *Server) HeadBlobHandler(c *gin.Context) {
path, err := manifest.BlobsPath(c.Param("digest"))
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if _, err := os.Stat(path); err != nil {
c.AbortWithStatusJSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("blob %q not found", c.Param("digest"))})
return
}
c.Status(http.StatusOK)
}
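// CreateBlobHandler implements POST /api/blobs/:digest, storing an uploaded
// blob after verifying that its computed digest matches the one in the URL.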
func (s *Server) CreateBlobHandler(c *gin.Context) {
if ib, ok := intermediateBlobs[c.Param("digest")]; ok {
p, err := manifest.BlobsPath(ib)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if _, err := os.Stat(p); errors.Is(err, os.ErrNotExist) {
slog.Info("evicting intermediate blob which no longer exists", "digest", ib)
delete(intermediateBlobs, c.Param("digest"))
} else if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
} else {
c.Status(http.StatusOK)
return
}
}
path, err := manifest.BlobsPath(c.Param("digest"))
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
_, err = os.Stat(path)
switch {
case errors.Is(err, os.ErrNotExist):
// noop
case err != nil:
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
default:
c.Status(http.StatusOK)
return
}
layer, err := manifest.NewLayer(c.Request.Body, "")
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if layer.Digest != c.Param("digest") {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("digest mismatch, expected %q, got %q", c.Param("digest"), layer.Digest)})
return
}
c.Status(http.StatusCreated)
}
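// isLocalIP reports whether ip is assigned to one of this machine's network
// interfaces.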
func isLocalIP(ip netip.Addr) bool {
if interfaces, err := net.Interfaces(); err == nil {
for _, iface := range interfaces {
addrs, err := iface.Addrs()
if err != nil {
continue
}
for _, a := range addrs {
if parsed, _, err := net.ParseCIDR(a.String()); err == nil {
if parsed.String() == ip.String() {
return true
}
}
}
}
}
return false
}
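// allowedHost reports whether a request Host header looks local: empty,
// localhost, this machine's hostname, or a name under a local TLD
// (.localhost, .local, .internal).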
func allowedHost(host string) bool {
host = strings.ToLower(host)
if host == "" || host == "localhost" {
return true
}
if hostname, err := os.Hostname(); err == nil && host == strings.ToLower(hostname) {
return true
}
tlds := []string{
"localhost",
"local",
"internal",
}
// check if the host is a local TLD
for _, tld := range tlds {
if strings.HasSuffix(host, "."+tld) {
return true
}
}
return false
}
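// allowedHostsMiddleware protects against DNS rebinding: when the server is
// bound to a loopback address, only requests addressed to loopback, private,
// or otherwise local-looking hosts are allowed through.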
func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
return func(c *gin.Context) {
if addr == nil {
c.Next()
return
}
if addr, err := netip.ParseAddrPort(addr.String()); err == nil && !addr.Addr().IsLoopback() {
c.Next()
return
}
host, _, err := net.SplitHostPort(c.Request.Host)
if err != nil {
host = c.Request.Host
}
if addr, err := netip.ParseAddr(host); err == nil {
if addr.IsLoopback() || addr.IsPrivate() || addr.IsUnspecified() || isLocalIP(addr) {
c.Next()
return
}
}
if allowedHost(host) {
if c.Request.Method == http.MethodOptions {
c.AbortWithStatus(http.StatusNoContent)
return
}
c.Next()
return
}
c.AbortWithStatus(http.StatusForbidden)
}
}
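// GenerateRoutes builds the HTTP handler: CORS and allowed-host middleware,
// the native /api routes, and the OpenAI- and Anthropic-compatible /v1 routes.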
func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
corsConfig := cors.DefaultConfig()
corsConfig.AllowWildcard = true
corsConfig.AllowBrowserExtensions = true
corsConfig.AllowHeaders = []string{
"Authorization",
"Content-Type",
"User-Agent",
"Accept",
"X-Requested-With",
// OpenAI compatibility headers
"OpenAI-Beta",
"x-stainless-arch",
"x-stainless-async",
"x-stainless-custom-poll-interval",
"x-stainless-helper-method",
"x-stainless-lang",
"x-stainless-os",
"x-stainless-package-version",
"x-stainless-poll-helper",
"x-stainless-retry-count",
"x-stainless-runtime",
"x-stainless-runtime-version",
"x-stainless-timeout",
}
corsConfig.AllowOrigins = envconfig.AllowedOrigins()
r := gin.Default()
r.HandleMethodNotAllowed = true
r.Use(
cors.New(corsConfig),
allowedHostsMiddleware(s.addr),
)
// General
r.HEAD("/", func(c *gin.Context) { c.String(http.StatusOK, "Ollama is running") })
r.GET("/", func(c *gin.Context) { c.String(http.StatusOK, "Ollama is running") })
r.HEAD("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })
r.GET("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })
r.GET("/api/status", s.StatusHandler)
// Local model cache management (new implementation is at end of function)
r.POST("/api/pull", s.PullHandler)
r.POST("/api/push", s.PushHandler)
r.HEAD("/api/tags", s.ListHandler)
r.GET("/api/tags", s.ListHandler)
r.POST("/api/show", s.ShowHandler)
r.DELETE("/api/delete", s.DeleteHandler)
r.POST("/api/me", s.WhoamiHandler)
r.POST("/api/signout", s.SignoutHandler)
// deprecated
r.DELETE("/api/user/keys/:encodedKey", s.SignoutHandler)
// Create
r.POST("/api/create", s.CreateHandler)
r.POST("/api/blobs/:digest", s.CreateBlobHandler)
r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
r.POST("/api/copy", s.CopyHandler)
r.POST("/api/experimental/web_search", s.WebSearchExperimentalHandler)
r.POST("/api/experimental/web_fetch", s.WebFetchExperimentalHandler)
// Inference
r.GET("/api/ps", s.PsHandler)
r.POST("/api/generate", s.withInferenceRequestLogging("/api/generate", s.GenerateHandler)...)
r.POST("/api/chat", s.withInferenceRequestLogging("/api/chat", s.ChatHandler)...)
r.POST("/api/embed", s.EmbedHandler)
r.POST("/api/embeddings", s.EmbeddingsHandler)
// Inference (OpenAI compatibility)
// TODO(cloud-stage-a): apply Modelfile overlay deltas for local models with cloud
// parents on v1 request families while preserving this explicit :cloud passthrough.
r.POST("/v1/chat/completions", s.withInferenceRequestLogging("/v1/chat/completions", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ChatMiddleware(), s.ChatHandler)...)
r.POST("/v1/completions", s.withInferenceRequestLogging("/v1/completions", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.CompletionsMiddleware(), s.GenerateHandler)...)
r.POST("/v1/embeddings", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.EmbeddingsMiddleware(), s.EmbedHandler)
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
r.GET("/v1/models/:model", cloudModelPathPassthroughMiddleware(cloudErrRemoteModelDetailsUnavailable), middleware.RetrieveMiddleware(), s.ShowHandler)
r.POST("/v1/responses", s.withInferenceRequestLogging("/v1/responses", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ResponsesMiddleware(), s.ChatHandler)...)
// OpenAI-compatible image generation endpoints
r.POST("/v1/images/generations", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ImageGenerationsMiddleware(), s.GenerateHandler)
r.POST("/v1/images/edits", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ImageEditsMiddleware(), s.GenerateHandler)
// OpenAI-compatible audio endpoint
r.POST("/v1/audio/transcriptions", middleware.TranscriptionMiddleware(), s.ChatHandler)
// Inference (Anthropic compatibility)
r.POST("/v1/messages", s.withInferenceRequestLogging("/v1/messages", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.AnthropicMessagesMiddleware(), s.ChatHandler)...)
if rc != nil {
// wrap old with new
rs := &registry.Local{
Client: rc,
Logger: slog.Default(), // TODO(bmizerany): Take a logger, do not use slog.Default()
Fallback: r,
Prune: PruneLayers,
}
return rs, nil
}
return r, nil
}
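// Serve is the server entrypoint: it repairs and prunes the blob store,
// builds the routes, starts the scheduler, installs signal handling, selects
// a VRAM-based default context length, and serves HTTP on ln until shutdown.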
func Serve(ln net.Listener) error {
slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))
slog.Info("server config", "env", envconfig.Values())
cloudDisabled, _ := internalcloud.Status()
slog.Info(fmt.Sprintf("Ollama cloud disabled: %t", cloudDisabled))
blobsDir, err := manifest.BlobsPath("")
if err != nil {
return err
}
if err := fixBlobs(blobsDir); err != nil {
return err
}
if !envconfig.NoPrune() {
if _, err := manifest.Manifests(false); err != nil {
slog.Warn("corrupt manifests detected, skipping prune operation. Re-pull or delete to clear", "error", err)
} else {
// clean up unused layers and manifests
if err := PruneLayers(); err != nil {
return err
}
manifestsPath, err := manifest.Path()
if err != nil {
return err
}
if err := manifest.PruneDirectory(manifestsPath); err != nil {
return err
}
}
}
s := &Server{addr: ln.Addr()}
if err := s.initRequestLogging(); err != nil {
return err
}
var rc *ollama.Registry
if useClient2 {
var err error
rc, err = ollama.DefaultRegistry()
if err != nil {
return err
}
}
h, err := s.GenerateRoutes(rc)
if err != nil {
return err
}
http.Handle("/", h)
ctx, done := context.WithCancel(context.Background())
schedCtx, schedDone := context.WithCancel(ctx)
sched := InitScheduler(schedCtx)
s.sched = sched
slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version))
srvr := &http.Server{
// Use http.DefaultServeMux so we get net/http/pprof for
// free.
//
// TODO(bmizerany): Decide if we want to make this
// configurable so it is not exposed by default, or allow
// users to bind it to a different port. This was a quick
// and easy way to get pprof, but it may not be the best
// way.
Handler: nil,
}
// listen for a ctrl+c and stop any loaded llm
signals := make(chan os.Signal, 1)
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-signals
srvr.Close()
schedDone()
sched.unloadAllRunners()
done()
}()
s.sched.Run(schedCtx)
// register the experimental webp decoder
// so webp images can be used in multimodal inputs
image.RegisterFormat("webp", "RIFF????WEBP", webp.Decode, webp.DecodeConfig)
// At startup we retrieve GPU information so we can get log messages before loading a model
// This logs warnings if there are problems with the detected GPUs
gpus := discover.GPUDevices(ctx, nil)
discover.LogDetails(gpus)
var totalVRAM uint64
for _, gpu := range gpus {
totalVRAM += gpu.TotalMemory - envconfig.GpuOverhead()
}
// Set default context based on VRAM tier
// Use slightly lower thresholds (47/23 GiB vs. 48/24 GiB) to account for small differences in the exact value
switch {
case totalVRAM >= 47*format.GibiByte:
s.defaultNumCtx = 262144
case totalVRAM >= 23*format.GibiByte:
s.defaultNumCtx = 32768
default:
s.defaultNumCtx = 4096
}
slog.Info("vram-based default context", "total_vram", format.HumanBytes2(totalVRAM), "default_num_ctx", s.defaultNumCtx)
err = srvr.Serve(ln)
// If the server was closed by the signal handler, wait for the ctx to be
// done; otherwise error out quickly
if !errors.Is(err, http.ErrServerClosed) {
return err
}
<-ctx.Done()
return nil
}
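// waitForStream drains ch and responds with only the final progress update;
// used when the client sets stream=false.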
func waitForStream(c *gin.Context, ch chan any) {
c.Header("Content-Type", "application/json")
var latest api.ProgressResponse
for resp := range ch {
switch r := resp.(type) {
case api.ProgressResponse:
latest = r
case gin.H:
status, ok := r["status"].(int)
if !ok {
status = http.StatusInternalServerError
}
errorMsg, ok := r["error"].(string)
if !ok {
errorMsg = "unknown error"
}
c.JSON(status, gin.H{"error": errorMsg})
return
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": "unknown message type"})
return
}
}
c.JSON(http.StatusOK, latest)
}
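// streamResponse writes each value received on ch as a newline-delimited
// JSON chunk until the channel closes.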
func streamResponse(c *gin.Context, ch chan any) {
c.Header("Content-Type", "application/x-ndjson")
c.Stream(func(w io.Writer) bool {
val, ok := <-ch
if !ok {
return false
}
// errors are provided as a gin.H with an "error" field and
// an optional "status" field. For errors that are streamed
// before any content, we need to set the status code and
// content type for the error.
if h, ok := val.(gin.H); ok {
if e, ok := h["error"].(string); ok {
status, ok := h["status"].(int)
if !ok {
status = http.StatusInternalServerError
}
if !c.Writer.Written() {
c.Header("Content-Type", "application/json")
c.JSON(status, gin.H{"error": e})
} else {
if err := json.NewEncoder(c.Writer).Encode(gin.H{"error": e}); err != nil {
slog.Error("streamResponse failed to encode json error", "error", err)
}
}
return false
}
}
bts, err := json.Marshal(val)
if err != nil {
slog.Info("streamResponse: json.Marshal failed", "error", err)
return false
}
// delimit chunks with a trailing newline
bts = append(bts, '\n')
if _, err := w.Write(bts); err != nil {
slog.Info("streamResponse: w.Write failed", "error", err)
return false
}
return true
})
}
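// StatusHandler reports whether cloud features are disabled and, if so,
// which setting disabled them.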
func (s *Server) StatusHandler(c *gin.Context) {
disabled, source := internalcloud.Status()
c.JSON(http.StatusOK, api.StatusResponse{
Cloud: api.CloudStatus{
Disabled: disabled,
Source: source,
},
})
}
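// WebSearchExperimentalHandler proxies web search requests to the cloud
// /api/web_search endpoint.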
func (s *Server) WebSearchExperimentalHandler(c *gin.Context) {
s.webExperimentalProxyHandler(c, "/api/web_search", cloudErrWebSearchUnavailable)
}
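// WebFetchExperimentalHandler proxies web fetch requests to the cloud
// /api/web_fetch endpoint.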
func (s *Server) WebFetchExperimentalHandler(c *gin.Context) {
s.webExperimentalProxyHandler(c, "/api/web_fetch", cloudErrWebFetchUnavailable)
}
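// webExperimentalProxyHandler validates that the request carries a
// non-empty body and forwards it to proxyPath on the cloud service,
// reporting disabledOperation as the error when cloud access is off.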
func (s *Server) webExperimentalProxyHandler(c *gin.Context, proxyPath, disabledOperation string) {
body, err := readRequestBody(c.Request)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if len(bytes.TrimSpace(body)) == 0 {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
}
proxyCloudRequestWithPath(c, body, proxyPath, disabledOperation)
}
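// WhoamiHandler reports the user currently signed in to ollama.com. If
// nobody is signed in, it responds 401 with a signin_url the client can
// open to authenticate.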
func (s *Server) WhoamiHandler(c *gin.Context) {
// TODO: allow other hosts
u, err := url.Parse("https://ollama.com")
if err != nil {
slog.Error(err.Error())
c.JSON(http.StatusInternalServerError, gin.H{"error": "URL parse error"})
return
}
client := api.NewClient(u, http.DefaultClient)
user, err := client.Whoami(c)
if err != nil {
slog.Error(err.Error())
}
// user isn't signed in
if user != nil && user.Name == "" {
sURL, sErr := signinURL()
if sErr != nil {
slog.Error(sErr.Error())
c.JSON(http.StatusInternalServerError, gin.H{"error": "error getting authorization details"})
return
}
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized", "signin_url": sURL})
return
}
c.JSON(http.StatusOK, user)
}
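// SignoutHandler disconnects this machine's public key from the
// ollama.com account it is registered to.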
func (s *Server) SignoutHandler(c *gin.Context) {
pubKey, err := auth.GetPublicKey()
if err != nil {
slog.Error("couldn't get public key", "error", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "there was an error signing out"})
return
}
encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))
// TODO: allow other hosts
u, err := url.Parse("https://ollama.com")
if err != nil {
slog.Error(err.Error())
c.JSON(http.StatusInternalServerError, gin.H{"error": "URL parse error"})
return
}
client := api.NewClient(u, http.DefaultClient)
err = client.Disconnect(c, encKey)
if err != nil {
var authError api.AuthorizationError
if errors.As(err, &authError) {
c.JSON(http.StatusUnauthorized, gin.H{"error": "you are not currently signed in"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": "there was an error signing out"})
return
}
c.JSON(http.StatusOK, nil)
}
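// PsHandler implements /api/ps, listing every loaded model with its
// memory usage, context length, and expiry, sorted so the model with
// the longest remaining lifetime comes first.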
func (s *Server) PsHandler(c *gin.Context) {
models := []api.ProcessModelResponse{}
for _, v := range s.sched.loaded {
model := v.model
modelDetails := api.ModelDetails{
Format: model.Config.ModelFormat,
Family: model.Config.ModelFamily,
Families: model.Config.ModelFamilies,
ParameterSize: model.Config.ModelType,
QuantizationLevel: model.Config.FileType,
}
mr := api.ProcessModelResponse{
Model: model.ShortName,
Name: model.ShortName,
Size: int64(v.totalSize),
SizeVRAM: int64(v.vramSize),
Digest: model.Digest,
Details: modelDetails,
ExpiresAt: v.expiresAt,
}
if v.llama != nil {
mr.ContextLength = v.llama.ContextLength()
total, vram := v.llama.MemorySize()
mr.Size = int64(total)
mr.SizeVRAM = int64(vram)
}
// The scheduler defers setting expiresAt, so a model that is still
// loading may report the zero time. In that case, estimate the expiry
// from the session duration instead.
if v.expiresAt.IsZero() {
mr.ExpiresAt = time.Now().Add(v.sessionDuration)
}
models = append(models, mr)
}
slices.SortStableFunc(models, func(i, j api.ProcessModelResponse) int {
// longest duration remaining listed first
return cmp.Compare(j.ExpiresAt.Unix(), i.ExpiresAt.Unix())
})
c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
}
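// toolCallId returns a pseudo-random identifier for a tool call: the
// "call_" prefix followed by eight random lowercase alphanumeric
// characters, e.g. "call_x7k29fq1".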
func toolCallId() string {
const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
b := make([]byte, 8)
for i := range b {
b[i] = letterBytes[rand.Intn(len(letterBytes))]
}
return "call_" + strings.ToLower(string(b))
}
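// ChatHandler implements /api/chat. It validates the request, proxies
// cloud and remote models, schedules a local runner, renders the chat
// prompt, and streams responses back to the client. For thinking models
// with a format constraint, it may cancel and rerun the completion once
// thinking has finished (see the structured outputs state machine
// below).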
func (s *Server) ChatHandler(c *gin.Context) {
checkpointStart := time.Now()
var req api.ChatRequest
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
} else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if req.TopLogprobs < 0 || req.TopLogprobs > 20 {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "top_logprobs must be between 0 and 20"})
return
}
modelRef, err := parseAndValidateModelRef(req.Model)
if err != nil {
writeModelRefParseError(c, err, http.StatusBadRequest, "model is required")
return
}
if modelRef.Source == modelSourceCloud {
req.Model = modelRef.Base
if c.GetBool(legacyCloudAnthropicKey) {
proxyCloudJSONRequestWithPath(c, req, "/api/chat", cloudErrRemoteInferenceUnavailable)
return
}
proxyCloudJSONRequest(c, req, cloudErrRemoteInferenceUnavailable)
return
}
name := modelRef.Name
name, err = getExistingName(name)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
m, err := GetModel(name.String())
if err != nil {
switch {
case os.IsNotExist(err):
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
case err.Error() == errtypes.InvalidModelNameErrMsg:
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
if modelRef.Source == modelSourceLocal && m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
return
}
// expire the runner
if len(req.Messages) == 0 && req.KeepAlive != nil && req.KeepAlive.Duration == 0 {
s.sched.expireRunner(m)
c.JSON(http.StatusOK, api.ChatResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Message: api.Message{Role: "assistant"},
Done: true,
DoneReason: "unload",
})
return
}
if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
if disabled, _ := internalcloud.Status(); disabled {
c.JSON(http.StatusForbidden, gin.H{"error": internalcloud.DisabledError(cloudErrRemoteInferenceUnavailable)})
return
}
origModel := req.Model
remoteURL, err := url.Parse(m.Config.RemoteHost)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if !slices.Contains(envconfig.Remotes(), remoteURL.Hostname()) {
slog.Info("remote model", "remotes", envconfig.Remotes(), "remoteURL", m.Config.RemoteHost, "hostname", remoteURL.Hostname())
c.JSON(http.StatusBadRequest, gin.H{"error": "this server cannot run this remote model"})
return
}
req.Model = m.Config.RemoteModel
if req.Options == nil {
req.Options = map[string]any{}
}
var msgs []api.Message
if len(req.Messages) > 0 {
msgs = append(m.Messages, req.Messages...)
if req.Messages[0].Role != "system" && m.System != "" {
msgs = append([]api.Message{{Role: "system", Content: m.System}}, msgs...)
}
}
msgs = filterThinkTags(msgs, m)
req.Messages = msgs
for k, v := range m.Options {
if _, ok := req.Options[k]; !ok {
req.Options[k] = v
}
}
contentType := "application/x-ndjson"
if req.Stream != nil && !*req.Stream {
contentType = "application/json; charset=utf-8"
}
c.Header("Content-Type", contentType)
fn := func(resp api.ChatResponse) error {
resp.Model = origModel
resp.RemoteModel = m.Config.RemoteModel
resp.RemoteHost = m.Config.RemoteHost
data, err := json.Marshal(resp)
if err != nil {
return err
}
if _, err = c.Writer.Write(append(data, '\n')); err != nil {
return err
}
c.Writer.Flush()
return nil
}
client := api.NewClient(remoteURL, http.DefaultClient)
err = client.Chat(c, &req, fn)
if err != nil {
var authError api.AuthorizationError
if errors.As(err, &authError) {
sURL, sErr := signinURL()
if sErr != nil {
slog.Error(sErr.Error())
c.JSON(http.StatusInternalServerError, gin.H{"error": "error getting authorization details"})
return
}
c.JSON(authError.StatusCode, gin.H{"error": "unauthorized", "signin_url": sURL})
return
}
var apiError api.StatusError
if errors.As(err, &apiError) {
c.JSON(apiError.StatusCode, apiError)
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
return
}
caps := []model.Capability{model.CapabilityCompletion}
if len(req.Tools) > 0 {
caps = append(caps, model.CapabilityTools)
}
modelCaps := m.Capabilities()
if slices.Contains(modelCaps, model.CapabilityThinking) {
caps = append(caps, model.CapabilityThinking)
if req.Think == nil {
req.Think = &api.ThinkValue{Value: true}
}
} else {
if req.Think != nil && req.Think.Bool() {
// When requests arrive via the Anthropic-compatible API (used by tools like Claude Code), relax thinking to nil instead of rejecting the request
if _, ok := c.Get("relax_thinking"); ok {
slog.Warn("model does not support thinking, relaxing thinking to nil", "model", req.Model)
req.Think = nil
} else {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support thinking", req.Model)})
return
}
}
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
if errors.Is(err, errCapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
return
} else if err != nil {
handleScheduleError(c, req.Model, err)
return
}
checkpointLoaded := time.Now()
if len(req.Messages) == 0 {
c.JSON(http.StatusOK, api.ChatResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Message: api.Message{Role: "assistant"},
Done: true,
DoneReason: "load",
})
return
}
msgs := append(m.Messages, req.Messages...)
if req.Messages[0].Role != "system" && m.System != "" {
msgs = append([]api.Message{{Role: "system", Content: m.System}}, msgs...)
}
msgs = filterThinkTags(msgs, m)
if shouldUseHarmony(m) && m.Config.Parser == "" {
m.Config.Parser = "harmony"
}
var builtinParser parsers.Parser
processedTools := req.Tools
if m.Config.Parser != "" {
builtinParser = parsers.ParserForName(m.Config.Parser)
if builtinParser != nil {
// Determine last message for chat prefill
var lastMessage *api.Message
if len(msgs) > 0 {
lastMessage = &msgs[len(msgs)-1]
}
// Initialize parser and get processed tools
processedTools = builtinParser.Init(req.Tools, lastMessage, req.Think)
}
}
truncate := req.Truncate == nil || *req.Truncate
if m.IsMLX() {
truncate = false
}
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
if err != nil {
slog.Error("chat prompt error", "error", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// If debug mode is enabled, return the rendered template instead of calling the model
if req.DebugRenderOnly {
c.JSON(http.StatusOK, api.ChatResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
DebugInfo: &api.DebugInfo{
RenderedTemplate: prompt,
ImageCount: len(images),
},
})
return
}
var thinkingState *thinking.Parser
openingTag, closingTag := thinking.InferTags(m.Template.Template)
if req.Think != nil && req.Think.Bool() && openingTag != "" && closingTag != "" {
thinkingState = &thinking.Parser{
OpeningTag: openingTag,
ClosingTag: closingTag,
}
if strings.HasSuffix(strings.TrimSpace(prompt), openingTag) {
thinkingState.AddContent(openingTag)
}
}
var toolParser *tools.Parser
if len(req.Tools) > 0 && (builtinParser == nil || !builtinParser.HasToolSupport()) {
toolParser = tools.NewParser(m.Template.Template, req.Tools)
}
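// structuredOutputsState tracks the two-pass flow used to combine
// thinking with structured outputs: run unconstrained until thinking
// ends, then cancel and rerun the completion with the format constraint
// applied.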
type structuredOutputsState int
const (
structuredOutputsState_None structuredOutputsState = iota
structuredOutputsState_ReadyToApply
structuredOutputsState_Applying
)
ch := make(chan any)
go func() {
defer close(ch)
structuredOutputsState := structuredOutputsState_None
for {
var tb strings.Builder
currentFormat := req.Format
// structured outputs via a double request are enabled when:
// 1. the model supports the thinking capability, and
// 2. it uses a built-in parser or our generic thinking parser
// The transition from parsed thinking content to parsed non-thinking
// content is the signal to turn constraining on, so this approach does
// not work for (potential future) non-thinking models that emit
// anything before the actual content.
if req.Format != nil && structuredOutputsState == structuredOutputsState_None && ((builtinParser != nil || thinkingState != nil) && slices.Contains(m.Capabilities(), model.CapabilityThinking)) {
currentFormat = nil
}
// set up a new context for this attempt, derived from the parent request context
ctx, cancel := context.WithCancel(c.Request.Context())
err := r.Completion(ctx, llm.CompletionRequest{
Prompt: prompt,
Images: images,
Format: currentFormat,
Options: opts,
Shift: req.Shift == nil || *req.Shift,
Truncate: truncate,
Logprobs: req.Logprobs,
TopLogprobs: req.TopLogprobs,
}, func(r llm.CompletionResponse) {
res := api.ChatResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Message: api.Message{Role: "assistant", Content: r.Content},
Done: r.Done,
Metrics: api.Metrics{
PromptEvalCount: r.PromptEvalCount,
PromptEvalDuration: r.PromptEvalDuration,
EvalCount: r.EvalCount,
EvalDuration: r.EvalDuration,
},
Logprobs: toAPILogprobs(r.Logprobs),
}
if r.Done {
res.DoneReason = r.DoneReason.String()
res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
}
if builtinParser != nil {
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser input", "parser", m.Config.Parser, "content", r.Content)
content, thinking, toolCalls, err := builtinParser.Add(r.Content, r.Done)
if err != nil {
ch <- gin.H{"error": err.Error()}
return
}
res.Message.Content = content
res.Message.Thinking = thinking
for i := range toolCalls {
toolCalls[i].ID = toolCallId()
}
res.Message.ToolCalls = toolCalls
tb.WriteString(thinking)
// we are now receiving content from the model - we should start applying structured outputs
if structuredOutputsState == structuredOutputsState_None && req.Format != nil && tb.String() != "" && res.Message.Content != "" {
structuredOutputsState = structuredOutputsState_ReadyToApply
cancel()
return
}
if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || r.Done || len(res.Logprobs) > 0 {
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser output", "parser", m.Config.Parser, "content", content, "thinking", thinking, "toolCalls", toolCalls, "done", r.Done)
ch <- res
} else {
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser empty output", "parser", m.Config.Parser)
}
return
}
if thinkingState != nil {
thinkingContent, remainingContent := thinkingState.AddContent(res.Message.Content)
if thinkingContent == "" && remainingContent == "" && !r.Done {
// need to accumulate more to decide what to send
return
}
res.Message.Thinking = thinkingContent
tb.WriteString(thinkingContent)
// emit the collected thinking text before restarting with structured outputs and clear unstructured content
// to avoid leaking mixed tokens like "</think>Hello"
if structuredOutputsState == structuredOutputsState_None && req.Format != nil && tb.String() != "" && remainingContent != "" {
structuredOutputsState = structuredOutputsState_ReadyToApply
res.Message.Content = ""
ch <- res
cancel()
return
}
res.Message.Content = remainingContent
}
if len(req.Tools) > 0 {
toolCalls, content := toolParser.Add(res.Message.Content)
if len(content) > 0 {
res.Message.Content = content
} else if len(toolCalls) > 0 {
for i := range toolCalls {
toolCalls[i].ID = toolCallId()
}
res.Message.ToolCalls = toolCalls
res.Message.Content = ""
} else if res.Message.Thinking != "" {
// don't return, fall through to send
} else {
// Send logprobs while content is being buffered by the parser for tool calls
if len(res.Logprobs) > 0 && !r.Done {
logprobRes := res
logprobRes.Message.Content = ""
logprobRes.Message.ToolCalls = nil
ch <- logprobRes
}
if r.Done {
res.Message.Content = toolParser.Content()
ch <- res
}
return
}
}
ch <- res
})
if err != nil {
if structuredOutputsState == structuredOutputsState_ReadyToApply && strings.Contains(err.Error(), "context canceled") && c.Request.Context().Err() == nil {
// ignore the error: it is the context cancellation we triggered in order to apply structured outputs
} else {
var serr api.StatusError
if errors.As(err, &serr) {
ch <- gin.H{"error": serr.ErrorMessage, "status": serr.StatusCode}
} else {
ch <- gin.H{"error": err.Error()}
}
return
}
}
// An ignored structured-outputs cancellation falls through to here: start a new request with the structured outputs applied and the updated prompt.
if structuredOutputsState == structuredOutputsState_ReadyToApply {
structuredOutputsState = structuredOutputsState_Applying
msg := api.Message{
Role: "assistant",
Thinking: tb.String(),
}
msgs = append(msgs, msg)
prompt, _, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
if err != nil {
slog.Error("chat prompt error applying structured outputs", "error", err)
ch <- gin.H{"error": err.Error()}
return
}
// Force constraining by terminating the thinking header; the parser is already in this state.
// When the last message is thinking, the renderer for gpt-oss cannot disambiguate between having
// the model continue thinking or ending thinking and outputting the final message.
// TODO(parthsareen): consider adding prefill disambiguation logic to the renderer for structured outputs.
if shouldUseHarmony(m) || (builtinParser != nil && m.Config.Parser == "harmony") {
prompt += "<|end|><|start|>assistant<|channel|>final<|message|>"
}
continue
}
break
}
}()
if req.Stream != nil && !*req.Stream {
var resp api.ChatResponse
var toolCalls []api.ToolCall
var allLogprobs []api.Logprob
var sbThinking strings.Builder
var sbContent strings.Builder
for rr := range ch {
switch t := rr.(type) {
case api.ChatResponse:
sbThinking.WriteString(t.Message.Thinking)
sbContent.WriteString(t.Message.Content)
resp = t
if len(req.Tools) > 0 {
toolCalls = append(toolCalls, t.Message.ToolCalls...)
}
// Accumulate logprobs from all chunks for non-streaming response
if len(t.Logprobs) > 0 {
allLogprobs = append(allLogprobs, t.Logprobs...)
}
case gin.H:
msg, ok := t["error"].(string)
if !ok {
msg = "unexpected error format in response"
}
status, ok := t["status"].(int)
if !ok {
status = http.StatusInternalServerError
}
c.JSON(status, gin.H{"error": msg})
return
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
return
}
}
resp.Message.Content = sbContent.String()
resp.Message.Thinking = sbThinking.String()
resp.Logprobs = allLogprobs
if len(toolCalls) > 0 {
resp.Message.ToolCalls = toolCalls
}
c.JSON(http.StatusOK, resp)
return
}
streamResponse(c, ch)
}
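// handleScheduleError maps scheduler errors onto HTTP status codes:
// capability and requirement errors become 400, a canceled request 499,
// a full queue 503, a missing model 404, and anything else 500.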
func handleScheduleError(c *gin.Context, name string, err error) {
switch {
case errors.Is(err, errCapabilities), errors.Is(err, errRequired):
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
case errors.Is(err, context.Canceled):
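// 499 is the non-standard "client closed request" status popularized by nginx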
c.JSON(499, gin.H{"error": "request canceled"})
case errors.Is(err, ErrMaxQueue):
c.JSON(http.StatusServiceUnavailable, gin.H{"error": err.Error()})
case errors.Is(err, os.ErrNotExist):
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found, try pulling it first", name)})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
}
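// filterThinkTags strips <think>...</think> spans from assistant
// messages that precede the final user message, so qwen3 and
// deepseek-r1 models do not see their earlier thinking on later turns.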
func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
if m.Config.ModelFamily == "qwen3" || model.ParseName(m.Name).Model == "deepseek-r1" {
finalUserIndex := -1
for i, msg := range msgs {
if msg.Role == "user" {
finalUserIndex = i
}
}
for i, msg := range msgs {
if msg.Role == "assistant" && i < finalUserIndex {
// TODO(drifkin): this is from before we added proper thinking support.
// However, even if thinking is not enabled (and therefore we shouldn't
// change the user output), we should probably perform this filtering
// for all thinking models (not just qwen3 & deepseek-r1) since it tends
// to save tokens and improve quality.
thinkingState := &thinking.Parser{
OpeningTag: "<think>",
ClosingTag: "</think>",
}
_, content := thinkingState.AddContent(msg.Content)
msgs[i].Content = content
}
}
}
return msgs
}
// handleImageGenerate handles image generation requests within GenerateHandler.
// This is called when the model has the Image capability.
func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, modelName string, checkpointStart time.Time) {
// Validate image dimensions
const maxDimension int32 = 4096
if req.Width > maxDimension || req.Height > maxDimension {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("width and height must be <= %d", maxDimension)})
return
}
// Schedule the runner for image generation
runner, _, _, err := s.scheduleRunner(c.Request.Context(), modelName, []model.Capability{model.CapabilityImage}, nil, req.KeepAlive)
if err != nil {
handleScheduleError(c, req.Model, err)
return
}
checkpointLoaded := time.Now()
// Handle load-only request (empty prompt)
if req.Prompt == "" {
c.JSON(http.StatusOK, api.GenerateResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Done: true,
DoneReason: "load",
})
return
}
// Check streaming preference
isStreaming := req.Stream == nil || *req.Stream
contentType := "application/x-ndjson"
if !isStreaming {
contentType = "application/json; charset=utf-8"
}
c.Header("Content-Type", contentType)
// Get seed from options if provided
var seed int64
if s, ok := req.Options["seed"]; ok {
switch v := s.(type) {
case int:
seed = int64(v)
case int64:
seed = v
case float64:
seed = int64(v)
}
}
var images []llm.ImageData
for i, imgData := range req.Images {
images = append(images, llm.ImageData{ID: i, Data: imgData})
}
var streamStarted bool
var finalResponse api.GenerateResponse
if err := runner.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: req.Prompt,
Width: req.Width,
Height: req.Height,
Steps: req.Steps,
Seed: seed,
Images: images,
}, func(cr llm.CompletionResponse) {
streamStarted = true
res := api.GenerateResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Done: cr.Done,
}
if cr.TotalSteps > 0 {
res.Completed = int64(cr.Step)
res.Total = int64(cr.TotalSteps)
}
if cr.Image != "" {
res.Image = cr.Image
}
if cr.Done {
res.DoneReason = cr.DoneReason.String()
res.Metrics.TotalDuration = time.Since(checkpointStart)
res.Metrics.LoadDuration = checkpointLoaded.Sub(checkpointStart)
}
if !isStreaming {
finalResponse = res
return
}
data, _ := json.Marshal(res)
c.Writer.Write(append(data, '\n'))
c.Writer.Flush()
}); err != nil {
// Only send JSON error if streaming hasn't started yet
// (once streaming starts, headers are committed and we can't change status code)
if !streamStarted {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
if !isStreaming {
c.JSON(http.StatusOK, finalResponse)
}
}