Compare commits

...

6 Commits

Author SHA1 Message Date
Patrick Devine
6f26695eae mlx: update upstream mlx version 2026-03-22 13:39:52 -07:00
Bruce MacDonald
22c2bdbd8a docs: nemoclaw integration (#14962)
---------

Co-authored-by: ParthSareen <parth.sareen@ollama.com>
2026-03-20 15:27:37 -07:00
Bruce MacDonald
6df6d097d9 launch: skip openclaw gateway health check when no daemon install (#14984) 2026-03-20 15:20:14 -07:00
Jesse Gross
d7c176ab91 llm, mlxrunner: fix done channel value consumed by first receiver
Receiving from a buffered chan error consumes the value, so only the
first caller (WaitUntilRunning, HasExited, or Close) sees the signal.
Subsequent receivers block or take the wrong branch. Replace with a
closed chan struct{} which can be received from any number of times,
and store the error in a separate field.
2026-03-19 17:44:28 -07:00
Jesse Gross
0ff7d724ff mlx: fix subprocess log deadlock
The stderr reader used bufio.Scanner which has a 64KB max line size.
If the subprocess wrote a line exceeding this limit, the scanner would
stop reading, the OS pipe buffer would fill, and the subprocess would
deadlock.

Replace the scanner with a statusWriter that wraps io.Copy. The writer
forwards all stderr to os.Stderr while capturing the last short line
(≤256 bytes) for error reporting, avoiding both the deadlock and the
need to buffer arbitrarily long lines.
2026-03-19 17:44:28 -07:00
Devon Rifkin
46cb7795e1 add ability to turn on debug request logging (#14106)
If `OLLAMA_DEBUG_LOG_REQUESTS` is set, then on server startup a temp
folder will be created. Upon any inference request, the body will be
logged to a file in this folder, as well as a small shell script to
"replay" the request using cURL.

This is just intended for debugging scenarios, not as something to turn
on normally.
2026-03-19 17:08:17 -07:00
20 changed files with 631 additions and 107 deletions

View File

@@ -1 +1 @@
v0.30.6
v0.31.1

View File

@@ -1 +1 @@
v0.5.0
v0.6.0

View File

@@ -80,6 +80,12 @@ func (c *Openclaw) Run(model string, args []string) error {
}
if canInstallDaemon() {
onboardArgs = append(onboardArgs, "--install-daemon")
} else {
// When we can't install a daemon (e.g. no systemd, sudo dropped
// XDG_RUNTIME_DIR, or container environment), skip the gateway
// health check so non-interactive onboarding completes. The
// gateway is started as a foreground child process after onboarding.
onboardArgs = append(onboardArgs, "--skip-health")
}
cmd := exec.Command(bin, onboardArgs...)
cmd.Stdin = os.Stdin

View File

@@ -160,6 +160,12 @@
"group": "More information",
"pages": [
"/cli",
{
"group": "Assistant Sandboxing",
"pages": [
"/integrations/nemoclaw"
]
},
"/modelfile",
"/context-length",
"/linux",

View File

@@ -0,0 +1,67 @@
---
title: NemoClaw
---
NemoClaw is NVIDIA's open source security stack for [OpenClaw](/integrations/openclaw). It wraps OpenClaw with the NVIDIA OpenShell runtime to provide kernel-level sandboxing, network policy controls, and audit trails for AI agents.
## Quick start
Pull a model:
```bash
ollama pull nemotron-3-nano:30b
```
Run the installer:
```bash
curl -fsSL https://www.nvidia.com/nemoclaw.sh | \
NEMOCLAW_NON_INTERACTIVE=1 \
NEMOCLAW_PROVIDER=ollama \
NEMOCLAW_MODEL=nemotron-3-nano:30b \
bash
```
Connect to your sandbox:
```bash
nemoclaw my-assistant connect
```
Open the TUI:
```bash
openclaw tui
```
<Note>Ollama support in NemoClaw is still experimental.</Note>
## Platform support
| Platform | Runtime | Status |
|----------|---------|--------|
| Linux (Ubuntu 22.04+) | Docker | Primary |
| macOS (Apple Silicon) | Colima or Docker Desktop | Supported |
| Windows | WSL2 with Docker Desktop | Supported |
CMD and PowerShell are not supported on Windows — WSL2 is required.
<Note>Ollama must be installed and running before the installer runs. When running inside WSL2 or a container, ensure Ollama is reachable from the sandbox (e.g. `OLLAMA_HOST=0.0.0.0`).</Note>
## System requirements
- CPU: 4 vCPU minimum
- RAM: 8 GB minimum (16 GB recommended)
- Disk: 20 GB free (40 GB recommended for local models)
- Node.js 20+ and npm 10+
- Container runtime (Docker preferred)
## Recommended models
- `nemotron-3-super:cloud` — Strong reasoning and coding
- `qwen3.5:cloud` — 397B; reasoning and code generation
- `nemotron-3-nano:30b` — Recommended local model; fits in 24 GB VRAM
- `qwen3.5:27b` — Fast local reasoning (~18 GB VRAM)
- `glm-4.7-flash` — Reasoning and code generation (~25 GB VRAM)
More models at [ollama.com/search](https://ollama.com/search).

View File

@@ -214,6 +214,8 @@ func LogLevel() slog.Level {
var (
// FlashAttention enables the experimental flash attention feature.
FlashAttention = BoolWithDefault("OLLAMA_FLASH_ATTENTION")
// DebugLogRequests logs inference requests to disk for replay/debugging.
DebugLogRequests = Bool("OLLAMA_DEBUG_LOG_REQUESTS")
// KvCacheType is the quantization type for the K/V cache.
KvCacheType = String("OLLAMA_KV_CACHE_TYPE")
// NoHistory disables readline history.
@@ -302,28 +304,29 @@ type EnvVar struct {
func AsMap() map[string]EnvVar {
ret := map[string]EnvVar{
"OLLAMA_DEBUG": {"OLLAMA_DEBUG", LogLevel(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
"OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(false), "Enabled flash attention"},
"OLLAMA_KV_CACHE_TYPE": {"OLLAMA_KV_CACHE_TYPE", KvCacheType(), "Quantization type for the K/V cache (default: f16)"},
"OLLAMA_GPU_OVERHEAD": {"OLLAMA_GPU_OVERHEAD", GpuOverhead(), "Reserve a portion of VRAM per GPU (bytes)"},
"OLLAMA_HOST": {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"},
"OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive(), "The duration that models stay loaded in memory (default \"5m\")"},
"OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary(), "Set LLM library to bypass autodetection"},
"OLLAMA_LOAD_TIMEOUT": {"OLLAMA_LOAD_TIMEOUT", LoadTimeout(), "How long to allow model loads to stall before giving up (default \"5m\")"},
"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners(), "Maximum number of loaded models per GPU"},
"OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueue(), "Maximum number of queued requests"},
"OLLAMA_MODELS": {"OLLAMA_MODELS", Models(), "The path to the models directory"},
"OLLAMA_NO_CLOUD": {"OLLAMA_NO_CLOUD", NoCloud(), "Disable Ollama cloud features (remote inference and web search)"},
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory(), "Do not preserve readline history"},
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune(), "Do not prune model blobs on startup"},
"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel(), "Maximum number of parallel requests"},
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowedOrigins(), "A comma separated list of allowed origins"},
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4k/32k/256k based on VRAM)"},
"OLLAMA_EDITOR": {"OLLAMA_EDITOR", Editor(), "Path to editor for interactive prompt editing (Ctrl+G)"},
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
"OLLAMA_REMOTES": {"OLLAMA_REMOTES", Remotes(), "Allowed hosts for remote models (default \"ollama.com\")"},
"OLLAMA_DEBUG": {"OLLAMA_DEBUG", LogLevel(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
"OLLAMA_DEBUG_LOG_REQUESTS": {"OLLAMA_DEBUG_LOG_REQUESTS", DebugLogRequests(), "Log inference request bodies and replay curl commands to a temp directory"},
"OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(false), "Enabled flash attention"},
"OLLAMA_KV_CACHE_TYPE": {"OLLAMA_KV_CACHE_TYPE", KvCacheType(), "Quantization type for the K/V cache (default: f16)"},
"OLLAMA_GPU_OVERHEAD": {"OLLAMA_GPU_OVERHEAD", GpuOverhead(), "Reserve a portion of VRAM per GPU (bytes)"},
"OLLAMA_HOST": {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"},
"OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive(), "The duration that models stay loaded in memory (default \"5m\")"},
"OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary(), "Set LLM library to bypass autodetection"},
"OLLAMA_LOAD_TIMEOUT": {"OLLAMA_LOAD_TIMEOUT", LoadTimeout(), "How long to allow model loads to stall before giving up (default \"5m\")"},
"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners(), "Maximum number of loaded models per GPU"},
"OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueue(), "Maximum number of queued requests"},
"OLLAMA_MODELS": {"OLLAMA_MODELS", Models(), "The path to the models directory"},
"OLLAMA_NO_CLOUD": {"OLLAMA_NO_CLOUD", NoCloud(), "Disable Ollama cloud features (remote inference and web search)"},
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory(), "Do not preserve readline history"},
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune(), "Do not prune model blobs on startup"},
"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel(), "Maximum number of parallel requests"},
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowedOrigins(), "A comma separated list of allowed origins"},
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4k/32k/256k based on VRAM)"},
"OLLAMA_EDITOR": {"OLLAMA_EDITOR", Editor(), "Path to editor for interactive prompt editing (Ctrl+G)"},
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
"OLLAMA_REMOTES": {"OLLAMA_REMOTES", Remotes(), "Allowed hosts for remote models (default \"ollama.com\")"},
// Informational
"HTTP_PROXY": {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"},

View File

@@ -87,7 +87,8 @@ type LlamaServer interface {
type llmServer struct {
port int
cmd *exec.Cmd
done chan error // Channel to signal when the process exits
done chan struct{} // closed when the process exits
doneErr error // valid after done is closed
status *StatusWriter
options api.Options
modelPath string
@@ -280,7 +281,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
sem: semaphore.NewWeighted(int64(numParallel)),
totalLayers: f.KV().BlockCount() + 1,
loadStart: time.Now(),
done: make(chan error, 1),
done: make(chan struct{}),
}
if err != nil {
@@ -304,10 +305,11 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
if strings.Contains(s.status.LastErrMsg, "unknown model") {
s.status.LastErrMsg = "this model is not supported by your version of Ollama. You may need to upgrade"
}
s.done <- errors.New(s.status.LastErrMsg)
s.doneErr = errors.New(s.status.LastErrMsg)
} else {
s.done <- err
s.doneErr = err
}
close(s.done)
}()
if tok != nil {
@@ -1356,8 +1358,8 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
case <-ctx.Done():
slog.Warn("client connection closed before server finished loading, aborting load")
return fmt.Errorf("timed out waiting for llama runner to start: %w", ctx.Err())
case err := <-s.done:
return fmt.Errorf("llama runner process has terminated: %w", err)
case <-s.done:
return fmt.Errorf("llama runner process has terminated: %w", s.doneErr)
default:
}
if time.Now().After(stallTimer) {

View File

@@ -0,0 +1,144 @@
package server
import (
"bytes"
"fmt"
"io"
"log/slog"
"os"
"path/filepath"
"strings"
"sync/atomic"
"time"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/envconfig"
)
type inferenceRequestLogger struct {
dir string
counter uint64
}
func newInferenceRequestLogger() (*inferenceRequestLogger, error) {
dir, err := os.MkdirTemp("", "ollama-request-logs-*")
if err != nil {
return nil, err
}
return &inferenceRequestLogger{dir: dir}, nil
}
func (s *Server) initRequestLogging() error {
if !envconfig.DebugLogRequests() {
return nil
}
requestLogger, err := newInferenceRequestLogger()
if err != nil {
return fmt.Errorf("enable OLLAMA_DEBUG_LOG_REQUESTS: %w", err)
}
s.requestLogger = requestLogger
slog.Info(fmt.Sprintf("request debug logging enabled; inference request logs will be stored in %s and include request bodies and replay curl commands", requestLogger.dir))
return nil
}
func (s *Server) withInferenceRequestLogging(route string, handlers ...gin.HandlerFunc) []gin.HandlerFunc {
if s.requestLogger == nil {
return handlers
}
return append([]gin.HandlerFunc{s.requestLogger.middleware(route)}, handlers...)
}
func (l *inferenceRequestLogger) middleware(route string) gin.HandlerFunc {
return func(c *gin.Context) {
if c.Request == nil {
c.Next()
return
}
method := c.Request.Method
host := c.Request.Host
scheme := "http"
if c.Request.TLS != nil {
scheme = "https"
}
contentType := c.GetHeader("Content-Type")
var body []byte
if c.Request.Body != nil {
var err error
body, err = io.ReadAll(c.Request.Body)
c.Request.Body = io.NopCloser(bytes.NewReader(body))
if err != nil {
slog.Warn("failed to read request body for debug logging", "route", route, "error", err)
}
}
c.Next()
l.log(route, method, scheme, host, contentType, body)
}
}
func (l *inferenceRequestLogger) log(route, method, scheme, host, contentType string, body []byte) {
if l == nil || l.dir == "" {
return
}
if contentType == "" {
contentType = "application/json"
}
if host == "" || scheme == "" {
base := envconfig.Host()
if host == "" {
host = base.Host
}
if scheme == "" {
scheme = base.Scheme
}
}
routeForFilename := sanitizeRouteForFilename(route)
timestamp := fmt.Sprintf("%s-%06d", time.Now().UTC().Format("20060102T150405.000000000Z"), atomic.AddUint64(&l.counter, 1))
bodyFilename := fmt.Sprintf("%s_%s_body.json", timestamp, routeForFilename)
curlFilename := fmt.Sprintf("%s_%s_request.sh", timestamp, routeForFilename)
bodyPath := filepath.Join(l.dir, bodyFilename)
curlPath := filepath.Join(l.dir, curlFilename)
if err := os.WriteFile(bodyPath, body, 0o600); err != nil {
slog.Warn("failed to write debug request body", "route", route, "error", err)
return
}
url := fmt.Sprintf("%s://%s%s", scheme, host, route)
curl := fmt.Sprintf("#!/bin/sh\nSCRIPT_DIR=\"$(CDPATH= cd -- \"$(dirname -- \"$0\")\" && pwd)\"\ncurl --request %s --url %q --header %q --data-binary @\"${SCRIPT_DIR}/%s\"\n", method, url, "Content-Type: "+contentType, bodyFilename)
if err := os.WriteFile(curlPath, []byte(curl), 0o600); err != nil {
slog.Warn("failed to write debug request replay command", "route", route, "error", err)
return
}
slog.Info(fmt.Sprintf("logged to %s, replay using curl with `sh %s`", bodyPath, curlPath))
}
func sanitizeRouteForFilename(route string) string {
route = strings.TrimPrefix(route, "/")
if route == "" {
return "root"
}
var b strings.Builder
b.Grow(len(route))
for _, r := range route {
if ('a' <= r && r <= 'z') || ('A' <= r && r <= 'Z') || ('0' <= r && r <= '9') {
b.WriteRune(r)
} else {
b.WriteByte('_')
}
}
return b.String()
}

View File

@@ -100,6 +100,7 @@ type Server struct {
addr net.Addr
sched *Scheduler
defaultNumCtx int
requestLogger *inferenceRequestLogger
}
func init() {
@@ -1686,26 +1687,26 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
// Inference
r.GET("/api/ps", s.PsHandler)
r.POST("/api/generate", s.GenerateHandler)
r.POST("/api/chat", s.ChatHandler)
r.POST("/api/generate", s.withInferenceRequestLogging("/api/generate", s.GenerateHandler)...)
r.POST("/api/chat", s.withInferenceRequestLogging("/api/chat", s.ChatHandler)...)
r.POST("/api/embed", s.EmbedHandler)
r.POST("/api/embeddings", s.EmbeddingsHandler)
// Inference (OpenAI compatibility)
// TODO(cloud-stage-a): apply Modelfile overlay deltas for local models with cloud
// parents on v1 request families while preserving this explicit :cloud passthrough.
r.POST("/v1/chat/completions", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ChatMiddleware(), s.ChatHandler)
r.POST("/v1/completions", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.CompletionsMiddleware(), s.GenerateHandler)
r.POST("/v1/chat/completions", s.withInferenceRequestLogging("/v1/chat/completions", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ChatMiddleware(), s.ChatHandler)...)
r.POST("/v1/completions", s.withInferenceRequestLogging("/v1/completions", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.CompletionsMiddleware(), s.GenerateHandler)...)
r.POST("/v1/embeddings", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.EmbeddingsMiddleware(), s.EmbedHandler)
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
r.GET("/v1/models/:model", cloudModelPathPassthroughMiddleware(cloudErrRemoteModelDetailsUnavailable), middleware.RetrieveMiddleware(), s.ShowHandler)
r.POST("/v1/responses", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ResponsesMiddleware(), s.ChatHandler)
r.POST("/v1/responses", s.withInferenceRequestLogging("/v1/responses", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ResponsesMiddleware(), s.ChatHandler)...)
// OpenAI-compatible image generation endpoints
r.POST("/v1/images/generations", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ImageGenerationsMiddleware(), s.GenerateHandler)
r.POST("/v1/images/edits", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ImageEditsMiddleware(), s.GenerateHandler)
// Inference (Anthropic compatibility)
r.POST("/v1/messages", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
r.POST("/v1/messages", s.withInferenceRequestLogging("/v1/messages", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.AnthropicMessagesMiddleware(), s.ChatHandler)...)
if rc != nil {
// wrap old with new
@@ -1757,6 +1758,9 @@ func Serve(ln net.Listener) error {
}
s := &Server{addr: ln.Addr()}
if err := s.initRequestLogging(); err != nil {
return err
}
var rc *ollama.Registry
if useClient2 {

View File

@@ -0,0 +1,128 @@
package server
import (
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"github.com/gin-gonic/gin"
)
func TestInferenceRequestLoggerMiddlewareWritesReplayArtifacts(t *testing.T) {
gin.SetMode(gin.TestMode)
logDir := t.TempDir()
requestLogger := &inferenceRequestLogger{dir: logDir}
const route = "/v1/chat/completions"
const requestBody = `{"model":"test-model","messages":[{"role":"user","content":"hello"}]}`
var bodySeenByHandler string
r := gin.New()
r.POST(route, requestLogger.middleware(route), func(c *gin.Context) {
body, err := io.ReadAll(c.Request.Body)
if err != nil {
t.Fatalf("failed to read body in handler: %v", err)
}
bodySeenByHandler = string(body)
c.Status(http.StatusOK)
})
req := httptest.NewRequest(http.MethodPost, route, strings.NewReader(requestBody))
req.Host = "127.0.0.1:11434"
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
if bodySeenByHandler != requestBody {
t.Fatalf("handler body mismatch:\nexpected: %s\ngot: %s", requestBody, bodySeenByHandler)
}
bodyFiles, err := filepath.Glob(filepath.Join(logDir, "*_v1_chat_completions_body.json"))
if err != nil {
t.Fatalf("failed to glob body logs: %v", err)
}
if len(bodyFiles) != 1 {
t.Fatalf("expected 1 body log, got %d (%v)", len(bodyFiles), bodyFiles)
}
curlFiles, err := filepath.Glob(filepath.Join(logDir, "*_v1_chat_completions_request.sh"))
if err != nil {
t.Fatalf("failed to glob curl logs: %v", err)
}
if len(curlFiles) != 1 {
t.Fatalf("expected 1 curl log, got %d (%v)", len(curlFiles), curlFiles)
}
bodyData, err := os.ReadFile(bodyFiles[0])
if err != nil {
t.Fatalf("failed to read body log: %v", err)
}
if string(bodyData) != requestBody {
t.Fatalf("body log mismatch:\nexpected: %s\ngot: %s", requestBody, string(bodyData))
}
curlData, err := os.ReadFile(curlFiles[0])
if err != nil {
t.Fatalf("failed to read curl log: %v", err)
}
curlString := string(curlData)
if !strings.Contains(curlString, "http://127.0.0.1:11434"+route) {
t.Fatalf("curl log does not contain expected route URL: %s", curlString)
}
bodyFileName := filepath.Base(bodyFiles[0])
if !strings.Contains(curlString, "@\"${SCRIPT_DIR}/"+bodyFileName+"\"") {
t.Fatalf("curl log does not reference sibling body file: %s", curlString)
}
}
func TestNewInferenceRequestLoggerCreatesDirectory(t *testing.T) {
requestLogger, err := newInferenceRequestLogger()
if err != nil {
t.Fatalf("expected no error creating request logger: %v", err)
}
t.Cleanup(func() {
_ = os.RemoveAll(requestLogger.dir)
})
if requestLogger == nil || requestLogger.dir == "" {
t.Fatalf("expected request logger directory to be set")
}
info, err := os.Stat(requestLogger.dir)
if err != nil {
t.Fatalf("expected directory to exist: %v", err)
}
if !info.IsDir() {
t.Fatalf("expected %q to be a directory", requestLogger.dir)
}
}
func TestSanitizeRouteForFilename(t *testing.T) {
tests := []struct {
route string
want string
}{
{route: "/api/generate", want: "api_generate"},
{route: "/v1/chat/completions", want: "v1_chat_completions"},
{route: "/v1/messages", want: "v1_messages"},
}
for _, tt := range tests {
if got := sanitizeRouteForFilename(tt.route); got != tt.want {
t.Fatalf("sanitizeRouteForFilename(%q) = %q, want %q", tt.route, got, tt.want)
}
}
}

View File

@@ -165,8 +165,8 @@ int (*mlx_distributed_sum_scatter_ptr)(mlx_array* res, const mlx_array x, const
int (*mlx_distributed_group_rank_ptr)(mlx_distributed_group group) = NULL;
int (*mlx_distributed_group_size_ptr)(mlx_distributed_group group) = NULL;
mlx_distributed_group (*mlx_distributed_group_split_ptr)(mlx_distributed_group group, int color, int key) = NULL;
bool (*mlx_distributed_is_available_ptr)(void) = NULL;
mlx_distributed_group (*mlx_distributed_init_ptr)(bool strict) = NULL;
bool (*mlx_distributed_is_available_ptr)(const char* bk) = NULL;
mlx_distributed_group (*mlx_distributed_init_ptr)(bool strict, const char* bk) = NULL;
void (*mlx_set_error_handler_ptr)(mlx_error_handler_func handler, void* data, void (*dtor)(void*)) = NULL;
void (*_mlx_error_ptr)(const char* file, const int line, const char* fmt, ...) = NULL;
int (*mlx_export_function_ptr)(const char* file, const mlx_closure fun, const mlx_vector_array args, bool shapeless) = NULL;
@@ -319,10 +319,12 @@ int (*mlx_astype_ptr)(mlx_array* res, const mlx_array a, mlx_dtype dtype, const
int (*mlx_atleast_1d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_atleast_2d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_atleast_3d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_bartlett_ptr)(mlx_array* res, int M, const mlx_stream s) = NULL;
int (*mlx_bitwise_and_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_bitwise_invert_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_bitwise_or_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_bitwise_xor_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_blackman_ptr)(mlx_array* res, int M, const mlx_stream s) = NULL;
int (*mlx_block_masked_mm_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s) = NULL;
int (*mlx_broadcast_arrays_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s) = NULL;
int (*mlx_broadcast_to_ptr)(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s) = NULL;
@@ -348,7 +350,7 @@ int (*mlx_cumprod_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse
int (*mlx_cumsum_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) = NULL;
int (*mlx_degrees_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_depends_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies) = NULL;
int (*mlx_dequantize_ptr)(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, mlx_optional_dtype dtype, const mlx_stream s) = NULL;
int (*mlx_dequantize_ptr)(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , mlx_optional_dtype dtype, const mlx_stream s) = NULL;
int (*mlx_diag_ptr)(mlx_array* res, const mlx_array a, int k, const mlx_stream s) = NULL;
int (*mlx_diagonal_ptr)(mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, const mlx_stream s) = NULL;
int (*mlx_divide_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
@@ -375,6 +377,8 @@ int (*mlx_gather_qmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w,
int (*mlx_greater_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_greater_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_hadamard_transform_ptr)(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s) = NULL;
int (*mlx_hamming_ptr)(mlx_array* res, int M, const mlx_stream s) = NULL;
int (*mlx_hanning_ptr)(mlx_array* res, int M, const mlx_stream s) = NULL;
int (*mlx_identity_ptr)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) = NULL;
int (*mlx_imag_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_inner_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
@@ -434,8 +438,8 @@ int (*mlx_prod_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, siz
int (*mlx_prod_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL;
int (*mlx_prod_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL;
int (*mlx_put_along_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s) = NULL;
int (*mlx_qqmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) = NULL;
int (*mlx_quantize_ptr)(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) = NULL;
int (*mlx_qqmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale_x , const mlx_array global_scale_w , const mlx_stream s) = NULL;
int (*mlx_quantize_ptr)(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , const mlx_stream s) = NULL;
int (*mlx_quantized_matmul_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) = NULL;
int (*mlx_radians_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_real_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
@@ -2101,6 +2105,11 @@ int mlx_load_functions(void* handle) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_atleast_3d\n");
return -1;
}
mlx_bartlett_ptr = GET_SYM(handle, "mlx_bartlett");
if (mlx_bartlett_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_bartlett\n");
return -1;
}
mlx_bitwise_and_ptr = GET_SYM(handle, "mlx_bitwise_and");
if (mlx_bitwise_and_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_and\n");
@@ -2121,6 +2130,11 @@ int mlx_load_functions(void* handle) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_xor\n");
return -1;
}
mlx_blackman_ptr = GET_SYM(handle, "mlx_blackman");
if (mlx_blackman_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_blackman\n");
return -1;
}
mlx_block_masked_mm_ptr = GET_SYM(handle, "mlx_block_masked_mm");
if (mlx_block_masked_mm_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_block_masked_mm\n");
@@ -2381,6 +2395,16 @@ int mlx_load_functions(void* handle) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_hadamard_transform\n");
return -1;
}
mlx_hamming_ptr = GET_SYM(handle, "mlx_hamming");
if (mlx_hamming_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_hamming\n");
return -1;
}
mlx_hanning_ptr = GET_SYM(handle, "mlx_hanning");
if (mlx_hanning_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_hanning\n");
return -1;
}
mlx_identity_ptr = GET_SYM(handle, "mlx_identity");
if (mlx_identity_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_identity\n");
@@ -4132,12 +4156,12 @@ mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, i
return mlx_distributed_group_split_ptr(group, color, key);
}
bool mlx_distributed_is_available(void) {
return mlx_distributed_is_available_ptr();
bool mlx_distributed_is_available(const char* bk) {
return mlx_distributed_is_available_ptr(bk);
}
mlx_distributed_group mlx_distributed_init(bool strict) {
return mlx_distributed_init_ptr(strict);
mlx_distributed_group mlx_distributed_init(bool strict, const char* bk) {
return mlx_distributed_init_ptr(strict, bk);
}
void mlx_set_error_handler(mlx_error_handler_func handler, void* data, void (*dtor)(void*)) {
@@ -4748,6 +4772,10 @@ int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s) {
return mlx_atleast_3d_ptr(res, a, s);
}
int mlx_bartlett(mlx_array* res, int M, const mlx_stream s) {
return mlx_bartlett_ptr(res, M, s);
}
int mlx_bitwise_and(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
return mlx_bitwise_and_ptr(res, a, b, s);
}
@@ -4764,6 +4792,10 @@ int mlx_bitwise_xor(mlx_array* res, const mlx_array a, const mlx_array b, const
return mlx_bitwise_xor_ptr(res, a, b, s);
}
int mlx_blackman(mlx_array* res, int M, const mlx_stream s) {
return mlx_blackman_ptr(res, M, s);
}
int mlx_block_masked_mm(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s) {
return mlx_block_masked_mm_ptr(res, a, b, block_size, mask_out, mask_lhs, mask_rhs, s);
}
@@ -4864,8 +4896,8 @@ int mlx_depends(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_
return mlx_depends_ptr(res, inputs, dependencies);
}
int mlx_dequantize(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, mlx_optional_dtype dtype, const mlx_stream s) {
return mlx_dequantize_ptr(res, w, scales, biases, group_size, bits, mode, dtype, s);
int mlx_dequantize(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , mlx_optional_dtype dtype, const mlx_stream s) {
return mlx_dequantize_ptr(res, w, scales, biases, group_size, bits, mode, global_scale, dtype, s);
}
int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s) {
@@ -4972,6 +5004,14 @@ int mlx_hadamard_transform(mlx_array* res, const mlx_array a, mlx_optional_float
return mlx_hadamard_transform_ptr(res, a, scale, s);
}
int mlx_hamming(mlx_array* res, int M, const mlx_stream s) {
return mlx_hamming_ptr(res, M, s);
}
int mlx_hanning(mlx_array* res, int M, const mlx_stream s) {
return mlx_hanning_ptr(res, M, s);
}
int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) {
return mlx_identity_ptr(res, n, dtype, s);
}
@@ -5208,12 +5248,12 @@ int mlx_put_along_axis(mlx_array* res, const mlx_array a, const mlx_array indice
return mlx_put_along_axis_ptr(res, a, indices, values, axis, s);
}
int mlx_qqmm(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) {
return mlx_qqmm_ptr(res, x, w, w_scales, group_size, bits, mode, s);
int mlx_qqmm(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale_x , const mlx_array global_scale_w , const mlx_stream s) {
return mlx_qqmm_ptr(res, x, w, w_scales, group_size, bits, mode, global_scale_x, global_scale_w, s);
}
int mlx_quantize(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) {
return mlx_quantize_ptr(res, w, group_size, bits, mode, s);
int mlx_quantize(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , const mlx_stream s) {
return mlx_quantize_ptr(res, w, group_size, bits, mode, global_scale, s);
}
int mlx_quantized_matmul(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) {

View File

@@ -2124,8 +2124,9 @@ func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, bias
defer C.free(unsafe.Pointer(cMode))
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
var globalScale C.mlx_array
res := C.mlx_vector_array_new()
C.mlx_quantize(&res, w.c, optGroupSize, optBits, cMode, C.default_stream())
C.mlx_quantize(&res, w.c, optGroupSize, optBits, cMode, globalScale, C.default_stream())
// Result is a vector of arrays: [weights, scales, biases?]
// mxfp8 mode returns only 2 elements (no biases)
@@ -2154,6 +2155,7 @@ func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Arr
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
optDtype := C.mlx_optional_dtype{has_value: false}
var globalScale C.mlx_array
var b C.mlx_array
if biases != nil {
@@ -2161,7 +2163,7 @@ func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Arr
}
res := C.mlx_array_new()
C.mlx_dequantize(&res, w.c, scales.c, b, optGroupSize, optBits, cMode, optDtype, C.default_stream())
C.mlx_dequantize(&res, w.c, scales.c, b, optGroupSize, optBits, cMode, globalScale, optDtype, C.default_stream())
return newArray(res)
}

View File

@@ -309,10 +309,12 @@
#undef mlx_atleast_1d
#undef mlx_atleast_2d
#undef mlx_atleast_3d
#undef mlx_bartlett
#undef mlx_bitwise_and
#undef mlx_bitwise_invert
#undef mlx_bitwise_or
#undef mlx_bitwise_xor
#undef mlx_blackman
#undef mlx_block_masked_mm
#undef mlx_broadcast_arrays
#undef mlx_broadcast_to
@@ -365,6 +367,8 @@
#undef mlx_greater
#undef mlx_greater_equal
#undef mlx_hadamard_transform
#undef mlx_hamming
#undef mlx_hanning
#undef mlx_identity
#undef mlx_imag
#undef mlx_inner
@@ -751,8 +755,8 @@ extern int (*mlx_distributed_sum_scatter_ptr)(mlx_array* res, const mlx_array x,
extern int (*mlx_distributed_group_rank_ptr)(mlx_distributed_group group);
extern int (*mlx_distributed_group_size_ptr)(mlx_distributed_group group);
extern mlx_distributed_group (*mlx_distributed_group_split_ptr)(mlx_distributed_group group, int color, int key);
extern bool (*mlx_distributed_is_available_ptr)(void);
extern mlx_distributed_group (*mlx_distributed_init_ptr)(bool strict);
extern bool (*mlx_distributed_is_available_ptr)(const char* bk);
extern mlx_distributed_group (*mlx_distributed_init_ptr)(bool strict, const char* bk);
extern void (*mlx_set_error_handler_ptr)(mlx_error_handler_func handler, void* data, void (*dtor)(void*));
extern void (*_mlx_error_ptr)(const char* file, const int line, const char* fmt, ...);
extern int (*mlx_export_function_ptr)(const char* file, const mlx_closure fun, const mlx_vector_array args, bool shapeless);
@@ -905,10 +909,12 @@ extern int (*mlx_astype_ptr)(mlx_array* res, const mlx_array a, mlx_dtype dtype,
extern int (*mlx_atleast_1d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_atleast_2d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_atleast_3d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_bartlett_ptr)(mlx_array* res, int M, const mlx_stream s);
extern int (*mlx_bitwise_and_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
extern int (*mlx_bitwise_invert_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_bitwise_or_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
extern int (*mlx_bitwise_xor_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
extern int (*mlx_blackman_ptr)(mlx_array* res, int M, const mlx_stream s);
extern int (*mlx_block_masked_mm_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s);
extern int (*mlx_broadcast_arrays_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s);
extern int (*mlx_broadcast_to_ptr)(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s);
@@ -934,7 +940,7 @@ extern int (*mlx_cumprod_ptr)(mlx_array* res, const mlx_array a, int axis, bool
extern int (*mlx_cumsum_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s);
extern int (*mlx_degrees_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_depends_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies);
extern int (*mlx_dequantize_ptr)(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, mlx_optional_dtype dtype, const mlx_stream s);
extern int (*mlx_dequantize_ptr)(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , mlx_optional_dtype dtype, const mlx_stream s);
extern int (*mlx_diag_ptr)(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
extern int (*mlx_diagonal_ptr)(mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, const mlx_stream s);
extern int (*mlx_divide_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
@@ -961,6 +967,8 @@ extern int (*mlx_gather_qmm_ptr)(mlx_array* res, const mlx_array x, const mlx_ar
extern int (*mlx_greater_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
extern int (*mlx_greater_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
extern int (*mlx_hadamard_transform_ptr)(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s);
extern int (*mlx_hamming_ptr)(mlx_array* res, int M, const mlx_stream s);
extern int (*mlx_hanning_ptr)(mlx_array* res, int M, const mlx_stream s);
extern int (*mlx_identity_ptr)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s);
extern int (*mlx_imag_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_inner_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
@@ -1020,8 +1028,8 @@ extern int (*mlx_prod_axes_ptr)(mlx_array* res, const mlx_array a, const int* ax
extern int (*mlx_prod_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s);
extern int (*mlx_prod_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s);
extern int (*mlx_put_along_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s);
extern int (*mlx_qqmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s);
extern int (*mlx_quantize_ptr)(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s);
extern int (*mlx_qqmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale_x , const mlx_array global_scale_w , const mlx_stream s);
extern int (*mlx_quantize_ptr)(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , const mlx_stream s);
extern int (*mlx_quantized_matmul_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s);
extern int (*mlx_radians_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_real_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
@@ -1492,9 +1500,9 @@ int mlx_distributed_group_size(mlx_distributed_group group);
mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, int color, int key);
bool mlx_distributed_is_available(void);
bool mlx_distributed_is_available(const char* bk);
mlx_distributed_group mlx_distributed_init(bool strict);
mlx_distributed_group mlx_distributed_init(bool strict, const char* bk);
void mlx_set_error_handler(mlx_error_handler_func handler, void* data, void (*dtor)(void*));
@@ -1800,6 +1808,8 @@ int mlx_atleast_2d(mlx_array* res, const mlx_array a, const mlx_stream s);
int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s);
int mlx_bartlett(mlx_array* res, int M, const mlx_stream s);
int mlx_bitwise_and(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
int mlx_bitwise_invert(mlx_array* res, const mlx_array a, const mlx_stream s);
@@ -1808,6 +1818,8 @@ int mlx_bitwise_or(mlx_array* res, const mlx_array a, const mlx_array b, const m
int mlx_bitwise_xor(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
int mlx_blackman(mlx_array* res, int M, const mlx_stream s);
int mlx_block_masked_mm(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s);
int mlx_broadcast_arrays(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s);
@@ -1858,7 +1870,7 @@ int mlx_degrees(mlx_array* res, const mlx_array a, const mlx_stream s);
int mlx_depends(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies);
int mlx_dequantize(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, mlx_optional_dtype dtype, const mlx_stream s);
int mlx_dequantize(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , mlx_optional_dtype dtype, const mlx_stream s);
int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
@@ -1912,6 +1924,10 @@ int mlx_greater_equal(mlx_array* res, const mlx_array a, const mlx_array b, cons
int mlx_hadamard_transform(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s);
int mlx_hamming(mlx_array* res, int M, const mlx_stream s);
int mlx_hanning(mlx_array* res, int M, const mlx_stream s);
int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s);
int mlx_imag(mlx_array* res, const mlx_array a, const mlx_stream s);
@@ -2030,9 +2046,9 @@ int mlx_prod(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream
int mlx_put_along_axis(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s);
int mlx_qqmm(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s);
int mlx_qqmm(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale_x , const mlx_array global_scale_w , const mlx_stream s);
int mlx_quantize(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s);
int mlx_quantize(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , const mlx_stream s);
int mlx_quantized_matmul(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s);

View File

@@ -2,6 +2,7 @@ package mlxrunner
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
@@ -36,14 +37,69 @@ type Client struct {
modelName string
contextLength atomic.Int64
memory atomic.Uint64
done chan error
done chan struct{}
doneErr error // valid after done is closed
client *http.Client
lastErr string
lastErrLock sync.Mutex
status *statusWriter
mu sync.Mutex
cmd *exec.Cmd
}
// statusWriter captures the last stderr line from the subprocess while
// forwarding all output to os.Stderr. Lines longer than maxStatusLen are
// truncated to the first maxStatusLen bytes.
type statusWriter struct {
lastErrMsg string
buf []byte
discarding bool
mu sync.Mutex
out *os.File
}
const maxStatusLen = 256
func (w *statusWriter) Write(b []byte) (int, error) {
n, err := w.out.Write(b)
w.mu.Lock()
defer w.mu.Unlock()
w.buf = append(w.buf, b...)
for {
i := bytes.IndexByte(w.buf, '\n')
if i < 0 {
break
}
if !w.discarding {
line := bytes.TrimSpace(w.buf[:i])
if len(line) > 0 {
if len(line) > maxStatusLen {
line = line[:maxStatusLen]
}
w.lastErrMsg = string(line)
}
}
w.buf = w.buf[i+1:]
w.discarding = false
}
// if the buffer grows past maxStatusLen without a newline, keep the front
if len(w.buf) > maxStatusLen {
if !w.discarding {
w.lastErrMsg = string(bytes.TrimSpace(w.buf[:maxStatusLen]))
w.discarding = true
}
w.buf = w.buf[:0]
}
return n, err
}
func (w *statusWriter) getLastErr() string {
w.mu.Lock()
defer w.mu.Unlock()
return w.lastErrMsg
}
// NewClient prepares a new MLX runner client for LLM models.
// The subprocess is not started until Load() is called.
func NewClient(modelName string) (*Client, error) {
@@ -53,7 +109,7 @@ func NewClient(modelName string) (*Client, error) {
c := &Client{
modelName: modelName,
done: make(chan error, 1),
done: make(chan struct{}),
client: &http.Client{Timeout: 10 * time.Minute},
}
@@ -66,12 +122,6 @@ func NewClient(modelName string) (*Client, error) {
return c, nil
}
func (c *Client) getLastErr() string {
c.lastErrLock.Lock()
defer c.lastErrLock.Unlock()
return c.lastErr
}
// WaitUntilRunning waits for the subprocess to be ready.
func (c *Client) WaitUntilRunning(ctx context.Context) error {
timeout := time.After(2 * time.Minute)
@@ -82,16 +132,14 @@ func (c *Client) WaitUntilRunning(ctx context.Context) error {
select {
case <-ctx.Done():
return ctx.Err()
case err := <-c.done:
errMsg := c.getLastErr()
if errMsg != "" {
return fmt.Errorf("mlx runner failed: %s (exit: %v)", errMsg, err)
case <-c.done:
if msg := c.status.getLastErr(); msg != "" {
return fmt.Errorf("mlx runner failed: %s (exit: %v)", msg, c.doneErr)
}
return fmt.Errorf("mlx runner exited unexpectedly: %w", err)
return fmt.Errorf("mlx runner exited unexpectedly: %w", c.doneErr)
case <-timeout:
errMsg := c.getLastErr()
if errMsg != "" {
return fmt.Errorf("timeout waiting for mlx runner: %s", errMsg)
if msg := c.status.getLastErr(); msg != "" {
return fmt.Errorf("timeout waiting for mlx runner: %s", msg)
}
return errors.New("timeout waiting for mlx runner to start")
case <-ticker.C:
@@ -348,18 +396,13 @@ func (c *Client) Load(ctx context.Context, _ ml.SystemInfo, gpus []ml.DeviceInfo
// Forward subprocess stdout/stderr to server logs
stdout, _ := cmd.StdoutPipe()
stderr, _ := cmd.StderrPipe()
status := &statusWriter{out: os.Stderr}
c.status = status
go func() {
io.Copy(os.Stderr, stdout) //nolint:errcheck
}()
go func() {
scanner := bufio.NewScanner(stderr)
for scanner.Scan() {
line := scanner.Text()
fmt.Fprintln(os.Stderr, line)
c.lastErrLock.Lock()
c.lastErr = line
c.lastErrLock.Unlock()
}
io.Copy(status, stderr) //nolint:errcheck
}()
slog.Info("starting mlx runner subprocess", "model", c.modelName, "port", c.port)
@@ -369,8 +412,8 @@ func (c *Client) Load(ctx context.Context, _ ml.SystemInfo, gpus []ml.DeviceInfo
// Reap subprocess when it exits
go func() {
err := cmd.Wait()
c.done <- err
c.doneErr = cmd.Wait()
close(c.done)
}()
return nil, nil

View File

@@ -15,7 +15,8 @@ set(CMAKE_INSTALL_RPATH "@loader_path")
include(FetchContent)
set(MLX_C_GIT_TAG "v0.5.0" CACHE STRING "")
file(READ "${CMAKE_CURRENT_SOURCE_DIR}/../../../MLX_VERSION" MLX_C_GIT_TAG)
string(STRIP "${MLX_C_GIT_TAG}" MLX_C_GIT_TAG)
FetchContent_Declare(
mlx-c

View File

@@ -326,8 +326,10 @@ int (*mlx_distributed_sum_scatter_)(
int (*mlx_distributed_group_rank_)(mlx_distributed_group group) = NULL;
int (*mlx_distributed_group_size_)(mlx_distributed_group group) = NULL;
mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key) = NULL;
bool (*mlx_distributed_is_available_)(void) = NULL;
mlx_distributed_group (*mlx_distributed_init_)(bool strict) = NULL;
bool (*mlx_distributed_is_available_)(const char* bk /* may be null */) = NULL;
mlx_distributed_group (*mlx_distributed_init_)(
bool strict,
const char* bk /* may be null */) = NULL;
void (*mlx_set_error_handler_)(
mlx_error_handler_func handler,
void* data,
@@ -924,6 +926,7 @@ int (*mlx_astype_)(
int (*mlx_atleast_1d_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_atleast_2d_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_atleast_3d_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_bartlett_)(mlx_array* res, int M, const mlx_stream s) = NULL;
int (*mlx_bitwise_and_)(
mlx_array* res,
const mlx_array a,
@@ -940,6 +943,7 @@ int (*mlx_bitwise_xor_)(
const mlx_array a,
const mlx_array b,
const mlx_stream s) = NULL;
int (*mlx_blackman_)(mlx_array* res, int M, const mlx_stream s) = NULL;
int (*mlx_block_masked_mm_)(
mlx_array* res,
const mlx_array a,
@@ -1120,6 +1124,7 @@ int (*mlx_dequantize_)(
mlx_optional_int group_size,
mlx_optional_int bits,
const char* mode,
const mlx_array global_scale /* may be null */,
mlx_optional_dtype dtype,
const mlx_stream s) = NULL;
int (*mlx_diag_)(mlx_array* res, const mlx_array a, int k, const mlx_stream s) = NULL;
@@ -1256,6 +1261,8 @@ int (*mlx_hadamard_transform_)(
const mlx_array a,
mlx_optional_float scale,
const mlx_stream s) = NULL;
int (*mlx_hamming_)(mlx_array* res, int M, const mlx_stream s) = NULL;
int (*mlx_hanning_)(mlx_array* res, int M, const mlx_stream s) = NULL;
int (*mlx_identity_)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) = NULL;
int (*mlx_imag_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_inner_)(
@@ -1548,6 +1555,8 @@ int (*mlx_qqmm_)(
mlx_optional_int group_size,
mlx_optional_int bits,
const char* mode,
const mlx_array global_scale_x /* may be null */,
const mlx_array global_scale_w /* may be null */,
const mlx_stream s) = NULL;
int (*mlx_quantize_)(
mlx_vector_array* res,
@@ -1555,6 +1564,7 @@ int (*mlx_quantize_)(
mlx_optional_int group_size,
mlx_optional_int bits,
const char* mode,
const mlx_array global_scale /* may be null */,
const mlx_stream s) = NULL;
int (*mlx_quantized_matmul_)(
mlx_array* res,
@@ -2550,10 +2560,12 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_atleast_1d);
CHECK_LOAD(handle, mlx_atleast_2d);
CHECK_LOAD(handle, mlx_atleast_3d);
CHECK_LOAD(handle, mlx_bartlett);
CHECK_LOAD(handle, mlx_bitwise_and);
CHECK_LOAD(handle, mlx_bitwise_invert);
CHECK_LOAD(handle, mlx_bitwise_or);
CHECK_LOAD(handle, mlx_bitwise_xor);
CHECK_LOAD(handle, mlx_blackman);
CHECK_LOAD(handle, mlx_block_masked_mm);
CHECK_LOAD(handle, mlx_broadcast_arrays);
CHECK_LOAD(handle, mlx_broadcast_to);
@@ -2606,6 +2618,8 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
CHECK_LOAD(handle, mlx_greater);
CHECK_LOAD(handle, mlx_greater_equal);
CHECK_LOAD(handle, mlx_hadamard_transform);
CHECK_LOAD(handle, mlx_hamming);
CHECK_LOAD(handle, mlx_hanning);
CHECK_LOAD(handle, mlx_identity);
CHECK_LOAD(handle, mlx_imag);
CHECK_LOAD(handle, mlx_inner);

View File

@@ -300,10 +300,12 @@
#define mlx_atleast_1d mlx_atleast_1d_mlx_gen_orig_
#define mlx_atleast_2d mlx_atleast_2d_mlx_gen_orig_
#define mlx_atleast_3d mlx_atleast_3d_mlx_gen_orig_
#define mlx_bartlett mlx_bartlett_mlx_gen_orig_
#define mlx_bitwise_and mlx_bitwise_and_mlx_gen_orig_
#define mlx_bitwise_invert mlx_bitwise_invert_mlx_gen_orig_
#define mlx_bitwise_or mlx_bitwise_or_mlx_gen_orig_
#define mlx_bitwise_xor mlx_bitwise_xor_mlx_gen_orig_
#define mlx_blackman mlx_blackman_mlx_gen_orig_
#define mlx_block_masked_mm mlx_block_masked_mm_mlx_gen_orig_
#define mlx_broadcast_arrays mlx_broadcast_arrays_mlx_gen_orig_
#define mlx_broadcast_to mlx_broadcast_to_mlx_gen_orig_
@@ -356,6 +358,8 @@
#define mlx_greater mlx_greater_mlx_gen_orig_
#define mlx_greater_equal mlx_greater_equal_mlx_gen_orig_
#define mlx_hadamard_transform mlx_hadamard_transform_mlx_gen_orig_
#define mlx_hamming mlx_hamming_mlx_gen_orig_
#define mlx_hanning mlx_hanning_mlx_gen_orig_
#define mlx_identity mlx_identity_mlx_gen_orig_
#define mlx_imag mlx_imag_mlx_gen_orig_
#define mlx_inner mlx_inner_mlx_gen_orig_
@@ -889,10 +893,12 @@
#undef mlx_atleast_1d
#undef mlx_atleast_2d
#undef mlx_atleast_3d
#undef mlx_bartlett
#undef mlx_bitwise_and
#undef mlx_bitwise_invert
#undef mlx_bitwise_or
#undef mlx_bitwise_xor
#undef mlx_blackman
#undef mlx_block_masked_mm
#undef mlx_broadcast_arrays
#undef mlx_broadcast_to
@@ -945,6 +951,8 @@
#undef mlx_greater
#undef mlx_greater_equal
#undef mlx_hadamard_transform
#undef mlx_hamming
#undef mlx_hanning
#undef mlx_identity
#undef mlx_imag
#undef mlx_inner
@@ -1501,8 +1509,10 @@ extern int (*mlx_distributed_sum_scatter_)(
extern int (*mlx_distributed_group_rank_)(mlx_distributed_group group);
extern int (*mlx_distributed_group_size_)(mlx_distributed_group group);
extern mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key);
extern bool (*mlx_distributed_is_available_)(void);
extern mlx_distributed_group (*mlx_distributed_init_)(bool strict);
extern bool (*mlx_distributed_is_available_)(const char* bk /* may be null */);
extern mlx_distributed_group (*mlx_distributed_init_)(
bool strict,
const char* bk /* may be null */);
extern void (*mlx_set_error_handler_)(
mlx_error_handler_func handler,
void* data,
@@ -2099,6 +2109,7 @@ extern int (*mlx_astype_)(
extern int (*mlx_atleast_1d_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_atleast_2d_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_atleast_3d_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_bartlett_)(mlx_array* res, int M, const mlx_stream s);
extern int (*mlx_bitwise_and_)(
mlx_array* res,
const mlx_array a,
@@ -2115,6 +2126,7 @@ extern int (*mlx_bitwise_xor_)(
const mlx_array a,
const mlx_array b,
const mlx_stream s);
extern int (*mlx_blackman_)(mlx_array* res, int M, const mlx_stream s);
extern int (*mlx_block_masked_mm_)(
mlx_array* res,
const mlx_array a,
@@ -2295,6 +2307,7 @@ extern int (*mlx_dequantize_)(
mlx_optional_int group_size,
mlx_optional_int bits,
const char* mode,
const mlx_array global_scale /* may be null */,
mlx_optional_dtype dtype,
const mlx_stream s);
extern int (*mlx_diag_)(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
@@ -2431,6 +2444,8 @@ extern int (*mlx_hadamard_transform_)(
const mlx_array a,
mlx_optional_float scale,
const mlx_stream s);
extern int (*mlx_hamming_)(mlx_array* res, int M, const mlx_stream s);
extern int (*mlx_hanning_)(mlx_array* res, int M, const mlx_stream s);
extern int (*mlx_identity_)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s);
extern int (*mlx_imag_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_inner_)(
@@ -2723,6 +2738,8 @@ extern int (*mlx_qqmm_)(
mlx_optional_int group_size,
mlx_optional_int bits,
const char* mode,
const mlx_array global_scale_x /* may be null */,
const mlx_array global_scale_w /* may be null */,
const mlx_stream s);
extern int (*mlx_quantize_)(
mlx_vector_array* res,
@@ -2730,6 +2747,7 @@ extern int (*mlx_quantize_)(
mlx_optional_int group_size,
mlx_optional_int bits,
const char* mode,
const mlx_array global_scale /* may be null */,
const mlx_stream s);
extern int (*mlx_quantized_matmul_)(
mlx_array* res,
@@ -4033,11 +4051,13 @@ static inline int mlx_distributed_group_size(mlx_distributed_group group) {
static inline mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, int color, int key) {
return mlx_distributed_group_split_(group, color, key);
}
static inline bool mlx_distributed_is_available(void) {
return mlx_distributed_is_available_();
static inline bool mlx_distributed_is_available(const char* bk /* may be null */) {
return mlx_distributed_is_available_(bk);
}
static inline mlx_distributed_group mlx_distributed_init(bool strict) {
return mlx_distributed_init_(strict);
static inline mlx_distributed_group mlx_distributed_init(
bool strict,
const char* bk /* may be null */) {
return mlx_distributed_init_(strict, bk);
}
static inline void mlx_set_error_handler(
mlx_error_handler_func handler,
@@ -4939,6 +4959,9 @@ static inline int mlx_atleast_2d(mlx_array* res, const mlx_array a, const mlx_st
static inline int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s) {
return mlx_atleast_3d_(res, a, s);
}
static inline int mlx_bartlett(mlx_array* res, int M, const mlx_stream s) {
return mlx_bartlett_(res, M, s);
}
static inline int mlx_bitwise_and(
mlx_array* res,
const mlx_array a,
@@ -4963,6 +4986,9 @@ static inline int mlx_bitwise_xor(
const mlx_stream s) {
return mlx_bitwise_xor_(res, a, b, s);
}
static inline int mlx_blackman(mlx_array* res, int M, const mlx_stream s) {
return mlx_blackman_(res, M, s);
}
static inline int mlx_block_masked_mm(
mlx_array* res,
const mlx_array a,
@@ -5193,9 +5219,10 @@ static inline int mlx_dequantize(
mlx_optional_int group_size,
mlx_optional_int bits,
const char* mode,
const mlx_array global_scale /* may be null */,
mlx_optional_dtype dtype,
const mlx_stream s) {
return mlx_dequantize_(res, w, scales, biases, group_size, bits, mode, dtype, s);
return mlx_dequantize_(res, w, scales, biases, group_size, bits, mode, global_scale, dtype, s);
}
static inline int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s) {
return mlx_diag_(res, a, k, s);
@@ -5383,6 +5410,12 @@ static inline int mlx_hadamard_transform(
const mlx_stream s) {
return mlx_hadamard_transform_(res, a, scale, s);
}
static inline int mlx_hamming(mlx_array* res, int M, const mlx_stream s) {
return mlx_hamming_(res, M, s);
}
static inline int mlx_hanning(mlx_array* res, int M, const mlx_stream s) {
return mlx_hanning_(res, M, s);
}
static inline int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) {
return mlx_identity_(res, n, dtype, s);
}
@@ -5793,8 +5826,10 @@ static inline int mlx_qqmm(
mlx_optional_int group_size,
mlx_optional_int bits,
const char* mode,
const mlx_array global_scale_x /* may be null */,
const mlx_array global_scale_w /* may be null */,
const mlx_stream s) {
return mlx_qqmm_(res, x, w, w_scales, group_size, bits, mode, s);
return mlx_qqmm_(res, x, w, w_scales, group_size, bits, mode, global_scale_x, global_scale_w, s);
}
static inline int mlx_quantize(
mlx_vector_array* res,
@@ -5802,8 +5837,9 @@ static inline int mlx_quantize(
mlx_optional_int group_size,
mlx_optional_int bits,
const char* mode,
const mlx_array global_scale /* may be null */,
const mlx_stream s) {
return mlx_quantize_(res, w, group_size, bits, mode, s);
return mlx_quantize_(res, w, group_size, bits, mode, global_scale, s);
}
static inline int mlx_quantized_matmul(
mlx_array* res,

View File

@@ -42,12 +42,14 @@ mlx_distributed_group_split(mlx_distributed_group group, int color, int key);
/**
* Check if distributed is available.
*/
bool mlx_distributed_is_available(void);
bool mlx_distributed_is_available(const char* bk /* may be null */);
/**
* Initialize distributed.
*/
mlx_distributed_group mlx_distributed_init(bool strict);
mlx_distributed_group mlx_distributed_init(
bool strict,
const char* bk /* may be null */);
/**@}*/

View File

@@ -166,6 +166,7 @@ int mlx_astype(
int mlx_atleast_1d(mlx_array* res, const mlx_array a, const mlx_stream s);
int mlx_atleast_2d(mlx_array* res, const mlx_array a, const mlx_stream s);
int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s);
int mlx_bartlett(mlx_array* res, int M, const mlx_stream s);
int mlx_bitwise_and(
mlx_array* res,
const mlx_array a,
@@ -182,6 +183,7 @@ int mlx_bitwise_xor(
const mlx_array a,
const mlx_array b,
const mlx_stream s);
int mlx_blackman(mlx_array* res, int M, const mlx_stream s);
int mlx_block_masked_mm(
mlx_array* res,
const mlx_array a,
@@ -362,6 +364,7 @@ int mlx_dequantize(
mlx_optional_int group_size,
mlx_optional_int bits,
const char* mode,
const mlx_array global_scale /* may be null */,
mlx_optional_dtype dtype,
const mlx_stream s);
int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
@@ -498,6 +501,8 @@ int mlx_hadamard_transform(
const mlx_array a,
mlx_optional_float scale,
const mlx_stream s);
int mlx_hamming(mlx_array* res, int M, const mlx_stream s);
int mlx_hanning(mlx_array* res, int M, const mlx_stream s);
int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s);
int mlx_imag(mlx_array* res, const mlx_array a, const mlx_stream s);
int mlx_inner(
@@ -790,6 +795,8 @@ int mlx_qqmm(
mlx_optional_int group_size,
mlx_optional_int bits,
const char* mode,
const mlx_array global_scale_x /* may be null */,
const mlx_array global_scale_w /* may be null */,
const mlx_stream s);
int mlx_quantize(
mlx_vector_array* res,
@@ -797,6 +804,7 @@ int mlx_quantize(
mlx_optional_int group_size,
mlx_optional_int bits,
const char* mode,
const mlx_array global_scale /* may be null */,
const mlx_stream s);
int mlx_quantized_matmul(
mlx_array* res,

View File

@@ -15,9 +15,10 @@ func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, bias
defer C.free(unsafe.Pointer(cMode))
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
var globalScale C.mlx_array
res := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(res)
C.mlx_quantize(&res, w.ctx, optGroupSize, optBits, cMode, DefaultStream().ctx)
C.mlx_quantize(&res, w.ctx, optGroupSize, optBits, cMode, globalScale, DefaultStream().ctx)
vecSize := int(C.mlx_vector_array_size(res))
w0 := New("QUANTIZE_W")
@@ -38,6 +39,7 @@ func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Arr
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
optDtype := C.mlx_optional_dtype{has_value: false}
var globalScale C.mlx_array
var b C.mlx_array
if biases != nil {
@@ -45,7 +47,7 @@ func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Arr
}
out := New("DEQUANTIZE")
C.mlx_dequantize(&out.ctx, w.ctx, scales.ctx, b, optGroupSize, optBits, cMode, optDtype, DefaultStream().ctx)
C.mlx_dequantize(&out.ctx, w.ctx, scales.ctx, b, optGroupSize, optBits, cMode, globalScale, optDtype, DefaultStream().ctx)
return out
}