mirror of
https://github.com/ollama/ollama.git
synced 2026-04-18 13:54:11 +02:00
Compare commits
6 Commits
pdevine/qw
...
pdevine/ml
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6f26695eae | ||
|
|
22c2bdbd8a | ||
|
|
6df6d097d9 | ||
|
|
d7c176ab91 | ||
|
|
0ff7d724ff | ||
|
|
46cb7795e1 |
@@ -1 +1 @@
|
||||
v0.30.6
|
||||
v0.31.1
|
||||
|
||||
@@ -1 +1 @@
|
||||
v0.5.0
|
||||
v0.6.0
|
||||
|
||||
@@ -80,6 +80,12 @@ func (c *Openclaw) Run(model string, args []string) error {
|
||||
}
|
||||
if canInstallDaemon() {
|
||||
onboardArgs = append(onboardArgs, "--install-daemon")
|
||||
} else {
|
||||
// When we can't install a daemon (e.g. no systemd, sudo dropped
|
||||
// XDG_RUNTIME_DIR, or container environment), skip the gateway
|
||||
// health check so non-interactive onboarding completes. The
|
||||
// gateway is started as a foreground child process after onboarding.
|
||||
onboardArgs = append(onboardArgs, "--skip-health")
|
||||
}
|
||||
cmd := exec.Command(bin, onboardArgs...)
|
||||
cmd.Stdin = os.Stdin
|
||||
|
||||
@@ -160,6 +160,12 @@
|
||||
"group": "More information",
|
||||
"pages": [
|
||||
"/cli",
|
||||
{
|
||||
"group": "Assistant Sandboxing",
|
||||
"pages": [
|
||||
"/integrations/nemoclaw"
|
||||
]
|
||||
},
|
||||
"/modelfile",
|
||||
"/context-length",
|
||||
"/linux",
|
||||
|
||||
67
docs/integrations/nemoclaw.mdx
Normal file
67
docs/integrations/nemoclaw.mdx
Normal file
@@ -0,0 +1,67 @@
|
||||
---
|
||||
title: NemoClaw
|
||||
---
|
||||
|
||||
NemoClaw is NVIDIA's open source security stack for [OpenClaw](/integrations/openclaw). It wraps OpenClaw with the NVIDIA OpenShell runtime to provide kernel-level sandboxing, network policy controls, and audit trails for AI agents.
|
||||
|
||||
## Quick start
|
||||
|
||||
Pull a model:
|
||||
|
||||
```bash
|
||||
ollama pull nemotron-3-nano:30b
|
||||
```
|
||||
|
||||
Run the installer:
|
||||
|
||||
```bash
|
||||
curl -fsSL https://www.nvidia.com/nemoclaw.sh | \
|
||||
NEMOCLAW_NON_INTERACTIVE=1 \
|
||||
NEMOCLAW_PROVIDER=ollama \
|
||||
NEMOCLAW_MODEL=nemotron-3-nano:30b \
|
||||
bash
|
||||
```
|
||||
|
||||
Connect to your sandbox:
|
||||
|
||||
```bash
|
||||
nemoclaw my-assistant connect
|
||||
```
|
||||
|
||||
Open the TUI:
|
||||
|
||||
```bash
|
||||
openclaw tui
|
||||
```
|
||||
|
||||
<Note>Ollama support in NemoClaw is still experimental.</Note>
|
||||
|
||||
## Platform support
|
||||
|
||||
| Platform | Runtime | Status |
|
||||
|----------|---------|--------|
|
||||
| Linux (Ubuntu 22.04+) | Docker | Primary |
|
||||
| macOS (Apple Silicon) | Colima or Docker Desktop | Supported |
|
||||
| Windows | WSL2 with Docker Desktop | Supported |
|
||||
|
||||
CMD and PowerShell are not supported on Windows — WSL2 is required.
|
||||
|
||||
<Note>Ollama must be installed and running before the installer runs. When running inside WSL2 or a container, ensure Ollama is reachable from the sandbox (e.g. `OLLAMA_HOST=0.0.0.0`).</Note>
|
||||
|
||||
## System requirements
|
||||
|
||||
- CPU: 4 vCPU minimum
|
||||
- RAM: 8 GB minimum (16 GB recommended)
|
||||
- Disk: 20 GB free (40 GB recommended for local models)
|
||||
- Node.js 20+ and npm 10+
|
||||
- Container runtime (Docker preferred)
|
||||
|
||||
## Recommended models
|
||||
|
||||
- `nemotron-3-super:cloud` — Strong reasoning and coding
|
||||
- `qwen3.5:cloud` — 397B; reasoning and code generation
|
||||
- `nemotron-3-nano:30b` — Recommended local model; fits in 24 GB VRAM
|
||||
- `qwen3.5:27b` — Fast local reasoning (~18 GB VRAM)
|
||||
- `glm-4.7-flash` — Reasoning and code generation (~25 GB VRAM)
|
||||
|
||||
More models at [ollama.com/search](https://ollama.com/search).
|
||||
@@ -214,6 +214,8 @@ func LogLevel() slog.Level {
|
||||
var (
|
||||
// FlashAttention enables the experimental flash attention feature.
|
||||
FlashAttention = BoolWithDefault("OLLAMA_FLASH_ATTENTION")
|
||||
// DebugLogRequests logs inference requests to disk for replay/debugging.
|
||||
DebugLogRequests = Bool("OLLAMA_DEBUG_LOG_REQUESTS")
|
||||
// KvCacheType is the quantization type for the K/V cache.
|
||||
KvCacheType = String("OLLAMA_KV_CACHE_TYPE")
|
||||
// NoHistory disables readline history.
|
||||
@@ -302,28 +304,29 @@ type EnvVar struct {
|
||||
|
||||
func AsMap() map[string]EnvVar {
|
||||
ret := map[string]EnvVar{
|
||||
"OLLAMA_DEBUG": {"OLLAMA_DEBUG", LogLevel(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
|
||||
"OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(false), "Enabled flash attention"},
|
||||
"OLLAMA_KV_CACHE_TYPE": {"OLLAMA_KV_CACHE_TYPE", KvCacheType(), "Quantization type for the K/V cache (default: f16)"},
|
||||
"OLLAMA_GPU_OVERHEAD": {"OLLAMA_GPU_OVERHEAD", GpuOverhead(), "Reserve a portion of VRAM per GPU (bytes)"},
|
||||
"OLLAMA_HOST": {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"},
|
||||
"OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive(), "The duration that models stay loaded in memory (default \"5m\")"},
|
||||
"OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary(), "Set LLM library to bypass autodetection"},
|
||||
"OLLAMA_LOAD_TIMEOUT": {"OLLAMA_LOAD_TIMEOUT", LoadTimeout(), "How long to allow model loads to stall before giving up (default \"5m\")"},
|
||||
"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners(), "Maximum number of loaded models per GPU"},
|
||||
"OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueue(), "Maximum number of queued requests"},
|
||||
"OLLAMA_MODELS": {"OLLAMA_MODELS", Models(), "The path to the models directory"},
|
||||
"OLLAMA_NO_CLOUD": {"OLLAMA_NO_CLOUD", NoCloud(), "Disable Ollama cloud features (remote inference and web search)"},
|
||||
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory(), "Do not preserve readline history"},
|
||||
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune(), "Do not prune model blobs on startup"},
|
||||
"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel(), "Maximum number of parallel requests"},
|
||||
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowedOrigins(), "A comma separated list of allowed origins"},
|
||||
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
|
||||
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
|
||||
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4k/32k/256k based on VRAM)"},
|
||||
"OLLAMA_EDITOR": {"OLLAMA_EDITOR", Editor(), "Path to editor for interactive prompt editing (Ctrl+G)"},
|
||||
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
|
||||
"OLLAMA_REMOTES": {"OLLAMA_REMOTES", Remotes(), "Allowed hosts for remote models (default \"ollama.com\")"},
|
||||
"OLLAMA_DEBUG": {"OLLAMA_DEBUG", LogLevel(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
|
||||
"OLLAMA_DEBUG_LOG_REQUESTS": {"OLLAMA_DEBUG_LOG_REQUESTS", DebugLogRequests(), "Log inference request bodies and replay curl commands to a temp directory"},
|
||||
"OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(false), "Enabled flash attention"},
|
||||
"OLLAMA_KV_CACHE_TYPE": {"OLLAMA_KV_CACHE_TYPE", KvCacheType(), "Quantization type for the K/V cache (default: f16)"},
|
||||
"OLLAMA_GPU_OVERHEAD": {"OLLAMA_GPU_OVERHEAD", GpuOverhead(), "Reserve a portion of VRAM per GPU (bytes)"},
|
||||
"OLLAMA_HOST": {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"},
|
||||
"OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive(), "The duration that models stay loaded in memory (default \"5m\")"},
|
||||
"OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary(), "Set LLM library to bypass autodetection"},
|
||||
"OLLAMA_LOAD_TIMEOUT": {"OLLAMA_LOAD_TIMEOUT", LoadTimeout(), "How long to allow model loads to stall before giving up (default \"5m\")"},
|
||||
"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners(), "Maximum number of loaded models per GPU"},
|
||||
"OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueue(), "Maximum number of queued requests"},
|
||||
"OLLAMA_MODELS": {"OLLAMA_MODELS", Models(), "The path to the models directory"},
|
||||
"OLLAMA_NO_CLOUD": {"OLLAMA_NO_CLOUD", NoCloud(), "Disable Ollama cloud features (remote inference and web search)"},
|
||||
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory(), "Do not preserve readline history"},
|
||||
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune(), "Do not prune model blobs on startup"},
|
||||
"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel(), "Maximum number of parallel requests"},
|
||||
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowedOrigins(), "A comma separated list of allowed origins"},
|
||||
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
|
||||
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
|
||||
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4k/32k/256k based on VRAM)"},
|
||||
"OLLAMA_EDITOR": {"OLLAMA_EDITOR", Editor(), "Path to editor for interactive prompt editing (Ctrl+G)"},
|
||||
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
|
||||
"OLLAMA_REMOTES": {"OLLAMA_REMOTES", Remotes(), "Allowed hosts for remote models (default \"ollama.com\")"},
|
||||
|
||||
// Informational
|
||||
"HTTP_PROXY": {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"},
|
||||
|
||||
@@ -87,7 +87,8 @@ type LlamaServer interface {
|
||||
type llmServer struct {
|
||||
port int
|
||||
cmd *exec.Cmd
|
||||
done chan error // Channel to signal when the process exits
|
||||
done chan struct{} // closed when the process exits
|
||||
doneErr error // valid after done is closed
|
||||
status *StatusWriter
|
||||
options api.Options
|
||||
modelPath string
|
||||
@@ -280,7 +281,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
|
||||
sem: semaphore.NewWeighted(int64(numParallel)),
|
||||
totalLayers: f.KV().BlockCount() + 1,
|
||||
loadStart: time.Now(),
|
||||
done: make(chan error, 1),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@@ -304,10 +305,11 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
|
||||
if strings.Contains(s.status.LastErrMsg, "unknown model") {
|
||||
s.status.LastErrMsg = "this model is not supported by your version of Ollama. You may need to upgrade"
|
||||
}
|
||||
s.done <- errors.New(s.status.LastErrMsg)
|
||||
s.doneErr = errors.New(s.status.LastErrMsg)
|
||||
} else {
|
||||
s.done <- err
|
||||
s.doneErr = err
|
||||
}
|
||||
close(s.done)
|
||||
}()
|
||||
|
||||
if tok != nil {
|
||||
@@ -1356,8 +1358,8 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
|
||||
case <-ctx.Done():
|
||||
slog.Warn("client connection closed before server finished loading, aborting load")
|
||||
return fmt.Errorf("timed out waiting for llama runner to start: %w", ctx.Err())
|
||||
case err := <-s.done:
|
||||
return fmt.Errorf("llama runner process has terminated: %w", err)
|
||||
case <-s.done:
|
||||
return fmt.Errorf("llama runner process has terminated: %w", s.doneErr)
|
||||
default:
|
||||
}
|
||||
if time.Now().After(stallTimer) {
|
||||
|
||||
144
server/inference_request_log.go
Normal file
144
server/inference_request_log.go
Normal file
@@ -0,0 +1,144 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
type inferenceRequestLogger struct {
|
||||
dir string
|
||||
counter uint64
|
||||
}
|
||||
|
||||
func newInferenceRequestLogger() (*inferenceRequestLogger, error) {
|
||||
dir, err := os.MkdirTemp("", "ollama-request-logs-*")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &inferenceRequestLogger{dir: dir}, nil
|
||||
}
|
||||
|
||||
func (s *Server) initRequestLogging() error {
|
||||
if !envconfig.DebugLogRequests() {
|
||||
return nil
|
||||
}
|
||||
|
||||
requestLogger, err := newInferenceRequestLogger()
|
||||
if err != nil {
|
||||
return fmt.Errorf("enable OLLAMA_DEBUG_LOG_REQUESTS: %w", err)
|
||||
}
|
||||
|
||||
s.requestLogger = requestLogger
|
||||
slog.Info(fmt.Sprintf("request debug logging enabled; inference request logs will be stored in %s and include request bodies and replay curl commands", requestLogger.dir))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) withInferenceRequestLogging(route string, handlers ...gin.HandlerFunc) []gin.HandlerFunc {
|
||||
if s.requestLogger == nil {
|
||||
return handlers
|
||||
}
|
||||
|
||||
return append([]gin.HandlerFunc{s.requestLogger.middleware(route)}, handlers...)
|
||||
}
|
||||
|
||||
func (l *inferenceRequestLogger) middleware(route string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if c.Request == nil {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
method := c.Request.Method
|
||||
host := c.Request.Host
|
||||
scheme := "http"
|
||||
if c.Request.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
contentType := c.GetHeader("Content-Type")
|
||||
|
||||
var body []byte
|
||||
if c.Request.Body != nil {
|
||||
var err error
|
||||
body, err = io.ReadAll(c.Request.Body)
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(body))
|
||||
if err != nil {
|
||||
slog.Warn("failed to read request body for debug logging", "route", route, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
c.Next()
|
||||
l.log(route, method, scheme, host, contentType, body)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *inferenceRequestLogger) log(route, method, scheme, host, contentType string, body []byte) {
|
||||
if l == nil || l.dir == "" {
|
||||
return
|
||||
}
|
||||
|
||||
if contentType == "" {
|
||||
contentType = "application/json"
|
||||
}
|
||||
if host == "" || scheme == "" {
|
||||
base := envconfig.Host()
|
||||
if host == "" {
|
||||
host = base.Host
|
||||
}
|
||||
if scheme == "" {
|
||||
scheme = base.Scheme
|
||||
}
|
||||
}
|
||||
|
||||
routeForFilename := sanitizeRouteForFilename(route)
|
||||
timestamp := fmt.Sprintf("%s-%06d", time.Now().UTC().Format("20060102T150405.000000000Z"), atomic.AddUint64(&l.counter, 1))
|
||||
bodyFilename := fmt.Sprintf("%s_%s_body.json", timestamp, routeForFilename)
|
||||
curlFilename := fmt.Sprintf("%s_%s_request.sh", timestamp, routeForFilename)
|
||||
bodyPath := filepath.Join(l.dir, bodyFilename)
|
||||
curlPath := filepath.Join(l.dir, curlFilename)
|
||||
|
||||
if err := os.WriteFile(bodyPath, body, 0o600); err != nil {
|
||||
slog.Warn("failed to write debug request body", "route", route, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s://%s%s", scheme, host, route)
|
||||
curl := fmt.Sprintf("#!/bin/sh\nSCRIPT_DIR=\"$(CDPATH= cd -- \"$(dirname -- \"$0\")\" && pwd)\"\ncurl --request %s --url %q --header %q --data-binary @\"${SCRIPT_DIR}/%s\"\n", method, url, "Content-Type: "+contentType, bodyFilename)
|
||||
if err := os.WriteFile(curlPath, []byte(curl), 0o600); err != nil {
|
||||
slog.Warn("failed to write debug request replay command", "route", route, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Info(fmt.Sprintf("logged to %s, replay using curl with `sh %s`", bodyPath, curlPath))
|
||||
}
|
||||
|
||||
func sanitizeRouteForFilename(route string) string {
|
||||
route = strings.TrimPrefix(route, "/")
|
||||
if route == "" {
|
||||
return "root"
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
b.Grow(len(route))
|
||||
for _, r := range route {
|
||||
if ('a' <= r && r <= 'z') || ('A' <= r && r <= 'Z') || ('0' <= r && r <= '9') {
|
||||
b.WriteRune(r)
|
||||
} else {
|
||||
b.WriteByte('_')
|
||||
}
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
@@ -100,6 +100,7 @@ type Server struct {
|
||||
addr net.Addr
|
||||
sched *Scheduler
|
||||
defaultNumCtx int
|
||||
requestLogger *inferenceRequestLogger
|
||||
}
|
||||
|
||||
func init() {
|
||||
@@ -1686,26 +1687,26 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||
|
||||
// Inference
|
||||
r.GET("/api/ps", s.PsHandler)
|
||||
r.POST("/api/generate", s.GenerateHandler)
|
||||
r.POST("/api/chat", s.ChatHandler)
|
||||
r.POST("/api/generate", s.withInferenceRequestLogging("/api/generate", s.GenerateHandler)...)
|
||||
r.POST("/api/chat", s.withInferenceRequestLogging("/api/chat", s.ChatHandler)...)
|
||||
r.POST("/api/embed", s.EmbedHandler)
|
||||
r.POST("/api/embeddings", s.EmbeddingsHandler)
|
||||
|
||||
// Inference (OpenAI compatibility)
|
||||
// TODO(cloud-stage-a): apply Modelfile overlay deltas for local models with cloud
|
||||
// parents on v1 request families while preserving this explicit :cloud passthrough.
|
||||
r.POST("/v1/chat/completions", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ChatMiddleware(), s.ChatHandler)
|
||||
r.POST("/v1/completions", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.CompletionsMiddleware(), s.GenerateHandler)
|
||||
r.POST("/v1/chat/completions", s.withInferenceRequestLogging("/v1/chat/completions", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ChatMiddleware(), s.ChatHandler)...)
|
||||
r.POST("/v1/completions", s.withInferenceRequestLogging("/v1/completions", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.CompletionsMiddleware(), s.GenerateHandler)...)
|
||||
r.POST("/v1/embeddings", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.EmbeddingsMiddleware(), s.EmbedHandler)
|
||||
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
|
||||
r.GET("/v1/models/:model", cloudModelPathPassthroughMiddleware(cloudErrRemoteModelDetailsUnavailable), middleware.RetrieveMiddleware(), s.ShowHandler)
|
||||
r.POST("/v1/responses", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ResponsesMiddleware(), s.ChatHandler)
|
||||
r.POST("/v1/responses", s.withInferenceRequestLogging("/v1/responses", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ResponsesMiddleware(), s.ChatHandler)...)
|
||||
// OpenAI-compatible image generation endpoints
|
||||
r.POST("/v1/images/generations", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ImageGenerationsMiddleware(), s.GenerateHandler)
|
||||
r.POST("/v1/images/edits", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ImageEditsMiddleware(), s.GenerateHandler)
|
||||
|
||||
// Inference (Anthropic compatibility)
|
||||
r.POST("/v1/messages", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
|
||||
r.POST("/v1/messages", s.withInferenceRequestLogging("/v1/messages", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.AnthropicMessagesMiddleware(), s.ChatHandler)...)
|
||||
|
||||
if rc != nil {
|
||||
// wrap old with new
|
||||
@@ -1757,6 +1758,9 @@ func Serve(ln net.Listener) error {
|
||||
}
|
||||
|
||||
s := &Server{addr: ln.Addr()}
|
||||
if err := s.initRequestLogging(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var rc *ollama.Registry
|
||||
if useClient2 {
|
||||
|
||||
128
server/routes_request_log_test.go
Normal file
128
server/routes_request_log_test.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestInferenceRequestLoggerMiddlewareWritesReplayArtifacts(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
logDir := t.TempDir()
|
||||
requestLogger := &inferenceRequestLogger{dir: logDir}
|
||||
|
||||
const route = "/v1/chat/completions"
|
||||
const requestBody = `{"model":"test-model","messages":[{"role":"user","content":"hello"}]}`
|
||||
|
||||
var bodySeenByHandler string
|
||||
|
||||
r := gin.New()
|
||||
r.POST(route, requestLogger.middleware(route), func(c *gin.Context) {
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read body in handler: %v", err)
|
||||
}
|
||||
|
||||
bodySeenByHandler = string(body)
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, route, strings.NewReader(requestBody))
|
||||
req.Host = "127.0.0.1:11434"
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if bodySeenByHandler != requestBody {
|
||||
t.Fatalf("handler body mismatch:\nexpected: %s\ngot: %s", requestBody, bodySeenByHandler)
|
||||
}
|
||||
|
||||
bodyFiles, err := filepath.Glob(filepath.Join(logDir, "*_v1_chat_completions_body.json"))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to glob body logs: %v", err)
|
||||
}
|
||||
if len(bodyFiles) != 1 {
|
||||
t.Fatalf("expected 1 body log, got %d (%v)", len(bodyFiles), bodyFiles)
|
||||
}
|
||||
|
||||
curlFiles, err := filepath.Glob(filepath.Join(logDir, "*_v1_chat_completions_request.sh"))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to glob curl logs: %v", err)
|
||||
}
|
||||
if len(curlFiles) != 1 {
|
||||
t.Fatalf("expected 1 curl log, got %d (%v)", len(curlFiles), curlFiles)
|
||||
}
|
||||
|
||||
bodyData, err := os.ReadFile(bodyFiles[0])
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read body log: %v", err)
|
||||
}
|
||||
if string(bodyData) != requestBody {
|
||||
t.Fatalf("body log mismatch:\nexpected: %s\ngot: %s", requestBody, string(bodyData))
|
||||
}
|
||||
|
||||
curlData, err := os.ReadFile(curlFiles[0])
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read curl log: %v", err)
|
||||
}
|
||||
|
||||
curlString := string(curlData)
|
||||
if !strings.Contains(curlString, "http://127.0.0.1:11434"+route) {
|
||||
t.Fatalf("curl log does not contain expected route URL: %s", curlString)
|
||||
}
|
||||
|
||||
bodyFileName := filepath.Base(bodyFiles[0])
|
||||
if !strings.Contains(curlString, "@\"${SCRIPT_DIR}/"+bodyFileName+"\"") {
|
||||
t.Fatalf("curl log does not reference sibling body file: %s", curlString)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewInferenceRequestLoggerCreatesDirectory(t *testing.T) {
|
||||
requestLogger, err := newInferenceRequestLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error creating request logger: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
_ = os.RemoveAll(requestLogger.dir)
|
||||
})
|
||||
|
||||
if requestLogger == nil || requestLogger.dir == "" {
|
||||
t.Fatalf("expected request logger directory to be set")
|
||||
}
|
||||
|
||||
info, err := os.Stat(requestLogger.dir)
|
||||
if err != nil {
|
||||
t.Fatalf("expected directory to exist: %v", err)
|
||||
}
|
||||
if !info.IsDir() {
|
||||
t.Fatalf("expected %q to be a directory", requestLogger.dir)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeRouteForFilename(t *testing.T) {
|
||||
tests := []struct {
|
||||
route string
|
||||
want string
|
||||
}{
|
||||
{route: "/api/generate", want: "api_generate"},
|
||||
{route: "/v1/chat/completions", want: "v1_chat_completions"},
|
||||
{route: "/v1/messages", want: "v1_messages"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
if got := sanitizeRouteForFilename(tt.route); got != tt.want {
|
||||
t.Fatalf("sanitizeRouteForFilename(%q) = %q, want %q", tt.route, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -165,8 +165,8 @@ int (*mlx_distributed_sum_scatter_ptr)(mlx_array* res, const mlx_array x, const
|
||||
int (*mlx_distributed_group_rank_ptr)(mlx_distributed_group group) = NULL;
|
||||
int (*mlx_distributed_group_size_ptr)(mlx_distributed_group group) = NULL;
|
||||
mlx_distributed_group (*mlx_distributed_group_split_ptr)(mlx_distributed_group group, int color, int key) = NULL;
|
||||
bool (*mlx_distributed_is_available_ptr)(void) = NULL;
|
||||
mlx_distributed_group (*mlx_distributed_init_ptr)(bool strict) = NULL;
|
||||
bool (*mlx_distributed_is_available_ptr)(const char* bk) = NULL;
|
||||
mlx_distributed_group (*mlx_distributed_init_ptr)(bool strict, const char* bk) = NULL;
|
||||
void (*mlx_set_error_handler_ptr)(mlx_error_handler_func handler, void* data, void (*dtor)(void*)) = NULL;
|
||||
void (*_mlx_error_ptr)(const char* file, const int line, const char* fmt, ...) = NULL;
|
||||
int (*mlx_export_function_ptr)(const char* file, const mlx_closure fun, const mlx_vector_array args, bool shapeless) = NULL;
|
||||
@@ -319,10 +319,12 @@ int (*mlx_astype_ptr)(mlx_array* res, const mlx_array a, mlx_dtype dtype, const
|
||||
int (*mlx_atleast_1d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
||||
int (*mlx_atleast_2d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
||||
int (*mlx_atleast_3d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
||||
int (*mlx_bartlett_ptr)(mlx_array* res, int M, const mlx_stream s) = NULL;
|
||||
int (*mlx_bitwise_and_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
||||
int (*mlx_bitwise_invert_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
||||
int (*mlx_bitwise_or_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
||||
int (*mlx_bitwise_xor_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
||||
int (*mlx_blackman_ptr)(mlx_array* res, int M, const mlx_stream s) = NULL;
|
||||
int (*mlx_block_masked_mm_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s) = NULL;
|
||||
int (*mlx_broadcast_arrays_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s) = NULL;
|
||||
int (*mlx_broadcast_to_ptr)(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s) = NULL;
|
||||
@@ -348,7 +350,7 @@ int (*mlx_cumprod_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse
|
||||
int (*mlx_cumsum_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) = NULL;
|
||||
int (*mlx_degrees_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
||||
int (*mlx_depends_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies) = NULL;
|
||||
int (*mlx_dequantize_ptr)(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, mlx_optional_dtype dtype, const mlx_stream s) = NULL;
|
||||
int (*mlx_dequantize_ptr)(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , mlx_optional_dtype dtype, const mlx_stream s) = NULL;
|
||||
int (*mlx_diag_ptr)(mlx_array* res, const mlx_array a, int k, const mlx_stream s) = NULL;
|
||||
int (*mlx_diagonal_ptr)(mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, const mlx_stream s) = NULL;
|
||||
int (*mlx_divide_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
||||
@@ -375,6 +377,8 @@ int (*mlx_gather_qmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w,
|
||||
int (*mlx_greater_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
||||
int (*mlx_greater_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
||||
int (*mlx_hadamard_transform_ptr)(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s) = NULL;
|
||||
int (*mlx_hamming_ptr)(mlx_array* res, int M, const mlx_stream s) = NULL;
|
||||
int (*mlx_hanning_ptr)(mlx_array* res, int M, const mlx_stream s) = NULL;
|
||||
int (*mlx_identity_ptr)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) = NULL;
|
||||
int (*mlx_imag_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
||||
int (*mlx_inner_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
||||
@@ -434,8 +438,8 @@ int (*mlx_prod_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, siz
|
||||
int (*mlx_prod_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL;
|
||||
int (*mlx_prod_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL;
|
||||
int (*mlx_put_along_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s) = NULL;
|
||||
int (*mlx_qqmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) = NULL;
|
||||
int (*mlx_quantize_ptr)(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) = NULL;
|
||||
int (*mlx_qqmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale_x , const mlx_array global_scale_w , const mlx_stream s) = NULL;
|
||||
int (*mlx_quantize_ptr)(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , const mlx_stream s) = NULL;
|
||||
int (*mlx_quantized_matmul_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) = NULL;
|
||||
int (*mlx_radians_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
||||
int (*mlx_real_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
||||
@@ -2101,6 +2105,11 @@ int mlx_load_functions(void* handle) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_atleast_3d\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_bartlett_ptr = GET_SYM(handle, "mlx_bartlett");
|
||||
if (mlx_bartlett_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_bartlett\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_bitwise_and_ptr = GET_SYM(handle, "mlx_bitwise_and");
|
||||
if (mlx_bitwise_and_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_and\n");
|
||||
@@ -2121,6 +2130,11 @@ int mlx_load_functions(void* handle) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_xor\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_blackman_ptr = GET_SYM(handle, "mlx_blackman");
|
||||
if (mlx_blackman_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_blackman\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_block_masked_mm_ptr = GET_SYM(handle, "mlx_block_masked_mm");
|
||||
if (mlx_block_masked_mm_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_block_masked_mm\n");
|
||||
@@ -2381,6 +2395,16 @@ int mlx_load_functions(void* handle) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_hadamard_transform\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_hamming_ptr = GET_SYM(handle, "mlx_hamming");
|
||||
if (mlx_hamming_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_hamming\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_hanning_ptr = GET_SYM(handle, "mlx_hanning");
|
||||
if (mlx_hanning_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_hanning\n");
|
||||
return -1;
|
||||
}
|
||||
mlx_identity_ptr = GET_SYM(handle, "mlx_identity");
|
||||
if (mlx_identity_ptr == NULL) {
|
||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_identity\n");
|
||||
@@ -4132,12 +4156,12 @@ mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, i
|
||||
return mlx_distributed_group_split_ptr(group, color, key);
|
||||
}
|
||||
|
||||
bool mlx_distributed_is_available(void) {
|
||||
return mlx_distributed_is_available_ptr();
|
||||
bool mlx_distributed_is_available(const char* bk) {
|
||||
return mlx_distributed_is_available_ptr(bk);
|
||||
}
|
||||
|
||||
mlx_distributed_group mlx_distributed_init(bool strict) {
|
||||
return mlx_distributed_init_ptr(strict);
|
||||
mlx_distributed_group mlx_distributed_init(bool strict, const char* bk) {
|
||||
return mlx_distributed_init_ptr(strict, bk);
|
||||
}
|
||||
|
||||
void mlx_set_error_handler(mlx_error_handler_func handler, void* data, void (*dtor)(void*)) {
|
||||
@@ -4748,6 +4772,10 @@ int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s) {
|
||||
return mlx_atleast_3d_ptr(res, a, s);
|
||||
}
|
||||
|
||||
int mlx_bartlett(mlx_array* res, int M, const mlx_stream s) {
|
||||
return mlx_bartlett_ptr(res, M, s);
|
||||
}
|
||||
|
||||
int mlx_bitwise_and(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
|
||||
return mlx_bitwise_and_ptr(res, a, b, s);
|
||||
}
|
||||
@@ -4764,6 +4792,10 @@ int mlx_bitwise_xor(mlx_array* res, const mlx_array a, const mlx_array b, const
|
||||
return mlx_bitwise_xor_ptr(res, a, b, s);
|
||||
}
|
||||
|
||||
int mlx_blackman(mlx_array* res, int M, const mlx_stream s) {
|
||||
return mlx_blackman_ptr(res, M, s);
|
||||
}
|
||||
|
||||
int mlx_block_masked_mm(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s) {
|
||||
return mlx_block_masked_mm_ptr(res, a, b, block_size, mask_out, mask_lhs, mask_rhs, s);
|
||||
}
|
||||
@@ -4864,8 +4896,8 @@ int mlx_depends(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_
|
||||
return mlx_depends_ptr(res, inputs, dependencies);
|
||||
}
|
||||
|
||||
int mlx_dequantize(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, mlx_optional_dtype dtype, const mlx_stream s) {
|
||||
return mlx_dequantize_ptr(res, w, scales, biases, group_size, bits, mode, dtype, s);
|
||||
int mlx_dequantize(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , mlx_optional_dtype dtype, const mlx_stream s) {
|
||||
return mlx_dequantize_ptr(res, w, scales, biases, group_size, bits, mode, global_scale, dtype, s);
|
||||
}
|
||||
|
||||
int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s) {
|
||||
@@ -4972,6 +5004,14 @@ int mlx_hadamard_transform(mlx_array* res, const mlx_array a, mlx_optional_float
|
||||
return mlx_hadamard_transform_ptr(res, a, scale, s);
|
||||
}
|
||||
|
||||
int mlx_hamming(mlx_array* res, int M, const mlx_stream s) {
|
||||
return mlx_hamming_ptr(res, M, s);
|
||||
}
|
||||
|
||||
int mlx_hanning(mlx_array* res, int M, const mlx_stream s) {
|
||||
return mlx_hanning_ptr(res, M, s);
|
||||
}
|
||||
|
||||
int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) {
|
||||
return mlx_identity_ptr(res, n, dtype, s);
|
||||
}
|
||||
@@ -5208,12 +5248,12 @@ int mlx_put_along_axis(mlx_array* res, const mlx_array a, const mlx_array indice
|
||||
return mlx_put_along_axis_ptr(res, a, indices, values, axis, s);
|
||||
}
|
||||
|
||||
int mlx_qqmm(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) {
|
||||
return mlx_qqmm_ptr(res, x, w, w_scales, group_size, bits, mode, s);
|
||||
int mlx_qqmm(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale_x , const mlx_array global_scale_w , const mlx_stream s) {
|
||||
return mlx_qqmm_ptr(res, x, w, w_scales, group_size, bits, mode, global_scale_x, global_scale_w, s);
|
||||
}
|
||||
|
||||
int mlx_quantize(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) {
|
||||
return mlx_quantize_ptr(res, w, group_size, bits, mode, s);
|
||||
int mlx_quantize(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , const mlx_stream s) {
|
||||
return mlx_quantize_ptr(res, w, group_size, bits, mode, global_scale, s);
|
||||
}
|
||||
|
||||
int mlx_quantized_matmul(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) {
|
||||
|
||||
@@ -2124,8 +2124,9 @@ func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, bias
|
||||
defer C.free(unsafe.Pointer(cMode))
|
||||
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
|
||||
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
|
||||
var globalScale C.mlx_array
|
||||
res := C.mlx_vector_array_new()
|
||||
C.mlx_quantize(&res, w.c, optGroupSize, optBits, cMode, C.default_stream())
|
||||
C.mlx_quantize(&res, w.c, optGroupSize, optBits, cMode, globalScale, C.default_stream())
|
||||
|
||||
// Result is a vector of arrays: [weights, scales, biases?]
|
||||
// mxfp8 mode returns only 2 elements (no biases)
|
||||
@@ -2154,6 +2155,7 @@ func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Arr
|
||||
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
|
||||
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
|
||||
optDtype := C.mlx_optional_dtype{has_value: false}
|
||||
var globalScale C.mlx_array
|
||||
|
||||
var b C.mlx_array
|
||||
if biases != nil {
|
||||
@@ -2161,7 +2163,7 @@ func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Arr
|
||||
}
|
||||
|
||||
res := C.mlx_array_new()
|
||||
C.mlx_dequantize(&res, w.c, scales.c, b, optGroupSize, optBits, cMode, optDtype, C.default_stream())
|
||||
C.mlx_dequantize(&res, w.c, scales.c, b, optGroupSize, optBits, cMode, globalScale, optDtype, C.default_stream())
|
||||
return newArray(res)
|
||||
}
|
||||
|
||||
|
||||
@@ -309,10 +309,12 @@
|
||||
#undef mlx_atleast_1d
|
||||
#undef mlx_atleast_2d
|
||||
#undef mlx_atleast_3d
|
||||
#undef mlx_bartlett
|
||||
#undef mlx_bitwise_and
|
||||
#undef mlx_bitwise_invert
|
||||
#undef mlx_bitwise_or
|
||||
#undef mlx_bitwise_xor
|
||||
#undef mlx_blackman
|
||||
#undef mlx_block_masked_mm
|
||||
#undef mlx_broadcast_arrays
|
||||
#undef mlx_broadcast_to
|
||||
@@ -365,6 +367,8 @@
|
||||
#undef mlx_greater
|
||||
#undef mlx_greater_equal
|
||||
#undef mlx_hadamard_transform
|
||||
#undef mlx_hamming
|
||||
#undef mlx_hanning
|
||||
#undef mlx_identity
|
||||
#undef mlx_imag
|
||||
#undef mlx_inner
|
||||
@@ -751,8 +755,8 @@ extern int (*mlx_distributed_sum_scatter_ptr)(mlx_array* res, const mlx_array x,
|
||||
extern int (*mlx_distributed_group_rank_ptr)(mlx_distributed_group group);
|
||||
extern int (*mlx_distributed_group_size_ptr)(mlx_distributed_group group);
|
||||
extern mlx_distributed_group (*mlx_distributed_group_split_ptr)(mlx_distributed_group group, int color, int key);
|
||||
extern bool (*mlx_distributed_is_available_ptr)(void);
|
||||
extern mlx_distributed_group (*mlx_distributed_init_ptr)(bool strict);
|
||||
extern bool (*mlx_distributed_is_available_ptr)(const char* bk);
|
||||
extern mlx_distributed_group (*mlx_distributed_init_ptr)(bool strict, const char* bk);
|
||||
extern void (*mlx_set_error_handler_ptr)(mlx_error_handler_func handler, void* data, void (*dtor)(void*));
|
||||
extern void (*_mlx_error_ptr)(const char* file, const int line, const char* fmt, ...);
|
||||
extern int (*mlx_export_function_ptr)(const char* file, const mlx_closure fun, const mlx_vector_array args, bool shapeless);
|
||||
@@ -905,10 +909,12 @@ extern int (*mlx_astype_ptr)(mlx_array* res, const mlx_array a, mlx_dtype dtype,
|
||||
extern int (*mlx_atleast_1d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||
extern int (*mlx_atleast_2d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||
extern int (*mlx_atleast_3d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||
extern int (*mlx_bartlett_ptr)(mlx_array* res, int M, const mlx_stream s);
|
||||
extern int (*mlx_bitwise_and_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
||||
extern int (*mlx_bitwise_invert_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||
extern int (*mlx_bitwise_or_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
||||
extern int (*mlx_bitwise_xor_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
||||
extern int (*mlx_blackman_ptr)(mlx_array* res, int M, const mlx_stream s);
|
||||
extern int (*mlx_block_masked_mm_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s);
|
||||
extern int (*mlx_broadcast_arrays_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s);
|
||||
extern int (*mlx_broadcast_to_ptr)(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s);
|
||||
@@ -934,7 +940,7 @@ extern int (*mlx_cumprod_ptr)(mlx_array* res, const mlx_array a, int axis, bool
|
||||
extern int (*mlx_cumsum_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s);
|
||||
extern int (*mlx_degrees_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||
extern int (*mlx_depends_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies);
|
||||
extern int (*mlx_dequantize_ptr)(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, mlx_optional_dtype dtype, const mlx_stream s);
|
||||
extern int (*mlx_dequantize_ptr)(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , mlx_optional_dtype dtype, const mlx_stream s);
|
||||
extern int (*mlx_diag_ptr)(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
|
||||
extern int (*mlx_diagonal_ptr)(mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, const mlx_stream s);
|
||||
extern int (*mlx_divide_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
||||
@@ -961,6 +967,8 @@ extern int (*mlx_gather_qmm_ptr)(mlx_array* res, const mlx_array x, const mlx_ar
|
||||
extern int (*mlx_greater_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
||||
extern int (*mlx_greater_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
||||
extern int (*mlx_hadamard_transform_ptr)(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s);
|
||||
extern int (*mlx_hamming_ptr)(mlx_array* res, int M, const mlx_stream s);
|
||||
extern int (*mlx_hanning_ptr)(mlx_array* res, int M, const mlx_stream s);
|
||||
extern int (*mlx_identity_ptr)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s);
|
||||
extern int (*mlx_imag_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||
extern int (*mlx_inner_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
||||
@@ -1020,8 +1028,8 @@ extern int (*mlx_prod_axes_ptr)(mlx_array* res, const mlx_array a, const int* ax
|
||||
extern int (*mlx_prod_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s);
|
||||
extern int (*mlx_prod_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s);
|
||||
extern int (*mlx_put_along_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s);
|
||||
extern int (*mlx_qqmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s);
|
||||
extern int (*mlx_quantize_ptr)(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s);
|
||||
extern int (*mlx_qqmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale_x , const mlx_array global_scale_w , const mlx_stream s);
|
||||
extern int (*mlx_quantize_ptr)(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , const mlx_stream s);
|
||||
extern int (*mlx_quantized_matmul_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s);
|
||||
extern int (*mlx_radians_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||
extern int (*mlx_real_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||
@@ -1492,9 +1500,9 @@ int mlx_distributed_group_size(mlx_distributed_group group);
|
||||
|
||||
mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, int color, int key);
|
||||
|
||||
bool mlx_distributed_is_available(void);
|
||||
bool mlx_distributed_is_available(const char* bk);
|
||||
|
||||
mlx_distributed_group mlx_distributed_init(bool strict);
|
||||
mlx_distributed_group mlx_distributed_init(bool strict, const char* bk);
|
||||
|
||||
void mlx_set_error_handler(mlx_error_handler_func handler, void* data, void (*dtor)(void*));
|
||||
|
||||
@@ -1800,6 +1808,8 @@ int mlx_atleast_2d(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||
|
||||
int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||
|
||||
int mlx_bartlett(mlx_array* res, int M, const mlx_stream s);
|
||||
|
||||
int mlx_bitwise_and(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
||||
|
||||
int mlx_bitwise_invert(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||
@@ -1808,6 +1818,8 @@ int mlx_bitwise_or(mlx_array* res, const mlx_array a, const mlx_array b, const m
|
||||
|
||||
int mlx_bitwise_xor(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
||||
|
||||
int mlx_blackman(mlx_array* res, int M, const mlx_stream s);
|
||||
|
||||
int mlx_block_masked_mm(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s);
|
||||
|
||||
int mlx_broadcast_arrays(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s);
|
||||
@@ -1858,7 +1870,7 @@ int mlx_degrees(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||
|
||||
int mlx_depends(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies);
|
||||
|
||||
int mlx_dequantize(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, mlx_optional_dtype dtype, const mlx_stream s);
|
||||
int mlx_dequantize(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , mlx_optional_dtype dtype, const mlx_stream s);
|
||||
|
||||
int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
|
||||
|
||||
@@ -1912,6 +1924,10 @@ int mlx_greater_equal(mlx_array* res, const mlx_array a, const mlx_array b, cons
|
||||
|
||||
int mlx_hadamard_transform(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s);
|
||||
|
||||
int mlx_hamming(mlx_array* res, int M, const mlx_stream s);
|
||||
|
||||
int mlx_hanning(mlx_array* res, int M, const mlx_stream s);
|
||||
|
||||
int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s);
|
||||
|
||||
int mlx_imag(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||
@@ -2030,9 +2046,9 @@ int mlx_prod(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream
|
||||
|
||||
int mlx_put_along_axis(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s);
|
||||
|
||||
int mlx_qqmm(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s);
|
||||
int mlx_qqmm(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale_x , const mlx_array global_scale_w , const mlx_stream s);
|
||||
|
||||
int mlx_quantize(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s);
|
||||
int mlx_quantize(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , const mlx_stream s);
|
||||
|
||||
int mlx_quantized_matmul(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s);
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package mlxrunner
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
@@ -36,14 +37,69 @@ type Client struct {
|
||||
modelName string
|
||||
contextLength atomic.Int64
|
||||
memory atomic.Uint64
|
||||
done chan error
|
||||
done chan struct{}
|
||||
doneErr error // valid after done is closed
|
||||
client *http.Client
|
||||
lastErr string
|
||||
lastErrLock sync.Mutex
|
||||
status *statusWriter
|
||||
mu sync.Mutex
|
||||
cmd *exec.Cmd
|
||||
}
|
||||
|
||||
// statusWriter captures the last stderr line from the subprocess while
|
||||
// forwarding all output to os.Stderr. Lines longer than maxStatusLen are
|
||||
// truncated to the first maxStatusLen bytes.
|
||||
type statusWriter struct {
|
||||
lastErrMsg string
|
||||
buf []byte
|
||||
discarding bool
|
||||
mu sync.Mutex
|
||||
out *os.File
|
||||
}
|
||||
|
||||
const maxStatusLen = 256
|
||||
|
||||
func (w *statusWriter) Write(b []byte) (int, error) {
|
||||
n, err := w.out.Write(b)
|
||||
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
w.buf = append(w.buf, b...)
|
||||
for {
|
||||
i := bytes.IndexByte(w.buf, '\n')
|
||||
if i < 0 {
|
||||
break
|
||||
}
|
||||
if !w.discarding {
|
||||
line := bytes.TrimSpace(w.buf[:i])
|
||||
if len(line) > 0 {
|
||||
if len(line) > maxStatusLen {
|
||||
line = line[:maxStatusLen]
|
||||
}
|
||||
w.lastErrMsg = string(line)
|
||||
}
|
||||
}
|
||||
w.buf = w.buf[i+1:]
|
||||
w.discarding = false
|
||||
}
|
||||
// if the buffer grows past maxStatusLen without a newline, keep the front
|
||||
if len(w.buf) > maxStatusLen {
|
||||
if !w.discarding {
|
||||
w.lastErrMsg = string(bytes.TrimSpace(w.buf[:maxStatusLen]))
|
||||
w.discarding = true
|
||||
}
|
||||
w.buf = w.buf[:0]
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (w *statusWriter) getLastErr() string {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
return w.lastErrMsg
|
||||
}
|
||||
|
||||
// NewClient prepares a new MLX runner client for LLM models.
|
||||
// The subprocess is not started until Load() is called.
|
||||
func NewClient(modelName string) (*Client, error) {
|
||||
@@ -53,7 +109,7 @@ func NewClient(modelName string) (*Client, error) {
|
||||
|
||||
c := &Client{
|
||||
modelName: modelName,
|
||||
done: make(chan error, 1),
|
||||
done: make(chan struct{}),
|
||||
client: &http.Client{Timeout: 10 * time.Minute},
|
||||
}
|
||||
|
||||
@@ -66,12 +122,6 @@ func NewClient(modelName string) (*Client, error) {
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *Client) getLastErr() string {
|
||||
c.lastErrLock.Lock()
|
||||
defer c.lastErrLock.Unlock()
|
||||
return c.lastErr
|
||||
}
|
||||
|
||||
// WaitUntilRunning waits for the subprocess to be ready.
|
||||
func (c *Client) WaitUntilRunning(ctx context.Context) error {
|
||||
timeout := time.After(2 * time.Minute)
|
||||
@@ -82,16 +132,14 @@ func (c *Client) WaitUntilRunning(ctx context.Context) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case err := <-c.done:
|
||||
errMsg := c.getLastErr()
|
||||
if errMsg != "" {
|
||||
return fmt.Errorf("mlx runner failed: %s (exit: %v)", errMsg, err)
|
||||
case <-c.done:
|
||||
if msg := c.status.getLastErr(); msg != "" {
|
||||
return fmt.Errorf("mlx runner failed: %s (exit: %v)", msg, c.doneErr)
|
||||
}
|
||||
return fmt.Errorf("mlx runner exited unexpectedly: %w", err)
|
||||
return fmt.Errorf("mlx runner exited unexpectedly: %w", c.doneErr)
|
||||
case <-timeout:
|
||||
errMsg := c.getLastErr()
|
||||
if errMsg != "" {
|
||||
return fmt.Errorf("timeout waiting for mlx runner: %s", errMsg)
|
||||
if msg := c.status.getLastErr(); msg != "" {
|
||||
return fmt.Errorf("timeout waiting for mlx runner: %s", msg)
|
||||
}
|
||||
return errors.New("timeout waiting for mlx runner to start")
|
||||
case <-ticker.C:
|
||||
@@ -348,18 +396,13 @@ func (c *Client) Load(ctx context.Context, _ ml.SystemInfo, gpus []ml.DeviceInfo
|
||||
// Forward subprocess stdout/stderr to server logs
|
||||
stdout, _ := cmd.StdoutPipe()
|
||||
stderr, _ := cmd.StderrPipe()
|
||||
status := &statusWriter{out: os.Stderr}
|
||||
c.status = status
|
||||
go func() {
|
||||
io.Copy(os.Stderr, stdout) //nolint:errcheck
|
||||
}()
|
||||
go func() {
|
||||
scanner := bufio.NewScanner(stderr)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
fmt.Fprintln(os.Stderr, line)
|
||||
c.lastErrLock.Lock()
|
||||
c.lastErr = line
|
||||
c.lastErrLock.Unlock()
|
||||
}
|
||||
io.Copy(status, stderr) //nolint:errcheck
|
||||
}()
|
||||
|
||||
slog.Info("starting mlx runner subprocess", "model", c.modelName, "port", c.port)
|
||||
@@ -369,8 +412,8 @@ func (c *Client) Load(ctx context.Context, _ ml.SystemInfo, gpus []ml.DeviceInfo
|
||||
|
||||
// Reap subprocess when it exits
|
||||
go func() {
|
||||
err := cmd.Wait()
|
||||
c.done <- err
|
||||
c.doneErr = cmd.Wait()
|
||||
close(c.done)
|
||||
}()
|
||||
|
||||
return nil, nil
|
||||
|
||||
@@ -15,7 +15,8 @@ set(CMAKE_INSTALL_RPATH "@loader_path")
|
||||
|
||||
include(FetchContent)
|
||||
|
||||
set(MLX_C_GIT_TAG "v0.5.0" CACHE STRING "")
|
||||
file(READ "${CMAKE_CURRENT_SOURCE_DIR}/../../../MLX_VERSION" MLX_C_GIT_TAG)
|
||||
string(STRIP "${MLX_C_GIT_TAG}" MLX_C_GIT_TAG)
|
||||
|
||||
FetchContent_Declare(
|
||||
mlx-c
|
||||
|
||||
@@ -326,8 +326,10 @@ int (*mlx_distributed_sum_scatter_)(
|
||||
int (*mlx_distributed_group_rank_)(mlx_distributed_group group) = NULL;
|
||||
int (*mlx_distributed_group_size_)(mlx_distributed_group group) = NULL;
|
||||
mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key) = NULL;
|
||||
bool (*mlx_distributed_is_available_)(void) = NULL;
|
||||
mlx_distributed_group (*mlx_distributed_init_)(bool strict) = NULL;
|
||||
bool (*mlx_distributed_is_available_)(const char* bk /* may be null */) = NULL;
|
||||
mlx_distributed_group (*mlx_distributed_init_)(
|
||||
bool strict,
|
||||
const char* bk /* may be null */) = NULL;
|
||||
void (*mlx_set_error_handler_)(
|
||||
mlx_error_handler_func handler,
|
||||
void* data,
|
||||
@@ -924,6 +926,7 @@ int (*mlx_astype_)(
|
||||
int (*mlx_atleast_1d_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
||||
int (*mlx_atleast_2d_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
||||
int (*mlx_atleast_3d_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
||||
int (*mlx_bartlett_)(mlx_array* res, int M, const mlx_stream s) = NULL;
|
||||
int (*mlx_bitwise_and_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
@@ -940,6 +943,7 @@ int (*mlx_bitwise_xor_)(
|
||||
const mlx_array a,
|
||||
const mlx_array b,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_blackman_)(mlx_array* res, int M, const mlx_stream s) = NULL;
|
||||
int (*mlx_block_masked_mm_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
@@ -1120,6 +1124,7 @@ int (*mlx_dequantize_)(
|
||||
mlx_optional_int group_size,
|
||||
mlx_optional_int bits,
|
||||
const char* mode,
|
||||
const mlx_array global_scale /* may be null */,
|
||||
mlx_optional_dtype dtype,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_diag_)(mlx_array* res, const mlx_array a, int k, const mlx_stream s) = NULL;
|
||||
@@ -1256,6 +1261,8 @@ int (*mlx_hadamard_transform_)(
|
||||
const mlx_array a,
|
||||
mlx_optional_float scale,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_hamming_)(mlx_array* res, int M, const mlx_stream s) = NULL;
|
||||
int (*mlx_hanning_)(mlx_array* res, int M, const mlx_stream s) = NULL;
|
||||
int (*mlx_identity_)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) = NULL;
|
||||
int (*mlx_imag_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
||||
int (*mlx_inner_)(
|
||||
@@ -1548,6 +1555,8 @@ int (*mlx_qqmm_)(
|
||||
mlx_optional_int group_size,
|
||||
mlx_optional_int bits,
|
||||
const char* mode,
|
||||
const mlx_array global_scale_x /* may be null */,
|
||||
const mlx_array global_scale_w /* may be null */,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_quantize_)(
|
||||
mlx_vector_array* res,
|
||||
@@ -1555,6 +1564,7 @@ int (*mlx_quantize_)(
|
||||
mlx_optional_int group_size,
|
||||
mlx_optional_int bits,
|
||||
const char* mode,
|
||||
const mlx_array global_scale /* may be null */,
|
||||
const mlx_stream s) = NULL;
|
||||
int (*mlx_quantized_matmul_)(
|
||||
mlx_array* res,
|
||||
@@ -2550,10 +2560,12 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_atleast_1d);
|
||||
CHECK_LOAD(handle, mlx_atleast_2d);
|
||||
CHECK_LOAD(handle, mlx_atleast_3d);
|
||||
CHECK_LOAD(handle, mlx_bartlett);
|
||||
CHECK_LOAD(handle, mlx_bitwise_and);
|
||||
CHECK_LOAD(handle, mlx_bitwise_invert);
|
||||
CHECK_LOAD(handle, mlx_bitwise_or);
|
||||
CHECK_LOAD(handle, mlx_bitwise_xor);
|
||||
CHECK_LOAD(handle, mlx_blackman);
|
||||
CHECK_LOAD(handle, mlx_block_masked_mm);
|
||||
CHECK_LOAD(handle, mlx_broadcast_arrays);
|
||||
CHECK_LOAD(handle, mlx_broadcast_to);
|
||||
@@ -2606,6 +2618,8 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
CHECK_LOAD(handle, mlx_greater);
|
||||
CHECK_LOAD(handle, mlx_greater_equal);
|
||||
CHECK_LOAD(handle, mlx_hadamard_transform);
|
||||
CHECK_LOAD(handle, mlx_hamming);
|
||||
CHECK_LOAD(handle, mlx_hanning);
|
||||
CHECK_LOAD(handle, mlx_identity);
|
||||
CHECK_LOAD(handle, mlx_imag);
|
||||
CHECK_LOAD(handle, mlx_inner);
|
||||
|
||||
@@ -300,10 +300,12 @@
|
||||
#define mlx_atleast_1d mlx_atleast_1d_mlx_gen_orig_
|
||||
#define mlx_atleast_2d mlx_atleast_2d_mlx_gen_orig_
|
||||
#define mlx_atleast_3d mlx_atleast_3d_mlx_gen_orig_
|
||||
#define mlx_bartlett mlx_bartlett_mlx_gen_orig_
|
||||
#define mlx_bitwise_and mlx_bitwise_and_mlx_gen_orig_
|
||||
#define mlx_bitwise_invert mlx_bitwise_invert_mlx_gen_orig_
|
||||
#define mlx_bitwise_or mlx_bitwise_or_mlx_gen_orig_
|
||||
#define mlx_bitwise_xor mlx_bitwise_xor_mlx_gen_orig_
|
||||
#define mlx_blackman mlx_blackman_mlx_gen_orig_
|
||||
#define mlx_block_masked_mm mlx_block_masked_mm_mlx_gen_orig_
|
||||
#define mlx_broadcast_arrays mlx_broadcast_arrays_mlx_gen_orig_
|
||||
#define mlx_broadcast_to mlx_broadcast_to_mlx_gen_orig_
|
||||
@@ -356,6 +358,8 @@
|
||||
#define mlx_greater mlx_greater_mlx_gen_orig_
|
||||
#define mlx_greater_equal mlx_greater_equal_mlx_gen_orig_
|
||||
#define mlx_hadamard_transform mlx_hadamard_transform_mlx_gen_orig_
|
||||
#define mlx_hamming mlx_hamming_mlx_gen_orig_
|
||||
#define mlx_hanning mlx_hanning_mlx_gen_orig_
|
||||
#define mlx_identity mlx_identity_mlx_gen_orig_
|
||||
#define mlx_imag mlx_imag_mlx_gen_orig_
|
||||
#define mlx_inner mlx_inner_mlx_gen_orig_
|
||||
@@ -889,10 +893,12 @@
|
||||
#undef mlx_atleast_1d
|
||||
#undef mlx_atleast_2d
|
||||
#undef mlx_atleast_3d
|
||||
#undef mlx_bartlett
|
||||
#undef mlx_bitwise_and
|
||||
#undef mlx_bitwise_invert
|
||||
#undef mlx_bitwise_or
|
||||
#undef mlx_bitwise_xor
|
||||
#undef mlx_blackman
|
||||
#undef mlx_block_masked_mm
|
||||
#undef mlx_broadcast_arrays
|
||||
#undef mlx_broadcast_to
|
||||
@@ -945,6 +951,8 @@
|
||||
#undef mlx_greater
|
||||
#undef mlx_greater_equal
|
||||
#undef mlx_hadamard_transform
|
||||
#undef mlx_hamming
|
||||
#undef mlx_hanning
|
||||
#undef mlx_identity
|
||||
#undef mlx_imag
|
||||
#undef mlx_inner
|
||||
@@ -1501,8 +1509,10 @@ extern int (*mlx_distributed_sum_scatter_)(
|
||||
extern int (*mlx_distributed_group_rank_)(mlx_distributed_group group);
|
||||
extern int (*mlx_distributed_group_size_)(mlx_distributed_group group);
|
||||
extern mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key);
|
||||
extern bool (*mlx_distributed_is_available_)(void);
|
||||
extern mlx_distributed_group (*mlx_distributed_init_)(bool strict);
|
||||
extern bool (*mlx_distributed_is_available_)(const char* bk /* may be null */);
|
||||
extern mlx_distributed_group (*mlx_distributed_init_)(
|
||||
bool strict,
|
||||
const char* bk /* may be null */);
|
||||
extern void (*mlx_set_error_handler_)(
|
||||
mlx_error_handler_func handler,
|
||||
void* data,
|
||||
@@ -2099,6 +2109,7 @@ extern int (*mlx_astype_)(
|
||||
extern int (*mlx_atleast_1d_)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||
extern int (*mlx_atleast_2d_)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||
extern int (*mlx_atleast_3d_)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||
extern int (*mlx_bartlett_)(mlx_array* res, int M, const mlx_stream s);
|
||||
extern int (*mlx_bitwise_and_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
@@ -2115,6 +2126,7 @@ extern int (*mlx_bitwise_xor_)(
|
||||
const mlx_array a,
|
||||
const mlx_array b,
|
||||
const mlx_stream s);
|
||||
extern int (*mlx_blackman_)(mlx_array* res, int M, const mlx_stream s);
|
||||
extern int (*mlx_block_masked_mm_)(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
@@ -2295,6 +2307,7 @@ extern int (*mlx_dequantize_)(
|
||||
mlx_optional_int group_size,
|
||||
mlx_optional_int bits,
|
||||
const char* mode,
|
||||
const mlx_array global_scale /* may be null */,
|
||||
mlx_optional_dtype dtype,
|
||||
const mlx_stream s);
|
||||
extern int (*mlx_diag_)(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
|
||||
@@ -2431,6 +2444,8 @@ extern int (*mlx_hadamard_transform_)(
|
||||
const mlx_array a,
|
||||
mlx_optional_float scale,
|
||||
const mlx_stream s);
|
||||
extern int (*mlx_hamming_)(mlx_array* res, int M, const mlx_stream s);
|
||||
extern int (*mlx_hanning_)(mlx_array* res, int M, const mlx_stream s);
|
||||
extern int (*mlx_identity_)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s);
|
||||
extern int (*mlx_imag_)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||
extern int (*mlx_inner_)(
|
||||
@@ -2723,6 +2738,8 @@ extern int (*mlx_qqmm_)(
|
||||
mlx_optional_int group_size,
|
||||
mlx_optional_int bits,
|
||||
const char* mode,
|
||||
const mlx_array global_scale_x /* may be null */,
|
||||
const mlx_array global_scale_w /* may be null */,
|
||||
const mlx_stream s);
|
||||
extern int (*mlx_quantize_)(
|
||||
mlx_vector_array* res,
|
||||
@@ -2730,6 +2747,7 @@ extern int (*mlx_quantize_)(
|
||||
mlx_optional_int group_size,
|
||||
mlx_optional_int bits,
|
||||
const char* mode,
|
||||
const mlx_array global_scale /* may be null */,
|
||||
const mlx_stream s);
|
||||
extern int (*mlx_quantized_matmul_)(
|
||||
mlx_array* res,
|
||||
@@ -4033,11 +4051,13 @@ static inline int mlx_distributed_group_size(mlx_distributed_group group) {
|
||||
static inline mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, int color, int key) {
|
||||
return mlx_distributed_group_split_(group, color, key);
|
||||
}
|
||||
static inline bool mlx_distributed_is_available(void) {
|
||||
return mlx_distributed_is_available_();
|
||||
static inline bool mlx_distributed_is_available(const char* bk /* may be null */) {
|
||||
return mlx_distributed_is_available_(bk);
|
||||
}
|
||||
static inline mlx_distributed_group mlx_distributed_init(bool strict) {
|
||||
return mlx_distributed_init_(strict);
|
||||
static inline mlx_distributed_group mlx_distributed_init(
|
||||
bool strict,
|
||||
const char* bk /* may be null */) {
|
||||
return mlx_distributed_init_(strict, bk);
|
||||
}
|
||||
static inline void mlx_set_error_handler(
|
||||
mlx_error_handler_func handler,
|
||||
@@ -4939,6 +4959,9 @@ static inline int mlx_atleast_2d(mlx_array* res, const mlx_array a, const mlx_st
|
||||
static inline int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s) {
|
||||
return mlx_atleast_3d_(res, a, s);
|
||||
}
|
||||
static inline int mlx_bartlett(mlx_array* res, int M, const mlx_stream s) {
|
||||
return mlx_bartlett_(res, M, s);
|
||||
}
|
||||
static inline int mlx_bitwise_and(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
@@ -4963,6 +4986,9 @@ static inline int mlx_bitwise_xor(
|
||||
const mlx_stream s) {
|
||||
return mlx_bitwise_xor_(res, a, b, s);
|
||||
}
|
||||
static inline int mlx_blackman(mlx_array* res, int M, const mlx_stream s) {
|
||||
return mlx_blackman_(res, M, s);
|
||||
}
|
||||
static inline int mlx_block_masked_mm(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
@@ -5193,9 +5219,10 @@ static inline int mlx_dequantize(
|
||||
mlx_optional_int group_size,
|
||||
mlx_optional_int bits,
|
||||
const char* mode,
|
||||
const mlx_array global_scale /* may be null */,
|
||||
mlx_optional_dtype dtype,
|
||||
const mlx_stream s) {
|
||||
return mlx_dequantize_(res, w, scales, biases, group_size, bits, mode, dtype, s);
|
||||
return mlx_dequantize_(res, w, scales, biases, group_size, bits, mode, global_scale, dtype, s);
|
||||
}
|
||||
static inline int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s) {
|
||||
return mlx_diag_(res, a, k, s);
|
||||
@@ -5383,6 +5410,12 @@ static inline int mlx_hadamard_transform(
|
||||
const mlx_stream s) {
|
||||
return mlx_hadamard_transform_(res, a, scale, s);
|
||||
}
|
||||
static inline int mlx_hamming(mlx_array* res, int M, const mlx_stream s) {
|
||||
return mlx_hamming_(res, M, s);
|
||||
}
|
||||
static inline int mlx_hanning(mlx_array* res, int M, const mlx_stream s) {
|
||||
return mlx_hanning_(res, M, s);
|
||||
}
|
||||
static inline int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) {
|
||||
return mlx_identity_(res, n, dtype, s);
|
||||
}
|
||||
@@ -5793,8 +5826,10 @@ static inline int mlx_qqmm(
|
||||
mlx_optional_int group_size,
|
||||
mlx_optional_int bits,
|
||||
const char* mode,
|
||||
const mlx_array global_scale_x /* may be null */,
|
||||
const mlx_array global_scale_w /* may be null */,
|
||||
const mlx_stream s) {
|
||||
return mlx_qqmm_(res, x, w, w_scales, group_size, bits, mode, s);
|
||||
return mlx_qqmm_(res, x, w, w_scales, group_size, bits, mode, global_scale_x, global_scale_w, s);
|
||||
}
|
||||
static inline int mlx_quantize(
|
||||
mlx_vector_array* res,
|
||||
@@ -5802,8 +5837,9 @@ static inline int mlx_quantize(
|
||||
mlx_optional_int group_size,
|
||||
mlx_optional_int bits,
|
||||
const char* mode,
|
||||
const mlx_array global_scale /* may be null */,
|
||||
const mlx_stream s) {
|
||||
return mlx_quantize_(res, w, group_size, bits, mode, s);
|
||||
return mlx_quantize_(res, w, group_size, bits, mode, global_scale, s);
|
||||
}
|
||||
static inline int mlx_quantized_matmul(
|
||||
mlx_array* res,
|
||||
|
||||
@@ -42,12 +42,14 @@ mlx_distributed_group_split(mlx_distributed_group group, int color, int key);
|
||||
/**
|
||||
* Check if distributed is available.
|
||||
*/
|
||||
bool mlx_distributed_is_available(void);
|
||||
bool mlx_distributed_is_available(const char* bk /* may be null */);
|
||||
|
||||
/**
|
||||
* Initialize distributed.
|
||||
*/
|
||||
mlx_distributed_group mlx_distributed_init(bool strict);
|
||||
mlx_distributed_group mlx_distributed_init(
|
||||
bool strict,
|
||||
const char* bk /* may be null */);
|
||||
|
||||
/**@}*/
|
||||
|
||||
|
||||
@@ -166,6 +166,7 @@ int mlx_astype(
|
||||
int mlx_atleast_1d(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||
int mlx_atleast_2d(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||
int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||
int mlx_bartlett(mlx_array* res, int M, const mlx_stream s);
|
||||
int mlx_bitwise_and(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
@@ -182,6 +183,7 @@ int mlx_bitwise_xor(
|
||||
const mlx_array a,
|
||||
const mlx_array b,
|
||||
const mlx_stream s);
|
||||
int mlx_blackman(mlx_array* res, int M, const mlx_stream s);
|
||||
int mlx_block_masked_mm(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
@@ -362,6 +364,7 @@ int mlx_dequantize(
|
||||
mlx_optional_int group_size,
|
||||
mlx_optional_int bits,
|
||||
const char* mode,
|
||||
const mlx_array global_scale /* may be null */,
|
||||
mlx_optional_dtype dtype,
|
||||
const mlx_stream s);
|
||||
int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
|
||||
@@ -498,6 +501,8 @@ int mlx_hadamard_transform(
|
||||
const mlx_array a,
|
||||
mlx_optional_float scale,
|
||||
const mlx_stream s);
|
||||
int mlx_hamming(mlx_array* res, int M, const mlx_stream s);
|
||||
int mlx_hanning(mlx_array* res, int M, const mlx_stream s);
|
||||
int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s);
|
||||
int mlx_imag(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||
int mlx_inner(
|
||||
@@ -790,6 +795,8 @@ int mlx_qqmm(
|
||||
mlx_optional_int group_size,
|
||||
mlx_optional_int bits,
|
||||
const char* mode,
|
||||
const mlx_array global_scale_x /* may be null */,
|
||||
const mlx_array global_scale_w /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_quantize(
|
||||
mlx_vector_array* res,
|
||||
@@ -797,6 +804,7 @@ int mlx_quantize(
|
||||
mlx_optional_int group_size,
|
||||
mlx_optional_int bits,
|
||||
const char* mode,
|
||||
const mlx_array global_scale /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_quantized_matmul(
|
||||
mlx_array* res,
|
||||
|
||||
@@ -15,9 +15,10 @@ func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, bias
|
||||
defer C.free(unsafe.Pointer(cMode))
|
||||
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
|
||||
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
|
||||
var globalScale C.mlx_array
|
||||
res := C.mlx_vector_array_new()
|
||||
defer C.mlx_vector_array_free(res)
|
||||
C.mlx_quantize(&res, w.ctx, optGroupSize, optBits, cMode, DefaultStream().ctx)
|
||||
C.mlx_quantize(&res, w.ctx, optGroupSize, optBits, cMode, globalScale, DefaultStream().ctx)
|
||||
|
||||
vecSize := int(C.mlx_vector_array_size(res))
|
||||
w0 := New("QUANTIZE_W")
|
||||
@@ -38,6 +39,7 @@ func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Arr
|
||||
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
|
||||
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
|
||||
optDtype := C.mlx_optional_dtype{has_value: false}
|
||||
var globalScale C.mlx_array
|
||||
|
||||
var b C.mlx_array
|
||||
if biases != nil {
|
||||
@@ -45,7 +47,7 @@ func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Arr
|
||||
}
|
||||
|
||||
out := New("DEQUANTIZE")
|
||||
C.mlx_dequantize(&out.ctx, w.ctx, scales.ctx, b, optGroupSize, optBits, cMode, optDtype, DefaultStream().ctx)
|
||||
C.mlx_dequantize(&out.ctx, w.ctx, scales.ctx, b, optGroupSize, optBits, cMode, globalScale, optDtype, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user