mirror of
https://github.com/ollama/ollama.git
synced 2026-04-24 01:35:49 +02:00
Compare commits
6 Commits
pdevine/qw
...
parth-anth
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
40f56cf543 | ||
|
|
22c2bdbd8a | ||
|
|
6df6d097d9 | ||
|
|
d7c176ab91 | ||
|
|
0ff7d724ff | ||
|
|
46cb7795e1 |
@@ -372,6 +372,24 @@ func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
|
||||
return convertedRequest, nil
|
||||
}
|
||||
|
||||
func extractBase64Image(blockMap map[string]any) (api.ImageData, error) {
|
||||
source, ok := blockMap["source"].(map[string]any)
|
||||
if !ok {
|
||||
return nil, errors.New("invalid image source")
|
||||
}
|
||||
|
||||
sourceType, _ := source["type"].(string)
|
||||
if sourceType == "base64" {
|
||||
data, _ := source["data"].(string)
|
||||
decoded, err := base64.StdEncoding.DecodeString(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid base64 image data: %w", err)
|
||||
}
|
||||
return decoded, nil
|
||||
}
|
||||
return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported", sourceType)
|
||||
}
|
||||
|
||||
// convertMessage converts an Anthropic MessageParam to Ollama api.Message(s)
|
||||
func convertMessage(msg MessageParam) ([]api.Message, error) {
|
||||
var messages []api.Message
|
||||
@@ -414,26 +432,12 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
||||
|
||||
case "image":
|
||||
imageBlocks++
|
||||
source, ok := blockMap["source"].(map[string]any)
|
||||
if !ok {
|
||||
logutil.Trace("anthropic: invalid image source", "role", role)
|
||||
return nil, errors.New("invalid image source")
|
||||
decoded, err := extractBase64Image(blockMap)
|
||||
if err != nil {
|
||||
logutil.Trace("anthropic: failed to extract image", "role", role, "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sourceType, _ := source["type"].(string)
|
||||
if sourceType == "base64" {
|
||||
data, _ := source["data"].(string)
|
||||
decoded, err := base64.StdEncoding.DecodeString(data)
|
||||
if err != nil {
|
||||
logutil.Trace("anthropic: invalid base64 image data", "role", role, "error", err)
|
||||
return nil, fmt.Errorf("invalid base64 image data: %w", err)
|
||||
}
|
||||
images = append(images, decoded)
|
||||
} else {
|
||||
logutil.Trace("anthropic: unsupported image source type", "role", role, "source_type", sourceType)
|
||||
return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported.", sourceType)
|
||||
}
|
||||
// URL images would need to be fetched - skip for now
|
||||
images = append(images, decoded)
|
||||
|
||||
case "tool_use":
|
||||
toolUseBlocks++
|
||||
@@ -462,6 +466,7 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
||||
toolResultBlocks++
|
||||
toolUseID, _ := blockMap["tool_use_id"].(string)
|
||||
var resultContent string
|
||||
var resultImages []api.ImageData
|
||||
|
||||
switch c := blockMap["content"].(type) {
|
||||
case string:
|
||||
@@ -469,10 +474,18 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
||||
case []any:
|
||||
for _, cb := range c {
|
||||
if cbMap, ok := cb.(map[string]any); ok {
|
||||
if cbMap["type"] == "text" {
|
||||
switch cbMap["type"] {
|
||||
case "text":
|
||||
if text, ok := cbMap["text"].(string); ok {
|
||||
resultContent += text
|
||||
}
|
||||
case "image":
|
||||
decoded, err := extractBase64Image(cbMap)
|
||||
if err != nil {
|
||||
logutil.Trace("anthropic: failed to extract image from tool_result", "role", role, "error", err)
|
||||
return nil, err
|
||||
}
|
||||
resultImages = append(resultImages, decoded)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -481,6 +494,7 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
||||
toolResults = append(toolResults, api.Message{
|
||||
Role: "tool",
|
||||
Content: resultContent,
|
||||
Images: resultImages,
|
||||
ToolCallID: toolUseID,
|
||||
})
|
||||
|
||||
|
||||
@@ -266,6 +266,124 @@ func TestFromMessagesRequest_WithToolResult(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithToolResultContainingImage(t *testing.T) {
|
||||
imgData, _ := base64.StdEncoding.DecodeString(testImage)
|
||||
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{
|
||||
Role: "user",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "call_456",
|
||||
"content": []any{
|
||||
map[string]any{"type": "text", "text": "Here is the screenshot:"},
|
||||
map[string]any{
|
||||
"type": "image",
|
||||
"source": map[string]any{
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": testImage,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
msg := result.Messages[0]
|
||||
if msg.Role != "tool" {
|
||||
t.Errorf("expected role 'tool', got %q", msg.Role)
|
||||
}
|
||||
if msg.ToolCallID != "call_456" {
|
||||
t.Errorf("expected tool_call_id 'call_456', got %q", msg.ToolCallID)
|
||||
}
|
||||
if msg.Content != "Here is the screenshot:" {
|
||||
t.Errorf("expected content 'Here is the screenshot:', got %q", msg.Content)
|
||||
}
|
||||
if len(msg.Images) != 1 {
|
||||
t.Fatalf("expected 1 image in tool result, got %d", len(msg.Images))
|
||||
}
|
||||
if string(msg.Images[0]) != string(imgData) {
|
||||
t.Error("image data mismatch in tool result")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithToolResultContainingMultipleImages(t *testing.T) {
|
||||
imgData, _ := base64.StdEncoding.DecodeString(testImage)
|
||||
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
MaxTokens: 1024,
|
||||
Messages: []MessageParam{
|
||||
{
|
||||
Role: "user",
|
||||
Content: []any{
|
||||
map[string]any{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "call_789",
|
||||
"content": []any{
|
||||
map[string]any{
|
||||
"type": "image",
|
||||
"source": map[string]any{
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": testImage,
|
||||
},
|
||||
},
|
||||
map[string]any{"type": "text", "text": "First image above, second below:"},
|
||||
map[string]any{
|
||||
"type": "image",
|
||||
"source": map[string]any{
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": testImage,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromMessagesRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
msg := result.Messages[0]
|
||||
if msg.Role != "tool" {
|
||||
t.Errorf("expected role 'tool', got %q", msg.Role)
|
||||
}
|
||||
if len(msg.Images) != 2 {
|
||||
t.Fatalf("expected 2 images in tool result, got %d", len(msg.Images))
|
||||
}
|
||||
for i, img := range msg.Images {
|
||||
if string(img) != string(imgData) {
|
||||
t.Errorf("image %d data mismatch in tool result", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromMessagesRequest_WithTools(t *testing.T) {
|
||||
req := MessagesRequest{
|
||||
Model: "test-model",
|
||||
|
||||
@@ -80,6 +80,12 @@ func (c *Openclaw) Run(model string, args []string) error {
|
||||
}
|
||||
if canInstallDaemon() {
|
||||
onboardArgs = append(onboardArgs, "--install-daemon")
|
||||
} else {
|
||||
// When we can't install a daemon (e.g. no systemd, sudo dropped
|
||||
// XDG_RUNTIME_DIR, or container environment), skip the gateway
|
||||
// health check so non-interactive onboarding completes. The
|
||||
// gateway is started as a foreground child process after onboarding.
|
||||
onboardArgs = append(onboardArgs, "--skip-health")
|
||||
}
|
||||
cmd := exec.Command(bin, onboardArgs...)
|
||||
cmd.Stdin = os.Stdin
|
||||
|
||||
@@ -160,6 +160,12 @@
|
||||
"group": "More information",
|
||||
"pages": [
|
||||
"/cli",
|
||||
{
|
||||
"group": "Assistant Sandboxing",
|
||||
"pages": [
|
||||
"/integrations/nemoclaw"
|
||||
]
|
||||
},
|
||||
"/modelfile",
|
||||
"/context-length",
|
||||
"/linux",
|
||||
|
||||
67
docs/integrations/nemoclaw.mdx
Normal file
67
docs/integrations/nemoclaw.mdx
Normal file
@@ -0,0 +1,67 @@
|
||||
---
|
||||
title: NemoClaw
|
||||
---
|
||||
|
||||
NemoClaw is NVIDIA's open source security stack for [OpenClaw](/integrations/openclaw). It wraps OpenClaw with the NVIDIA OpenShell runtime to provide kernel-level sandboxing, network policy controls, and audit trails for AI agents.
|
||||
|
||||
## Quick start
|
||||
|
||||
Pull a model:
|
||||
|
||||
```bash
|
||||
ollama pull nemotron-3-nano:30b
|
||||
```
|
||||
|
||||
Run the installer:
|
||||
|
||||
```bash
|
||||
curl -fsSL https://www.nvidia.com/nemoclaw.sh | \
|
||||
NEMOCLAW_NON_INTERACTIVE=1 \
|
||||
NEMOCLAW_PROVIDER=ollama \
|
||||
NEMOCLAW_MODEL=nemotron-3-nano:30b \
|
||||
bash
|
||||
```
|
||||
|
||||
Connect to your sandbox:
|
||||
|
||||
```bash
|
||||
nemoclaw my-assistant connect
|
||||
```
|
||||
|
||||
Open the TUI:
|
||||
|
||||
```bash
|
||||
openclaw tui
|
||||
```
|
||||
|
||||
<Note>Ollama support in NemoClaw is still experimental.</Note>
|
||||
|
||||
## Platform support
|
||||
|
||||
| Platform | Runtime | Status |
|
||||
|----------|---------|--------|
|
||||
| Linux (Ubuntu 22.04+) | Docker | Primary |
|
||||
| macOS (Apple Silicon) | Colima or Docker Desktop | Supported |
|
||||
| Windows | WSL2 with Docker Desktop | Supported |
|
||||
|
||||
CMD and PowerShell are not supported on Windows — WSL2 is required.
|
||||
|
||||
<Note>Ollama must be installed and running before the installer runs. When running inside WSL2 or a container, ensure Ollama is reachable from the sandbox (e.g. `OLLAMA_HOST=0.0.0.0`).</Note>
|
||||
|
||||
## System requirements
|
||||
|
||||
- CPU: 4 vCPU minimum
|
||||
- RAM: 8 GB minimum (16 GB recommended)
|
||||
- Disk: 20 GB free (40 GB recommended for local models)
|
||||
- Node.js 20+ and npm 10+
|
||||
- Container runtime (Docker preferred)
|
||||
|
||||
## Recommended models
|
||||
|
||||
- `nemotron-3-super:cloud` — Strong reasoning and coding
|
||||
- `qwen3.5:cloud` — 397B; reasoning and code generation
|
||||
- `nemotron-3-nano:30b` — Recommended local model; fits in 24 GB VRAM
|
||||
- `qwen3.5:27b` — Fast local reasoning (~18 GB VRAM)
|
||||
- `glm-4.7-flash` — Reasoning and code generation (~25 GB VRAM)
|
||||
|
||||
More models at [ollama.com/search](https://ollama.com/search).
|
||||
@@ -214,6 +214,8 @@ func LogLevel() slog.Level {
|
||||
var (
|
||||
// FlashAttention enables the experimental flash attention feature.
|
||||
FlashAttention = BoolWithDefault("OLLAMA_FLASH_ATTENTION")
|
||||
// DebugLogRequests logs inference requests to disk for replay/debugging.
|
||||
DebugLogRequests = Bool("OLLAMA_DEBUG_LOG_REQUESTS")
|
||||
// KvCacheType is the quantization type for the K/V cache.
|
||||
KvCacheType = String("OLLAMA_KV_CACHE_TYPE")
|
||||
// NoHistory disables readline history.
|
||||
@@ -302,28 +304,29 @@ type EnvVar struct {
|
||||
|
||||
func AsMap() map[string]EnvVar {
|
||||
ret := map[string]EnvVar{
|
||||
"OLLAMA_DEBUG": {"OLLAMA_DEBUG", LogLevel(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
|
||||
"OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(false), "Enabled flash attention"},
|
||||
"OLLAMA_KV_CACHE_TYPE": {"OLLAMA_KV_CACHE_TYPE", KvCacheType(), "Quantization type for the K/V cache (default: f16)"},
|
||||
"OLLAMA_GPU_OVERHEAD": {"OLLAMA_GPU_OVERHEAD", GpuOverhead(), "Reserve a portion of VRAM per GPU (bytes)"},
|
||||
"OLLAMA_HOST": {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"},
|
||||
"OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive(), "The duration that models stay loaded in memory (default \"5m\")"},
|
||||
"OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary(), "Set LLM library to bypass autodetection"},
|
||||
"OLLAMA_LOAD_TIMEOUT": {"OLLAMA_LOAD_TIMEOUT", LoadTimeout(), "How long to allow model loads to stall before giving up (default \"5m\")"},
|
||||
"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners(), "Maximum number of loaded models per GPU"},
|
||||
"OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueue(), "Maximum number of queued requests"},
|
||||
"OLLAMA_MODELS": {"OLLAMA_MODELS", Models(), "The path to the models directory"},
|
||||
"OLLAMA_NO_CLOUD": {"OLLAMA_NO_CLOUD", NoCloud(), "Disable Ollama cloud features (remote inference and web search)"},
|
||||
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory(), "Do not preserve readline history"},
|
||||
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune(), "Do not prune model blobs on startup"},
|
||||
"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel(), "Maximum number of parallel requests"},
|
||||
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowedOrigins(), "A comma separated list of allowed origins"},
|
||||
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
|
||||
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
|
||||
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4k/32k/256k based on VRAM)"},
|
||||
"OLLAMA_EDITOR": {"OLLAMA_EDITOR", Editor(), "Path to editor for interactive prompt editing (Ctrl+G)"},
|
||||
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
|
||||
"OLLAMA_REMOTES": {"OLLAMA_REMOTES", Remotes(), "Allowed hosts for remote models (default \"ollama.com\")"},
|
||||
"OLLAMA_DEBUG": {"OLLAMA_DEBUG", LogLevel(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
|
||||
"OLLAMA_DEBUG_LOG_REQUESTS": {"OLLAMA_DEBUG_LOG_REQUESTS", DebugLogRequests(), "Log inference request bodies and replay curl commands to a temp directory"},
|
||||
"OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(false), "Enabled flash attention"},
|
||||
"OLLAMA_KV_CACHE_TYPE": {"OLLAMA_KV_CACHE_TYPE", KvCacheType(), "Quantization type for the K/V cache (default: f16)"},
|
||||
"OLLAMA_GPU_OVERHEAD": {"OLLAMA_GPU_OVERHEAD", GpuOverhead(), "Reserve a portion of VRAM per GPU (bytes)"},
|
||||
"OLLAMA_HOST": {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"},
|
||||
"OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive(), "The duration that models stay loaded in memory (default \"5m\")"},
|
||||
"OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary(), "Set LLM library to bypass autodetection"},
|
||||
"OLLAMA_LOAD_TIMEOUT": {"OLLAMA_LOAD_TIMEOUT", LoadTimeout(), "How long to allow model loads to stall before giving up (default \"5m\")"},
|
||||
"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners(), "Maximum number of loaded models per GPU"},
|
||||
"OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueue(), "Maximum number of queued requests"},
|
||||
"OLLAMA_MODELS": {"OLLAMA_MODELS", Models(), "The path to the models directory"},
|
||||
"OLLAMA_NO_CLOUD": {"OLLAMA_NO_CLOUD", NoCloud(), "Disable Ollama cloud features (remote inference and web search)"},
|
||||
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory(), "Do not preserve readline history"},
|
||||
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune(), "Do not prune model blobs on startup"},
|
||||
"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel(), "Maximum number of parallel requests"},
|
||||
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowedOrigins(), "A comma separated list of allowed origins"},
|
||||
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
|
||||
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
|
||||
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4k/32k/256k based on VRAM)"},
|
||||
"OLLAMA_EDITOR": {"OLLAMA_EDITOR", Editor(), "Path to editor for interactive prompt editing (Ctrl+G)"},
|
||||
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
|
||||
"OLLAMA_REMOTES": {"OLLAMA_REMOTES", Remotes(), "Allowed hosts for remote models (default \"ollama.com\")"},
|
||||
|
||||
// Informational
|
||||
"HTTP_PROXY": {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"},
|
||||
|
||||
@@ -87,7 +87,8 @@ type LlamaServer interface {
|
||||
type llmServer struct {
|
||||
port int
|
||||
cmd *exec.Cmd
|
||||
done chan error // Channel to signal when the process exits
|
||||
done chan struct{} // closed when the process exits
|
||||
doneErr error // valid after done is closed
|
||||
status *StatusWriter
|
||||
options api.Options
|
||||
modelPath string
|
||||
@@ -280,7 +281,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
|
||||
sem: semaphore.NewWeighted(int64(numParallel)),
|
||||
totalLayers: f.KV().BlockCount() + 1,
|
||||
loadStart: time.Now(),
|
||||
done: make(chan error, 1),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@@ -304,10 +305,11 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
|
||||
if strings.Contains(s.status.LastErrMsg, "unknown model") {
|
||||
s.status.LastErrMsg = "this model is not supported by your version of Ollama. You may need to upgrade"
|
||||
}
|
||||
s.done <- errors.New(s.status.LastErrMsg)
|
||||
s.doneErr = errors.New(s.status.LastErrMsg)
|
||||
} else {
|
||||
s.done <- err
|
||||
s.doneErr = err
|
||||
}
|
||||
close(s.done)
|
||||
}()
|
||||
|
||||
if tok != nil {
|
||||
@@ -1356,8 +1358,8 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
|
||||
case <-ctx.Done():
|
||||
slog.Warn("client connection closed before server finished loading, aborting load")
|
||||
return fmt.Errorf("timed out waiting for llama runner to start: %w", ctx.Err())
|
||||
case err := <-s.done:
|
||||
return fmt.Errorf("llama runner process has terminated: %w", err)
|
||||
case <-s.done:
|
||||
return fmt.Errorf("llama runner process has terminated: %w", s.doneErr)
|
||||
default:
|
||||
}
|
||||
if time.Now().After(stallTimer) {
|
||||
|
||||
144
server/inference_request_log.go
Normal file
144
server/inference_request_log.go
Normal file
@@ -0,0 +1,144 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
type inferenceRequestLogger struct {
|
||||
dir string
|
||||
counter uint64
|
||||
}
|
||||
|
||||
func newInferenceRequestLogger() (*inferenceRequestLogger, error) {
|
||||
dir, err := os.MkdirTemp("", "ollama-request-logs-*")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &inferenceRequestLogger{dir: dir}, nil
|
||||
}
|
||||
|
||||
func (s *Server) initRequestLogging() error {
|
||||
if !envconfig.DebugLogRequests() {
|
||||
return nil
|
||||
}
|
||||
|
||||
requestLogger, err := newInferenceRequestLogger()
|
||||
if err != nil {
|
||||
return fmt.Errorf("enable OLLAMA_DEBUG_LOG_REQUESTS: %w", err)
|
||||
}
|
||||
|
||||
s.requestLogger = requestLogger
|
||||
slog.Info(fmt.Sprintf("request debug logging enabled; inference request logs will be stored in %s and include request bodies and replay curl commands", requestLogger.dir))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) withInferenceRequestLogging(route string, handlers ...gin.HandlerFunc) []gin.HandlerFunc {
|
||||
if s.requestLogger == nil {
|
||||
return handlers
|
||||
}
|
||||
|
||||
return append([]gin.HandlerFunc{s.requestLogger.middleware(route)}, handlers...)
|
||||
}
|
||||
|
||||
func (l *inferenceRequestLogger) middleware(route string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if c.Request == nil {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
method := c.Request.Method
|
||||
host := c.Request.Host
|
||||
scheme := "http"
|
||||
if c.Request.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
contentType := c.GetHeader("Content-Type")
|
||||
|
||||
var body []byte
|
||||
if c.Request.Body != nil {
|
||||
var err error
|
||||
body, err = io.ReadAll(c.Request.Body)
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(body))
|
||||
if err != nil {
|
||||
slog.Warn("failed to read request body for debug logging", "route", route, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
c.Next()
|
||||
l.log(route, method, scheme, host, contentType, body)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *inferenceRequestLogger) log(route, method, scheme, host, contentType string, body []byte) {
|
||||
if l == nil || l.dir == "" {
|
||||
return
|
||||
}
|
||||
|
||||
if contentType == "" {
|
||||
contentType = "application/json"
|
||||
}
|
||||
if host == "" || scheme == "" {
|
||||
base := envconfig.Host()
|
||||
if host == "" {
|
||||
host = base.Host
|
||||
}
|
||||
if scheme == "" {
|
||||
scheme = base.Scheme
|
||||
}
|
||||
}
|
||||
|
||||
routeForFilename := sanitizeRouteForFilename(route)
|
||||
timestamp := fmt.Sprintf("%s-%06d", time.Now().UTC().Format("20060102T150405.000000000Z"), atomic.AddUint64(&l.counter, 1))
|
||||
bodyFilename := fmt.Sprintf("%s_%s_body.json", timestamp, routeForFilename)
|
||||
curlFilename := fmt.Sprintf("%s_%s_request.sh", timestamp, routeForFilename)
|
||||
bodyPath := filepath.Join(l.dir, bodyFilename)
|
||||
curlPath := filepath.Join(l.dir, curlFilename)
|
||||
|
||||
if err := os.WriteFile(bodyPath, body, 0o600); err != nil {
|
||||
slog.Warn("failed to write debug request body", "route", route, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s://%s%s", scheme, host, route)
|
||||
curl := fmt.Sprintf("#!/bin/sh\nSCRIPT_DIR=\"$(CDPATH= cd -- \"$(dirname -- \"$0\")\" && pwd)\"\ncurl --request %s --url %q --header %q --data-binary @\"${SCRIPT_DIR}/%s\"\n", method, url, "Content-Type: "+contentType, bodyFilename)
|
||||
if err := os.WriteFile(curlPath, []byte(curl), 0o600); err != nil {
|
||||
slog.Warn("failed to write debug request replay command", "route", route, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Info(fmt.Sprintf("logged to %s, replay using curl with `sh %s`", bodyPath, curlPath))
|
||||
}
|
||||
|
||||
func sanitizeRouteForFilename(route string) string {
|
||||
route = strings.TrimPrefix(route, "/")
|
||||
if route == "" {
|
||||
return "root"
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
b.Grow(len(route))
|
||||
for _, r := range route {
|
||||
if ('a' <= r && r <= 'z') || ('A' <= r && r <= 'Z') || ('0' <= r && r <= '9') {
|
||||
b.WriteRune(r)
|
||||
} else {
|
||||
b.WriteByte('_')
|
||||
}
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
@@ -100,6 +100,7 @@ type Server struct {
|
||||
addr net.Addr
|
||||
sched *Scheduler
|
||||
defaultNumCtx int
|
||||
requestLogger *inferenceRequestLogger
|
||||
}
|
||||
|
||||
func init() {
|
||||
@@ -1686,26 +1687,26 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||
|
||||
// Inference
|
||||
r.GET("/api/ps", s.PsHandler)
|
||||
r.POST("/api/generate", s.GenerateHandler)
|
||||
r.POST("/api/chat", s.ChatHandler)
|
||||
r.POST("/api/generate", s.withInferenceRequestLogging("/api/generate", s.GenerateHandler)...)
|
||||
r.POST("/api/chat", s.withInferenceRequestLogging("/api/chat", s.ChatHandler)...)
|
||||
r.POST("/api/embed", s.EmbedHandler)
|
||||
r.POST("/api/embeddings", s.EmbeddingsHandler)
|
||||
|
||||
// Inference (OpenAI compatibility)
|
||||
// TODO(cloud-stage-a): apply Modelfile overlay deltas for local models with cloud
|
||||
// parents on v1 request families while preserving this explicit :cloud passthrough.
|
||||
r.POST("/v1/chat/completions", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ChatMiddleware(), s.ChatHandler)
|
||||
r.POST("/v1/completions", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.CompletionsMiddleware(), s.GenerateHandler)
|
||||
r.POST("/v1/chat/completions", s.withInferenceRequestLogging("/v1/chat/completions", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ChatMiddleware(), s.ChatHandler)...)
|
||||
r.POST("/v1/completions", s.withInferenceRequestLogging("/v1/completions", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.CompletionsMiddleware(), s.GenerateHandler)...)
|
||||
r.POST("/v1/embeddings", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.EmbeddingsMiddleware(), s.EmbedHandler)
|
||||
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
|
||||
r.GET("/v1/models/:model", cloudModelPathPassthroughMiddleware(cloudErrRemoteModelDetailsUnavailable), middleware.RetrieveMiddleware(), s.ShowHandler)
|
||||
r.POST("/v1/responses", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ResponsesMiddleware(), s.ChatHandler)
|
||||
r.POST("/v1/responses", s.withInferenceRequestLogging("/v1/responses", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ResponsesMiddleware(), s.ChatHandler)...)
|
||||
// OpenAI-compatible image generation endpoints
|
||||
r.POST("/v1/images/generations", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ImageGenerationsMiddleware(), s.GenerateHandler)
|
||||
r.POST("/v1/images/edits", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ImageEditsMiddleware(), s.GenerateHandler)
|
||||
|
||||
// Inference (Anthropic compatibility)
|
||||
r.POST("/v1/messages", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
|
||||
r.POST("/v1/messages", s.withInferenceRequestLogging("/v1/messages", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.AnthropicMessagesMiddleware(), s.ChatHandler)...)
|
||||
|
||||
if rc != nil {
|
||||
// wrap old with new
|
||||
@@ -1757,6 +1758,9 @@ func Serve(ln net.Listener) error {
|
||||
}
|
||||
|
||||
s := &Server{addr: ln.Addr()}
|
||||
if err := s.initRequestLogging(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var rc *ollama.Registry
|
||||
if useClient2 {
|
||||
|
||||
128
server/routes_request_log_test.go
Normal file
128
server/routes_request_log_test.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestInferenceRequestLoggerMiddlewareWritesReplayArtifacts(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
logDir := t.TempDir()
|
||||
requestLogger := &inferenceRequestLogger{dir: logDir}
|
||||
|
||||
const route = "/v1/chat/completions"
|
||||
const requestBody = `{"model":"test-model","messages":[{"role":"user","content":"hello"}]}`
|
||||
|
||||
var bodySeenByHandler string
|
||||
|
||||
r := gin.New()
|
||||
r.POST(route, requestLogger.middleware(route), func(c *gin.Context) {
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read body in handler: %v", err)
|
||||
}
|
||||
|
||||
bodySeenByHandler = string(body)
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, route, strings.NewReader(requestBody))
|
||||
req.Host = "127.0.0.1:11434"
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if bodySeenByHandler != requestBody {
|
||||
t.Fatalf("handler body mismatch:\nexpected: %s\ngot: %s", requestBody, bodySeenByHandler)
|
||||
}
|
||||
|
||||
bodyFiles, err := filepath.Glob(filepath.Join(logDir, "*_v1_chat_completions_body.json"))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to glob body logs: %v", err)
|
||||
}
|
||||
if len(bodyFiles) != 1 {
|
||||
t.Fatalf("expected 1 body log, got %d (%v)", len(bodyFiles), bodyFiles)
|
||||
}
|
||||
|
||||
curlFiles, err := filepath.Glob(filepath.Join(logDir, "*_v1_chat_completions_request.sh"))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to glob curl logs: %v", err)
|
||||
}
|
||||
if len(curlFiles) != 1 {
|
||||
t.Fatalf("expected 1 curl log, got %d (%v)", len(curlFiles), curlFiles)
|
||||
}
|
||||
|
||||
bodyData, err := os.ReadFile(bodyFiles[0])
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read body log: %v", err)
|
||||
}
|
||||
if string(bodyData) != requestBody {
|
||||
t.Fatalf("body log mismatch:\nexpected: %s\ngot: %s", requestBody, string(bodyData))
|
||||
}
|
||||
|
||||
curlData, err := os.ReadFile(curlFiles[0])
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read curl log: %v", err)
|
||||
}
|
||||
|
||||
curlString := string(curlData)
|
||||
if !strings.Contains(curlString, "http://127.0.0.1:11434"+route) {
|
||||
t.Fatalf("curl log does not contain expected route URL: %s", curlString)
|
||||
}
|
||||
|
||||
bodyFileName := filepath.Base(bodyFiles[0])
|
||||
if !strings.Contains(curlString, "@\"${SCRIPT_DIR}/"+bodyFileName+"\"") {
|
||||
t.Fatalf("curl log does not reference sibling body file: %s", curlString)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewInferenceRequestLoggerCreatesDirectory(t *testing.T) {
|
||||
requestLogger, err := newInferenceRequestLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error creating request logger: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
_ = os.RemoveAll(requestLogger.dir)
|
||||
})
|
||||
|
||||
if requestLogger == nil || requestLogger.dir == "" {
|
||||
t.Fatalf("expected request logger directory to be set")
|
||||
}
|
||||
|
||||
info, err := os.Stat(requestLogger.dir)
|
||||
if err != nil {
|
||||
t.Fatalf("expected directory to exist: %v", err)
|
||||
}
|
||||
if !info.IsDir() {
|
||||
t.Fatalf("expected %q to be a directory", requestLogger.dir)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeRouteForFilename(t *testing.T) {
|
||||
tests := []struct {
|
||||
route string
|
||||
want string
|
||||
}{
|
||||
{route: "/api/generate", want: "api_generate"},
|
||||
{route: "/v1/chat/completions", want: "v1_chat_completions"},
|
||||
{route: "/v1/messages", want: "v1_messages"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
if got := sanitizeRouteForFilename(tt.route); got != tt.want {
|
||||
t.Fatalf("sanitizeRouteForFilename(%q) = %q, want %q", tt.route, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package mlxrunner
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
@@ -36,14 +37,69 @@ type Client struct {
|
||||
modelName string
|
||||
contextLength atomic.Int64
|
||||
memory atomic.Uint64
|
||||
done chan error
|
||||
done chan struct{}
|
||||
doneErr error // valid after done is closed
|
||||
client *http.Client
|
||||
lastErr string
|
||||
lastErrLock sync.Mutex
|
||||
status *statusWriter
|
||||
mu sync.Mutex
|
||||
cmd *exec.Cmd
|
||||
}
|
||||
|
||||
// statusWriter captures the last stderr line from the subprocess while
|
||||
// forwarding all output to os.Stderr. Lines longer than maxStatusLen are
|
||||
// truncated to the first maxStatusLen bytes.
|
||||
type statusWriter struct {
|
||||
lastErrMsg string
|
||||
buf []byte
|
||||
discarding bool
|
||||
mu sync.Mutex
|
||||
out *os.File
|
||||
}
|
||||
|
||||
const maxStatusLen = 256
|
||||
|
||||
func (w *statusWriter) Write(b []byte) (int, error) {
|
||||
n, err := w.out.Write(b)
|
||||
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
w.buf = append(w.buf, b...)
|
||||
for {
|
||||
i := bytes.IndexByte(w.buf, '\n')
|
||||
if i < 0 {
|
||||
break
|
||||
}
|
||||
if !w.discarding {
|
||||
line := bytes.TrimSpace(w.buf[:i])
|
||||
if len(line) > 0 {
|
||||
if len(line) > maxStatusLen {
|
||||
line = line[:maxStatusLen]
|
||||
}
|
||||
w.lastErrMsg = string(line)
|
||||
}
|
||||
}
|
||||
w.buf = w.buf[i+1:]
|
||||
w.discarding = false
|
||||
}
|
||||
// if the buffer grows past maxStatusLen without a newline, keep the front
|
||||
if len(w.buf) > maxStatusLen {
|
||||
if !w.discarding {
|
||||
w.lastErrMsg = string(bytes.TrimSpace(w.buf[:maxStatusLen]))
|
||||
w.discarding = true
|
||||
}
|
||||
w.buf = w.buf[:0]
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (w *statusWriter) getLastErr() string {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
return w.lastErrMsg
|
||||
}
|
||||
|
||||
// NewClient prepares a new MLX runner client for LLM models.
|
||||
// The subprocess is not started until Load() is called.
|
||||
func NewClient(modelName string) (*Client, error) {
|
||||
@@ -53,7 +109,7 @@ func NewClient(modelName string) (*Client, error) {
|
||||
|
||||
c := &Client{
|
||||
modelName: modelName,
|
||||
done: make(chan error, 1),
|
||||
done: make(chan struct{}),
|
||||
client: &http.Client{Timeout: 10 * time.Minute},
|
||||
}
|
||||
|
||||
@@ -66,12 +122,6 @@ func NewClient(modelName string) (*Client, error) {
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *Client) getLastErr() string {
|
||||
c.lastErrLock.Lock()
|
||||
defer c.lastErrLock.Unlock()
|
||||
return c.lastErr
|
||||
}
|
||||
|
||||
// WaitUntilRunning waits for the subprocess to be ready.
|
||||
func (c *Client) WaitUntilRunning(ctx context.Context) error {
|
||||
timeout := time.After(2 * time.Minute)
|
||||
@@ -82,16 +132,14 @@ func (c *Client) WaitUntilRunning(ctx context.Context) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case err := <-c.done:
|
||||
errMsg := c.getLastErr()
|
||||
if errMsg != "" {
|
||||
return fmt.Errorf("mlx runner failed: %s (exit: %v)", errMsg, err)
|
||||
case <-c.done:
|
||||
if msg := c.status.getLastErr(); msg != "" {
|
||||
return fmt.Errorf("mlx runner failed: %s (exit: %v)", msg, c.doneErr)
|
||||
}
|
||||
return fmt.Errorf("mlx runner exited unexpectedly: %w", err)
|
||||
return fmt.Errorf("mlx runner exited unexpectedly: %w", c.doneErr)
|
||||
case <-timeout:
|
||||
errMsg := c.getLastErr()
|
||||
if errMsg != "" {
|
||||
return fmt.Errorf("timeout waiting for mlx runner: %s", errMsg)
|
||||
if msg := c.status.getLastErr(); msg != "" {
|
||||
return fmt.Errorf("timeout waiting for mlx runner: %s", msg)
|
||||
}
|
||||
return errors.New("timeout waiting for mlx runner to start")
|
||||
case <-ticker.C:
|
||||
@@ -348,18 +396,13 @@ func (c *Client) Load(ctx context.Context, _ ml.SystemInfo, gpus []ml.DeviceInfo
|
||||
// Forward subprocess stdout/stderr to server logs
|
||||
stdout, _ := cmd.StdoutPipe()
|
||||
stderr, _ := cmd.StderrPipe()
|
||||
status := &statusWriter{out: os.Stderr}
|
||||
c.status = status
|
||||
go func() {
|
||||
io.Copy(os.Stderr, stdout) //nolint:errcheck
|
||||
}()
|
||||
go func() {
|
||||
scanner := bufio.NewScanner(stderr)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
fmt.Fprintln(os.Stderr, line)
|
||||
c.lastErrLock.Lock()
|
||||
c.lastErr = line
|
||||
c.lastErrLock.Unlock()
|
||||
}
|
||||
io.Copy(status, stderr) //nolint:errcheck
|
||||
}()
|
||||
|
||||
slog.Info("starting mlx runner subprocess", "model", c.modelName, "port", c.port)
|
||||
@@ -369,8 +412,8 @@ func (c *Client) Load(ctx context.Context, _ ml.SystemInfo, gpus []ml.DeviceInfo
|
||||
|
||||
// Reap subprocess when it exits
|
||||
go func() {
|
||||
err := cmd.Wait()
|
||||
c.done <- err
|
||||
c.doneErr = cmd.Wait()
|
||||
close(c.done)
|
||||
}()
|
||||
|
||||
return nil, nil
|
||||
|
||||
Reference in New Issue
Block a user