mirror of
https://github.com/ollama/ollama.git
synced 2026-04-17 19:54:03 +02:00
* bench: add prompt calibration, context size flag, and NumCtx reporting

  Add a --num-ctx flag to set the context size, and report NumCtx in the model info header. Calibrate the tokens-per-word ratio during warmup using actual tokenization metrics from the model, replacing the fixed 1.3 heuristic. This produces more accurate prompt token counts for --prompt-tokens. Also add fetchContextLength() to query the running model's context via /api/ps.

* integration: improve vision test robustness and add thinking tests

  Add skipIfNoVisionOverride() to skip vision tests when OLLAMA_TEST_MODEL is set to a non-vision model. Add Think:false to the context exhaustion test to prevent thinking models from using all context before the test can measure it. Add a third test image (the ollama homepage) and replace the OCR test with an ImageDescription test that uses it. Relax match strings for broader model compatibility. Add TestThinkingEnabled and TestThinkingSuppressed to verify thinking output and channel tag handling.

* gemma4: add Gemma 4 GGML model support

  Add full Gemma 4 model family support (E2B, E4B, 26B MoE, 31B Dense) for the GGML backend, including text, vision, converter, parser, and renderer.

  Text model features:
  - Sliding window + full attention with per-layer patterns
  - KV sharing across layers with donor map
  - Per-layer embeddings (PLE) with learned projections
  - MoE routing with RMSNorm + learned scale
  - Proportional RoPE with freq_factors for global attention
  - Final logit softcapping

  Vision model features:
  - SigLIP vision encoder with 2D RoPE
  - ClippableLinear with input/output clamping via packed v.clamp_data
  - Adaptive average pooling with nMerge kernel
  - Multi-modal projection with unweighted RMSNorm

  Converter:
  - Safetensors to GGUF with vision tensor renaming
  - Fused MoE gate_up_proj splitting
  - Vision patch embedding reshape (HF to Conv2D layout)
  - Packed clamp data tensor for ClippableLinear bounds
  - Proportional RoPE freq_factors generation

  Also includes:
  - BackendGet() on ml.Tensor for reading weight tensor data
  - Q6_K CUDA get_rows kernel support
  - MoE-aware ffn_down quantization layer counting
  - Gemma4 parser with tool calling and thinking support
  - Gemma4 renderer with structured tool format
  - Architecture-based auto-detection of renderer/parser/stop tokens
  - Integration test gemma4 model list additions

* gemma4: add audio support with USM conformer encoder

  Add audio encoding for Gemma 4 using the USM conformer architecture:
  - Converter: audio tensor mapping, SSCP/conformer/embedder name replacements, softplus repacker for per_dim_scale, F32 enforcement for conv weights
  - GGML backend: Conv1DDW and PadExt tensor ops
  - Audio encoder: SSCP Conv2D, 12 conformer blocks (FFW + block-local attention with relative position embeddings + LightConv1d + FFW), output projection, audio-to-text embedding projector
  - Audio preprocessing: WAV decode, mel spectrogram, FFT (pure Go)
  - Model wiring: WAV detection, audio token handling, unified PostTokenize

  Correctly transcribes "why is the sky blue" from test audio.

* integration: add gemma4 audio tests including OpenAI API coverage

  Test audio transcription and response via the Ollama native API, plus two new tests exercising the OpenAI-compatible endpoints:
  - /v1/audio/transcriptions (multipart form upload)
  - /v1/chat/completions with input_audio content type

  All tests use capability checks and skip models without audio support.

* gemma4: add OpenAI audio API support and capability detection

  - Add CapabilityAudio and detect it from audio.block_count in GGUF
  - Add a /v1/audio/transcriptions endpoint with TranscriptionMiddleware
  - Add input_audio content type support in /v1/chat/completions
  - Add TranscriptionRequest/Response types in the openai package

* gemma4: add audio input support for the run command

  - /audio toggle in interactive mode for voice chat
  - Platform-specific microphone recording (AVFoundation on macOS, PulseAudio/ALSA on Linux, WASAPI on Windows)
  - Space to start/stop recording, automatic chunking for long audio

* gemma4: add transcribe command (ollama transcribe MODEL)

  - Interactive mode with readline prompt and slash commands
  - Non-interactive mode for piped audio or record-until-Ctrl+C
  - Chunked streaming transcription for long recordings
  - Word-wrapped output matching run command style

* gemma4: add parser, renderer, and integration test plumbing

* gemma4: fix renderer to emit BOS token

* gemma4: add OpenAI audio transcription API and input_audio support

* gemma4: update converter for new weight drop naming

* gemma4: add per_expert_scale to MoE router and fix moe_intermediate_size config

* gemma4: rewrite renderer to match HF Jinja2 template exactly

  Fix 8 bugs found by building 55 reference tests verified against the HF Jinja2 chat template (VERIFY_JINJA2=1 shells out to Python):
  - Tool responses use separate <|turn>tool turns (not inline tags)
  - Tool calls emitted before content in assistant messages
  - Thinking content stripped from assistant history (strip_thinking)
  - User, tool, and system content trimmed (template does | trim)
  - Empty system message still emits a system turn (check role, not content)
  - Nested object properties rendered recursively with required field
  - Array items specification rendered for array-type properties
  - OBJECT/ARRAY type-specific rendering comma logic matches template

  Also adds a Required field to api.ToolProperty for nested object schemas, replaces the old gemma4_test.go with a comprehensive gemma4_reference_test.go, and commits the Jinja2 template as testdata for verification.

* gemma4: fix MoE fused gate_up split and multiline tool-call arg parsing

  - Text MoE: split `ffn_gate_up_exps` into contiguous `[gate|up]` halves instead of stride-2 slices.
  - Parser: escape control characters in `<|"|>...<|"|>` string literals when converting tool-call args to JSON.
  - Fixes warnings like `invalid character '\n' in string literal` for multiline tool arguments.
  - Add Gemma4 parser regressions for multiline tool-call args and `gemma4ArgsToJSON`.

* cmd: simplify audio input to dropped file attachments

* gemma4: use full SWA memory for better cache reuse

* gemma4: initialize clamps after backend load

* convert: align gemma4 audio tensor renames with llama.cpp

* Remove redundant comments in gemma4 vision model

* Format Gemma4 MoE block field alignment

* use 4096 kvcache.NewSWAMemCache

* convert: support new Gemma4 audio_tower tensor naming (#15221)

  Co-authored-by: jmorganca <jmorganca@gmail.com>

* fix integration test defaults for audio

* review comments and lint fixes

* remove unused audio/video files

---------

Co-authored-by: jmorganca <jmorganca@gmail.com>
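The bench calibration described above replaces a fixed 1.3 tokens-per-word guess with a ratio measured from the model's own tokenizer during warmup. A minimal sketch of that idea in Go (hypothetical helper names, not the actual bench code; the real implementation reads the token count from the warmup request's metrics):

```go
package main

import (
	"fmt"
	"strings"
)

// calibrateTokensPerWord estimates a tokens-per-word ratio from a warmup
// prompt and the token count the model's tokenizer reported for it.
// Falls back to the old fixed 1.3 heuristic when no metrics are available.
func calibrateTokensPerWord(warmupPrompt string, tokenCount int) float64 {
	words := len(strings.Fields(warmupPrompt))
	if words == 0 || tokenCount <= 0 {
		return 1.3 // old fixed heuristic
	}
	return float64(tokenCount) / float64(words)
}

// promptWordsFor returns roughly how many words are needed to hit a
// target prompt token count, given a calibrated ratio.
func promptWordsFor(targetTokens int, ratio float64) int {
	if ratio <= 0 {
		ratio = 1.3
	}
	return int(float64(targetTokens)/ratio + 0.5)
}

func main() {
	ratio := calibrateTokensPerWord("the quick brown fox jumps over the lazy dog", 12)
	fmt.Printf("ratio=%.2f words-for-512-tokens=%d\n", ratio, promptWordsFor(512, ratio))
}
```

Since tokenization varies by model (subword vocabularies split words differently), measuring the ratio once per model gives more accurate --prompt-tokens sizing than any single constant.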
2773 lines
78 KiB
Go
package server

import (
	"bytes"
	"cmp"
	"context"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"image"
	"io"
	"io/fs"
	"log/slog"
	"math"
	"math/rand"
	"net"
	"net/http"
	"net/netip"
	"net/url"
	"os"
	"os/signal"
	"slices"
	"strings"
	"sync/atomic"
	"syscall"
	"time"

	"github.com/gin-contrib/cors"
	"github.com/gin-gonic/gin"
	"golang.org/x/image/webp"
	"golang.org/x/sync/errgroup"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/auth"
	"github.com/ollama/ollama/discover"
	"github.com/ollama/ollama/envconfig"
	"github.com/ollama/ollama/format"
	"github.com/ollama/ollama/fs/ggml"
	internalcloud "github.com/ollama/ollama/internal/cloud"
	"github.com/ollama/ollama/llm"
	"github.com/ollama/ollama/logutil"
	"github.com/ollama/ollama/manifest"
	"github.com/ollama/ollama/middleware"
	"github.com/ollama/ollama/model/parsers"
	"github.com/ollama/ollama/model/renderers"
	"github.com/ollama/ollama/server/internal/client/ollama"
	"github.com/ollama/ollama/server/internal/registry"
	"github.com/ollama/ollama/template"
	"github.com/ollama/ollama/thinking"
	"github.com/ollama/ollama/tools"
	"github.com/ollama/ollama/types/errtypes"
	"github.com/ollama/ollama/types/model"
	"github.com/ollama/ollama/version"
	imagegenmanifest "github.com/ollama/ollama/x/imagegen/manifest"
	xserver "github.com/ollama/ollama/x/server"
)

const signinURLStr = "https://ollama.com/connect?name=%s&key=%s"

const (
	cloudErrRemoteInferenceUnavailable    = "remote model is unavailable"
	cloudErrRemoteModelDetailsUnavailable = "remote model details are unavailable"
	cloudErrWebSearchUnavailable          = "web search is unavailable"
	cloudErrWebFetchUnavailable           = "web fetch is unavailable"
	copilotChatUserAgentPrefix            = "GitHubCopilotChat/"
)

func writeModelRefParseError(c *gin.Context, err error, fallbackStatus int, fallbackMessage string) {
	switch {
	case errors.Is(err, errConflictingModelSource):
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
	case errors.Is(err, model.ErrUnqualifiedName):
		c.JSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg})
	default:
		c.JSON(fallbackStatus, gin.H{"error": fallbackMessage})
	}
}

func shouldUseHarmony(model *Model) bool {
	if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) {
		// heuristic to check whether the template expects to be parsed via harmony:
		// search for harmony tags that are nearly always used
		if model.Template.Contains("<|start|>") && model.Template.Contains("<|end|>") {
			return true
		}
	}

	return false
}

func experimentEnabled(name string) bool {
	return slices.Contains(strings.Split(os.Getenv("OLLAMA_EXPERIMENT"), ","), name)
}

var useClient2 = experimentEnabled("client2")

var mode string = gin.DebugMode

type Server struct {
	addr          net.Addr
	sched         *Scheduler
	defaultNumCtx int
	requestLogger *inferenceRequestLogger
}

func init() {
	switch mode {
	case gin.DebugMode:
	case gin.ReleaseMode:
	case gin.TestMode:
	default:
		mode = gin.DebugMode
	}

	gin.SetMode(mode)

	// Tell renderers to use [img] tags
	renderers.RenderImgTags = true
}

var (
	errRequired    = errors.New("is required")
	errBadTemplate = errors.New("template error")
)

func (s *Server) modelOptions(model *Model, requestOpts map[string]any) (api.Options, error) {
	opts := api.DefaultOptions()
	if opts.NumCtx == 0 {
		opts.NumCtx = s.defaultNumCtx
	}

	if err := opts.FromMap(model.Options); err != nil {
		return api.Options{}, err
	}

	if err := opts.FromMap(requestOpts); err != nil {
		return api.Options{}, err
	}

	return opts, nil
}

// scheduleRunner schedules a runner after validating inputs such as capabilities and model options.
// It returns the allocated runner, model instance, and consolidated options if successful and error otherwise.
func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
	if name == "" {
		return nil, nil, nil, fmt.Errorf("model %w", errRequired)
	}

	model, err := GetModel(name)
	if err != nil {
		return nil, nil, nil, err
	}

	if slices.Contains(model.Config.ModelFamilies, "mllama") && len(model.ProjectorPaths) > 0 {
		return nil, nil, nil, fmt.Errorf("'llama3.2-vision' is no longer compatible with your version of Ollama and has been replaced by a newer version. To re-download, run 'ollama pull llama3.2-vision'")
	}

	if err := model.CheckCapabilities(caps...); err != nil {
		return nil, nil, nil, fmt.Errorf("%s %w", name, err)
	}

	// Deprecated runner override option; ignore if present.
	delete(requestOpts, "use_imagegen_runner")

	opts, err := s.modelOptions(model, requestOpts)
	if err != nil {
		return nil, nil, nil, err
	}

	runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive)
	var runner *runnerRef
	select {
	case runner = <-runnerCh:
	case err = <-errCh:
		return nil, nil, nil, err
	}

	return runner.llama, model, &opts, nil
}

func signinURL() (string, error) {
	pubKey, err := auth.GetPublicKey()
	if err != nil {
		return "", err
	}

	encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))
	h, _ := os.Hostname()
	return fmt.Sprintf(signinURLStr, url.PathEscape(h), encKey), nil
}

func (s *Server) GenerateHandler(c *gin.Context) {
	checkpointStart := time.Now()
	var req api.GenerateRequest
	if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
		return
	} else if err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	if req.TopLogprobs < 0 || req.TopLogprobs > 20 {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "top_logprobs must be between 0 and 20"})
		return
	}

	modelRef, err := parseAndValidateModelRef(req.Model)
	if err != nil {
		writeModelRefParseError(c, err, http.StatusNotFound, fmt.Sprintf("model '%s' not found", req.Model))
		return
	}

	if modelRef.Source == modelSourceCloud {
		// TODO(drifkin): evaluate an `/api/*` passthrough for cloud where the
		// original body (modulo model name normalization) is sent to cloud.
		req.Model = modelRef.Base
		proxyCloudJSONRequest(c, req, cloudErrRemoteInferenceUnavailable)
		return
	}

	name := modelRef.Name

	// We cannot currently consolidate this into GetModel because we'd
	// induce infinite recursion given the current code structure.
	name, err = getExistingName(name)
	if err != nil {
		c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
		return
	}

	m, err := GetModel(name.String())
	if err != nil {
		switch {
		case errors.Is(err, fs.ErrNotExist):
			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
		case err.Error() == errtypes.InvalidModelNameErrMsg:
			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		default:
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		}
		return
	}

	if modelRef.Source == modelSourceLocal && m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
		c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
		return
	}

	if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
		if disabled, _ := internalcloud.Status(); disabled {
			c.JSON(http.StatusForbidden, gin.H{"error": internalcloud.DisabledError(cloudErrRemoteInferenceUnavailable)})
			return
		}

		origModel := req.Model

		remoteURL, err := url.Parse(m.Config.RemoteHost)
		if err != nil {
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
			return
		}

		if !slices.Contains(envconfig.Remotes(), remoteURL.Hostname()) {
			slog.Info("remote model", "remotes", envconfig.Remotes(), "remoteURL", m.Config.RemoteHost, "hostname", remoteURL.Hostname())
			c.JSON(http.StatusBadRequest, gin.H{"error": "this server cannot run this remote model"})
			return
		}

		req.Model = m.Config.RemoteModel

		if req.Template == "" && m.Template.String() != "" {
			req.Template = m.Template.String()
		}

		if req.Options == nil {
			req.Options = map[string]any{}
		}

		for k, v := range m.Options {
			if _, ok := req.Options[k]; !ok {
				req.Options[k] = v
			}
		}

		// update the system prompt from the model if one isn't already specified
		if req.System == "" && m.System != "" {
			req.System = m.System
		}

		if len(m.Messages) > 0 {
			slog.Warn("embedded messages in the model not supported with '/api/generate'; try '/api/chat' instead")
		}

		contentType := "application/x-ndjson"
		if req.Stream != nil && !*req.Stream {
			contentType = "application/json; charset=utf-8"
		}
		c.Header("Content-Type", contentType)

		fn := func(resp api.GenerateResponse) error {
			resp.Model = origModel
			resp.RemoteModel = m.Config.RemoteModel
			resp.RemoteHost = m.Config.RemoteHost

			data, err := json.Marshal(resp)
			if err != nil {
				return err
			}

			if _, err = c.Writer.Write(append(data, '\n')); err != nil {
				return err
			}
			c.Writer.Flush()
			return nil
		}

		client := api.NewClient(remoteURL, http.DefaultClient)
		err = client.Generate(c, &req, fn)
		if err != nil {
			var authError api.AuthorizationError
			if errors.As(err, &authError) {
				sURL, sErr := signinURL()
				if sErr != nil {
					slog.Error(sErr.Error())
					c.JSON(http.StatusInternalServerError, gin.H{"error": "error getting authorization details"})
					return
				}

				c.JSON(authError.StatusCode, gin.H{"error": "unauthorized", "signin_url": sURL})
				return
			}
			var apiError api.StatusError
			if errors.As(err, &apiError) {
				c.JSON(apiError.StatusCode, apiError)
				return
			}
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
			return
		}

		return
	}

	// expire the runner if unload is requested (empty prompt, keep alive is 0)
	if req.Prompt == "" && req.KeepAlive != nil && req.KeepAlive.Duration == 0 {
		s.sched.expireRunner(m)

		c.JSON(http.StatusOK, api.GenerateResponse{
			Model:      req.Model,
			CreatedAt:  time.Now().UTC(),
			Response:   "",
			Done:       true,
			DoneReason: "unload",
		})
		return
	}

	// Handle image generation models
	if slices.Contains(m.Capabilities(), model.CapabilityImage) {
		s.handleImageGenerate(c, req, name.String(), checkpointStart)
		return
	}

	if req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0) {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
		return
	}

	var builtinParser parsers.Parser
	if shouldUseHarmony(m) && m.Config.Parser == "" {
		m.Config.Parser = "harmony"
	}

	if !req.Raw && m.Config.Parser != "" {
		builtinParser = parsers.ParserForName(m.Config.Parser)
		if builtinParser != nil {
			// no tools or last message for generate endpoint
			builtinParser.Init(nil, nil, req.Think)
		}
	}

	caps := []model.Capability{model.CapabilityCompletion}
	if req.Suffix != "" {
		caps = append(caps, model.CapabilityInsert)
	}

	modelCaps := m.Capabilities()
	if slices.Contains(modelCaps, model.CapabilityThinking) {
		caps = append(caps, model.CapabilityThinking)
		if req.Think == nil {
			req.Think = &api.ThinkValue{Value: true}
		}
	} else {
		if req.Think != nil && req.Think.Bool() {
			c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support thinking", req.Model)})
			return
		}
	}

	r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
	if errors.Is(err, errCapabilityCompletion) {
		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
		return
	} else if err != nil {
		handleScheduleError(c, req.Model, err)
		return
	}

	checkpointLoaded := time.Now()

	// load the model
	if req.Prompt == "" {
		c.JSON(http.StatusOK, api.GenerateResponse{
			Model:      req.Model,
			CreatedAt:  time.Now().UTC(),
			Done:       true,
			DoneReason: "load",
		})
		return
	}

	if slices.Contains(m.Config.ModelFamilies, "mllama") && len(req.Images) > 1 {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "this model only supports one image while more than one image requested"})
		return
	}

	images := make([]llm.ImageData, len(req.Images))
	for i := range req.Images {
		images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
	}

	prompt := req.Prompt
	if !req.Raw {
		tmpl := m.Template
		if req.Template != "" {
			tmpl, err = template.Parse(req.Template)
			if err != nil {
				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
				return
			}
		}

		var values template.Values
		if req.Suffix != "" {
			values.Prompt = prompt
			values.Suffix = req.Suffix
		} else {
			var msgs []api.Message
			if req.System != "" {
				msgs = append(msgs, api.Message{Role: "system", Content: req.System})
			} else if m.System != "" {
				msgs = append(msgs, api.Message{Role: "system", Content: m.System})
			}

			if req.Context == nil {
				msgs = append(msgs, m.Messages...)
			}

			userMsg := api.Message{Role: "user", Content: req.Prompt}
			for _, i := range images {
				userMsg.Images = append(userMsg.Images, i.Data)
			}
			values.Messages = append(msgs, userMsg)
		}

		values.Think = req.Think != nil && req.Think.Bool()
		values.ThinkLevel = ""
		if req.Think != nil {
			values.ThinkLevel = req.Think.String()
		}
		values.IsThinkSet = req.Think != nil

		var b bytes.Buffer
		if req.Context != nil {
			slog.Warn("the context field is deprecated and will be removed in a future version of Ollama")
			s, err := r.Detokenize(c.Request.Context(), req.Context)
			if err != nil {
				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
				return
			}
			b.WriteString(s)
		}

		// check that we're in the `api/chat`-like flow, and if so, generate the
		// prompt the same way
		// TEMP(drifkin): we should really just detect the chat-like flow and call
		// the real chat handler, but doing this as a stopgap to get renderer
		// support for generate
		if values.Messages != nil && values.Suffix == "" && req.Template == "" {
			genTruncate := (req.Truncate == nil || *req.Truncate) && !m.IsMLX()
			prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, genTruncate)
			if err != nil {
				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
				return
			}
			// TEMP(drifkin): req.Context will be removed very soon, but we're temporarily supporting it in this flow here
			if req.Context != nil {
				b.WriteString(prompt)
				prompt = b.String()
			}
		} else {
			// legacy flow
			if err := tmpl.Execute(&b, values); err != nil {
				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
				return
			}

			prompt = b.String()
		}
	}

	// If debug mode is enabled, return the rendered template instead of calling the model
	if req.DebugRenderOnly {
		c.JSON(http.StatusOK, api.GenerateResponse{
			Model:     req.Model,
			CreatedAt: time.Now().UTC(),
			DebugInfo: &api.DebugInfo{
				RenderedTemplate: prompt,
				ImageCount:       len(images),
			},
		})
		return
	}

	var thinkingState *thinking.Parser
	if builtinParser == nil {
		openingTag, closingTag := thinking.InferTags(m.Template.Template)
		if req.Think != nil && req.Think.Bool() && openingTag != "" && closingTag != "" {
			thinkingState = &thinking.Parser{
				OpeningTag: openingTag,
				ClosingTag: closingTag,
			}
			if strings.HasSuffix(strings.TrimSpace(prompt), openingTag) {
				thinkingState.AddContent(openingTag)
			}
		}
	}

	ch := make(chan any)
	go func() {
		// TODO (jmorganca): avoid building the response twice both here and below
		var sb strings.Builder
		defer close(ch)
		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
			Prompt:      prompt,
			Images:      images,
			Format:      req.Format,
			Options:     opts,
			Shift:       req.Shift == nil || *req.Shift,
			Truncate:    req.Truncate == nil || *req.Truncate,
			Logprobs:    req.Logprobs,
			TopLogprobs: req.TopLogprobs,
		}, func(cr llm.CompletionResponse) {
			res := api.GenerateResponse{
				Model:     req.Model,
				CreatedAt: time.Now().UTC(),
				Response:  cr.Content,
				Done:      cr.Done,
				Metrics: api.Metrics{
					PromptEvalCount:    cr.PromptEvalCount,
					PromptEvalDuration: cr.PromptEvalDuration,
					EvalCount:          cr.EvalCount,
					EvalDuration:       cr.EvalDuration,
				},
				Logprobs: toAPILogprobs(cr.Logprobs),
			}

			if builtinParser != nil {
				content, thinking, toolCalls, err := builtinParser.Add(cr.Content, cr.Done)
				if err != nil {
					ch <- gin.H{"error": err.Error()}
					return
				}
				res.Response = content
				res.Thinking = thinking
				if cr.Done && len(toolCalls) > 0 {
					res.ToolCalls = toolCalls
				}
			} else if thinkingState != nil {
				thinking, content := thinkingState.AddContent(cr.Content)
				res.Thinking = thinking
				res.Response = content
			}

			if _, err := sb.WriteString(cr.Content); err != nil {
				ch <- gin.H{"error": err.Error()}
			}

			if cr.Done {
				res.DoneReason = cr.DoneReason.String()
				res.TotalDuration = time.Since(checkpointStart)
				res.LoadDuration = checkpointLoaded.Sub(checkpointStart)

				if !req.Raw {
					tokens, err := r.Tokenize(c.Request.Context(), prompt+sb.String())
					if err != nil {
						ch <- gin.H{"error": err.Error()}
						return
					}
					res.Context = tokens
				}
			}

			if builtinParser != nil {
				// only send messages with meaningful content (empty messages confuse clients)
				if res.Response != "" || res.Thinking != "" || res.Done || len(res.ToolCalls) > 0 {
					ch <- res
				}

				return
			}

			ch <- res
		}); err != nil {
			var serr api.StatusError
			if errors.As(err, &serr) {
				ch <- gin.H{"error": serr.ErrorMessage, "status": serr.StatusCode}
			} else {
				ch <- gin.H{"error": err.Error()}
			}
		}
	}()

	if req.Stream != nil && !*req.Stream {
		var r api.GenerateResponse
		var allLogprobs []api.Logprob
		var sbThinking strings.Builder
		var sbContent strings.Builder
		for rr := range ch {
			switch t := rr.(type) {
			case api.GenerateResponse:
				sbThinking.WriteString(t.Thinking)
				sbContent.WriteString(t.Response)
				r = t
				// Accumulate logprobs from all chunks for non-streaming response
				if len(t.Logprobs) > 0 {
					allLogprobs = append(allLogprobs, t.Logprobs...)
				}
			case gin.H:
				msg, ok := t["error"].(string)
				if !ok {
					msg = "unexpected error format in response"
				}

				status, ok := t["status"].(int)
				if !ok {
					status = http.StatusInternalServerError
				}

				c.JSON(status, gin.H{"error": msg})
				return
			default:
				c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
				return
			}
		}

		r.Thinking = sbThinking.String()
		r.Response = sbContent.String()
		r.Logprobs = allLogprobs

		c.JSON(http.StatusOK, r)
		return
	}

	streamResponse(c, ch)
}

func (s *Server) EmbedHandler(c *gin.Context) {
	checkpointStart := time.Now()
	var req api.EmbedRequest
	err := c.ShouldBindJSON(&req)
	switch {
	case errors.Is(err, io.EOF):
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
		return
	case err != nil:
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	modelRef, err := parseAndValidateModelRef(req.Model)
	if err != nil {
		writeModelRefParseError(c, err, http.StatusNotFound, fmt.Sprintf("model '%s' not found", req.Model))
		return
	}

	if modelRef.Source == modelSourceCloud {
		req.Model = modelRef.Base
		proxyCloudJSONRequest(c, req, cloudErrRemoteInferenceUnavailable)
		return
	}

	var input []string

	switch i := req.Input.(type) {
	case string:
		if len(i) > 0 {
			input = append(input, i)
		}
	case []any:
		for _, v := range i {
			if _, ok := v.(string); !ok {
				c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
				return
			}
			input = append(input, v.(string))
		}
	default:
		if req.Input != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
			return
		}
	}

	name, err := getExistingName(modelRef.Name)
	if err != nil {
		c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
		return
	}

	r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
	if err != nil {
		handleScheduleError(c, req.Model, err)
		return
	}

	checkpointLoaded := time.Now()

	if len(input) == 0 {
		c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
		return
	}

	kvData, _, err := getModelData(m.ModelPath, false)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	ctx := c.Request.Context()

	embedWithRetry := func(text string) ([]float32, int, error) {
		emb, tokCount, err := r.Embedding(ctx, text)
		if err == nil {
			return emb, tokCount, nil
		}

		var serr api.StatusError
		if !errors.As(err, &serr) || serr.StatusCode != http.StatusBadRequest {
			return nil, 0, err
		}
		if req.Truncate != nil && !*req.Truncate {
			return nil, 0, err
		}

		tokens, err := r.Tokenize(ctx, text)
		if err != nil {
			return nil, 0, err
		}

		// TODO @nicolepardal: avoid reaching into kvData here; pass required tokenizer metadata via model/options instead
		ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
		if bos := kvData.Uint("tokenizer.ggml.bos_token_id"); len(tokens) > 0 && tokens[0] != int(bos) && kvData.Bool("add_bos_token", true) {
			ctxLen--
		}
		if eos := kvData.Uint("tokenizer.ggml.eos_token_id"); len(tokens) > 0 && tokens[len(tokens)-1] != int(eos) && kvData.Bool("add_eos_token", true) {
			ctxLen--
		}

		if len(tokens) <= ctxLen {
			return nil, 0, fmt.Errorf("input exceeds maximum context length and cannot be truncated further")
		}
		if ctxLen <= 0 {
			return nil, 0, fmt.Errorf("input after truncation exceeds maximum context length")
		}

		truncatedTokens := tokens[:ctxLen]
		truncated, err := r.Detokenize(ctx, truncatedTokens)
		if err != nil {
			return nil, 0, err
		}
		return r.Embedding(ctx, truncated)
	}

	var g errgroup.Group
	embeddings := make([][]float32, len(input))
	var totalTokens uint64
	for i, text := range input {
		g.Go(func() error {
			embedding, tokenCount, err := embedWithRetry(text)
			if err != nil {
				return err
			}
			// TODO: this first normalization should be done by the model
			embedding, err = normalize(embedding)
			if err != nil {
				return err
			}
			if req.Dimensions > 0 && req.Dimensions < len(embedding) {
				embedding, err = normalize(embedding[:req.Dimensions])
				if err != nil {
					return err
				}
			}
			embeddings[i] = embedding
			atomic.AddUint64(&totalTokens, uint64(tokenCount))
			return nil
		})
	}

	if err := g.Wait(); err != nil {
		var serr api.StatusError
		if errors.As(err, &serr) {
			c.AbortWithStatusJSON(serr.StatusCode, gin.H{
				"error": strings.TrimSpace(serr.ErrorMessage),
			})
			return
		}

		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{
			"error": strings.TrimSpace(err.Error()),
		})
		return
	}

	resp := api.EmbedResponse{
		Model:           req.Model,
		Embeddings:      embeddings,
		TotalDuration:   time.Since(checkpointStart),
		LoadDuration:    checkpointLoaded.Sub(checkpointStart),
		PromptEvalCount: int(totalTokens),
	}
	c.JSON(http.StatusOK, resp)
}

func normalize(vec []float32) ([]float32, error) {
	var sum float32
	for _, v := range vec {
		if math.IsNaN(float64(v)) || math.IsInf(float64(v), 0) {
			return nil, errors.New("embedding contains NaN or Inf values")
		}
		sum += v * v
	}

	norm := float32(1.0 / max(math.Sqrt(float64(sum)), 1e-12))
	for i := range vec {
		vec[i] *= norm
	}
	return vec, nil
}

func (s *Server) EmbeddingsHandler(c *gin.Context) {
	var req api.EmbeddingRequest
	if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
		return
	} else if err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	modelRef, err := parseAndValidateModelRef(req.Model)
	if err != nil {
		writeModelRefParseError(c, err, http.StatusBadRequest, "model is required")
		return
	}

	if modelRef.Source == modelSourceCloud {
		req.Model = modelRef.Base
		proxyCloudJSONRequest(c, req, cloudErrRemoteInferenceUnavailable)
		return
	}

	name := modelRef.Name

	r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
	if err != nil {
		handleScheduleError(c, req.Model, err)
		return
	}

	// an empty request loads the model
	if req.Prompt == "" {
		c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: []float64{}})
		return
	}

	embedding, _, err := r.Embedding(c.Request.Context(), req.Prompt)
	if err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())})
		return
	}

	var e []float64
	for _, v := range embedding {
		e = append(e, float64(v))
	}

	resp := api.EmbeddingResponse{
		Embedding: e,
	}
	c.JSON(http.StatusOK, resp)
}

func (s *Server) PullHandler(c *gin.Context) {
	var req api.PullRequest
	err := c.ShouldBindJSON(&req)
	switch {
	case errors.Is(err, io.EOF):
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
		return
	case err != nil:
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	// TEMP(drifkin): we temporarily allow pulling cloud model stub files until
	// cloud models are integrated into `/api/tags` (at which point this
	// roundabout way of "adding" cloud models won't be needed anymore). So
	// right here we normalize any `:cloud` models into the legacy-style
	// suffixes `:<tag>-cloud` and `:cloud`.
	modelRef, err := parseNormalizePullModelRef(cmp.Or(req.Model, req.Name))
	if err != nil {
		writeModelRefParseError(c, err, http.StatusBadRequest, errtypes.InvalidModelNameErrMsg)
		return
	}

	name := modelRef.Name

	name, err = getExistingName(name)
	if err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	ch := make(chan any)
	go func() {
		defer close(ch)
		fn := func(r api.ProgressResponse) {
			ch <- r
		}

		regOpts := &registryOptions{
			Insecure: req.Insecure,
		}

		ctx, cancel := context.WithCancel(c.Request.Context())
		defer cancel()

		if err := PullModel(ctx, name.DisplayShortest(), regOpts, fn); err != nil {
			ch <- gin.H{"error": err.Error()}
		}
	}()

	if req.Stream != nil && !*req.Stream {
		waitForStream(c, ch)
		return
	}

	streamResponse(c, ch)
}

func (s *Server) PushHandler(c *gin.Context) {
	var req api.PushRequest
	err := c.ShouldBindJSON(&req)
	switch {
	case errors.Is(err, io.EOF):
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
		return
	case err != nil:
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	var mname string
	if req.Model != "" {
		mname = req.Model
	} else if req.Name != "" {
		mname = req.Name
	} else {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
		return
	}

	ch := make(chan any)
	go func() {
		defer close(ch)
		fn := func(r api.ProgressResponse) {
			ch <- r
		}

		regOpts := &registryOptions{
			Insecure: req.Insecure,
		}

		ctx, cancel := context.WithCancel(c.Request.Context())
		defer cancel()

		name, err := getExistingName(model.ParseName(mname))
		if err != nil {
			ch <- gin.H{"error": err.Error()}
			return
		}

		if err := PushModel(ctx, name.DisplayShortest(), regOpts, fn); err != nil {
			ch <- gin.H{"error": err.Error()}
		}
	}()

	if req.Stream != nil && !*req.Stream {
		waitForStream(c, ch)
		return
	}

	streamResponse(c, ch)
}

// getExistingName searches the models directory for a case-insensitive match
// of the input name and returns the input name with each matching part
// replaced by the existing part's canonical casing. If no parts match, the
// input name is returned as is.
func getExistingName(n model.Name) (model.Name, error) {
	var zero model.Name
	existing, err := manifest.Manifests(true)
	if err != nil {
		return zero, err
	}
	var set model.Name // tracks parts already canonicalized
	for e := range existing {
		if set.Host == "" && strings.EqualFold(e.Host, n.Host) {
			n.Host = e.Host
		}
		if set.Namespace == "" && strings.EqualFold(e.Namespace, n.Namespace) {
			n.Namespace = e.Namespace
		}
		if set.Model == "" && strings.EqualFold(e.Model, n.Model) {
			n.Model = e.Model
		}
		if set.Tag == "" && strings.EqualFold(e.Tag, n.Tag) {
			n.Tag = e.Tag
		}
	}
	return n, nil
}

func (s *Server) DeleteHandler(c *gin.Context) {
	var r api.DeleteRequest
	if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
		return
	} else if err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	modelRef, err := parseNormalizePullModelRef(cmp.Or(r.Model, r.Name))
	if err != nil {
		switch {
		case errors.Is(err, errConflictingModelSource):
			c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		case errors.Is(err, model.ErrUnqualifiedName):
			c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))})
		default:
			c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		}
		return
	}

	n, err := getExistingName(modelRef.Name)
	if err != nil {
		c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", cmp.Or(r.Model, r.Name))})
		return
	}

	m, err := manifest.ParseNamedManifest(n)
	if err != nil {
		switch {
		case os.IsNotExist(err):
			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", cmp.Or(r.Model, r.Name))})
		default:
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		}
		return
	}

	if err := m.Remove(); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	if err := m.RemoveLayers(); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
}

func (s *Server) ShowHandler(c *gin.Context) {
	var req api.ShowRequest
	err := c.ShouldBindJSON(&req)
	switch {
	case errors.Is(err, io.EOF):
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
		return
	case err != nil:
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	if req.Model != "" {
		// noop
	} else if req.Name != "" {
		req.Model = req.Name
	} else {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
		return
	}

	modelRef, err := parseAndValidateModelRef(req.Model)
	if err != nil {
		writeModelRefParseError(c, err, http.StatusBadRequest, err.Error())
		return
	}

	if modelRef.Source == modelSourceCloud {
		req.Model = modelRef.Base
		proxyCloudJSONRequest(c, req, cloudErrRemoteModelDetailsUnavailable)
		return
	}

	req.Model = modelRef.Base

	resp, err := GetModelInfo(req)
	if err != nil {
		var statusErr api.StatusError
		switch {
		case os.IsNotExist(err):
			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
		case errors.As(err, &statusErr):
			c.JSON(statusErr.StatusCode, gin.H{"error": statusErr.ErrorMessage})
		case err.Error() == errtypes.InvalidModelNameErrMsg:
			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		default:
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		}
		return
	}

	if modelRef.Source == modelSourceLocal && resp.RemoteHost != "" {
		c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", modelRef.Original)})
		return
	}

	userAgent := c.Request.UserAgent()
	if strings.HasPrefix(userAgent, copilotChatUserAgentPrefix) {
		if resp.ModelInfo == nil {
			resp.ModelInfo = map[string]any{}
		}
		// Copilot Chat prefers `general.basename`, but this is usually not what
		// users are familiar with, so let's just echo back what we had returned in
		// `/api/tags`
		resp.ModelInfo["general.basename"] = req.Model
	}

	c.JSON(http.StatusOK, resp)
}

func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
	name := model.ParseName(req.Model)
	if !name.IsValid() {
		return nil, model.Unqualified(name)
	}
	name, err := getExistingName(name)
	if err != nil {
		return nil, err
	}

	m, err := GetModel(name.String())
	if err != nil {
		return nil, err
	}

	if m.Config.RemoteHost != "" {
		if disabled, _ := internalcloud.Status(); disabled {
			return nil, api.StatusError{
				StatusCode:   http.StatusForbidden,
				ErrorMessage: internalcloud.DisabledError(cloudErrRemoteModelDetailsUnavailable),
			}
		}
	}

	modelDetails := api.ModelDetails{
		ParentModel:       m.ParentModel,
		Format:            m.Config.ModelFormat,
		Family:            m.Config.ModelFamily,
		Families:          m.Config.ModelFamilies,
		ParameterSize:     m.Config.ModelType,
		QuantizationLevel: m.Config.FileType,
	}

	// For image generation models, populate details from the imagegen package
	if slices.Contains(m.Capabilities(), model.CapabilityImage) {
		if info, err := imagegenmanifest.GetModelInfo(name.String()); err == nil {
			modelDetails.Family = info.Architecture
			modelDetails.ParameterSize = format.HumanNumber(uint64(info.ParameterCount))
			modelDetails.QuantizationLevel = info.Quantization
		}
	}

	// For safetensors LLM models (experimental), populate details from config.json
	if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") {
		if info, err := xserver.GetSafetensorsLLMInfo(name); err == nil {
			if arch, ok := info["general.architecture"].(string); ok && arch != "" {
				modelDetails.Family = arch
			}
			if paramCount, ok := info["general.parameter_count"].(int64); ok && paramCount > 0 {
				modelDetails.ParameterSize = format.HumanNumber(uint64(paramCount))
			}
		}
		// Older manifests may not have file_type populated for safetensors models.
		if modelDetails.QuantizationLevel == "" {
			if dtype, err := xserver.GetSafetensorsDtype(name); err == nil && dtype != "" {
				modelDetails.QuantizationLevel = dtype
			}
		}
	}

	if req.System != "" {
		m.System = req.System
	}

	msgs := make([]api.Message, len(m.Messages))
	for i, msg := range m.Messages {
		msgs[i] = api.Message{Role: msg.Role, Content: msg.Content}
	}

	mf, err := manifest.ParseNamedManifest(name)
	if err != nil {
		return nil, err
	}

	resp := &api.ShowResponse{
		License:      strings.Join(m.License, "\n"),
		System:       m.System,
		Template:     m.Template.String(),
		Details:      modelDetails,
		Messages:     msgs,
		Capabilities: m.Capabilities(),
		ModifiedAt:   mf.FileInfo().ModTime(),
		Requires:     m.Config.Requires,
		// Several integrations crash on a nil/omitempty+empty ModelInfo, so by
		// default we return an empty map.
		ModelInfo: make(map[string]any),
	}

	if m.Config.RemoteHost != "" {
		resp.RemoteHost = m.Config.RemoteHost
		resp.RemoteModel = m.Config.RemoteModel

		if m.Config.ModelFamily != "" {
			resp.ModelInfo = make(map[string]any)
			resp.ModelInfo["general.architecture"] = m.Config.ModelFamily

			if m.Config.BaseName != "" {
				resp.ModelInfo["general.basename"] = m.Config.BaseName
			}

			if m.Config.ContextLen > 0 {
				resp.ModelInfo[fmt.Sprintf("%s.context_length", m.Config.ModelFamily)] = m.Config.ContextLen
			}

			if m.Config.EmbedLen > 0 {
				resp.ModelInfo[fmt.Sprintf("%s.embedding_length", m.Config.ModelFamily)] = m.Config.EmbedLen
			}
		}
	}

	var params []string
	cs := 30
	for k, v := range m.Options {
		switch val := v.(type) {
		case []any:
			for _, nv := range val {
				params = append(params, fmt.Sprintf("%-*s %#v", cs, k, nv))
			}
		default:
			params = append(params, fmt.Sprintf("%-*s %#v", cs, k, v))
		}
	}
	resp.Parameters = strings.Join(params, "\n")

	if len(req.Options) > 0 {
		if m.Options == nil {
			m.Options = make(map[string]any)
		}
		for k, v := range req.Options {
			m.Options[k] = v
		}
	}

	var sb strings.Builder
	fmt.Fprintln(&sb, "# Modelfile generated by \"ollama show\"")
	fmt.Fprintln(&sb, "# To build a new Modelfile based on this, replace FROM with:")
	fmt.Fprintf(&sb, "# FROM %s\n\n", m.ShortName)
	fmt.Fprint(&sb, m.String())
	resp.Modelfile = sb.String()

	// skip loading tensor information if this is a remote model
	if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
		return resp, nil
	}

	if slices.Contains(m.Capabilities(), model.CapabilityImage) {
		// Populate tensor info if verbose
		if req.Verbose {
			if tensors, err := xserver.GetSafetensorsTensorInfo(name); err == nil {
				resp.Tensors = tensors
			}
		}
		return resp, nil
	}

	// For safetensors LLM models (experimental), populate ModelInfo from config.json
	if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") {
		if info, err := xserver.GetSafetensorsLLMInfo(name); err == nil {
			resp.ModelInfo = info
		}
		// Populate tensor info if verbose
		if req.Verbose {
			if tensors, err := xserver.GetSafetensorsTensorInfo(name); err == nil {
				resp.Tensors = tensors
			}
		}
		return resp, nil
	}

	kvData, tensors, err := getModelData(m.ModelPath, req.Verbose)
	if err != nil {
		return nil, err
	}

	delete(kvData, "general.name")
	delete(kvData, "tokenizer.chat_template")
	resp.ModelInfo = kvData

	tensorData := make([]api.Tensor, len(tensors.Items()))
	for cnt, t := range tensors.Items() {
		tensorData[cnt] = api.Tensor{Name: t.Name, Type: t.Type(), Shape: t.Shape}
	}
	resp.Tensors = tensorData

	if len(m.ProjectorPaths) > 0 {
		projectorData, _, err := getModelData(m.ProjectorPaths[0], req.Verbose)
		if err != nil {
			return nil, err
		}
		resp.ProjectorInfo = projectorData
	}

	return resp, nil
}

func getModelData(digest string, verbose bool) (ggml.KV, ggml.Tensors, error) {
	maxArraySize := 0
	if verbose {
		maxArraySize = -1
	}
	data, err := llm.LoadModel(digest, maxArraySize)
	if err != nil {
		return nil, ggml.Tensors{}, err
	}

	kv := data.KV()

	if !verbose {
		for k := range kv {
			if t, ok := kv[k].([]any); ok && len(t) > 5 {
				kv[k] = []any{}
			}
		}
	}

	return kv, data.Tensors(), nil
}

func (s *Server) ListHandler(c *gin.Context) {
	ms, err := manifest.Manifests(true)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	models := []api.ListModelResponse{}
	for n, m := range ms {
		var cf model.ConfigV2

		if m.Config.Digest != "" {
			f, err := m.Config.Open()
			if err != nil {
				slog.Warn("bad manifest filepath", "name", n, "error", err)
				continue
			}
			defer f.Close()

			if err := json.NewDecoder(f).Decode(&cf); err != nil {
				slog.Warn("bad manifest config", "name", n, "error", err)
				continue
			}
		}

		// tag should never be masked
		models = append(models, api.ListModelResponse{
			Model:       n.DisplayShortest(),
			Name:        n.DisplayShortest(),
			RemoteModel: cf.RemoteModel,
			RemoteHost:  cf.RemoteHost,
			Size:        m.Size(),
			Digest:      m.Digest(),
			ModifiedAt:  m.FileInfo().ModTime(),
			Details: api.ModelDetails{
				Format:            cf.ModelFormat,
				Family:            cf.ModelFamily,
				Families:          cf.ModelFamilies,
				ParameterSize:     cf.ModelType,
				QuantizationLevel: cf.FileType,
			},
		})
	}

	slices.SortStableFunc(models, func(i, j api.ListModelResponse) int {
		// most recently modified first
		return cmp.Compare(j.ModifiedAt.Unix(), i.ModifiedAt.Unix())
	})

	c.JSON(http.StatusOK, api.ListResponse{Models: models})
}

func (s *Server) CopyHandler(c *gin.Context) {
	var r api.CopyRequest
	if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
		return
	} else if err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	src := model.ParseName(r.Source)
	if !src.IsValid() {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("source %q is invalid", r.Source)})
		return
	}
	src, err := getExistingName(src)
	if err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	dst := model.ParseName(r.Destination)
	if !dst.IsValid() {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("destination %q is invalid", r.Destination)})
		return
	}
	dst, err = getExistingName(dst)
	if err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	if err := CopyModel(src, dst); errors.Is(err, os.ErrNotExist) {
		c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found", r.Source)})
	} else if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
	}
}

func (s *Server) HeadBlobHandler(c *gin.Context) {
	path, err := manifest.BlobsPath(c.Param("digest"))
	if err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	if _, err := os.Stat(path); err != nil {
		c.AbortWithStatusJSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("blob %q not found", c.Param("digest"))})
		return
	}

	c.Status(http.StatusOK)
}

func (s *Server) CreateBlobHandler(c *gin.Context) {
	if ib, ok := intermediateBlobs[c.Param("digest")]; ok {
		p, err := manifest.BlobsPath(ib)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
			return
		}

		if _, err := os.Stat(p); errors.Is(err, os.ErrNotExist) {
			slog.Info("evicting intermediate blob which no longer exists", "digest", ib)
			delete(intermediateBlobs, c.Param("digest"))
		} else if err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
			return
		} else {
			c.Status(http.StatusOK)
			return
		}
	}

	path, err := manifest.BlobsPath(c.Param("digest"))
	if err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	_, err = os.Stat(path)
	switch {
	case errors.Is(err, os.ErrNotExist):
		// noop
	case err != nil:
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	default:
		c.Status(http.StatusOK)
		return
	}

	layer, err := manifest.NewLayer(c.Request.Body, "")
	if err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	if layer.Digest != c.Param("digest") {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("digest mismatch, expected %q, got %q", c.Param("digest"), layer.Digest)})
		return
	}

	c.Status(http.StatusCreated)
}

func isLocalIP(ip netip.Addr) bool {
	if interfaces, err := net.Interfaces(); err == nil {
		for _, iface := range interfaces {
			addrs, err := iface.Addrs()
			if err != nil {
				continue
			}

			for _, a := range addrs {
				if parsed, _, err := net.ParseCIDR(a.String()); err == nil {
					if parsed.String() == ip.String() {
						return true
					}
				}
			}
		}
	}

	return false
}

func allowedHost(host string) bool {
	host = strings.ToLower(host)

	if host == "" || host == "localhost" {
		return true
	}

	if hostname, err := os.Hostname(); err == nil && host == strings.ToLower(hostname) {
		return true
	}

	tlds := []string{
		"localhost",
		"local",
		"internal",
	}

	// check if the host is a local TLD
	for _, tld := range tlds {
		if strings.HasSuffix(host, "."+tld) {
			return true
		}
	}

	return false
}

func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
	return func(c *gin.Context) {
		if addr == nil {
			c.Next()
			return
		}

		if addr, err := netip.ParseAddrPort(addr.String()); err == nil && !addr.Addr().IsLoopback() {
			c.Next()
			return
		}

		host, _, err := net.SplitHostPort(c.Request.Host)
		if err != nil {
			host = c.Request.Host
		}

		if addr, err := netip.ParseAddr(host); err == nil {
			if addr.IsLoopback() || addr.IsPrivate() || addr.IsUnspecified() || isLocalIP(addr) {
				c.Next()
				return
			}
		}

		if allowedHost(host) {
			if c.Request.Method == http.MethodOptions {
				c.AbortWithStatus(http.StatusNoContent)
				return
			}

			c.Next()
			return
		}

		c.AbortWithStatus(http.StatusForbidden)
	}
}

func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
	corsConfig := cors.DefaultConfig()
	corsConfig.AllowWildcard = true
	corsConfig.AllowBrowserExtensions = true
	corsConfig.AllowHeaders = []string{
		"Authorization",
		"Content-Type",
		"User-Agent",
		"Accept",
		"X-Requested-With",

		// OpenAI compatibility headers
		"OpenAI-Beta",
		"x-stainless-arch",
		"x-stainless-async",
		"x-stainless-custom-poll-interval",
		"x-stainless-helper-method",
		"x-stainless-lang",
		"x-stainless-os",
		"x-stainless-package-version",
		"x-stainless-poll-helper",
		"x-stainless-retry-count",
		"x-stainless-runtime",
		"x-stainless-runtime-version",
		"x-stainless-timeout",
	}
	corsConfig.AllowOrigins = envconfig.AllowedOrigins()

	r := gin.Default()
	r.HandleMethodNotAllowed = true
	r.Use(
		cors.New(corsConfig),
		allowedHostsMiddleware(s.addr),
	)

	// General
	r.HEAD("/", func(c *gin.Context) { c.String(http.StatusOK, "Ollama is running") })
	r.GET("/", func(c *gin.Context) { c.String(http.StatusOK, "Ollama is running") })
	r.HEAD("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })
	r.GET("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })
	r.GET("/api/status", s.StatusHandler)

	// Local model cache management (new implementation is at end of function)
	r.POST("/api/pull", s.PullHandler)
	r.POST("/api/push", s.PushHandler)
	r.HEAD("/api/tags", s.ListHandler)
	r.GET("/api/tags", s.ListHandler)
	r.POST("/api/show", s.ShowHandler)
	r.DELETE("/api/delete", s.DeleteHandler)

	r.POST("/api/me", s.WhoamiHandler)

	r.POST("/api/signout", s.SignoutHandler)
	// deprecated
	r.DELETE("/api/user/keys/:encodedKey", s.SignoutHandler)

	// Create
	r.POST("/api/create", s.CreateHandler)
	r.POST("/api/blobs/:digest", s.CreateBlobHandler)
	r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
	r.POST("/api/copy", s.CopyHandler)
	r.POST("/api/experimental/web_search", s.WebSearchExperimentalHandler)
	r.POST("/api/experimental/web_fetch", s.WebFetchExperimentalHandler)

	// Inference
	r.GET("/api/ps", s.PsHandler)
	r.POST("/api/generate", s.withInferenceRequestLogging("/api/generate", s.GenerateHandler)...)
	r.POST("/api/chat", s.withInferenceRequestLogging("/api/chat", s.ChatHandler)...)
	r.POST("/api/embed", s.EmbedHandler)
	r.POST("/api/embeddings", s.EmbeddingsHandler)

	// Inference (OpenAI compatibility)
	// TODO(cloud-stage-a): apply Modelfile overlay deltas for local models with cloud
	// parents on v1 request families while preserving this explicit :cloud passthrough.
	r.POST("/v1/chat/completions", s.withInferenceRequestLogging("/v1/chat/completions", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ChatMiddleware(), s.ChatHandler)...)
	r.POST("/v1/completions", s.withInferenceRequestLogging("/v1/completions", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.CompletionsMiddleware(), s.GenerateHandler)...)
	r.POST("/v1/embeddings", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.EmbeddingsMiddleware(), s.EmbedHandler)
	r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
	r.GET("/v1/models/:model", cloudModelPathPassthroughMiddleware(cloudErrRemoteModelDetailsUnavailable), middleware.RetrieveMiddleware(), s.ShowHandler)
	r.POST("/v1/responses", s.withInferenceRequestLogging("/v1/responses", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ResponsesMiddleware(), s.ChatHandler)...)
	// OpenAI-compatible image generation endpoints
	r.POST("/v1/images/generations", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ImageGenerationsMiddleware(), s.GenerateHandler)
	r.POST("/v1/images/edits", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ImageEditsMiddleware(), s.GenerateHandler)
	// OpenAI-compatible audio endpoint
	r.POST("/v1/audio/transcriptions", middleware.TranscriptionMiddleware(), s.ChatHandler)

	// Inference (Anthropic compatibility)
	r.POST("/v1/messages", s.withInferenceRequestLogging("/v1/messages", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.AnthropicMessagesMiddleware(), s.ChatHandler)...)

	if rc != nil {
		// wrap old with new
		rs := &registry.Local{
			Client:   rc,
			Logger:   slog.Default(), // TODO(bmizerany): Take a logger, do not use slog.Default()
			Fallback: r,

			Prune: PruneLayers,
		}
		return rs, nil
	}

	return r, nil
}

func Serve(ln net.Listener) error {
	slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))
	slog.Info("server config", "env", envconfig.Values())
	cloudDisabled, _ := internalcloud.Status()
	slog.Info(fmt.Sprintf("Ollama cloud disabled: %t", cloudDisabled))

	blobsDir, err := manifest.BlobsPath("")
	if err != nil {
		return err
	}
	if err := fixBlobs(blobsDir); err != nil {
		return err
	}

	if !envconfig.NoPrune() {
		if _, err := manifest.Manifests(false); err != nil {
			slog.Warn("corrupt manifests detected, skipping prune operation. Re-pull or delete to clear", "error", err)
		} else {
			// clean up unused layers and manifests
			if err := PruneLayers(); err != nil {
				return err
			}

			manifestsPath, err := manifest.Path()
			if err != nil {
				return err
			}

			if err := manifest.PruneDirectory(manifestsPath); err != nil {
				return err
			}
		}
	}

	s := &Server{addr: ln.Addr()}
	if err := s.initRequestLogging(); err != nil {
		return err
	}

	var rc *ollama.Registry
	if useClient2 {
		var err error
		rc, err = ollama.DefaultRegistry()
		if err != nil {
			return err
		}
	}

	h, err := s.GenerateRoutes(rc)
	if err != nil {
		return err
	}

	http.Handle("/", h)

	ctx, done := context.WithCancel(context.Background())
	schedCtx, schedDone := context.WithCancel(ctx)
	sched := InitScheduler(schedCtx)
	s.sched = sched

	slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version))
	srvr := &http.Server{
		// Use http.DefaultServeMux so we get net/http/pprof for
		// free.
		//
		// TODO(bmizerany): Decide if we want to make this
		// configurable so it is not exposed by default, or allow
		// users to bind it to a different port. This was a quick
		// and easy way to get pprof, but it may not be the best
		// way.
		Handler: nil,
	}

	// listen for a ctrl+c and stop any loaded llm
	signals := make(chan os.Signal, 1)
	signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
	go func() {
		<-signals
		srvr.Close()
		schedDone()
		sched.unloadAllRunners()
		done()
	}()

	s.sched.Run(schedCtx)

	// register the experimental webp decoder
	// so webp images can be used in multimodal inputs
	image.RegisterFormat("webp", "RIFF????WEBP", webp.Decode, webp.DecodeConfig)

	// At startup we retrieve GPU information so we can get log messages before loading a model
	// This will log warnings to the log in case we have problems with detected GPUs
	gpus := discover.GPUDevices(ctx, nil)
	discover.LogDetails(gpus)

	var totalVRAM uint64
	for _, gpu := range gpus {
		totalVRAM += gpu.TotalMemory - envconfig.GpuOverhead()
	}

	// Set the default context size based on VRAM tier.
	// Use slightly lower thresholds (47/23 GiB vs. 48/24 GiB) to account for small differences in the exact value
	switch {
	case totalVRAM >= 47*format.GibiByte:
		s.defaultNumCtx = 262144
	case totalVRAM >= 23*format.GibiByte:
		s.defaultNumCtx = 32768
	default:
		s.defaultNumCtx = 4096
	}
	slog.Info("vram-based default context", "total_vram", format.HumanBytes2(totalVRAM), "default_num_ctx", s.defaultNumCtx)

	err = srvr.Serve(ln)
	// If server is closed from the signal handler, wait for the ctx to be done
	// otherwise error out quickly
	if !errors.Is(err, http.ErrServerClosed) {
		return err
	}
	<-ctx.Done()
	return nil
}

func waitForStream(c *gin.Context, ch chan any) {
	c.Header("Content-Type", "application/json")
	var latest api.ProgressResponse
	for resp := range ch {
		switch r := resp.(type) {
		case api.ProgressResponse:
			latest = r
		case gin.H:
			status, ok := r["status"].(int)
			if !ok {
				status = http.StatusInternalServerError
			}
			errorMsg, ok := r["error"].(string)
			if !ok {
				errorMsg = "unknown error"
			}
			c.JSON(status, gin.H{"error": errorMsg})
			return
		default:
			c.JSON(http.StatusInternalServerError, gin.H{"error": "unknown message type"})
			return
		}
	}

	c.JSON(http.StatusOK, latest)
}

func streamResponse(c *gin.Context, ch chan any) {
	c.Header("Content-Type", "application/x-ndjson")
	c.Stream(func(w io.Writer) bool {
		val, ok := <-ch
		if !ok {
			return false
		}

		// errors are provided as a gin.H with an "error" field and
		// an optional "status" field. For errors that are streamed
		// before any content, we need to set the status code and
		// content type for the error.
		if h, ok := val.(gin.H); ok {
			if e, ok := h["error"].(string); ok {
				status, ok := h["status"].(int)
				if !ok {
					status = http.StatusInternalServerError
				}

				if !c.Writer.Written() {
					c.Header("Content-Type", "application/json")
					c.JSON(status, gin.H{"error": e})
				} else {
					if err := json.NewEncoder(c.Writer).Encode(gin.H{"error": e}); err != nil {
						slog.Error("streamResponse failed to encode json error", "error", err)
					}
				}

				return false
			}
		}

		bts, err := json.Marshal(val)
		if err != nil {
			slog.Info(fmt.Sprintf("streamResponse: json.Marshal failed with %s", err))
			return false
		}

		// Delineate chunks with new-line delimiter
		bts = append(bts, '\n')
		if _, err := w.Write(bts); err != nil {
			slog.Info(fmt.Sprintf("streamResponse: w.Write failed with %s", err))
			return false
		}

		return true
	})
}

func (s *Server) StatusHandler(c *gin.Context) {
	disabled, source := internalcloud.Status()
	c.JSON(http.StatusOK, api.StatusResponse{
		Cloud: api.CloudStatus{
			Disabled: disabled,
			Source:   source,
		},
	})
}

func (s *Server) WebSearchExperimentalHandler(c *gin.Context) {
	s.webExperimentalProxyHandler(c, "/api/web_search", cloudErrWebSearchUnavailable)
}

func (s *Server) WebFetchExperimentalHandler(c *gin.Context) {
	s.webExperimentalProxyHandler(c, "/api/web_fetch", cloudErrWebFetchUnavailable)
}

func (s *Server) webExperimentalProxyHandler(c *gin.Context, proxyPath, disabledOperation string) {
	body, err := readRequestBody(c.Request)
	if err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	if len(bytes.TrimSpace(body)) == 0 {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
		return
	}

	proxyCloudRequestWithPath(c, body, proxyPath, disabledOperation)
}

func (s *Server) WhoamiHandler(c *gin.Context) {
	// todo allow other hosts
	u, err := url.Parse("https://ollama.com")
	if err != nil {
		slog.Error(err.Error())
		c.JSON(http.StatusInternalServerError, gin.H{"error": "URL parse error"})
		return
	}

	client := api.NewClient(u, http.DefaultClient)
	user, err := client.Whoami(c)
	if err != nil {
		slog.Error(err.Error())
	}

	// user isn't signed in
	if user != nil && user.Name == "" {
		sURL, sErr := signinURL()
		if sErr != nil {
			slog.Error(sErr.Error())
			c.JSON(http.StatusInternalServerError, gin.H{"error": "error getting authorization details"})
			return
		}

		c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized", "signin_url": sURL})
		return
	}

	c.JSON(http.StatusOK, user)
}

func (s *Server) SignoutHandler(c *gin.Context) {
	pubKey, err := auth.GetPublicKey()
	if err != nil {
		slog.Error("couldn't get public key", "error", err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "there was an error signing out"})
		return
	}

	encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))

	// todo allow other hosts
	u, err := url.Parse("https://ollama.com")
	if err != nil {
		slog.Error(err.Error())
		c.JSON(http.StatusInternalServerError, gin.H{"error": "URL parse error"})
		return
	}

	client := api.NewClient(u, http.DefaultClient)
	err = client.Disconnect(c, encKey)
	if err != nil {
		var authError api.AuthorizationError
		if errors.As(err, &authError) {
			c.JSON(http.StatusUnauthorized, gin.H{"error": "you are not currently signed in"})
			return
		}
		c.JSON(http.StatusInternalServerError, gin.H{"error": "there was an error signing out"})
		return
	}

	c.JSON(http.StatusOK, nil)
}

func (s *Server) PsHandler(c *gin.Context) {
	models := []api.ProcessModelResponse{}

	for _, v := range s.sched.loaded {
		model := v.model
		modelDetails := api.ModelDetails{
			Format:            model.Config.ModelFormat,
			Family:            model.Config.ModelFamily,
			Families:          model.Config.ModelFamilies,
			ParameterSize:     model.Config.ModelType,
			QuantizationLevel: model.Config.FileType,
		}

		mr := api.ProcessModelResponse{
			Model:     model.ShortName,
			Name:      model.ShortName,
			Size:      int64(v.totalSize),
			SizeVRAM:  int64(v.vramSize),
			Digest:    model.Digest,
			Details:   modelDetails,
			ExpiresAt: v.expiresAt,
		}
		if v.llama != nil {
			mr.ContextLength = v.llama.ContextLength()
			total, vram := v.llama.MemorySize()
			mr.Size = int64(total)
			mr.SizeVRAM = int64(vram)
		}
		// The scheduler waits to set expiresAt, so if a model is loading it's
		// possible that it will be set to the zero time. For those cases, just
		// calculate the time w/ the sessionDuration instead.
		var epoch time.Time
		if v.expiresAt == epoch {
			mr.ExpiresAt = time.Now().Add(v.sessionDuration)
		}

		models = append(models, mr)
	}

	slices.SortStableFunc(models, func(i, j api.ProcessModelResponse) int {
		// longest duration remaining listed first
		return cmp.Compare(j.ExpiresAt.Unix(), i.ExpiresAt.Unix())
	})

	c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
}

func toolCallId() string {
	const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
	b := make([]byte, 8)
	for i := range b {
		b[i] = letterBytes[rand.Intn(len(letterBytes))]
	}
	return "call_" + string(b)
}

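// Illustrative note (not exercised by the server itself): toolCallId above
// yields IDs of the form "call_" followed by eight characters drawn from
// [a-z0-9]. A response containing two tool calls might therefore carry
// something like:
//
//	{"message": {"tool_calls": [
//	  {"id": "call_a1b2c3d4", "function": {"name": "get_weather"}},
//	  {"id": "call_e5f6g7h8", "function": {"name": "get_time"}}
//	]}}
//
// Because it uses math/rand rather than crypto/rand, these IDs are only meant
// to correlate tool calls within a response, not to be unguessable.
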
func (s *Server) ChatHandler(c *gin.Context) {
	checkpointStart := time.Now()

	var req api.ChatRequest
	if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
		return
	} else if err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	if req.TopLogprobs < 0 || req.TopLogprobs > 20 {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "top_logprobs must be between 0 and 20"})
		return
	}

	modelRef, err := parseAndValidateModelRef(req.Model)
	if err != nil {
		writeModelRefParseError(c, err, http.StatusBadRequest, "model is required")
		return
	}

	if modelRef.Source == modelSourceCloud {
		req.Model = modelRef.Base
		if c.GetBool(legacyCloudAnthropicKey) {
			proxyCloudJSONRequestWithPath(c, req, "/api/chat", cloudErrRemoteInferenceUnavailable)
			return
		}
		proxyCloudJSONRequest(c, req, cloudErrRemoteInferenceUnavailable)
		return
	}

	name := modelRef.Name

	name, err = getExistingName(name)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
		return
	}

	m, err := GetModel(name.String())
	if err != nil {
		switch {
		case os.IsNotExist(err):
			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
		case err.Error() == errtypes.InvalidModelNameErrMsg:
			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		default:
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		}
		return
	}

	if modelRef.Source == modelSourceLocal && m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
		c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
		return
	}

	// expire the runner
	if len(req.Messages) == 0 && req.KeepAlive != nil && req.KeepAlive.Duration == 0 {
		s.sched.expireRunner(m)

		c.JSON(http.StatusOK, api.ChatResponse{
			Model:      req.Model,
			CreatedAt:  time.Now().UTC(),
			Message:    api.Message{Role: "assistant"},
			Done:       true,
			DoneReason: "unload",
		})
		return
	}

	if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
		if disabled, _ := internalcloud.Status(); disabled {
			c.JSON(http.StatusForbidden, gin.H{"error": internalcloud.DisabledError(cloudErrRemoteInferenceUnavailable)})
			return
		}

		origModel := req.Model

		remoteURL, err := url.Parse(m.Config.RemoteHost)
		if err != nil {
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
			return
		}

		if !slices.Contains(envconfig.Remotes(), remoteURL.Hostname()) {
			slog.Info("remote model", "remotes", envconfig.Remotes(), "remoteURL", m.Config.RemoteHost, "hostname", remoteURL.Hostname())
			c.JSON(http.StatusBadRequest, gin.H{"error": "this server cannot run this remote model"})
			return
		}

		req.Model = m.Config.RemoteModel
		if req.Options == nil {
			req.Options = map[string]any{}
		}

		var msgs []api.Message
		if len(req.Messages) > 0 {
			msgs = append(m.Messages, req.Messages...)
			if req.Messages[0].Role != "system" && m.System != "" {
				msgs = append([]api.Message{{Role: "system", Content: m.System}}, msgs...)
			}
		}

		msgs = filterThinkTags(msgs, m)
		req.Messages = msgs

		for k, v := range m.Options {
			if _, ok := req.Options[k]; !ok {
				req.Options[k] = v
			}
		}

		contentType := "application/x-ndjson"
		if req.Stream != nil && !*req.Stream {
			contentType = "application/json; charset=utf-8"
		}
		c.Header("Content-Type", contentType)

		fn := func(resp api.ChatResponse) error {
			resp.Model = origModel
			resp.RemoteModel = m.Config.RemoteModel
			resp.RemoteHost = m.Config.RemoteHost

			data, err := json.Marshal(resp)
			if err != nil {
				return err
			}

			if _, err = c.Writer.Write(append(data, '\n')); err != nil {
				return err
			}
			c.Writer.Flush()
			return nil
		}

		client := api.NewClient(remoteURL, http.DefaultClient)
		err = client.Chat(c, &req, fn)
		if err != nil {
			var authError api.AuthorizationError
			if errors.As(err, &authError) {
				sURL, sErr := signinURL()
				if sErr != nil {
					slog.Error(sErr.Error())
					c.JSON(http.StatusInternalServerError, gin.H{"error": "error getting authorization details"})
					return
				}

				c.JSON(authError.StatusCode, gin.H{"error": "unauthorized", "signin_url": sURL})
				return
			}
			var apiError api.StatusError
			if errors.As(err, &apiError) {
				c.JSON(apiError.StatusCode, apiError)
				return
			}
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
			return
		}

		return
	}

	caps := []model.Capability{model.CapabilityCompletion}
	if len(req.Tools) > 0 {
		caps = append(caps, model.CapabilityTools)
	}

	modelCaps := m.Capabilities()
	if slices.Contains(modelCaps, model.CapabilityThinking) {
		caps = append(caps, model.CapabilityThinking)
		if req.Think == nil {
			req.Think = &api.ThinkValue{Value: true}
		}
	} else {
		if req.Think != nil && req.Think.Bool() {
			// Set think to nil when used with the Anthropic API to connect to tools like Claude Code
			if _, ok := c.Get("relax_thinking"); ok {
				slog.Warn("model does not support thinking, relaxing thinking to nil", "model", req.Model)
				req.Think = nil
			} else {
				c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support thinking", req.Model)})
				return
			}
		}
	}

	r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
	if errors.Is(err, errCapabilityCompletion) {
		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
		return
	} else if err != nil {
		handleScheduleError(c, req.Model, err)
		return
	}

	checkpointLoaded := time.Now()

	if len(req.Messages) == 0 {
		c.JSON(http.StatusOK, api.ChatResponse{
			Model:      req.Model,
			CreatedAt:  time.Now().UTC(),
			Message:    api.Message{Role: "assistant"},
			Done:       true,
			DoneReason: "load",
		})
		return
	}

	msgs := append(m.Messages, req.Messages...)
	if req.Messages[0].Role != "system" && m.System != "" {
		msgs = append([]api.Message{{Role: "system", Content: m.System}}, msgs...)
	}
	msgs = filterThinkTags(msgs, m)

	if shouldUseHarmony(m) && m.Config.Parser == "" {
		m.Config.Parser = "harmony"
	}

	var builtinParser parsers.Parser
	processedTools := req.Tools

	if m.Config.Parser != "" {
		builtinParser = parsers.ParserForName(m.Config.Parser)
		if builtinParser != nil {
			// Determine last message for chat prefill
			var lastMessage *api.Message
			if len(msgs) > 0 {
				lastMessage = &msgs[len(msgs)-1]
			}
			// Initialize parser and get processed tools
			processedTools = builtinParser.Init(req.Tools, lastMessage, req.Think)
		}
	}

	truncate := req.Truncate == nil || *req.Truncate
	if m.IsMLX() {
		truncate = false
	}
	prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
	if err != nil {
		slog.Error("chat prompt error", "error", err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	// If debug mode is enabled, return the rendered template instead of calling the model
	if req.DebugRenderOnly {
		c.JSON(http.StatusOK, api.ChatResponse{
			Model:     req.Model,
			CreatedAt: time.Now().UTC(),
			DebugInfo: &api.DebugInfo{
				RenderedTemplate: prompt,
				ImageCount:       len(images),
			},
		})
		return
	}

	var thinkingState *thinking.Parser
	openingTag, closingTag := thinking.InferTags(m.Template.Template)
	if req.Think != nil && req.Think.Bool() && openingTag != "" && closingTag != "" {
		thinkingState = &thinking.Parser{
			OpeningTag: openingTag,
			ClosingTag: closingTag,
		}

		if strings.HasSuffix(strings.TrimSpace(prompt), openingTag) {
			thinkingState.AddContent(openingTag)
		}
	}

	var toolParser *tools.Parser
	if len(req.Tools) > 0 && (builtinParser == nil || !builtinParser.HasToolSupport()) {
		toolParser = tools.NewParser(m.Template.Template, req.Tools)
	}

	type structuredOutputsState int
	const (
		structuredOutputsState_None structuredOutputsState = iota
		structuredOutputsState_ReadyToApply
		structuredOutputsState_Applying
	)

	ch := make(chan any)
	go func() {
		defer close(ch)

		structuredOutputsState := structuredOutputsState_None

		for {
			var tb strings.Builder

			currentFormat := req.Format
			// structured outputs via double request is enabled when:
			// 1. the model supports the thinking capability and
			// 2. it uses a built-in parser or our generic thinking parser

			// Note that the current approach does not work for (potential future)
			// non-thinking models that emit anything before actual content. This
			// current approach uses the transition from parsed thinking content to
			// parsed non-thinking content as the signal to turn constraining on
			if req.Format != nil && structuredOutputsState == structuredOutputsState_None && ((builtinParser != nil || thinkingState != nil) && slices.Contains(m.Capabilities(), model.CapabilityThinking)) {
				currentFormat = nil
			}

			// sets up new context given parent context per request
			ctx, cancel := context.WithCancel(c.Request.Context())
			err := r.Completion(ctx, llm.CompletionRequest{
				Prompt:      prompt,
				Images:      images,
				Format:      currentFormat,
				Options:     opts,
				Shift:       req.Shift == nil || *req.Shift,
				Truncate:    truncate,
				Logprobs:    req.Logprobs,
				TopLogprobs: req.TopLogprobs,
			}, func(r llm.CompletionResponse) {
				res := api.ChatResponse{
					Model:     req.Model,
					CreatedAt: time.Now().UTC(),
					Message:   api.Message{Role: "assistant", Content: r.Content},
					Done:      r.Done,
					Metrics: api.Metrics{
						PromptEvalCount:    r.PromptEvalCount,
						PromptEvalDuration: r.PromptEvalDuration,
						EvalCount:          r.EvalCount,
						EvalDuration:       r.EvalDuration,
					},
					Logprobs: toAPILogprobs(r.Logprobs),
				}

				if r.Done {
					res.DoneReason = r.DoneReason.String()
					res.TotalDuration = time.Since(checkpointStart)
					res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
				}

				if builtinParser != nil {
					slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser input", "parser", m.Config.Parser, "content", r.Content)

					content, thinking, toolCalls, err := builtinParser.Add(r.Content, r.Done)
					if err != nil {
						ch <- gin.H{"error": err.Error()}
						return
					}

					res.Message.Content = content
					res.Message.Thinking = thinking
					for i := range toolCalls {
						toolCalls[i].ID = toolCallId()
					}
					res.Message.ToolCalls = toolCalls

					tb.WriteString(thinking)
					// we are now receiving content from the model - we should start applying structured outputs
					if structuredOutputsState == structuredOutputsState_None && req.Format != nil && tb.String() != "" && res.Message.Content != "" {
						structuredOutputsState = structuredOutputsState_ReadyToApply
						cancel()
						return
					}

					if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || r.Done || len(res.Logprobs) > 0 {
						slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser output", "parser", m.Config.Parser, "content", content, "thinking", thinking, "toolCalls", toolCalls, "done", r.Done)
						ch <- res
					} else {
						slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser empty output", "parser", m.Config.Parser)
					}
					return
				}

				if thinkingState != nil {
					thinkingContent, remainingContent := thinkingState.AddContent(res.Message.Content)
					if thinkingContent == "" && remainingContent == "" && !r.Done {
						// need to accumulate more to decide what to send
						return
					}
					res.Message.Thinking = thinkingContent
					tb.WriteString(thinkingContent)
					// emit the collected thinking text before restarting with structured outputs and clear unstructured content
					// to avoid leaking mixed tokens like "</think>Hello"
					if structuredOutputsState == structuredOutputsState_None && req.Format != nil && tb.String() != "" && remainingContent != "" {
						structuredOutputsState = structuredOutputsState_ReadyToApply
						res.Message.Content = ""
						ch <- res
						cancel()
						return
					}
					res.Message.Content = remainingContent
				}

				if len(req.Tools) > 0 {
					toolCalls, content := toolParser.Add(res.Message.Content)
					if len(content) > 0 {
						res.Message.Content = content
					} else if len(toolCalls) > 0 {
						for i := range toolCalls {
							toolCalls[i].ID = toolCallId()
						}
						res.Message.ToolCalls = toolCalls
						res.Message.Content = ""
					} else if res.Message.Thinking != "" {
						// don't return, fall through to send
					} else {
						// Send logprobs while content is being buffered by the parser for tool calls
						if len(res.Logprobs) > 0 && !r.Done {
							logprobRes := res
							logprobRes.Message.Content = ""
							logprobRes.Message.ToolCalls = nil
							ch <- logprobRes
						}

						if r.Done {
							res.Message.Content = toolParser.Content()
							ch <- res
						}
						return
					}
				}

				ch <- res
			})
			if err != nil {
				if structuredOutputsState == structuredOutputsState_ReadyToApply && strings.Contains(err.Error(), "context canceled") && c.Request.Context().Err() == nil {
					// only ignore the error if it's a context cancellation caused by enabling structured outputs
				} else {
					var serr api.StatusError
					if errors.As(err, &serr) {
						ch <- gin.H{"error": serr.ErrorMessage, "status": serr.StatusCode}
					} else {
						ch <- gin.H{"error": err.Error()}
					}
					return
				}
			}

			// an ignored structured outputs cancellation falls through to here; start a new
			// request with structured outputs enabled and the updated prompt
			if structuredOutputsState == structuredOutputsState_ReadyToApply {
				structuredOutputsState = structuredOutputsState_Applying
				msg := api.Message{
					Role:     "assistant",
					Thinking: tb.String(),
				}

				msgs = append(msgs, msg)
				prompt, _, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
				if err != nil {
					slog.Error("chat prompt error applying structured outputs", "error", err)
					ch <- gin.H{"error": err.Error()}
					return
				}
				// force constraining by terminating the thinking header; the parser is already in this state.
				// when the last message is thinking, the renderer for gpt-oss cannot disambiguate between having the
				// model continue thinking or ending thinking and outputting the final message.
				// TODO(parthsareen): consider adding prefill disambiguation logic to the renderer for structured outputs.
				if shouldUseHarmony(m) || (builtinParser != nil && m.Config.Parser == "harmony") {
					prompt += "<|end|><|start|>assistant<|channel|>final<|message|>"
				}
				continue
			}

			break
		}
	}()

	if req.Stream != nil && !*req.Stream {
		var resp api.ChatResponse
		var toolCalls []api.ToolCall
		var allLogprobs []api.Logprob
		var sbThinking strings.Builder
		var sbContent strings.Builder
		for rr := range ch {
			switch t := rr.(type) {
			case api.ChatResponse:
				sbThinking.WriteString(t.Message.Thinking)
				sbContent.WriteString(t.Message.Content)
				resp = t
				if len(req.Tools) > 0 {
					toolCalls = append(toolCalls, t.Message.ToolCalls...)
				}
				// Accumulate logprobs from all chunks for non-streaming response
				if len(t.Logprobs) > 0 {
					allLogprobs = append(allLogprobs, t.Logprobs...)
				}
			case gin.H:
				msg, ok := t["error"].(string)
				if !ok {
					msg = "unexpected error format in response"
				}

				status, ok := t["status"].(int)
				if !ok {
					status = http.StatusInternalServerError
				}

				c.JSON(status, gin.H{"error": msg})
				return
			default:
				c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
				return
			}
		}

		resp.Message.Content = sbContent.String()
		resp.Message.Thinking = sbThinking.String()
		resp.Logprobs = allLogprobs

		if len(toolCalls) > 0 {
			resp.Message.ToolCalls = toolCalls
		}

		c.JSON(http.StatusOK, resp)
		return
	}

	streamResponse(c, ch)
}

func handleScheduleError(c *gin.Context, name string, err error) {
	switch {
	case errors.Is(err, errCapabilities), errors.Is(err, errRequired):
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
	case errors.Is(err, context.Canceled):
		c.JSON(499, gin.H{"error": "request canceled"})
	case errors.Is(err, ErrMaxQueue):
		c.JSON(http.StatusServiceUnavailable, gin.H{"error": err.Error()})
	case errors.Is(err, os.ErrNotExist):
		c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found, try pulling it first", name)})
	default:
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
	}
}

func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
	if m.Config.ModelFamily == "qwen3" || model.ParseName(m.Name).Model == "deepseek-r1" {
		finalUserIndex := -1
		for i, msg := range msgs {
			if msg.Role == "user" {
				finalUserIndex = i
			}
		}

		for i, msg := range msgs {
			if msg.Role == "assistant" && i < finalUserIndex {
				// TODO(drifkin): this is from before we added proper thinking support.
				// However, even if thinking is not enabled (and therefore we shouldn't
				// change the user output), we should probably perform this filtering
				// for all thinking models (not just qwen3 & deepseek-r1) since it tends
				// to save tokens and improve quality.
				thinkingState := &thinking.Parser{
					OpeningTag: "<think>",
					ClosingTag: "</think>",
				}
				_, content := thinkingState.AddContent(msg.Content)
				msgs[i].Content = content
			}
		}
	}
	return msgs
}

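// Illustrative sketch of the filtering performed by filterThinkTags: given a
// history where an assistant turn before the final user message contains
// thinking output, e.g.
//
//	[{user: "hi"},
//	 {assistant: "<think>planning a greeting</think>hello"},
//	 {user: "bye"}]
//
// the assistant content is rewritten to just "hello" for qwen3 and
// deepseek-r1 models. Assistant turns at or after the final user message are
// left untouched, so in-flight thinking for the current turn is preserved.
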
// handleImageGenerate handles image generation requests within GenerateHandler.
// This is called when the model has the Image capability.
func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, modelName string, checkpointStart time.Time) {
	// Validate image dimensions
	const maxDimension int32 = 4096
	if req.Width > maxDimension || req.Height > maxDimension {
		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("width and height must be <= %d", maxDimension)})
		return
	}

	// Schedule the runner for image generation
	runner, _, _, err := s.scheduleRunner(c.Request.Context(), modelName, []model.Capability{model.CapabilityImage}, nil, req.KeepAlive)
	if err != nil {
		handleScheduleError(c, req.Model, err)
		return
	}

	checkpointLoaded := time.Now()

	// Handle load-only request (empty prompt)
	if req.Prompt == "" {
		c.JSON(http.StatusOK, api.GenerateResponse{
			Model:      req.Model,
			CreatedAt:  time.Now().UTC(),
			Done:       true,
			DoneReason: "load",
		})
		return
	}

	// Check streaming preference
	isStreaming := req.Stream == nil || *req.Stream

	contentType := "application/x-ndjson"
	if !isStreaming {
		contentType = "application/json; charset=utf-8"
	}
	c.Header("Content-Type", contentType)

	// Get seed from options if provided
	var seed int64
	if s, ok := req.Options["seed"]; ok {
		switch v := s.(type) {
		case int:
			seed = int64(v)
		case int64:
			seed = v
		case float64:
			seed = int64(v)
		}
	}

	var images []llm.ImageData
	for i, imgData := range req.Images {
		images = append(images, llm.ImageData{ID: i, Data: imgData})
	}

	var streamStarted bool
	var finalResponse api.GenerateResponse

	if err := runner.Completion(c.Request.Context(), llm.CompletionRequest{
		Prompt: req.Prompt,
		Width:  req.Width,
		Height: req.Height,
		Steps:  req.Steps,
		Seed:   seed,
		Images: images,
	}, func(cr llm.CompletionResponse) {
		streamStarted = true
		res := api.GenerateResponse{
			Model:     req.Model,
			CreatedAt: time.Now().UTC(),
			Done:      cr.Done,
		}

		if cr.TotalSteps > 0 {
			res.Completed = int64(cr.Step)
			res.Total = int64(cr.TotalSteps)
		}

		if cr.Image != "" {
			res.Image = cr.Image
		}

		if cr.Done {
			res.DoneReason = cr.DoneReason.String()
			res.Metrics.TotalDuration = time.Since(checkpointStart)
			res.Metrics.LoadDuration = checkpointLoaded.Sub(checkpointStart)
		}

		if !isStreaming {
			finalResponse = res
			return
		}

		data, _ := json.Marshal(res)
		c.Writer.Write(append(data, '\n'))
		c.Writer.Flush()
	}); err != nil {
		// Only send JSON error if streaming hasn't started yet
		// (once streaming starts, headers are committed and we can't change status code)
		if !streamStarted {
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		}
		return
	}

	if !isStreaming {
		c.JSON(http.StatusOK, finalResponse)
	}
}