mirror of
https://github.com/ollama/ollama.git
synced 2026-04-27 19:25:55 +02:00
Compare commits
12 Commits
pdevine/qw
...
parth/upda
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
00af64a0ae | ||
|
|
21f0db0d37 | ||
|
|
95ee7fbd29 | ||
|
|
ec55536734 | ||
|
|
77491439c2 | ||
|
|
b166b36cd2 | ||
|
|
c2b0bb7a52 | ||
|
|
22c2bdbd8a | ||
|
|
6df6d097d9 | ||
|
|
d7c176ab91 | ||
|
|
0ff7d724ff | ||
|
|
46cb7795e1 |
@@ -157,7 +157,7 @@ COPY CMakeLists.txt CMakePresets.json .
|
|||||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||||
COPY x/imagegen/mlx x/imagegen/mlx
|
COPY x/imagegen/mlx x/imagegen/mlx
|
||||||
COPY go.mod go.sum .
|
COPY go.mod go.sum .
|
||||||
COPY MLX_VERSION MLX_CORE_VERSION .
|
COPY MLX_VERSION MLX_C_VERSION .
|
||||||
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
|
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
|
||||||
ENV PATH=/usr/local/go/bin:$PATH
|
ENV PATH=/usr/local/go/bin:$PATH
|
||||||
RUN go mod download
|
RUN go mod download
|
||||||
|
|||||||
@@ -1 +0,0 @@
|
|||||||
v0.30.6
|
|
||||||
1
MLX_C_VERSION
Normal file
1
MLX_C_VERSION
Normal file
@@ -0,0 +1 @@
|
|||||||
|
0726ca922fc902c4c61ef9c27d94132be418e945
|
||||||
@@ -1 +1 @@
|
|||||||
v0.5.0
|
38ad257088fb2193ad47e527cf6534a689f30943
|
||||||
|
|||||||
@@ -80,6 +80,12 @@ func (c *Openclaw) Run(model string, args []string) error {
|
|||||||
}
|
}
|
||||||
if canInstallDaemon() {
|
if canInstallDaemon() {
|
||||||
onboardArgs = append(onboardArgs, "--install-daemon")
|
onboardArgs = append(onboardArgs, "--install-daemon")
|
||||||
|
} else {
|
||||||
|
// When we can't install a daemon (e.g. no systemd, sudo dropped
|
||||||
|
// XDG_RUNTIME_DIR, or container environment), skip the gateway
|
||||||
|
// health check so non-interactive onboarding completes. The
|
||||||
|
// gateway is started as a foreground child process after onboarding.
|
||||||
|
onboardArgs = append(onboardArgs, "--skip-health")
|
||||||
}
|
}
|
||||||
cmd := exec.Command(bin, onboardArgs...)
|
cmd := exec.Command(bin, onboardArgs...)
|
||||||
cmd.Stdin = os.Stdin
|
cmd.Stdin = os.Stdin
|
||||||
|
|||||||
@@ -160,6 +160,12 @@
|
|||||||
"group": "More information",
|
"group": "More information",
|
||||||
"pages": [
|
"pages": [
|
||||||
"/cli",
|
"/cli",
|
||||||
|
{
|
||||||
|
"group": "Assistant Sandboxing",
|
||||||
|
"pages": [
|
||||||
|
"/integrations/nemoclaw"
|
||||||
|
]
|
||||||
|
},
|
||||||
"/modelfile",
|
"/modelfile",
|
||||||
"/context-length",
|
"/context-length",
|
||||||
"/linux",
|
"/linux",
|
||||||
|
|||||||
@@ -96,6 +96,33 @@ The `/loop` command runs a prompt or slash command on a recurring schedule insid
|
|||||||
/loop 1h Remind me to review the deploy status
|
/loop 1h Remind me to review the deploy status
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Channels
|
||||||
|
|
||||||
|
Chat with Claude Code from Telegram by connecting a bot to your session. Create a bot via [@BotFather](https://t.me/BotFather).
|
||||||
|
|
||||||
|
Install the telegram plugin:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
/plugin install telegram@claude-plugins-official
|
||||||
|
```
|
||||||
|
|
||||||
|
Configure the token:
|
||||||
|
```shell
|
||||||
|
/telegram:configure 123456789:ABCdEF...
|
||||||
|
```
|
||||||
|
|
||||||
|
Launch with Ollama:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
ollama launch claude -- --channels plugin:telegram@claude-plugins-official
|
||||||
|
```
|
||||||
|
|
||||||
|
See the [plugin README](https://github.com/anthropics/claude-plugins-official/tree/main/external_plugins/telegram) for full setup instructions including pairing and access control.
|
||||||
|
|
||||||
|
Claude Code will prompt for permission on most actions. To allow the bot to work autonomously, configure [permission rules](https://code.claude.com/docs/en/permissions) or pass `--dangerously-skip-permissions` in isolated environments.
|
||||||
|
|
||||||
|
Other channels may also be added by following the [Claude Code docs](https://code.claude.com/docs/en/channels-reference).
|
||||||
|
|
||||||
## Manual setup
|
## Manual setup
|
||||||
|
|
||||||
Claude Code connects to Ollama using the Anthropic-compatible API.
|
Claude Code connects to Ollama using the Anthropic-compatible API.
|
||||||
|
|||||||
67
docs/integrations/nemoclaw.mdx
Normal file
67
docs/integrations/nemoclaw.mdx
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
---
|
||||||
|
title: NemoClaw
|
||||||
|
---
|
||||||
|
|
||||||
|
NemoClaw is NVIDIA's open source security stack for [OpenClaw](/integrations/openclaw). It wraps OpenClaw with the NVIDIA OpenShell runtime to provide kernel-level sandboxing, network policy controls, and audit trails for AI agents.
|
||||||
|
|
||||||
|
## Quick start
|
||||||
|
|
||||||
|
Pull a model:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ollama pull nemotron-3-nano:30b
|
||||||
|
```
|
||||||
|
|
||||||
|
Run the installer:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -fsSL https://www.nvidia.com/nemoclaw.sh | \
|
||||||
|
NEMOCLAW_NON_INTERACTIVE=1 \
|
||||||
|
NEMOCLAW_PROVIDER=ollama \
|
||||||
|
NEMOCLAW_MODEL=nemotron-3-nano:30b \
|
||||||
|
bash
|
||||||
|
```
|
||||||
|
|
||||||
|
Connect to your sandbox:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
nemoclaw my-assistant connect
|
||||||
|
```
|
||||||
|
|
||||||
|
Open the TUI:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
openclaw tui
|
||||||
|
```
|
||||||
|
|
||||||
|
<Note>Ollama support in NemoClaw is still experimental.</Note>
|
||||||
|
|
||||||
|
## Platform support
|
||||||
|
|
||||||
|
| Platform | Runtime | Status |
|
||||||
|
|----------|---------|--------|
|
||||||
|
| Linux (Ubuntu 22.04+) | Docker | Primary |
|
||||||
|
| macOS (Apple Silicon) | Colima or Docker Desktop | Supported |
|
||||||
|
| Windows | WSL2 with Docker Desktop | Supported |
|
||||||
|
|
||||||
|
CMD and PowerShell are not supported on Windows — WSL2 is required.
|
||||||
|
|
||||||
|
<Note>Ollama must be installed and running before the installer runs. When running inside WSL2 or a container, ensure Ollama is reachable from the sandbox (e.g. `OLLAMA_HOST=0.0.0.0`).</Note>
|
||||||
|
|
||||||
|
## System requirements
|
||||||
|
|
||||||
|
- CPU: 4 vCPU minimum
|
||||||
|
- RAM: 8 GB minimum (16 GB recommended)
|
||||||
|
- Disk: 20 GB free (40 GB recommended for local models)
|
||||||
|
- Node.js 20+ and npm 10+
|
||||||
|
- Container runtime (Docker preferred)
|
||||||
|
|
||||||
|
## Recommended models
|
||||||
|
|
||||||
|
- `nemotron-3-super:cloud` — Strong reasoning and coding
|
||||||
|
- `qwen3.5:cloud` — 397B; reasoning and code generation
|
||||||
|
- `nemotron-3-nano:30b` — Recommended local model; fits in 24 GB VRAM
|
||||||
|
- `qwen3.5:27b` — Fast local reasoning (~18 GB VRAM)
|
||||||
|
- `glm-4.7-flash` — Reasoning and code generation (~25 GB VRAM)
|
||||||
|
|
||||||
|
More models at [ollama.com/search](https://ollama.com/search).
|
||||||
@@ -214,6 +214,8 @@ func LogLevel() slog.Level {
|
|||||||
var (
|
var (
|
||||||
// FlashAttention enables the experimental flash attention feature.
|
// FlashAttention enables the experimental flash attention feature.
|
||||||
FlashAttention = BoolWithDefault("OLLAMA_FLASH_ATTENTION")
|
FlashAttention = BoolWithDefault("OLLAMA_FLASH_ATTENTION")
|
||||||
|
// DebugLogRequests logs inference requests to disk for replay/debugging.
|
||||||
|
DebugLogRequests = Bool("OLLAMA_DEBUG_LOG_REQUESTS")
|
||||||
// KvCacheType is the quantization type for the K/V cache.
|
// KvCacheType is the quantization type for the K/V cache.
|
||||||
KvCacheType = String("OLLAMA_KV_CACHE_TYPE")
|
KvCacheType = String("OLLAMA_KV_CACHE_TYPE")
|
||||||
// NoHistory disables readline history.
|
// NoHistory disables readline history.
|
||||||
@@ -302,28 +304,29 @@ type EnvVar struct {
|
|||||||
|
|
||||||
func AsMap() map[string]EnvVar {
|
func AsMap() map[string]EnvVar {
|
||||||
ret := map[string]EnvVar{
|
ret := map[string]EnvVar{
|
||||||
"OLLAMA_DEBUG": {"OLLAMA_DEBUG", LogLevel(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
|
"OLLAMA_DEBUG": {"OLLAMA_DEBUG", LogLevel(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
|
||||||
"OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(false), "Enabled flash attention"},
|
"OLLAMA_DEBUG_LOG_REQUESTS": {"OLLAMA_DEBUG_LOG_REQUESTS", DebugLogRequests(), "Log inference request bodies and replay curl commands to a temp directory"},
|
||||||
"OLLAMA_KV_CACHE_TYPE": {"OLLAMA_KV_CACHE_TYPE", KvCacheType(), "Quantization type for the K/V cache (default: f16)"},
|
"OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(false), "Enabled flash attention"},
|
||||||
"OLLAMA_GPU_OVERHEAD": {"OLLAMA_GPU_OVERHEAD", GpuOverhead(), "Reserve a portion of VRAM per GPU (bytes)"},
|
"OLLAMA_KV_CACHE_TYPE": {"OLLAMA_KV_CACHE_TYPE", KvCacheType(), "Quantization type for the K/V cache (default: f16)"},
|
||||||
"OLLAMA_HOST": {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"},
|
"OLLAMA_GPU_OVERHEAD": {"OLLAMA_GPU_OVERHEAD", GpuOverhead(), "Reserve a portion of VRAM per GPU (bytes)"},
|
||||||
"OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive(), "The duration that models stay loaded in memory (default \"5m\")"},
|
"OLLAMA_HOST": {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"},
|
||||||
"OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary(), "Set LLM library to bypass autodetection"},
|
"OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive(), "The duration that models stay loaded in memory (default \"5m\")"},
|
||||||
"OLLAMA_LOAD_TIMEOUT": {"OLLAMA_LOAD_TIMEOUT", LoadTimeout(), "How long to allow model loads to stall before giving up (default \"5m\")"},
|
"OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary(), "Set LLM library to bypass autodetection"},
|
||||||
"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners(), "Maximum number of loaded models per GPU"},
|
"OLLAMA_LOAD_TIMEOUT": {"OLLAMA_LOAD_TIMEOUT", LoadTimeout(), "How long to allow model loads to stall before giving up (default \"5m\")"},
|
||||||
"OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueue(), "Maximum number of queued requests"},
|
"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners(), "Maximum number of loaded models per GPU"},
|
||||||
"OLLAMA_MODELS": {"OLLAMA_MODELS", Models(), "The path to the models directory"},
|
"OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueue(), "Maximum number of queued requests"},
|
||||||
"OLLAMA_NO_CLOUD": {"OLLAMA_NO_CLOUD", NoCloud(), "Disable Ollama cloud features (remote inference and web search)"},
|
"OLLAMA_MODELS": {"OLLAMA_MODELS", Models(), "The path to the models directory"},
|
||||||
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory(), "Do not preserve readline history"},
|
"OLLAMA_NO_CLOUD": {"OLLAMA_NO_CLOUD", NoCloud(), "Disable Ollama cloud features (remote inference and web search)"},
|
||||||
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune(), "Do not prune model blobs on startup"},
|
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory(), "Do not preserve readline history"},
|
||||||
"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel(), "Maximum number of parallel requests"},
|
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune(), "Do not prune model blobs on startup"},
|
||||||
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowedOrigins(), "A comma separated list of allowed origins"},
|
"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel(), "Maximum number of parallel requests"},
|
||||||
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
|
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowedOrigins(), "A comma separated list of allowed origins"},
|
||||||
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
|
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
|
||||||
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4k/32k/256k based on VRAM)"},
|
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
|
||||||
"OLLAMA_EDITOR": {"OLLAMA_EDITOR", Editor(), "Path to editor for interactive prompt editing (Ctrl+G)"},
|
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4k/32k/256k based on VRAM)"},
|
||||||
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
|
"OLLAMA_EDITOR": {"OLLAMA_EDITOR", Editor(), "Path to editor for interactive prompt editing (Ctrl+G)"},
|
||||||
"OLLAMA_REMOTES": {"OLLAMA_REMOTES", Remotes(), "Allowed hosts for remote models (default \"ollama.com\")"},
|
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
|
||||||
|
"OLLAMA_REMOTES": {"OLLAMA_REMOTES", Remotes(), "Allowed hosts for remote models (default \"ollama.com\")"},
|
||||||
|
|
||||||
// Informational
|
// Informational
|
||||||
"HTTP_PROXY": {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"},
|
"HTTP_PROXY": {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"},
|
||||||
|
|||||||
@@ -87,7 +87,8 @@ type LlamaServer interface {
|
|||||||
type llmServer struct {
|
type llmServer struct {
|
||||||
port int
|
port int
|
||||||
cmd *exec.Cmd
|
cmd *exec.Cmd
|
||||||
done chan error // Channel to signal when the process exits
|
done chan struct{} // closed when the process exits
|
||||||
|
doneErr error // valid after done is closed
|
||||||
status *StatusWriter
|
status *StatusWriter
|
||||||
options api.Options
|
options api.Options
|
||||||
modelPath string
|
modelPath string
|
||||||
@@ -280,7 +281,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
|
|||||||
sem: semaphore.NewWeighted(int64(numParallel)),
|
sem: semaphore.NewWeighted(int64(numParallel)),
|
||||||
totalLayers: f.KV().BlockCount() + 1,
|
totalLayers: f.KV().BlockCount() + 1,
|
||||||
loadStart: time.Now(),
|
loadStart: time.Now(),
|
||||||
done: make(chan error, 1),
|
done: make(chan struct{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -304,10 +305,11 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
|
|||||||
if strings.Contains(s.status.LastErrMsg, "unknown model") {
|
if strings.Contains(s.status.LastErrMsg, "unknown model") {
|
||||||
s.status.LastErrMsg = "this model is not supported by your version of Ollama. You may need to upgrade"
|
s.status.LastErrMsg = "this model is not supported by your version of Ollama. You may need to upgrade"
|
||||||
}
|
}
|
||||||
s.done <- errors.New(s.status.LastErrMsg)
|
s.doneErr = errors.New(s.status.LastErrMsg)
|
||||||
} else {
|
} else {
|
||||||
s.done <- err
|
s.doneErr = err
|
||||||
}
|
}
|
||||||
|
close(s.done)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if tok != nil {
|
if tok != nil {
|
||||||
@@ -1356,8 +1358,8 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
|
|||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
slog.Warn("client connection closed before server finished loading, aborting load")
|
slog.Warn("client connection closed before server finished loading, aborting load")
|
||||||
return fmt.Errorf("timed out waiting for llama runner to start: %w", ctx.Err())
|
return fmt.Errorf("timed out waiting for llama runner to start: %w", ctx.Err())
|
||||||
case err := <-s.done:
|
case <-s.done:
|
||||||
return fmt.Errorf("llama runner process has terminated: %w", err)
|
return fmt.Errorf("llama runner process has terminated: %w", s.doneErr)
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
if time.Now().After(stallTimer) {
|
if time.Now().After(stallTimer) {
|
||||||
|
|||||||
144
server/inference_request_log.go
Normal file
144
server/inference_request_log.go
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log/slog"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/envconfig"
|
||||||
|
)
|
||||||
|
|
||||||
|
type inferenceRequestLogger struct {
|
||||||
|
dir string
|
||||||
|
counter uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
func newInferenceRequestLogger() (*inferenceRequestLogger, error) {
|
||||||
|
dir, err := os.MkdirTemp("", "ollama-request-logs-*")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &inferenceRequestLogger{dir: dir}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) initRequestLogging() error {
|
||||||
|
if !envconfig.DebugLogRequests() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
requestLogger, err := newInferenceRequestLogger()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("enable OLLAMA_DEBUG_LOG_REQUESTS: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.requestLogger = requestLogger
|
||||||
|
slog.Info(fmt.Sprintf("request debug logging enabled; inference request logs will be stored in %s and include request bodies and replay curl commands", requestLogger.dir))
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) withInferenceRequestLogging(route string, handlers ...gin.HandlerFunc) []gin.HandlerFunc {
|
||||||
|
if s.requestLogger == nil {
|
||||||
|
return handlers
|
||||||
|
}
|
||||||
|
|
||||||
|
return append([]gin.HandlerFunc{s.requestLogger.middleware(route)}, handlers...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *inferenceRequestLogger) middleware(route string) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
if c.Request == nil {
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
method := c.Request.Method
|
||||||
|
host := c.Request.Host
|
||||||
|
scheme := "http"
|
||||||
|
if c.Request.TLS != nil {
|
||||||
|
scheme = "https"
|
||||||
|
}
|
||||||
|
contentType := c.GetHeader("Content-Type")
|
||||||
|
|
||||||
|
var body []byte
|
||||||
|
if c.Request.Body != nil {
|
||||||
|
var err error
|
||||||
|
body, err = io.ReadAll(c.Request.Body)
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("failed to read request body for debug logging", "route", route, "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Next()
|
||||||
|
l.log(route, method, scheme, host, contentType, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *inferenceRequestLogger) log(route, method, scheme, host, contentType string, body []byte) {
|
||||||
|
if l == nil || l.dir == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if contentType == "" {
|
||||||
|
contentType = "application/json"
|
||||||
|
}
|
||||||
|
if host == "" || scheme == "" {
|
||||||
|
base := envconfig.Host()
|
||||||
|
if host == "" {
|
||||||
|
host = base.Host
|
||||||
|
}
|
||||||
|
if scheme == "" {
|
||||||
|
scheme = base.Scheme
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
routeForFilename := sanitizeRouteForFilename(route)
|
||||||
|
timestamp := fmt.Sprintf("%s-%06d", time.Now().UTC().Format("20060102T150405.000000000Z"), atomic.AddUint64(&l.counter, 1))
|
||||||
|
bodyFilename := fmt.Sprintf("%s_%s_body.json", timestamp, routeForFilename)
|
||||||
|
curlFilename := fmt.Sprintf("%s_%s_request.sh", timestamp, routeForFilename)
|
||||||
|
bodyPath := filepath.Join(l.dir, bodyFilename)
|
||||||
|
curlPath := filepath.Join(l.dir, curlFilename)
|
||||||
|
|
||||||
|
if err := os.WriteFile(bodyPath, body, 0o600); err != nil {
|
||||||
|
slog.Warn("failed to write debug request body", "route", route, "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
url := fmt.Sprintf("%s://%s%s", scheme, host, route)
|
||||||
|
curl := fmt.Sprintf("#!/bin/sh\nSCRIPT_DIR=\"$(CDPATH= cd -- \"$(dirname -- \"$0\")\" && pwd)\"\ncurl --request %s --url %q --header %q --data-binary @\"${SCRIPT_DIR}/%s\"\n", method, url, "Content-Type: "+contentType, bodyFilename)
|
||||||
|
if err := os.WriteFile(curlPath, []byte(curl), 0o600); err != nil {
|
||||||
|
slog.Warn("failed to write debug request replay command", "route", route, "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Info(fmt.Sprintf("logged to %s, replay using curl with `sh %s`", bodyPath, curlPath))
|
||||||
|
}
|
||||||
|
|
||||||
|
func sanitizeRouteForFilename(route string) string {
|
||||||
|
route = strings.TrimPrefix(route, "/")
|
||||||
|
if route == "" {
|
||||||
|
return "root"
|
||||||
|
}
|
||||||
|
|
||||||
|
var b strings.Builder
|
||||||
|
b.Grow(len(route))
|
||||||
|
for _, r := range route {
|
||||||
|
if ('a' <= r && r <= 'z') || ('A' <= r && r <= 'Z') || ('0' <= r && r <= '9') {
|
||||||
|
b.WriteRune(r)
|
||||||
|
} else {
|
||||||
|
b.WriteByte('_')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
@@ -100,6 +100,7 @@ type Server struct {
|
|||||||
addr net.Addr
|
addr net.Addr
|
||||||
sched *Scheduler
|
sched *Scheduler
|
||||||
defaultNumCtx int
|
defaultNumCtx int
|
||||||
|
requestLogger *inferenceRequestLogger
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
@@ -1686,26 +1687,26 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
|||||||
|
|
||||||
// Inference
|
// Inference
|
||||||
r.GET("/api/ps", s.PsHandler)
|
r.GET("/api/ps", s.PsHandler)
|
||||||
r.POST("/api/generate", s.GenerateHandler)
|
r.POST("/api/generate", s.withInferenceRequestLogging("/api/generate", s.GenerateHandler)...)
|
||||||
r.POST("/api/chat", s.ChatHandler)
|
r.POST("/api/chat", s.withInferenceRequestLogging("/api/chat", s.ChatHandler)...)
|
||||||
r.POST("/api/embed", s.EmbedHandler)
|
r.POST("/api/embed", s.EmbedHandler)
|
||||||
r.POST("/api/embeddings", s.EmbeddingsHandler)
|
r.POST("/api/embeddings", s.EmbeddingsHandler)
|
||||||
|
|
||||||
// Inference (OpenAI compatibility)
|
// Inference (OpenAI compatibility)
|
||||||
// TODO(cloud-stage-a): apply Modelfile overlay deltas for local models with cloud
|
// TODO(cloud-stage-a): apply Modelfile overlay deltas for local models with cloud
|
||||||
// parents on v1 request families while preserving this explicit :cloud passthrough.
|
// parents on v1 request families while preserving this explicit :cloud passthrough.
|
||||||
r.POST("/v1/chat/completions", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ChatMiddleware(), s.ChatHandler)
|
r.POST("/v1/chat/completions", s.withInferenceRequestLogging("/v1/chat/completions", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ChatMiddleware(), s.ChatHandler)...)
|
||||||
r.POST("/v1/completions", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.CompletionsMiddleware(), s.GenerateHandler)
|
r.POST("/v1/completions", s.withInferenceRequestLogging("/v1/completions", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.CompletionsMiddleware(), s.GenerateHandler)...)
|
||||||
r.POST("/v1/embeddings", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.EmbeddingsMiddleware(), s.EmbedHandler)
|
r.POST("/v1/embeddings", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.EmbeddingsMiddleware(), s.EmbedHandler)
|
||||||
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
|
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
|
||||||
r.GET("/v1/models/:model", cloudModelPathPassthroughMiddleware(cloudErrRemoteModelDetailsUnavailable), middleware.RetrieveMiddleware(), s.ShowHandler)
|
r.GET("/v1/models/:model", cloudModelPathPassthroughMiddleware(cloudErrRemoteModelDetailsUnavailable), middleware.RetrieveMiddleware(), s.ShowHandler)
|
||||||
r.POST("/v1/responses", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ResponsesMiddleware(), s.ChatHandler)
|
r.POST("/v1/responses", s.withInferenceRequestLogging("/v1/responses", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ResponsesMiddleware(), s.ChatHandler)...)
|
||||||
// OpenAI-compatible image generation endpoints
|
// OpenAI-compatible image generation endpoints
|
||||||
r.POST("/v1/images/generations", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ImageGenerationsMiddleware(), s.GenerateHandler)
|
r.POST("/v1/images/generations", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ImageGenerationsMiddleware(), s.GenerateHandler)
|
||||||
r.POST("/v1/images/edits", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ImageEditsMiddleware(), s.GenerateHandler)
|
r.POST("/v1/images/edits", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ImageEditsMiddleware(), s.GenerateHandler)
|
||||||
|
|
||||||
// Inference (Anthropic compatibility)
|
// Inference (Anthropic compatibility)
|
||||||
r.POST("/v1/messages", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
|
r.POST("/v1/messages", s.withInferenceRequestLogging("/v1/messages", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.AnthropicMessagesMiddleware(), s.ChatHandler)...)
|
||||||
|
|
||||||
if rc != nil {
|
if rc != nil {
|
||||||
// wrap old with new
|
// wrap old with new
|
||||||
@@ -1757,6 +1758,9 @@ func Serve(ln net.Listener) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s := &Server{addr: ln.Addr()}
|
s := &Server{addr: ln.Addr()}
|
||||||
|
if err := s.initRequestLogging(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
var rc *ollama.Registry
|
var rc *ollama.Registry
|
||||||
if useClient2 {
|
if useClient2 {
|
||||||
|
|||||||
128
server/routes_request_log_test.go
Normal file
128
server/routes_request_log_test.go
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestInferenceRequestLoggerMiddlewareWritesReplayArtifacts(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
logDir := t.TempDir()
|
||||||
|
requestLogger := &inferenceRequestLogger{dir: logDir}
|
||||||
|
|
||||||
|
const route = "/v1/chat/completions"
|
||||||
|
const requestBody = `{"model":"test-model","messages":[{"role":"user","content":"hello"}]}`
|
||||||
|
|
||||||
|
var bodySeenByHandler string
|
||||||
|
|
||||||
|
r := gin.New()
|
||||||
|
r.POST(route, requestLogger.middleware(route), func(c *gin.Context) {
|
||||||
|
body, err := io.ReadAll(c.Request.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read body in handler: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
bodySeenByHandler = string(body)
|
||||||
|
c.Status(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, route, strings.NewReader(requestBody))
|
||||||
|
req.Host = "127.0.0.1:11434"
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if bodySeenByHandler != requestBody {
|
||||||
|
t.Fatalf("handler body mismatch:\nexpected: %s\ngot: %s", requestBody, bodySeenByHandler)
|
||||||
|
}
|
||||||
|
|
||||||
|
bodyFiles, err := filepath.Glob(filepath.Join(logDir, "*_v1_chat_completions_body.json"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to glob body logs: %v", err)
|
||||||
|
}
|
||||||
|
if len(bodyFiles) != 1 {
|
||||||
|
t.Fatalf("expected 1 body log, got %d (%v)", len(bodyFiles), bodyFiles)
|
||||||
|
}
|
||||||
|
|
||||||
|
curlFiles, err := filepath.Glob(filepath.Join(logDir, "*_v1_chat_completions_request.sh"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to glob curl logs: %v", err)
|
||||||
|
}
|
||||||
|
if len(curlFiles) != 1 {
|
||||||
|
t.Fatalf("expected 1 curl log, got %d (%v)", len(curlFiles), curlFiles)
|
||||||
|
}
|
||||||
|
|
||||||
|
bodyData, err := os.ReadFile(bodyFiles[0])
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read body log: %v", err)
|
||||||
|
}
|
||||||
|
if string(bodyData) != requestBody {
|
||||||
|
t.Fatalf("body log mismatch:\nexpected: %s\ngot: %s", requestBody, string(bodyData))
|
||||||
|
}
|
||||||
|
|
||||||
|
curlData, err := os.ReadFile(curlFiles[0])
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read curl log: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
curlString := string(curlData)
|
||||||
|
if !strings.Contains(curlString, "http://127.0.0.1:11434"+route) {
|
||||||
|
t.Fatalf("curl log does not contain expected route URL: %s", curlString)
|
||||||
|
}
|
||||||
|
|
||||||
|
bodyFileName := filepath.Base(bodyFiles[0])
|
||||||
|
if !strings.Contains(curlString, "@\"${SCRIPT_DIR}/"+bodyFileName+"\"") {
|
||||||
|
t.Fatalf("curl log does not reference sibling body file: %s", curlString)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewInferenceRequestLoggerCreatesDirectory(t *testing.T) {
|
||||||
|
requestLogger, err := newInferenceRequestLogger()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected no error creating request logger: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = os.RemoveAll(requestLogger.dir)
|
||||||
|
})
|
||||||
|
|
||||||
|
if requestLogger == nil || requestLogger.dir == "" {
|
||||||
|
t.Fatalf("expected request logger directory to be set")
|
||||||
|
}
|
||||||
|
|
||||||
|
info, err := os.Stat(requestLogger.dir)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected directory to exist: %v", err)
|
||||||
|
}
|
||||||
|
if !info.IsDir() {
|
||||||
|
t.Fatalf("expected %q to be a directory", requestLogger.dir)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeRouteForFilename(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
route string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{route: "/api/generate", want: "api_generate"},
|
||||||
|
{route: "/v1/chat/completions", want: "v1_chat_completions"},
|
||||||
|
{route: "/v1/messages", want: "v1_messages"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
if got := sanitizeRouteForFilename(tt.route); got != tt.want {
|
||||||
|
t.Fatalf("sanitizeRouteForFilename(%q) = %q, want %q", tt.route, got, tt.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,11 +1,11 @@
|
|||||||
include(FetchContent)
|
include(FetchContent)
|
||||||
|
|
||||||
# Read MLX version from top-level file (shared with Dockerfile)
|
# Read MLX-C version from top-level file (shared with Dockerfile)
|
||||||
file(READ "${CMAKE_SOURCE_DIR}/MLX_VERSION" MLX_C_GIT_TAG)
|
file(READ "${CMAKE_SOURCE_DIR}/MLX_C_VERSION" MLX_C_GIT_TAG)
|
||||||
string(STRIP "${MLX_C_GIT_TAG}" MLX_C_GIT_TAG)
|
string(STRIP "${MLX_C_GIT_TAG}" MLX_C_GIT_TAG)
|
||||||
|
|
||||||
# Read MLX core version from top-level file
|
# Read MLX version from top-level file
|
||||||
file(READ "${CMAKE_SOURCE_DIR}/MLX_CORE_VERSION" MLX_GIT_TAG)
|
file(READ "${CMAKE_SOURCE_DIR}/MLX_VERSION" MLX_GIT_TAG)
|
||||||
string(STRIP "${MLX_GIT_TAG}" MLX_GIT_TAG)
|
string(STRIP "${MLX_GIT_TAG}" MLX_GIT_TAG)
|
||||||
|
|
||||||
set(MLX_C_BUILD_EXAMPLES OFF)
|
set(MLX_C_BUILD_EXAMPLES OFF)
|
||||||
@@ -98,6 +98,15 @@ FetchContent_MakeAvailable(mlx-c)
|
|||||||
file(GLOB _mlx_c_hdrs "${mlx-c_SOURCE_DIR}/mlx/c/*.h")
|
file(GLOB _mlx_c_hdrs "${mlx-c_SOURCE_DIR}/mlx/c/*.h")
|
||||||
file(COPY ${_mlx_c_hdrs} DESTINATION "${CMAKE_SOURCE_DIR}/x/mlxrunner/mlx/include/mlx/c/")
|
file(COPY ${_mlx_c_hdrs} DESTINATION "${CMAKE_SOURCE_DIR}/x/mlxrunner/mlx/include/mlx/c/")
|
||||||
|
|
||||||
|
# Regenerate Go/C shim wrappers from the (possibly updated) headers.
|
||||||
|
find_program(GO_EXECUTABLE go REQUIRED)
|
||||||
|
message(STATUS "Regenerating MLX Go wrappers")
|
||||||
|
execute_process(
|
||||||
|
COMMAND ${GO_EXECUTABLE} generate ./x/...
|
||||||
|
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}
|
||||||
|
COMMAND_ERROR_IS_FATAL ANY
|
||||||
|
)
|
||||||
|
|
||||||
# For local dev builds, override MLX_VERSION with git describe output
|
# For local dev builds, override MLX_VERSION with git describe output
|
||||||
if(TARGET mlx_version AND DEFINED FETCHCONTENT_SOURCE_DIR_MLX)
|
if(TARGET mlx_version AND DEFINED FETCHCONTENT_SOURCE_DIR_MLX)
|
||||||
execute_process(
|
execute_process(
|
||||||
|
|||||||
@@ -165,8 +165,8 @@ int (*mlx_distributed_sum_scatter_ptr)(mlx_array* res, const mlx_array x, const
|
|||||||
int (*mlx_distributed_group_rank_ptr)(mlx_distributed_group group) = NULL;
|
int (*mlx_distributed_group_rank_ptr)(mlx_distributed_group group) = NULL;
|
||||||
int (*mlx_distributed_group_size_ptr)(mlx_distributed_group group) = NULL;
|
int (*mlx_distributed_group_size_ptr)(mlx_distributed_group group) = NULL;
|
||||||
mlx_distributed_group (*mlx_distributed_group_split_ptr)(mlx_distributed_group group, int color, int key) = NULL;
|
mlx_distributed_group (*mlx_distributed_group_split_ptr)(mlx_distributed_group group, int color, int key) = NULL;
|
||||||
bool (*mlx_distributed_is_available_ptr)(void) = NULL;
|
bool (*mlx_distributed_is_available_ptr)(const char* bk) = NULL;
|
||||||
mlx_distributed_group (*mlx_distributed_init_ptr)(bool strict) = NULL;
|
mlx_distributed_group (*mlx_distributed_init_ptr)(bool strict, const char* bk) = NULL;
|
||||||
void (*mlx_set_error_handler_ptr)(mlx_error_handler_func handler, void* data, void (*dtor)(void*)) = NULL;
|
void (*mlx_set_error_handler_ptr)(mlx_error_handler_func handler, void* data, void (*dtor)(void*)) = NULL;
|
||||||
void (*_mlx_error_ptr)(const char* file, const int line, const char* fmt, ...) = NULL;
|
void (*_mlx_error_ptr)(const char* file, const int line, const char* fmt, ...) = NULL;
|
||||||
int (*mlx_export_function_ptr)(const char* file, const mlx_closure fun, const mlx_vector_array args, bool shapeless) = NULL;
|
int (*mlx_export_function_ptr)(const char* file, const mlx_closure fun, const mlx_vector_array args, bool shapeless) = NULL;
|
||||||
@@ -319,10 +319,12 @@ int (*mlx_astype_ptr)(mlx_array* res, const mlx_array a, mlx_dtype dtype, const
|
|||||||
int (*mlx_atleast_1d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
int (*mlx_atleast_1d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
||||||
int (*mlx_atleast_2d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
int (*mlx_atleast_2d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
||||||
int (*mlx_atleast_3d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
int (*mlx_atleast_3d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
||||||
|
int (*mlx_bartlett_ptr)(mlx_array* res, int M, const mlx_stream s) = NULL;
|
||||||
int (*mlx_bitwise_and_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
int (*mlx_bitwise_and_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
||||||
int (*mlx_bitwise_invert_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
int (*mlx_bitwise_invert_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
||||||
int (*mlx_bitwise_or_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
int (*mlx_bitwise_or_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
||||||
int (*mlx_bitwise_xor_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
int (*mlx_bitwise_xor_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
||||||
|
int (*mlx_blackman_ptr)(mlx_array* res, int M, const mlx_stream s) = NULL;
|
||||||
int (*mlx_block_masked_mm_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s) = NULL;
|
int (*mlx_block_masked_mm_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s) = NULL;
|
||||||
int (*mlx_broadcast_arrays_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s) = NULL;
|
int (*mlx_broadcast_arrays_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s) = NULL;
|
||||||
int (*mlx_broadcast_to_ptr)(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s) = NULL;
|
int (*mlx_broadcast_to_ptr)(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s) = NULL;
|
||||||
@@ -348,7 +350,7 @@ int (*mlx_cumprod_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse
|
|||||||
int (*mlx_cumsum_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) = NULL;
|
int (*mlx_cumsum_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) = NULL;
|
||||||
int (*mlx_degrees_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
int (*mlx_degrees_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
||||||
int (*mlx_depends_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies) = NULL;
|
int (*mlx_depends_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies) = NULL;
|
||||||
int (*mlx_dequantize_ptr)(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, mlx_optional_dtype dtype, const mlx_stream s) = NULL;
|
int (*mlx_dequantize_ptr)(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , mlx_optional_dtype dtype, const mlx_stream s) = NULL;
|
||||||
int (*mlx_diag_ptr)(mlx_array* res, const mlx_array a, int k, const mlx_stream s) = NULL;
|
int (*mlx_diag_ptr)(mlx_array* res, const mlx_array a, int k, const mlx_stream s) = NULL;
|
||||||
int (*mlx_diagonal_ptr)(mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, const mlx_stream s) = NULL;
|
int (*mlx_diagonal_ptr)(mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, const mlx_stream s) = NULL;
|
||||||
int (*mlx_divide_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
int (*mlx_divide_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
||||||
@@ -375,6 +377,8 @@ int (*mlx_gather_qmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w,
|
|||||||
int (*mlx_greater_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
int (*mlx_greater_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
||||||
int (*mlx_greater_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
int (*mlx_greater_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
||||||
int (*mlx_hadamard_transform_ptr)(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s) = NULL;
|
int (*mlx_hadamard_transform_ptr)(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s) = NULL;
|
||||||
|
int (*mlx_hamming_ptr)(mlx_array* res, int M, const mlx_stream s) = NULL;
|
||||||
|
int (*mlx_hanning_ptr)(mlx_array* res, int M, const mlx_stream s) = NULL;
|
||||||
int (*mlx_identity_ptr)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) = NULL;
|
int (*mlx_identity_ptr)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) = NULL;
|
||||||
int (*mlx_imag_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
int (*mlx_imag_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
||||||
int (*mlx_inner_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
int (*mlx_inner_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
|
||||||
@@ -434,8 +438,8 @@ int (*mlx_prod_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, siz
|
|||||||
int (*mlx_prod_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL;
|
int (*mlx_prod_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL;
|
||||||
int (*mlx_prod_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL;
|
int (*mlx_prod_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL;
|
||||||
int (*mlx_put_along_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s) = NULL;
|
int (*mlx_put_along_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s) = NULL;
|
||||||
int (*mlx_qqmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) = NULL;
|
int (*mlx_qqmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale_x , const mlx_array global_scale_w , const mlx_stream s) = NULL;
|
||||||
int (*mlx_quantize_ptr)(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) = NULL;
|
int (*mlx_quantize_ptr)(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , const mlx_stream s) = NULL;
|
||||||
int (*mlx_quantized_matmul_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) = NULL;
|
int (*mlx_quantized_matmul_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) = NULL;
|
||||||
int (*mlx_radians_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
int (*mlx_radians_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
||||||
int (*mlx_real_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
int (*mlx_real_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
||||||
@@ -2101,6 +2105,11 @@ int mlx_load_functions(void* handle) {
|
|||||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_atleast_3d\n");
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_atleast_3d\n");
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
mlx_bartlett_ptr = GET_SYM(handle, "mlx_bartlett");
|
||||||
|
if (mlx_bartlett_ptr == NULL) {
|
||||||
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_bartlett\n");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
mlx_bitwise_and_ptr = GET_SYM(handle, "mlx_bitwise_and");
|
mlx_bitwise_and_ptr = GET_SYM(handle, "mlx_bitwise_and");
|
||||||
if (mlx_bitwise_and_ptr == NULL) {
|
if (mlx_bitwise_and_ptr == NULL) {
|
||||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_and\n");
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_and\n");
|
||||||
@@ -2121,6 +2130,11 @@ int mlx_load_functions(void* handle) {
|
|||||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_xor\n");
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_xor\n");
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
mlx_blackman_ptr = GET_SYM(handle, "mlx_blackman");
|
||||||
|
if (mlx_blackman_ptr == NULL) {
|
||||||
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_blackman\n");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
mlx_block_masked_mm_ptr = GET_SYM(handle, "mlx_block_masked_mm");
|
mlx_block_masked_mm_ptr = GET_SYM(handle, "mlx_block_masked_mm");
|
||||||
if (mlx_block_masked_mm_ptr == NULL) {
|
if (mlx_block_masked_mm_ptr == NULL) {
|
||||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_block_masked_mm\n");
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_block_masked_mm\n");
|
||||||
@@ -2381,6 +2395,16 @@ int mlx_load_functions(void* handle) {
|
|||||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_hadamard_transform\n");
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_hadamard_transform\n");
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
mlx_hamming_ptr = GET_SYM(handle, "mlx_hamming");
|
||||||
|
if (mlx_hamming_ptr == NULL) {
|
||||||
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_hamming\n");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
mlx_hanning_ptr = GET_SYM(handle, "mlx_hanning");
|
||||||
|
if (mlx_hanning_ptr == NULL) {
|
||||||
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_hanning\n");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
mlx_identity_ptr = GET_SYM(handle, "mlx_identity");
|
mlx_identity_ptr = GET_SYM(handle, "mlx_identity");
|
||||||
if (mlx_identity_ptr == NULL) {
|
if (mlx_identity_ptr == NULL) {
|
||||||
fprintf(stderr, "MLX: Failed to load symbol: mlx_identity\n");
|
fprintf(stderr, "MLX: Failed to load symbol: mlx_identity\n");
|
||||||
@@ -4132,12 +4156,12 @@ mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, i
|
|||||||
return mlx_distributed_group_split_ptr(group, color, key);
|
return mlx_distributed_group_split_ptr(group, color, key);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool mlx_distributed_is_available(void) {
|
bool mlx_distributed_is_available(const char* bk) {
|
||||||
return mlx_distributed_is_available_ptr();
|
return mlx_distributed_is_available_ptr(bk);
|
||||||
}
|
}
|
||||||
|
|
||||||
mlx_distributed_group mlx_distributed_init(bool strict) {
|
mlx_distributed_group mlx_distributed_init(bool strict, const char* bk) {
|
||||||
return mlx_distributed_init_ptr(strict);
|
return mlx_distributed_init_ptr(strict, bk);
|
||||||
}
|
}
|
||||||
|
|
||||||
void mlx_set_error_handler(mlx_error_handler_func handler, void* data, void (*dtor)(void*)) {
|
void mlx_set_error_handler(mlx_error_handler_func handler, void* data, void (*dtor)(void*)) {
|
||||||
@@ -4748,6 +4772,10 @@ int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s) {
|
|||||||
return mlx_atleast_3d_ptr(res, a, s);
|
return mlx_atleast_3d_ptr(res, a, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int mlx_bartlett(mlx_array* res, int M, const mlx_stream s) {
|
||||||
|
return mlx_bartlett_ptr(res, M, s);
|
||||||
|
}
|
||||||
|
|
||||||
int mlx_bitwise_and(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
|
int mlx_bitwise_and(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
|
||||||
return mlx_bitwise_and_ptr(res, a, b, s);
|
return mlx_bitwise_and_ptr(res, a, b, s);
|
||||||
}
|
}
|
||||||
@@ -4764,6 +4792,10 @@ int mlx_bitwise_xor(mlx_array* res, const mlx_array a, const mlx_array b, const
|
|||||||
return mlx_bitwise_xor_ptr(res, a, b, s);
|
return mlx_bitwise_xor_ptr(res, a, b, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int mlx_blackman(mlx_array* res, int M, const mlx_stream s) {
|
||||||
|
return mlx_blackman_ptr(res, M, s);
|
||||||
|
}
|
||||||
|
|
||||||
int mlx_block_masked_mm(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s) {
|
int mlx_block_masked_mm(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s) {
|
||||||
return mlx_block_masked_mm_ptr(res, a, b, block_size, mask_out, mask_lhs, mask_rhs, s);
|
return mlx_block_masked_mm_ptr(res, a, b, block_size, mask_out, mask_lhs, mask_rhs, s);
|
||||||
}
|
}
|
||||||
@@ -4864,8 +4896,8 @@ int mlx_depends(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_
|
|||||||
return mlx_depends_ptr(res, inputs, dependencies);
|
return mlx_depends_ptr(res, inputs, dependencies);
|
||||||
}
|
}
|
||||||
|
|
||||||
int mlx_dequantize(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, mlx_optional_dtype dtype, const mlx_stream s) {
|
int mlx_dequantize(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , mlx_optional_dtype dtype, const mlx_stream s) {
|
||||||
return mlx_dequantize_ptr(res, w, scales, biases, group_size, bits, mode, dtype, s);
|
return mlx_dequantize_ptr(res, w, scales, biases, group_size, bits, mode, global_scale, dtype, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s) {
|
int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s) {
|
||||||
@@ -4972,6 +5004,14 @@ int mlx_hadamard_transform(mlx_array* res, const mlx_array a, mlx_optional_float
|
|||||||
return mlx_hadamard_transform_ptr(res, a, scale, s);
|
return mlx_hadamard_transform_ptr(res, a, scale, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int mlx_hamming(mlx_array* res, int M, const mlx_stream s) {
|
||||||
|
return mlx_hamming_ptr(res, M, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
int mlx_hanning(mlx_array* res, int M, const mlx_stream s) {
|
||||||
|
return mlx_hanning_ptr(res, M, s);
|
||||||
|
}
|
||||||
|
|
||||||
int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) {
|
int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) {
|
||||||
return mlx_identity_ptr(res, n, dtype, s);
|
return mlx_identity_ptr(res, n, dtype, s);
|
||||||
}
|
}
|
||||||
@@ -5208,12 +5248,12 @@ int mlx_put_along_axis(mlx_array* res, const mlx_array a, const mlx_array indice
|
|||||||
return mlx_put_along_axis_ptr(res, a, indices, values, axis, s);
|
return mlx_put_along_axis_ptr(res, a, indices, values, axis, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
int mlx_qqmm(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) {
|
int mlx_qqmm(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale_x , const mlx_array global_scale_w , const mlx_stream s) {
|
||||||
return mlx_qqmm_ptr(res, x, w, w_scales, group_size, bits, mode, s);
|
return mlx_qqmm_ptr(res, x, w, w_scales, group_size, bits, mode, global_scale_x, global_scale_w, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
int mlx_quantize(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) {
|
int mlx_quantize(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , const mlx_stream s) {
|
||||||
return mlx_quantize_ptr(res, w, group_size, bits, mode, s);
|
return mlx_quantize_ptr(res, w, group_size, bits, mode, global_scale, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
int mlx_quantized_matmul(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) {
|
int mlx_quantized_matmul(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) {
|
||||||
|
|||||||
@@ -2125,7 +2125,8 @@ func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, bias
|
|||||||
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
|
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
|
||||||
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
|
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
|
||||||
res := C.mlx_vector_array_new()
|
res := C.mlx_vector_array_new()
|
||||||
C.mlx_quantize(&res, w.c, optGroupSize, optBits, cMode, C.default_stream())
|
var globalScale C.mlx_array
|
||||||
|
C.mlx_quantize(&res, w.c, optGroupSize, optBits, cMode, globalScale, C.default_stream())
|
||||||
|
|
||||||
// Result is a vector of arrays: [weights, scales, biases?]
|
// Result is a vector of arrays: [weights, scales, biases?]
|
||||||
// mxfp8 mode returns only 2 elements (no biases)
|
// mxfp8 mode returns only 2 elements (no biases)
|
||||||
@@ -2161,7 +2162,8 @@ func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Arr
|
|||||||
}
|
}
|
||||||
|
|
||||||
res := C.mlx_array_new()
|
res := C.mlx_array_new()
|
||||||
C.mlx_dequantize(&res, w.c, scales.c, b, optGroupSize, optBits, cMode, optDtype, C.default_stream())
|
var globalScale C.mlx_array
|
||||||
|
C.mlx_dequantize(&res, w.c, scales.c, b, optGroupSize, optBits, cMode, globalScale, optDtype, C.default_stream())
|
||||||
return newArray(res)
|
return newArray(res)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -309,10 +309,12 @@
|
|||||||
#undef mlx_atleast_1d
|
#undef mlx_atleast_1d
|
||||||
#undef mlx_atleast_2d
|
#undef mlx_atleast_2d
|
||||||
#undef mlx_atleast_3d
|
#undef mlx_atleast_3d
|
||||||
|
#undef mlx_bartlett
|
||||||
#undef mlx_bitwise_and
|
#undef mlx_bitwise_and
|
||||||
#undef mlx_bitwise_invert
|
#undef mlx_bitwise_invert
|
||||||
#undef mlx_bitwise_or
|
#undef mlx_bitwise_or
|
||||||
#undef mlx_bitwise_xor
|
#undef mlx_bitwise_xor
|
||||||
|
#undef mlx_blackman
|
||||||
#undef mlx_block_masked_mm
|
#undef mlx_block_masked_mm
|
||||||
#undef mlx_broadcast_arrays
|
#undef mlx_broadcast_arrays
|
||||||
#undef mlx_broadcast_to
|
#undef mlx_broadcast_to
|
||||||
@@ -365,6 +367,8 @@
|
|||||||
#undef mlx_greater
|
#undef mlx_greater
|
||||||
#undef mlx_greater_equal
|
#undef mlx_greater_equal
|
||||||
#undef mlx_hadamard_transform
|
#undef mlx_hadamard_transform
|
||||||
|
#undef mlx_hamming
|
||||||
|
#undef mlx_hanning
|
||||||
#undef mlx_identity
|
#undef mlx_identity
|
||||||
#undef mlx_imag
|
#undef mlx_imag
|
||||||
#undef mlx_inner
|
#undef mlx_inner
|
||||||
@@ -751,8 +755,8 @@ extern int (*mlx_distributed_sum_scatter_ptr)(mlx_array* res, const mlx_array x,
|
|||||||
extern int (*mlx_distributed_group_rank_ptr)(mlx_distributed_group group);
|
extern int (*mlx_distributed_group_rank_ptr)(mlx_distributed_group group);
|
||||||
extern int (*mlx_distributed_group_size_ptr)(mlx_distributed_group group);
|
extern int (*mlx_distributed_group_size_ptr)(mlx_distributed_group group);
|
||||||
extern mlx_distributed_group (*mlx_distributed_group_split_ptr)(mlx_distributed_group group, int color, int key);
|
extern mlx_distributed_group (*mlx_distributed_group_split_ptr)(mlx_distributed_group group, int color, int key);
|
||||||
extern bool (*mlx_distributed_is_available_ptr)(void);
|
extern bool (*mlx_distributed_is_available_ptr)(const char* bk);
|
||||||
extern mlx_distributed_group (*mlx_distributed_init_ptr)(bool strict);
|
extern mlx_distributed_group (*mlx_distributed_init_ptr)(bool strict, const char* bk);
|
||||||
extern void (*mlx_set_error_handler_ptr)(mlx_error_handler_func handler, void* data, void (*dtor)(void*));
|
extern void (*mlx_set_error_handler_ptr)(mlx_error_handler_func handler, void* data, void (*dtor)(void*));
|
||||||
extern void (*_mlx_error_ptr)(const char* file, const int line, const char* fmt, ...);
|
extern void (*_mlx_error_ptr)(const char* file, const int line, const char* fmt, ...);
|
||||||
extern int (*mlx_export_function_ptr)(const char* file, const mlx_closure fun, const mlx_vector_array args, bool shapeless);
|
extern int (*mlx_export_function_ptr)(const char* file, const mlx_closure fun, const mlx_vector_array args, bool shapeless);
|
||||||
@@ -905,10 +909,12 @@ extern int (*mlx_astype_ptr)(mlx_array* res, const mlx_array a, mlx_dtype dtype,
|
|||||||
extern int (*mlx_atleast_1d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
extern int (*mlx_atleast_1d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
extern int (*mlx_atleast_2d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
extern int (*mlx_atleast_2d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
extern int (*mlx_atleast_3d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
extern int (*mlx_atleast_3d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
|
extern int (*mlx_bartlett_ptr)(mlx_array* res, int M, const mlx_stream s);
|
||||||
extern int (*mlx_bitwise_and_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
extern int (*mlx_bitwise_and_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
||||||
extern int (*mlx_bitwise_invert_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
extern int (*mlx_bitwise_invert_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
extern int (*mlx_bitwise_or_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
extern int (*mlx_bitwise_or_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
||||||
extern int (*mlx_bitwise_xor_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
extern int (*mlx_bitwise_xor_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
||||||
|
extern int (*mlx_blackman_ptr)(mlx_array* res, int M, const mlx_stream s);
|
||||||
extern int (*mlx_block_masked_mm_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s);
|
extern int (*mlx_block_masked_mm_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s);
|
||||||
extern int (*mlx_broadcast_arrays_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s);
|
extern int (*mlx_broadcast_arrays_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s);
|
||||||
extern int (*mlx_broadcast_to_ptr)(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s);
|
extern int (*mlx_broadcast_to_ptr)(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s);
|
||||||
@@ -934,7 +940,7 @@ extern int (*mlx_cumprod_ptr)(mlx_array* res, const mlx_array a, int axis, bool
|
|||||||
extern int (*mlx_cumsum_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s);
|
extern int (*mlx_cumsum_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s);
|
||||||
extern int (*mlx_degrees_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
extern int (*mlx_degrees_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
extern int (*mlx_depends_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies);
|
extern int (*mlx_depends_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies);
|
||||||
extern int (*mlx_dequantize_ptr)(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, mlx_optional_dtype dtype, const mlx_stream s);
|
extern int (*mlx_dequantize_ptr)(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , mlx_optional_dtype dtype, const mlx_stream s);
|
||||||
extern int (*mlx_diag_ptr)(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
|
extern int (*mlx_diag_ptr)(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
|
||||||
extern int (*mlx_diagonal_ptr)(mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, const mlx_stream s);
|
extern int (*mlx_diagonal_ptr)(mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, const mlx_stream s);
|
||||||
extern int (*mlx_divide_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
extern int (*mlx_divide_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
||||||
@@ -961,6 +967,8 @@ extern int (*mlx_gather_qmm_ptr)(mlx_array* res, const mlx_array x, const mlx_ar
|
|||||||
extern int (*mlx_greater_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
extern int (*mlx_greater_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
||||||
extern int (*mlx_greater_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
extern int (*mlx_greater_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
||||||
extern int (*mlx_hadamard_transform_ptr)(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s);
|
extern int (*mlx_hadamard_transform_ptr)(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s);
|
||||||
|
extern int (*mlx_hamming_ptr)(mlx_array* res, int M, const mlx_stream s);
|
||||||
|
extern int (*mlx_hanning_ptr)(mlx_array* res, int M, const mlx_stream s);
|
||||||
extern int (*mlx_identity_ptr)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s);
|
extern int (*mlx_identity_ptr)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s);
|
||||||
extern int (*mlx_imag_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
extern int (*mlx_imag_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
extern int (*mlx_inner_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
extern int (*mlx_inner_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
||||||
@@ -1020,8 +1028,8 @@ extern int (*mlx_prod_axes_ptr)(mlx_array* res, const mlx_array a, const int* ax
|
|||||||
extern int (*mlx_prod_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s);
|
extern int (*mlx_prod_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s);
|
||||||
extern int (*mlx_prod_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s);
|
extern int (*mlx_prod_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s);
|
||||||
extern int (*mlx_put_along_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s);
|
extern int (*mlx_put_along_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s);
|
||||||
extern int (*mlx_qqmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s);
|
extern int (*mlx_qqmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale_x , const mlx_array global_scale_w , const mlx_stream s);
|
||||||
extern int (*mlx_quantize_ptr)(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s);
|
extern int (*mlx_quantize_ptr)(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , const mlx_stream s);
|
||||||
extern int (*mlx_quantized_matmul_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s);
|
extern int (*mlx_quantized_matmul_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s);
|
||||||
extern int (*mlx_radians_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
extern int (*mlx_radians_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
extern int (*mlx_real_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
extern int (*mlx_real_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
@@ -1492,9 +1500,9 @@ int mlx_distributed_group_size(mlx_distributed_group group);
|
|||||||
|
|
||||||
mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, int color, int key);
|
mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, int color, int key);
|
||||||
|
|
||||||
bool mlx_distributed_is_available(void);
|
bool mlx_distributed_is_available(const char* bk);
|
||||||
|
|
||||||
mlx_distributed_group mlx_distributed_init(bool strict);
|
mlx_distributed_group mlx_distributed_init(bool strict, const char* bk);
|
||||||
|
|
||||||
void mlx_set_error_handler(mlx_error_handler_func handler, void* data, void (*dtor)(void*));
|
void mlx_set_error_handler(mlx_error_handler_func handler, void* data, void (*dtor)(void*));
|
||||||
|
|
||||||
@@ -1800,6 +1808,8 @@ int mlx_atleast_2d(mlx_array* res, const mlx_array a, const mlx_stream s);
|
|||||||
|
|
||||||
int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s);
|
int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
|
|
||||||
|
int mlx_bartlett(mlx_array* res, int M, const mlx_stream s);
|
||||||
|
|
||||||
int mlx_bitwise_and(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
int mlx_bitwise_and(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
||||||
|
|
||||||
int mlx_bitwise_invert(mlx_array* res, const mlx_array a, const mlx_stream s);
|
int mlx_bitwise_invert(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
@@ -1808,6 +1818,8 @@ int mlx_bitwise_or(mlx_array* res, const mlx_array a, const mlx_array b, const m
|
|||||||
|
|
||||||
int mlx_bitwise_xor(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
int mlx_bitwise_xor(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s);
|
||||||
|
|
||||||
|
int mlx_blackman(mlx_array* res, int M, const mlx_stream s);
|
||||||
|
|
||||||
int mlx_block_masked_mm(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s);
|
int mlx_block_masked_mm(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s);
|
||||||
|
|
||||||
int mlx_broadcast_arrays(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s);
|
int mlx_broadcast_arrays(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s);
|
||||||
@@ -1858,7 +1870,7 @@ int mlx_degrees(mlx_array* res, const mlx_array a, const mlx_stream s);
|
|||||||
|
|
||||||
int mlx_depends(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies);
|
int mlx_depends(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies);
|
||||||
|
|
||||||
int mlx_dequantize(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, mlx_optional_dtype dtype, const mlx_stream s);
|
int mlx_dequantize(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , mlx_optional_dtype dtype, const mlx_stream s);
|
||||||
|
|
||||||
int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
|
int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
|
||||||
|
|
||||||
@@ -1912,6 +1924,10 @@ int mlx_greater_equal(mlx_array* res, const mlx_array a, const mlx_array b, cons
|
|||||||
|
|
||||||
int mlx_hadamard_transform(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s);
|
int mlx_hadamard_transform(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s);
|
||||||
|
|
||||||
|
int mlx_hamming(mlx_array* res, int M, const mlx_stream s);
|
||||||
|
|
||||||
|
int mlx_hanning(mlx_array* res, int M, const mlx_stream s);
|
||||||
|
|
||||||
int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s);
|
int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s);
|
||||||
|
|
||||||
int mlx_imag(mlx_array* res, const mlx_array a, const mlx_stream s);
|
int mlx_imag(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
@@ -2030,9 +2046,9 @@ int mlx_prod(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream
|
|||||||
|
|
||||||
int mlx_put_along_axis(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s);
|
int mlx_put_along_axis(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s);
|
||||||
|
|
||||||
int mlx_qqmm(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s);
|
int mlx_qqmm(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale_x , const mlx_array global_scale_w , const mlx_stream s);
|
||||||
|
|
||||||
int mlx_quantize(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s);
|
int mlx_quantize(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_array global_scale , const mlx_stream s);
|
||||||
|
|
||||||
int mlx_quantized_matmul(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s);
|
int mlx_quantized_matmul(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s);
|
||||||
|
|
||||||
|
|||||||
@@ -93,21 +93,8 @@ func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
|
|||||||
matchPath, matched = findBestMatch(c.root, inputs[:len(inputs)-1])
|
matchPath, matched = findBestMatch(c.root, inputs[:len(inputs)-1])
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for partial match within a node's edge — truncate path
|
|
||||||
// to the parent boundary. snapshot() will split the node and
|
|
||||||
// create the branch point during prefill when caches are ready.
|
|
||||||
partialMatch := false
|
|
||||||
if len(matchPath) > 1 {
|
|
||||||
lastNode := matchPath[len(matchPath)-1]
|
|
||||||
matchedInEdge := matched - lastNode.startOffset()
|
|
||||||
if matchedInEdge > 0 && matchedInEdge < len(lastNode.tokens) {
|
|
||||||
matchPath = matchPath[:len(matchPath)-1]
|
|
||||||
partialMatch = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Switch to the matched path, paging in/out as needed.
|
// Switch to the matched path, paging in/out as needed.
|
||||||
c.switchToPath(matchPath)
|
c.switchToPath(matchPath, matched)
|
||||||
|
|
||||||
// switchToPath aligns caches to a common offset
|
// switchToPath aligns caches to a common offset
|
||||||
prefix := c.minCacheOffset()
|
prefix := c.minCacheOffset()
|
||||||
@@ -116,7 +103,7 @@ func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
|
|||||||
// Schedule a snapshot at the branch point during prefill so future
|
// Schedule a snapshot at the branch point during prefill so future
|
||||||
// requests diverging here can restore instead of re-evaluating.
|
// requests diverging here can restore instead of re-evaluating.
|
||||||
var snapshotAt int
|
var snapshotAt int
|
||||||
if partialMatch || (prefix == 0 && matched > 0) {
|
if prefix < matched {
|
||||||
snapshotAt = matched
|
snapshotAt = matched
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -142,7 +129,7 @@ func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
|
|||||||
|
|
||||||
// switchToPath transitions from the current active path to a new path,
|
// switchToPath transitions from the current active path to a new path,
|
||||||
// paging out diverging segments and paging in the new path.
|
// paging out diverging segments and paging in the new path.
|
||||||
func (c *kvCache) switchToPath(newPath []*trieNode) {
|
func (c *kvCache) switchToPath(newPath []*trieNode, matched int) {
|
||||||
defer c.enforceEvictionPolicy()
|
defer c.enforceEvictionPolicy()
|
||||||
|
|
||||||
// Find common ancestor index.
|
// Find common ancestor index.
|
||||||
@@ -167,7 +154,10 @@ func (c *kvCache) switchToPath(newPath []*trieNode) {
|
|||||||
// non-leaf nodes here would produce wrong results for non-rewindable
|
// non-leaf nodes here would produce wrong results for non-rewindable
|
||||||
// caches (e.g. RecurrentCache) whose state reflects the leaf, not
|
// caches (e.g. RecurrentCache) whose state reflects the leaf, not
|
||||||
// the intermediate boundary.
|
// the intermediate boundary.
|
||||||
if leaf := len(c.activePath) - 1; leaf >= commonLen {
|
leaf := len(c.activePath) - 1
|
||||||
|
leafDiverges := leaf >= commonLen
|
||||||
|
leafNeedsRewind := matched < c.activePath[leaf].endOffset
|
||||||
|
if leafDiverges || leafNeedsRewind {
|
||||||
node := c.activePath[leaf]
|
node := c.activePath[leaf]
|
||||||
if !node.hasAllSnapshots() {
|
if !node.hasAllSnapshots() {
|
||||||
fromOffset := node.startOffset()
|
fromOffset := node.startOffset()
|
||||||
@@ -184,14 +174,16 @@ func (c *kvCache) switchToPath(newPath []*trieNode) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Rewind each cache to the ancestor offset or free it. Freed
|
// Rewind each cache to the target offset or free it. When matched
|
||||||
// caches (e.g. RecurrentCache that can't rewind) will be restored
|
// falls within the ancestor's range (same-path case), we rewind
|
||||||
// from snapshots during page-in.
|
// directly to the match point. Otherwise we rewind to the ancestor
|
||||||
|
// and let page-in bring us forward to matched.
|
||||||
|
rewindTarget := min(ancestorOffset, matched)
|
||||||
for _, kv := range c.caches {
|
for _, kv := range c.caches {
|
||||||
if kv == nil {
|
if kv == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if !kv.Restore(nil, ancestorOffset) {
|
if !kv.Restore(nil, rewindTarget) {
|
||||||
kv.Free()
|
kv.Free()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -199,10 +191,12 @@ func (c *kvCache) switchToPath(newPath []*trieNode) {
|
|||||||
// Page in — walk the full new path, restoring from snapshots.
|
// Page in — walk the full new path, restoring from snapshots.
|
||||||
// Freed caches naturally pick up the first available snapshot.
|
// Freed caches naturally pick up the first available snapshot.
|
||||||
// Caches already past a node skip it via offset check.
|
// Caches already past a node skip it via offset check.
|
||||||
|
pageIn:
|
||||||
for _, node := range newPath {
|
for _, node := range newPath {
|
||||||
if len(node.snapshots) == 0 {
|
if !node.hasSnapshots() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
nodeTarget := min(node.endOffset, matched)
|
||||||
for j, kv := range c.caches {
|
for j, kv := range c.caches {
|
||||||
if kv == nil {
|
if kv == nil {
|
||||||
continue
|
continue
|
||||||
@@ -210,19 +204,18 @@ func (c *kvCache) switchToPath(newPath []*trieNode) {
|
|||||||
if j >= len(node.snapshots) || node.snapshots[j] == nil {
|
if j >= len(node.snapshots) || node.snapshots[j] == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if kv.Offset() >= node.endOffset {
|
if kv.Offset() >= nodeTarget {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if !kv.Restore(node.snapshots[j], node.endOffset) {
|
if !kv.Restore(node.snapshots[j], nodeTarget) {
|
||||||
slog.Warn("cache restore failure during page-in, freeing all caches", "layer", j, "offset", node.startOffset())
|
// Restore failed — stop page-in and let alignment
|
||||||
c.freeAll()
|
// bring all caches to a consistent offset.
|
||||||
c.activePath = []*trieNode{c.root}
|
break pageIn
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if node.endOffset > ancestorOffset {
|
if node.endOffset > ancestorOffset {
|
||||||
pageInCount++
|
pageInCount++
|
||||||
logutil.Trace(fmt.Sprintf("page in: [%d, %d)", node.startOffset(), node.endOffset))
|
logutil.Trace(fmt.Sprintf("page in: [%d, %d)", node.startOffset(), nodeTarget))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -536,6 +529,9 @@ func (c *kvCache) dumpTree() {
|
|||||||
if nodeBytes > 0 {
|
if nodeBytes > 0 {
|
||||||
label += " " + mlx.PrettyBytes(int(nodeBytes)).String()
|
label += " " + mlx.PrettyBytes(int(nodeBytes)).String()
|
||||||
}
|
}
|
||||||
|
if !n.lastUsed.IsZero() {
|
||||||
|
label += fmt.Sprintf(" %s ago", time.Since(n.lastUsed).Truncate(time.Millisecond))
|
||||||
|
}
|
||||||
var flags []string
|
var flags []string
|
||||||
if n.user {
|
if n.user {
|
||||||
flags = append(flags, "user")
|
flags = append(flags, "user")
|
||||||
|
|||||||
28
x/mlxrunner/cache/cache.go
vendored
28
x/mlxrunner/cache/cache.go
vendored
@@ -17,7 +17,8 @@ type Cache interface {
|
|||||||
Snapshot(fromOffset int) Snapshot
|
Snapshot(fromOffset int) Snapshot
|
||||||
|
|
||||||
// Restore brings the cache to target. If snapshot is nil, rewinds
|
// Restore brings the cache to target. If snapshot is nil, rewinds
|
||||||
// using the cache's own live state.
|
// using the cache's own live state. Returns false if the target is
|
||||||
|
// unreachable (e.g. target > current offset, or negative).
|
||||||
Restore(snapshot Snapshot, target int) bool
|
Restore(snapshot Snapshot, target int) bool
|
||||||
|
|
||||||
// Merge combines two sequential snapshots [a,b) and [b,c) into [a,c).
|
// Merge combines two sequential snapshots [a,b) and [b,c) into [a,c).
|
||||||
@@ -122,17 +123,21 @@ func (c *KVCache) Snapshot(fromOffset int) Snapshot {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *KVCache) Restore(snapshot Snapshot, target int) bool {
|
func (c *KVCache) Restore(snapshot Snapshot, target int) bool {
|
||||||
|
if target < 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
if snapshot == nil {
|
if snapshot == nil {
|
||||||
// Rewind using live state — just clamp offset.
|
if target > c.offset {
|
||||||
target = max(0, min(target, c.offset))
|
return false
|
||||||
|
}
|
||||||
c.offset = target
|
c.offset = target
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
snap := snapshot.(*kvSnapshot)
|
snap := snapshot.(*kvSnapshot)
|
||||||
|
|
||||||
// Check that the cache has data up to the snapshot's starting point.
|
if target > snap.toOffset || c.offset < snap.fromOffset {
|
||||||
if c.offset < snap.fromOffset {
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -354,7 +359,14 @@ func (c *RotatingKVCache) Snapshot(fromOffset int) Snapshot {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *RotatingKVCache) Restore(snapshot Snapshot, target int) bool {
|
func (c *RotatingKVCache) Restore(snapshot Snapshot, target int) bool {
|
||||||
|
if target < 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
if snapshot == nil {
|
if snapshot == nil {
|
||||||
|
if target >= c.offset {
|
||||||
|
return target == c.offset
|
||||||
|
}
|
||||||
// Live rewind is only safe when the buffer hasn't filled yet
|
// Live rewind is only safe when the buffer hasn't filled yet
|
||||||
// (offset <= maxSize). Once the window has shifted, rewinding
|
// (offset <= maxSize). Once the window has shifted, rewinding
|
||||||
// leaves fewer than maxSize trailing tokens to attend to —
|
// leaves fewer than maxSize trailing tokens to attend to —
|
||||||
@@ -362,7 +374,6 @@ func (c *RotatingKVCache) Restore(snapshot Snapshot, target int) bool {
|
|||||||
if c.offset > c.maxSize {
|
if c.offset > c.maxSize {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
target = max(0, min(target, c.offset))
|
|
||||||
c.offset = target
|
c.offset = target
|
||||||
c.idx = target
|
c.idx = target
|
||||||
return true
|
return true
|
||||||
@@ -370,6 +381,10 @@ func (c *RotatingKVCache) Restore(snapshot Snapshot, target int) bool {
|
|||||||
|
|
||||||
snap := snapshot.(*rotatingSnapshot)
|
snap := snapshot.(*rotatingSnapshot)
|
||||||
|
|
||||||
|
if target > snap.toOffset {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// Reject if clamping would leave an incomplete window.
|
// Reject if clamping would leave an incomplete window.
|
||||||
if target < snap.toOffset && snap.toOffset > c.maxSize {
|
if target < snap.toOffset && snap.toOffset > c.maxSize {
|
||||||
return false
|
return false
|
||||||
@@ -388,7 +403,6 @@ func (c *RotatingKVCache) Restore(snapshot Snapshot, target int) bool {
|
|||||||
|
|
||||||
// Clamp to target if needed.
|
// Clamp to target if needed.
|
||||||
if target < c.offset {
|
if target < c.offset {
|
||||||
target = max(0, target)
|
|
||||||
c.offset = target
|
c.offset = target
|
||||||
c.idx = target
|
c.idx = target
|
||||||
}
|
}
|
||||||
|
|||||||
22
x/mlxrunner/cache/recurrent.go
vendored
22
x/mlxrunner/cache/recurrent.go
vendored
@@ -22,14 +22,9 @@ func (c *RecurrentCache) setStateRaw(old, v *mlx.Array) *mlx.Array {
|
|||||||
if v == nil || !v.Valid() {
|
if v == nil || !v.Valid() {
|
||||||
return old
|
return old
|
||||||
}
|
}
|
||||||
if old == v {
|
|
||||||
return old
|
|
||||||
}
|
|
||||||
|
|
||||||
mlx.Pin(v)
|
mlx.Pin(v)
|
||||||
if old != nil && old != v {
|
mlx.Unpin(old)
|
||||||
mlx.Unpin(old)
|
|
||||||
}
|
|
||||||
|
|
||||||
return v
|
return v
|
||||||
}
|
}
|
||||||
@@ -38,9 +33,6 @@ func (c *RecurrentCache) setStateDetached(old, v *mlx.Array, ensureContiguous bo
|
|||||||
if v == nil || !v.Valid() {
|
if v == nil || !v.Valid() {
|
||||||
return old
|
return old
|
||||||
}
|
}
|
||||||
if old == v {
|
|
||||||
return old
|
|
||||||
}
|
|
||||||
|
|
||||||
root := v
|
root := v
|
||||||
if ensureContiguous {
|
if ensureContiguous {
|
||||||
@@ -49,9 +41,7 @@ func (c *RecurrentCache) setStateDetached(old, v *mlx.Array, ensureContiguous bo
|
|||||||
detached := root.Clone()
|
detached := root.Clone()
|
||||||
|
|
||||||
mlx.Pin(detached)
|
mlx.Pin(detached)
|
||||||
if old != nil && old != detached {
|
mlx.Unpin(old)
|
||||||
mlx.Unpin(old)
|
|
||||||
}
|
|
||||||
|
|
||||||
return detached
|
return detached
|
||||||
}
|
}
|
||||||
@@ -150,10 +140,10 @@ func (c *RecurrentCache) Restore(snapshot Snapshot, target int) bool {
|
|||||||
|
|
||||||
snap := snapshot.(*recurrentSnapshot)
|
snap := snapshot.(*recurrentSnapshot)
|
||||||
|
|
||||||
// Recurrent state encodes all tokens up to snap.offset. Restoring
|
// Recurrent snapshots encode cumulative state up to exactly
|
||||||
// to a target before that would leave stale state from tokens
|
// snap.offset. Target must match — rewinding would leave stale
|
||||||
// [target, snap.offset) baked in. Only allow restoring forward.
|
// state, and advancing isn't possible without feeding tokens.
|
||||||
if target < snap.offset {
|
if target != snap.offset {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
34
x/mlxrunner/cache/recurrent_test.go
vendored
34
x/mlxrunner/cache/recurrent_test.go
vendored
@@ -6,39 +6,35 @@ import (
|
|||||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestRecurrentCacheRestoreDirectionality verifies that RecurrentCache only
|
// TestRecurrentCacheRestoreExactOffset verifies that RecurrentCache restore
|
||||||
// allows restoring forward (target >= snapshot offset), never backward.
|
// only succeeds when target exactly matches the snapshot's offset. Recurrent
|
||||||
func TestRecurrentCacheRestoreDirectionality(t *testing.T) {
|
// state is cumulative, so it can't be rewound or fast-forwarded.
|
||||||
|
func TestRecurrentCacheRestoreExactOffset(t *testing.T) {
|
||||||
skipIfNoMLX(t)
|
skipIfNoMLX(t)
|
||||||
c := NewRecurrentCache(3, 12, 4, 8, 8)
|
c := NewRecurrentCache(3, 12, 4, 8, 8)
|
||||||
_ = c.ConvState(1, mlx.DTypeFloat16)
|
_ = c.ConvState(1, mlx.DTypeFloat16)
|
||||||
_ = c.DeltaState(1, mlx.DTypeFloat16)
|
_ = c.DeltaState(1, mlx.DTypeFloat16)
|
||||||
c.Advance(10)
|
c.Advance(10)
|
||||||
|
|
||||||
snap := c.Snapshot(0)
|
snap := c.Snapshot(0) // snap.offset == 10
|
||||||
|
|
||||||
c.Advance(5) // now at 15
|
c.Advance(5) // cache now at 15
|
||||||
|
|
||||||
// Restore backward should fail.
|
// target < snap.offset: fails (can't rewind past snapshot)
|
||||||
if c.Restore(snap, 5) {
|
if c.Restore(snap, 5) {
|
||||||
t.Fatal("Restore(snap, 5) should fail — target < snap.offset")
|
t.Fatal("Restore(snap, 5) should fail — target != snap.offset")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Restore to exact snap offset should succeed.
|
// target > snap.offset: fails (can't advance without feeding tokens)
|
||||||
|
if c.Restore(snap, 15) {
|
||||||
|
t.Fatal("Restore(snap, 15) should fail — target != snap.offset")
|
||||||
|
}
|
||||||
|
|
||||||
|
// target == snap.offset: succeeds
|
||||||
if !c.Restore(snap, 10) {
|
if !c.Restore(snap, 10) {
|
||||||
t.Fatal("Restore(snap, 10) should succeed")
|
t.Fatal("Restore(snap, 10) should succeed — target == snap.offset")
|
||||||
}
|
}
|
||||||
if c.Offset() != 10 {
|
if c.Offset() != 10 {
|
||||||
t.Fatalf("offset = %d, want 10", c.Offset())
|
t.Fatalf("offset = %d, want 10", c.Offset())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Restore forward (target > snap offset) should succeed, offset = snap.offset.
|
|
||||||
snap2 := c.Snapshot(0)
|
|
||||||
if !c.Restore(snap2, 15) {
|
|
||||||
t.Fatal("Restore(snap, 15) should succeed")
|
|
||||||
}
|
|
||||||
// Recurrent state is at snap.offset (10), not target (15).
|
|
||||||
if c.Offset() != 10 {
|
|
||||||
t.Fatalf("offset = %d, want 10 (snap offset)", c.Offset())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -79,20 +79,20 @@ func (c *fakeRewindableCache) Snapshot(fromOffset int) cache.Snapshot {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *fakeRewindableCache) Restore(snapshot cache.Snapshot, target int) bool {
|
func (c *fakeRewindableCache) Restore(snapshot cache.Snapshot, target int) bool {
|
||||||
|
if target < 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
if snapshot == nil {
|
if snapshot == nil {
|
||||||
// Rewind live state.
|
|
||||||
if target < 0 {
|
|
||||||
target = 0
|
|
||||||
}
|
|
||||||
if target > len(c.tokens) {
|
if target > len(c.tokens) {
|
||||||
target = len(c.tokens)
|
return false
|
||||||
}
|
}
|
||||||
c.tokens = c.tokens[:target]
|
c.tokens = c.tokens[:target]
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
s := snapshot.(*fakeSnapshot)
|
s := snapshot.(*fakeSnapshot)
|
||||||
if len(c.tokens) < s.from {
|
if target > s.to || len(c.tokens) < s.from {
|
||||||
return false // don't have base data up to snapshot start
|
return false
|
||||||
}
|
}
|
||||||
c.tokens = append(c.tokens[:s.from], s.tokens...)
|
c.tokens = append(c.tokens[:s.from], s.tokens...)
|
||||||
if target < len(c.tokens) {
|
if target < len(c.tokens) {
|
||||||
@@ -196,9 +196,13 @@ func (c *fakeSlidingWindowCache) Snapshot(fromOffset int) cache.Snapshot {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *fakeSlidingWindowCache) Restore(snapshot cache.Snapshot, target int) bool {
|
func (c *fakeSlidingWindowCache) Restore(snapshot cache.Snapshot, target int) bool {
|
||||||
|
if target < 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
if snapshot == nil {
|
if snapshot == nil {
|
||||||
if target == len(c.tokens) {
|
if target >= len(c.tokens) {
|
||||||
return true
|
return target == len(c.tokens)
|
||||||
}
|
}
|
||||||
// Live rewind only works when buffer hasn't filled (offset <= maxSize).
|
// Live rewind only works when buffer hasn't filled (offset <= maxSize).
|
||||||
if len(c.tokens) > c.maxSize {
|
if len(c.tokens) > c.maxSize {
|
||||||
@@ -208,6 +212,14 @@ func (c *fakeSlidingWindowCache) Restore(snapshot cache.Snapshot, target int) bo
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
s := snapshot.(*fakeSnapshot)
|
s := snapshot.(*fakeSnapshot)
|
||||||
|
if target > s.to {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// Reject if clamping would leave an incomplete window
|
||||||
|
// (matches RotatingKVCache behavior).
|
||||||
|
if target < s.to && s.to > c.maxSize {
|
||||||
|
return false
|
||||||
|
}
|
||||||
c.tokens = slices.Clone(s.tokens)
|
c.tokens = slices.Clone(s.tokens)
|
||||||
if target < len(c.tokens) {
|
if target < len(c.tokens) {
|
||||||
c.tokens = c.tokens[:target]
|
c.tokens = c.tokens[:target]
|
||||||
@@ -268,8 +280,8 @@ func (c *fakeRecurrentCache) Restore(snapshot cache.Snapshot, target int) bool {
|
|||||||
return target == len(c.tokens) // can only no-op
|
return target == len(c.tokens) // can only no-op
|
||||||
}
|
}
|
||||||
s := snapshot.(*fakeSnapshot)
|
s := snapshot.(*fakeSnapshot)
|
||||||
if target < s.to {
|
if target != s.to {
|
||||||
return false // can't go backward
|
return false // cumulative state requires exact match
|
||||||
}
|
}
|
||||||
c.tokens = slices.Clone(s.tokens)
|
c.tokens = slices.Clone(s.tokens)
|
||||||
return true
|
return true
|
||||||
@@ -294,9 +306,10 @@ type feedableCache interface {
|
|||||||
|
|
||||||
// testEnv encapsulates a kvCache and its fake caches for a test scenario.
|
// testEnv encapsulates a kvCache and its fake caches for a test scenario.
|
||||||
type testEnv struct {
|
type testEnv struct {
|
||||||
kvc *kvCache
|
kvc *kvCache
|
||||||
caches []cache.Cache // typed references for assertions
|
caches []cache.Cache // typed references for assertions
|
||||||
tracker *snapshotTracker
|
tracker *snapshotTracker
|
||||||
|
rewindable bool // true when all caches support arbitrary Restore(nil, target)
|
||||||
}
|
}
|
||||||
|
|
||||||
// newTransformerEnv creates a test environment with a single rewindable cache
|
// newTransformerEnv creates a test environment with a single rewindable cache
|
||||||
@@ -305,23 +318,28 @@ func newTransformerEnv() *testEnv {
|
|||||||
tracker := &snapshotTracker{}
|
tracker := &snapshotTracker{}
|
||||||
caches := []cache.Cache{&fakeRewindableCache{tracker: tracker}}
|
caches := []cache.Cache{&fakeRewindableCache{tracker: tracker}}
|
||||||
return &testEnv{
|
return &testEnv{
|
||||||
kvc: &kvCache{caches: caches},
|
kvc: &kvCache{caches: caches},
|
||||||
caches: caches,
|
caches: caches,
|
||||||
tracker: tracker,
|
tracker: tracker,
|
||||||
|
rewindable: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// newSlidingWindowEnv creates a test environment with one rewindable cache and
|
// newSlidingWindowEnv creates a test environment with one rewindable cache and
|
||||||
// one sliding window cache (Mistral-style architecture).
|
// one sliding window cache (Mistral-style architecture). The sliding window
|
||||||
|
// maxSize is set small enough that test sequences fill it, making
|
||||||
|
// Restore(nil, target) fail — the same behavior as production models where
|
||||||
|
// the window fills after a few turns.
|
||||||
func newSlidingWindowEnv() *testEnv {
|
func newSlidingWindowEnv() *testEnv {
|
||||||
tr := &snapshotTracker{}
|
tr := &snapshotTracker{}
|
||||||
rc := &fakeRewindableCache{tracker: tr}
|
rc := &fakeRewindableCache{tracker: tr}
|
||||||
sw := &fakeSlidingWindowCache{maxSize: 32, tracker: tr}
|
sw := &fakeSlidingWindowCache{maxSize: 4, tracker: tr}
|
||||||
caches := []cache.Cache{rc, sw}
|
caches := []cache.Cache{rc, sw}
|
||||||
return &testEnv{
|
return &testEnv{
|
||||||
kvc: &kvCache{caches: caches},
|
kvc: &kvCache{caches: caches},
|
||||||
caches: caches,
|
caches: caches,
|
||||||
tracker: tr,
|
tracker: tr,
|
||||||
|
rewindable: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -333,9 +351,10 @@ func newRecurrentEnv() *testEnv {
|
|||||||
nrc := &fakeRecurrentCache{tracker: tr}
|
nrc := &fakeRecurrentCache{tracker: tr}
|
||||||
caches := []cache.Cache{rc, nrc}
|
caches := []cache.Cache{rc, nrc}
|
||||||
return &testEnv{
|
return &testEnv{
|
||||||
kvc: &kvCache{caches: caches},
|
kvc: &kvCache{caches: caches},
|
||||||
caches: caches,
|
caches: caches,
|
||||||
tracker: tr,
|
tracker: tr,
|
||||||
|
rewindable: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -590,15 +609,24 @@ func TestBranchCreationAndReuse(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Request B: [1,2,3,4,5,10,11,12] — shares 5-token prefix with A.
|
// Request B: [1,2,3,4,5,10,11,12] — shares 5-token prefix with A.
|
||||||
// Partial match in A's edge triggers snapshotOffset.
|
// For rewindable caches, switchToPath rewinds to the match point
|
||||||
|
// so only the non-matching suffix needs evaluation. For non-rewindable
|
||||||
|
// caches (RecurrentCache), the rewind fails and freeAll fires.
|
||||||
resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 12}, []int32{30, 31})
|
resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 12}, []int32{30, 31})
|
||||||
if resB.snapshotOffset != 5 {
|
if env.rewindable {
|
||||||
t.Fatalf("B: snapshotOffset = %d, want 5", resB.snapshotOffset)
|
if resB.snapshotOffset != 0 {
|
||||||
}
|
t.Fatalf("B: snapshotOffset = %d, want 0 (rewind succeeded)", resB.snapshotOffset)
|
||||||
// Cache was rewound to 0 (partial match truncates path to root),
|
}
|
||||||
// so all tokens were re-evaluated.
|
if len(resB.remaining) != 3 {
|
||||||
if len(resB.remaining) != 8 {
|
t.Fatalf("B: remaining = %d, want 3 (rewind to match point)", len(resB.remaining))
|
||||||
t.Fatalf("B: remaining = %d, want 8", len(resB.remaining))
|
}
|
||||||
|
} else {
|
||||||
|
if resB.snapshotOffset != 5 {
|
||||||
|
t.Fatalf("B: snapshotOffset = %d, want 5", resB.snapshotOffset)
|
||||||
|
}
|
||||||
|
if len(resB.remaining) != 8 {
|
||||||
|
t.Fatalf("B: remaining = %d, want 8 (freeAll fallback)", len(resB.remaining))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 10, 11, 12, 30, 31})
|
env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 10, 11, 12, 30, 31})
|
||||||
|
|
||||||
@@ -635,14 +663,24 @@ func TestExactMatchSeedBehavior(t *testing.T) {
|
|||||||
simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10, 11})
|
simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10, 11})
|
||||||
|
|
||||||
// Request B: identical prompt. Holdback means matched=4, partial in
|
// Request B: identical prompt. Holdback means matched=4, partial in
|
||||||
// the 5-token edge, so path truncates to root and all tokens are
|
// the 5-token edge. For rewindable caches, switchToPath rewinds to
|
||||||
// re-evaluated. snapshotOffset should be set at the holdback point.
|
// offset 4, so only the held-back token needs re-evaluation. For
|
||||||
|
// non-rewindable caches, the rewind fails and freeAll fires.
|
||||||
resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{20, 21})
|
resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{20, 21})
|
||||||
if len(resB.remaining) != 5 {
|
if env.rewindable {
|
||||||
t.Fatalf("B: remaining = %d, want 5 (full re-eval due to holdback)", len(resB.remaining))
|
if len(resB.remaining) != 1 {
|
||||||
}
|
t.Fatalf("B: remaining = %d, want 1 (rewind to holdback point)", len(resB.remaining))
|
||||||
if resB.snapshotOffset != 4 {
|
}
|
||||||
t.Fatalf("B: snapshotOffset = %d, want 4", resB.snapshotOffset)
|
if resB.snapshotOffset != 0 {
|
||||||
|
t.Fatalf("B: snapshotOffset = %d, want 0 (rewind succeeded)", resB.snapshotOffset)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if len(resB.remaining) != 5 {
|
||||||
|
t.Fatalf("B: remaining = %d, want 5 (freeAll fallback)", len(resB.remaining))
|
||||||
|
}
|
||||||
|
if resB.snapshotOffset != 4 {
|
||||||
|
t.Fatalf("B: snapshotOffset = %d, want 4", resB.snapshotOffset)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 20, 21})
|
env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 20, 21})
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package mlxrunner
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
@@ -36,14 +37,69 @@ type Client struct {
|
|||||||
modelName string
|
modelName string
|
||||||
contextLength atomic.Int64
|
contextLength atomic.Int64
|
||||||
memory atomic.Uint64
|
memory atomic.Uint64
|
||||||
done chan error
|
done chan struct{}
|
||||||
|
doneErr error // valid after done is closed
|
||||||
client *http.Client
|
client *http.Client
|
||||||
lastErr string
|
status *statusWriter
|
||||||
lastErrLock sync.Mutex
|
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
cmd *exec.Cmd
|
cmd *exec.Cmd
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// statusWriter captures the last stderr line from the subprocess while
|
||||||
|
// forwarding all output to os.Stderr. Lines longer than maxStatusLen are
|
||||||
|
// truncated to the first maxStatusLen bytes.
|
||||||
|
type statusWriter struct {
|
||||||
|
lastErrMsg string
|
||||||
|
buf []byte
|
||||||
|
discarding bool
|
||||||
|
mu sync.Mutex
|
||||||
|
out *os.File
|
||||||
|
}
|
||||||
|
|
||||||
|
const maxStatusLen = 256
|
||||||
|
|
||||||
|
func (w *statusWriter) Write(b []byte) (int, error) {
|
||||||
|
n, err := w.out.Write(b)
|
||||||
|
|
||||||
|
w.mu.Lock()
|
||||||
|
defer w.mu.Unlock()
|
||||||
|
|
||||||
|
w.buf = append(w.buf, b...)
|
||||||
|
for {
|
||||||
|
i := bytes.IndexByte(w.buf, '\n')
|
||||||
|
if i < 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if !w.discarding {
|
||||||
|
line := bytes.TrimSpace(w.buf[:i])
|
||||||
|
if len(line) > 0 {
|
||||||
|
if len(line) > maxStatusLen {
|
||||||
|
line = line[:maxStatusLen]
|
||||||
|
}
|
||||||
|
w.lastErrMsg = string(line)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
w.buf = w.buf[i+1:]
|
||||||
|
w.discarding = false
|
||||||
|
}
|
||||||
|
// if the buffer grows past maxStatusLen without a newline, keep the front
|
||||||
|
if len(w.buf) > maxStatusLen {
|
||||||
|
if !w.discarding {
|
||||||
|
w.lastErrMsg = string(bytes.TrimSpace(w.buf[:maxStatusLen]))
|
||||||
|
w.discarding = true
|
||||||
|
}
|
||||||
|
w.buf = w.buf[:0]
|
||||||
|
}
|
||||||
|
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *statusWriter) getLastErr() string {
|
||||||
|
w.mu.Lock()
|
||||||
|
defer w.mu.Unlock()
|
||||||
|
return w.lastErrMsg
|
||||||
|
}
|
||||||
|
|
||||||
// NewClient prepares a new MLX runner client for LLM models.
|
// NewClient prepares a new MLX runner client for LLM models.
|
||||||
// The subprocess is not started until Load() is called.
|
// The subprocess is not started until Load() is called.
|
||||||
func NewClient(modelName string) (*Client, error) {
|
func NewClient(modelName string) (*Client, error) {
|
||||||
@@ -53,7 +109,7 @@ func NewClient(modelName string) (*Client, error) {
|
|||||||
|
|
||||||
c := &Client{
|
c := &Client{
|
||||||
modelName: modelName,
|
modelName: modelName,
|
||||||
done: make(chan error, 1),
|
done: make(chan struct{}),
|
||||||
client: &http.Client{Timeout: 10 * time.Minute},
|
client: &http.Client{Timeout: 10 * time.Minute},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -66,12 +122,6 @@ func NewClient(modelName string) (*Client, error) {
|
|||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) getLastErr() string {
|
|
||||||
c.lastErrLock.Lock()
|
|
||||||
defer c.lastErrLock.Unlock()
|
|
||||||
return c.lastErr
|
|
||||||
}
|
|
||||||
|
|
||||||
// WaitUntilRunning waits for the subprocess to be ready.
|
// WaitUntilRunning waits for the subprocess to be ready.
|
||||||
func (c *Client) WaitUntilRunning(ctx context.Context) error {
|
func (c *Client) WaitUntilRunning(ctx context.Context) error {
|
||||||
timeout := time.After(2 * time.Minute)
|
timeout := time.After(2 * time.Minute)
|
||||||
@@ -82,16 +132,14 @@ func (c *Client) WaitUntilRunning(ctx context.Context) error {
|
|||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
case err := <-c.done:
|
case <-c.done:
|
||||||
errMsg := c.getLastErr()
|
if msg := c.status.getLastErr(); msg != "" {
|
||||||
if errMsg != "" {
|
return fmt.Errorf("mlx runner failed: %s (exit: %v)", msg, c.doneErr)
|
||||||
return fmt.Errorf("mlx runner failed: %s (exit: %v)", errMsg, err)
|
|
||||||
}
|
}
|
||||||
return fmt.Errorf("mlx runner exited unexpectedly: %w", err)
|
return fmt.Errorf("mlx runner exited unexpectedly: %w", c.doneErr)
|
||||||
case <-timeout:
|
case <-timeout:
|
||||||
errMsg := c.getLastErr()
|
if msg := c.status.getLastErr(); msg != "" {
|
||||||
if errMsg != "" {
|
return fmt.Errorf("timeout waiting for mlx runner: %s", msg)
|
||||||
return fmt.Errorf("timeout waiting for mlx runner: %s", errMsg)
|
|
||||||
}
|
}
|
||||||
return errors.New("timeout waiting for mlx runner to start")
|
return errors.New("timeout waiting for mlx runner to start")
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
@@ -182,6 +230,9 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
|
|||||||
|
|
||||||
resp, err := c.client.Do(httpReq)
|
resp, err := c.client.Do(httpReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if errMsg := c.status.getLastErr(); errMsg != "" {
|
||||||
|
return fmt.Errorf("mlx runner failed: %s", errMsg)
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
@@ -219,7 +270,13 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return scanner.Err()
|
if err := scanner.Err(); err != nil {
|
||||||
|
if errMsg := c.status.getLastErr(); errMsg != "" {
|
||||||
|
return fmt.Errorf("mlx runner failed: %s", errMsg)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) ContextLength() int {
|
func (c *Client) ContextLength() int {
|
||||||
@@ -348,18 +405,13 @@ func (c *Client) Load(ctx context.Context, _ ml.SystemInfo, gpus []ml.DeviceInfo
|
|||||||
// Forward subprocess stdout/stderr to server logs
|
// Forward subprocess stdout/stderr to server logs
|
||||||
stdout, _ := cmd.StdoutPipe()
|
stdout, _ := cmd.StdoutPipe()
|
||||||
stderr, _ := cmd.StderrPipe()
|
stderr, _ := cmd.StderrPipe()
|
||||||
|
status := &statusWriter{out: os.Stderr}
|
||||||
|
c.status = status
|
||||||
go func() {
|
go func() {
|
||||||
io.Copy(os.Stderr, stdout) //nolint:errcheck
|
io.Copy(os.Stderr, stdout) //nolint:errcheck
|
||||||
}()
|
}()
|
||||||
go func() {
|
go func() {
|
||||||
scanner := bufio.NewScanner(stderr)
|
io.Copy(status, stderr) //nolint:errcheck
|
||||||
for scanner.Scan() {
|
|
||||||
line := scanner.Text()
|
|
||||||
fmt.Fprintln(os.Stderr, line)
|
|
||||||
c.lastErrLock.Lock()
|
|
||||||
c.lastErr = line
|
|
||||||
c.lastErrLock.Unlock()
|
|
||||||
}
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
slog.Info("starting mlx runner subprocess", "model", c.modelName, "port", c.port)
|
slog.Info("starting mlx runner subprocess", "model", c.modelName, "port", c.port)
|
||||||
@@ -369,8 +421,8 @@ func (c *Client) Load(ctx context.Context, _ ml.SystemInfo, gpus []ml.DeviceInfo
|
|||||||
|
|
||||||
// Reap subprocess when it exits
|
// Reap subprocess when it exits
|
||||||
go func() {
|
go func() {
|
||||||
err := cmd.Wait()
|
c.doneErr = cmd.Wait()
|
||||||
c.done <- err
|
close(c.done)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return nil, nil
|
return nil, nil
|
||||||
|
|||||||
@@ -15,7 +15,9 @@ set(CMAKE_INSTALL_RPATH "@loader_path")
|
|||||||
|
|
||||||
include(FetchContent)
|
include(FetchContent)
|
||||||
|
|
||||||
set(MLX_C_GIT_TAG "v0.5.0" CACHE STRING "")
|
# Read MLX-C version from top-level file (shared with imagegen CMakeLists)
|
||||||
|
file(READ "${CMAKE_SOURCE_DIR}/MLX_C_VERSION" MLX_C_GIT_TAG)
|
||||||
|
string(STRIP "${MLX_C_GIT_TAG}" MLX_C_GIT_TAG)
|
||||||
|
|
||||||
FetchContent_Declare(
|
FetchContent_Declare(
|
||||||
mlx-c
|
mlx-c
|
||||||
|
|||||||
@@ -137,6 +137,9 @@ func Unpin(s ...*Array) {
|
|||||||
for _, t := range s {
|
for _, t := range s {
|
||||||
if t != nil {
|
if t != nil {
|
||||||
t.pinned--
|
t.pinned--
|
||||||
|
if t.pinned < 0 {
|
||||||
|
panic(fmt.Sprintf("mlx.Unpin: negative pin count on array %q", t.name))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -261,7 +264,7 @@ func LogArrays() {
|
|||||||
|
|
||||||
for _, t := range arrays {
|
for _, t := range arrays {
|
||||||
nb := t.NumBytes()
|
nb := t.NumBytes()
|
||||||
logutil.Trace(fmt.Sprintf("tensor %-60s %5s %5s %v", t.name, t.DType(), PrettyBytes(nb), t.Dims()))
|
logutil.Trace(fmt.Sprintf("tensor %-60s %5s %5s pinned=%d %v", t.name, t.DType(), PrettyBytes(nb), t.pinned, t.Dims()))
|
||||||
}
|
}
|
||||||
logutil.Trace(fmt.Sprintf("tensors total: %d, size: %s", len(arrays), PrettyBytes(ActiveMemory())))
|
logutil.Trace(fmt.Sprintf("tensors total: %d, size: %s", len(arrays), PrettyBytes(ActiveMemory())))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,6 +13,10 @@ var (
|
|||||||
gatedDeltaMetalKernelOnce sync.Once
|
gatedDeltaMetalKernelOnce sync.Once
|
||||||
gatedDeltaMetalKernel C.mlx_fast_metal_kernel
|
gatedDeltaMetalKernel C.mlx_fast_metal_kernel
|
||||||
gatedDeltaMetalDisabled bool
|
gatedDeltaMetalDisabled bool
|
||||||
|
|
||||||
|
gatedDeltaCUDAKernelOnce sync.Once
|
||||||
|
gatedDeltaCUDAKernel C.mlx_fast_cuda_kernel
|
||||||
|
gatedDeltaCUDADisabled bool
|
||||||
)
|
)
|
||||||
|
|
||||||
const gatedDeltaMetalKernelSource = `
|
const gatedDeltaMetalKernelSource = `
|
||||||
@@ -83,6 +87,86 @@ for (int i = 0; i < n_per_t; ++i) {
|
|||||||
}
|
}
|
||||||
`
|
`
|
||||||
|
|
||||||
|
const gatedDeltaCUDAKernelSource = `
|
||||||
|
auto tid_x = threadIdx.x;
|
||||||
|
auto tid_y = threadIdx.y;
|
||||||
|
auto grid_y = blockIdx.y * blockDim.y + tid_y;
|
||||||
|
auto grid_z = blockIdx.z;
|
||||||
|
|
||||||
|
int T_val = static_cast<int>(*T);
|
||||||
|
|
||||||
|
auto n = grid_z;
|
||||||
|
auto b_idx = n / Hv;
|
||||||
|
auto hv_idx = n % Hv;
|
||||||
|
auto hk_idx = hv_idx / (Hv / Hk);
|
||||||
|
constexpr int n_per_t = Dk / 32;
|
||||||
|
|
||||||
|
// q, k: [B, T, Hk, Dk]
|
||||||
|
auto q_ = q + b_idx * T_val * Hk * Dk + hk_idx * Dk;
|
||||||
|
auto k_ = k + b_idx * T_val * Hk * Dk + hk_idx * Dk;
|
||||||
|
|
||||||
|
// v, y: [B, T, Hv, Dv]
|
||||||
|
auto dv_idx = grid_y;
|
||||||
|
auto v_ = v + b_idx * T_val * Hv * Dv + hv_idx * Dv;
|
||||||
|
y += b_idx * T_val * Hv * Dv + hv_idx * Dv;
|
||||||
|
|
||||||
|
auto dk_idx = tid_x;
|
||||||
|
|
||||||
|
// state_in, state_out: [B, Hv, Dv, Dk]
|
||||||
|
auto i_state = state_in + (n * Dv + dv_idx) * Dk;
|
||||||
|
auto o_state = state_out + (n * Dv + dv_idx) * Dk;
|
||||||
|
|
||||||
|
float state[n_per_t];
|
||||||
|
for (int i = 0; i < n_per_t; ++i) {
|
||||||
|
auto s_idx = n_per_t * dk_idx + i;
|
||||||
|
state[i] = static_cast<float>(i_state[s_idx]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// g: [B, T, Hv]
|
||||||
|
auto g_ = g + b_idx * T_val * Hv;
|
||||||
|
auto beta_ = beta + b_idx * T_val * Hv;
|
||||||
|
|
||||||
|
for (int t = 0; t < T_val; ++t) {
|
||||||
|
float kv_mem = 0.0f;
|
||||||
|
for (int i = 0; i < n_per_t; ++i) {
|
||||||
|
auto s_idx = n_per_t * dk_idx + i;
|
||||||
|
state[i] = state[i] * static_cast<float>(g_[hv_idx]);
|
||||||
|
kv_mem += state[i] * static_cast<float>(k_[s_idx]);
|
||||||
|
}
|
||||||
|
// Warp reduction (full warp, 32 threads in x)
|
||||||
|
for (int offset = 16; offset > 0; offset >>= 1)
|
||||||
|
kv_mem += __shfl_down_sync(0xffffffff, kv_mem, offset);
|
||||||
|
kv_mem = __shfl_sync(0xffffffff, kv_mem, 0);
|
||||||
|
|
||||||
|
auto delta = (static_cast<float>(v_[dv_idx]) - kv_mem) * static_cast<float>(beta_[hv_idx]);
|
||||||
|
|
||||||
|
float out = 0.0f;
|
||||||
|
for (int i = 0; i < n_per_t; ++i) {
|
||||||
|
auto s_idx = n_per_t * dk_idx + i;
|
||||||
|
state[i] = state[i] + static_cast<float>(k_[s_idx]) * delta;
|
||||||
|
out += state[i] * static_cast<float>(q_[s_idx]);
|
||||||
|
}
|
||||||
|
// Warp reduction
|
||||||
|
for (int offset = 16; offset > 0; offset >>= 1)
|
||||||
|
out += __shfl_down_sync(0xffffffff, out, offset);
|
||||||
|
if (tid_x == 0) {
|
||||||
|
y[dv_idx] = static_cast<InT>(out);
|
||||||
|
}
|
||||||
|
|
||||||
|
q_ += Hk * Dk;
|
||||||
|
k_ += Hk * Dk;
|
||||||
|
v_ += Hv * Dv;
|
||||||
|
y += Hv * Dv;
|
||||||
|
g_ += Hv;
|
||||||
|
beta_ += Hv;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < n_per_t; ++i) {
|
||||||
|
auto s_idx = n_per_t * dk_idx + i;
|
||||||
|
o_state[s_idx] = static_cast<InT>(state[i]);
|
||||||
|
}
|
||||||
|
`
|
||||||
|
|
||||||
func cStringVector(values []string) (C.mlx_vector_string, func(), bool) {
|
func cStringVector(values []string) (C.mlx_vector_string, func(), bool) {
|
||||||
vec := C.mlx_vector_string_new()
|
vec := C.mlx_vector_string_new()
|
||||||
ok := true
|
ok := true
|
||||||
@@ -352,11 +436,184 @@ func gatedDeltaFallback(q, k, v, g, beta, state *Array) (y, nextState *Array) {
|
|||||||
return Concatenate(outs, 1), nextState
|
return Concatenate(outs, 1), nextState
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func initGatedDeltaCUDAKernel() {
|
||||||
|
var cudaAvail C.bool
|
||||||
|
if C.mlx_cuda_is_available(&cudaAvail) != 0 || !bool(cudaAvail) {
|
||||||
|
gatedDeltaCUDADisabled = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
inputs, freeInputs, ok := cStringVector([]string{"q", "k", "v", "g", "beta", "state_in", "T"})
|
||||||
|
if !ok {
|
||||||
|
gatedDeltaCUDADisabled = true
|
||||||
|
freeInputs()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer freeInputs()
|
||||||
|
|
||||||
|
outputs, freeOutputs, ok := cStringVector([]string{"y", "state_out"})
|
||||||
|
if !ok {
|
||||||
|
gatedDeltaCUDADisabled = true
|
||||||
|
freeOutputs()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer freeOutputs()
|
||||||
|
|
||||||
|
cName := C.CString("gated_delta_step")
|
||||||
|
defer C.free(unsafe.Pointer(cName))
|
||||||
|
cSource := C.CString(gatedDeltaCUDAKernelSource)
|
||||||
|
defer C.free(unsafe.Pointer(cSource))
|
||||||
|
cHeader := C.CString("")
|
||||||
|
defer C.free(unsafe.Pointer(cHeader))
|
||||||
|
|
||||||
|
gatedDeltaCUDAKernel = C.mlx_fast_cuda_kernel_new(
|
||||||
|
cName,
|
||||||
|
inputs,
|
||||||
|
outputs,
|
||||||
|
cSource,
|
||||||
|
cHeader,
|
||||||
|
C.bool(true),
|
||||||
|
C.int(0),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func gatedDeltaCUDAKernelApply(q, k, v, g, beta, state *Array) (y, nextState *Array, ok bool) {
|
||||||
|
if gatedDeltaCUDADisabled {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
if q == nil || k == nil || v == nil || g == nil || beta == nil || state == nil {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
qd := q.Dims()
|
||||||
|
kd := k.Dims()
|
||||||
|
vd := v.Dims()
|
||||||
|
gd := g.Dims()
|
||||||
|
bd := beta.Dims()
|
||||||
|
sd := state.Dims()
|
||||||
|
if len(qd) != 4 || len(kd) != 4 || len(vd) != 4 || len(gd) != 3 || len(bd) != 3 || len(sd) != 4 {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
B, T, Hk, Dk := qd[0], qd[1], qd[2], qd[3]
|
||||||
|
if T <= 0 || Hk <= 0 || Dk <= 0 || Dk%32 != 0 {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
if kd[0] != B || kd[1] != T || kd[2] != Hk || kd[3] != Dk {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
Hv, Dv := vd[2], vd[3]
|
||||||
|
if vd[0] != B || vd[1] != T || Hv <= 0 || Dv <= 0 || Hv%Hk != 0 {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
if gd[0] != B || gd[1] != T || gd[2] != Hv {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
if bd[0] != B || bd[1] != T || bd[2] != Hv {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
if sd[0] != B || sd[1] != Hv || sd[2] != Dv || sd[3] != Dk {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
dtype := q.DType()
|
||||||
|
if k.DType() != dtype || v.DType() != dtype || g.DType() != dtype || beta.DType() != dtype || state.DType() != dtype {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
gatedDeltaCUDAKernelOnce.Do(initGatedDeltaCUDAKernel)
|
||||||
|
if gatedDeltaCUDADisabled {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := C.mlx_fast_cuda_kernel_config_new()
|
||||||
|
defer C.mlx_fast_cuda_kernel_config_free(cfg)
|
||||||
|
|
||||||
|
cInT := C.CString("InT")
|
||||||
|
defer C.free(unsafe.Pointer(cInT))
|
||||||
|
if C.mlx_fast_cuda_kernel_config_add_template_arg_dtype(cfg, cInT, C.mlx_dtype(dtype)) != 0 {
|
||||||
|
gatedDeltaCUDADisabled = true
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
for _, tpl := range []struct {
|
||||||
|
name string
|
||||||
|
value int
|
||||||
|
}{
|
||||||
|
{name: "Dk", value: Dk},
|
||||||
|
{name: "Dv", value: Dv},
|
||||||
|
{name: "Hk", value: Hk},
|
||||||
|
{name: "Hv", value: Hv},
|
||||||
|
} {
|
||||||
|
cn := C.CString(tpl.name)
|
||||||
|
rc := C.mlx_fast_cuda_kernel_config_add_template_arg_int(cfg, cn, C.int(tpl.value))
|
||||||
|
C.free(unsafe.Pointer(cn))
|
||||||
|
if rc != 0 {
|
||||||
|
gatedDeltaCUDADisabled = true
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
yShape := []C.int{C.int(B), C.int(T), C.int(Hv), C.int(Dv)}
|
||||||
|
stateShape := []C.int{C.int(B), C.int(Hv), C.int(Dv), C.int(Dk)}
|
||||||
|
if C.mlx_fast_cuda_kernel_config_add_output_arg(cfg, unsafe.SliceData(yShape), C.size_t(len(yShape)), C.mlx_dtype(dtype)) != 0 {
|
||||||
|
gatedDeltaCUDADisabled = true
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
if C.mlx_fast_cuda_kernel_config_add_output_arg(cfg, unsafe.SliceData(stateShape), C.size_t(len(stateShape)), C.mlx_dtype(dtype)) != 0 {
|
||||||
|
gatedDeltaCUDADisabled = true
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
if C.mlx_fast_cuda_kernel_config_set_grid(cfg, 32, C.int(Dv), C.int(B*Hv)) != 0 {
|
||||||
|
gatedDeltaCUDADisabled = true
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
threadY := Dv
|
||||||
|
if threadY > 4 {
|
||||||
|
threadY = 4
|
||||||
|
}
|
||||||
|
if C.mlx_fast_cuda_kernel_config_set_thread_group(cfg, 32, C.int(threadY), 1) != 0 {
|
||||||
|
gatedDeltaCUDADisabled = true
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
tScalar := FromValue(T)
|
||||||
|
inputs := []C.mlx_array{
|
||||||
|
q.ctx,
|
||||||
|
k.ctx,
|
||||||
|
v.ctx,
|
||||||
|
g.ctx,
|
||||||
|
beta.ctx,
|
||||||
|
state.ctx,
|
||||||
|
tScalar.ctx,
|
||||||
|
}
|
||||||
|
inVec := C.mlx_vector_array_new_data(unsafe.SliceData(inputs), C.size_t(len(inputs)))
|
||||||
|
defer C.mlx_vector_array_free(inVec)
|
||||||
|
|
||||||
|
outVec := C.mlx_vector_array_new()
|
||||||
|
defer C.mlx_vector_array_free(outVec)
|
||||||
|
if C.mlx_fast_cuda_kernel_apply(&outVec, gatedDeltaCUDAKernel, inVec, cfg, DefaultStream().ctx) != 0 {
|
||||||
|
gatedDeltaCUDADisabled = true
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
if int(C.mlx_vector_array_size(outVec)) < 2 {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
y = New("GATED_DELTA_CUDA_Y")
|
||||||
|
nextState = New("GATED_DELTA_CUDA_STATE")
|
||||||
|
C.mlx_vector_array_get(&y.ctx, outVec, 0)
|
||||||
|
C.mlx_vector_array_get(&nextState.ctx, outVec, 1)
|
||||||
|
return y, nextState, true
|
||||||
|
}
|
||||||
|
|
||||||
// GatedDelta runs the recurrent update operation.
|
// GatedDelta runs the recurrent update operation.
|
||||||
//
|
//
|
||||||
// It uses the fused Metal kernel when available and otherwise falls back to a
|
// It tries the fused CUDA kernel first, then Metal, then falls back to a
|
||||||
// backend-agnostic MLX implementation with identical inputs/outputs.
|
// backend-agnostic MLX implementation with identical inputs/outputs.
|
||||||
func GatedDelta(q, k, v, g, beta, state *Array) (y, nextState *Array) {
|
func GatedDelta(q, k, v, g, beta, state *Array) (y, nextState *Array) {
|
||||||
|
if y, nextState, ok := gatedDeltaCUDAKernelApply(q, k, v, g, beta, state); ok {
|
||||||
|
return y, nextState
|
||||||
|
}
|
||||||
if y, nextState, ok := gatedDeltaKernel(q, k, v, g, beta, state); ok {
|
if y, nextState, ok := gatedDeltaKernel(q, k, v, g, beta, state); ok {
|
||||||
return y, nextState
|
return y, nextState
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -326,8 +326,10 @@ int (*mlx_distributed_sum_scatter_)(
|
|||||||
int (*mlx_distributed_group_rank_)(mlx_distributed_group group) = NULL;
|
int (*mlx_distributed_group_rank_)(mlx_distributed_group group) = NULL;
|
||||||
int (*mlx_distributed_group_size_)(mlx_distributed_group group) = NULL;
|
int (*mlx_distributed_group_size_)(mlx_distributed_group group) = NULL;
|
||||||
mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key) = NULL;
|
mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key) = NULL;
|
||||||
bool (*mlx_distributed_is_available_)(void) = NULL;
|
bool (*mlx_distributed_is_available_)(const char* bk /* may be null */) = NULL;
|
||||||
mlx_distributed_group (*mlx_distributed_init_)(bool strict) = NULL;
|
mlx_distributed_group (*mlx_distributed_init_)(
|
||||||
|
bool strict,
|
||||||
|
const char* bk /* may be null */) = NULL;
|
||||||
void (*mlx_set_error_handler_)(
|
void (*mlx_set_error_handler_)(
|
||||||
mlx_error_handler_func handler,
|
mlx_error_handler_func handler,
|
||||||
void* data,
|
void* data,
|
||||||
@@ -924,6 +926,7 @@ int (*mlx_astype_)(
|
|||||||
int (*mlx_atleast_1d_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
int (*mlx_atleast_1d_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
||||||
int (*mlx_atleast_2d_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
int (*mlx_atleast_2d_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
||||||
int (*mlx_atleast_3d_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
int (*mlx_atleast_3d_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
||||||
|
int (*mlx_bartlett_)(mlx_array* res, int M, const mlx_stream s) = NULL;
|
||||||
int (*mlx_bitwise_and_)(
|
int (*mlx_bitwise_and_)(
|
||||||
mlx_array* res,
|
mlx_array* res,
|
||||||
const mlx_array a,
|
const mlx_array a,
|
||||||
@@ -940,6 +943,7 @@ int (*mlx_bitwise_xor_)(
|
|||||||
const mlx_array a,
|
const mlx_array a,
|
||||||
const mlx_array b,
|
const mlx_array b,
|
||||||
const mlx_stream s) = NULL;
|
const mlx_stream s) = NULL;
|
||||||
|
int (*mlx_blackman_)(mlx_array* res, int M, const mlx_stream s) = NULL;
|
||||||
int (*mlx_block_masked_mm_)(
|
int (*mlx_block_masked_mm_)(
|
||||||
mlx_array* res,
|
mlx_array* res,
|
||||||
const mlx_array a,
|
const mlx_array a,
|
||||||
@@ -1120,6 +1124,7 @@ int (*mlx_dequantize_)(
|
|||||||
mlx_optional_int group_size,
|
mlx_optional_int group_size,
|
||||||
mlx_optional_int bits,
|
mlx_optional_int bits,
|
||||||
const char* mode,
|
const char* mode,
|
||||||
|
const mlx_array global_scale /* may be null */,
|
||||||
mlx_optional_dtype dtype,
|
mlx_optional_dtype dtype,
|
||||||
const mlx_stream s) = NULL;
|
const mlx_stream s) = NULL;
|
||||||
int (*mlx_diag_)(mlx_array* res, const mlx_array a, int k, const mlx_stream s) = NULL;
|
int (*mlx_diag_)(mlx_array* res, const mlx_array a, int k, const mlx_stream s) = NULL;
|
||||||
@@ -1256,6 +1261,8 @@ int (*mlx_hadamard_transform_)(
|
|||||||
const mlx_array a,
|
const mlx_array a,
|
||||||
mlx_optional_float scale,
|
mlx_optional_float scale,
|
||||||
const mlx_stream s) = NULL;
|
const mlx_stream s) = NULL;
|
||||||
|
int (*mlx_hamming_)(mlx_array* res, int M, const mlx_stream s) = NULL;
|
||||||
|
int (*mlx_hanning_)(mlx_array* res, int M, const mlx_stream s) = NULL;
|
||||||
int (*mlx_identity_)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) = NULL;
|
int (*mlx_identity_)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) = NULL;
|
||||||
int (*mlx_imag_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
int (*mlx_imag_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
|
||||||
int (*mlx_inner_)(
|
int (*mlx_inner_)(
|
||||||
@@ -1548,6 +1555,8 @@ int (*mlx_qqmm_)(
|
|||||||
mlx_optional_int group_size,
|
mlx_optional_int group_size,
|
||||||
mlx_optional_int bits,
|
mlx_optional_int bits,
|
||||||
const char* mode,
|
const char* mode,
|
||||||
|
const mlx_array global_scale_x /* may be null */,
|
||||||
|
const mlx_array global_scale_w /* may be null */,
|
||||||
const mlx_stream s) = NULL;
|
const mlx_stream s) = NULL;
|
||||||
int (*mlx_quantize_)(
|
int (*mlx_quantize_)(
|
||||||
mlx_vector_array* res,
|
mlx_vector_array* res,
|
||||||
@@ -1555,6 +1564,7 @@ int (*mlx_quantize_)(
|
|||||||
mlx_optional_int group_size,
|
mlx_optional_int group_size,
|
||||||
mlx_optional_int bits,
|
mlx_optional_int bits,
|
||||||
const char* mode,
|
const char* mode,
|
||||||
|
const mlx_array global_scale /* may be null */,
|
||||||
const mlx_stream s) = NULL;
|
const mlx_stream s) = NULL;
|
||||||
int (*mlx_quantized_matmul_)(
|
int (*mlx_quantized_matmul_)(
|
||||||
mlx_array* res,
|
mlx_array* res,
|
||||||
@@ -2550,10 +2560,12 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
|||||||
CHECK_LOAD(handle, mlx_atleast_1d);
|
CHECK_LOAD(handle, mlx_atleast_1d);
|
||||||
CHECK_LOAD(handle, mlx_atleast_2d);
|
CHECK_LOAD(handle, mlx_atleast_2d);
|
||||||
CHECK_LOAD(handle, mlx_atleast_3d);
|
CHECK_LOAD(handle, mlx_atleast_3d);
|
||||||
|
CHECK_LOAD(handle, mlx_bartlett);
|
||||||
CHECK_LOAD(handle, mlx_bitwise_and);
|
CHECK_LOAD(handle, mlx_bitwise_and);
|
||||||
CHECK_LOAD(handle, mlx_bitwise_invert);
|
CHECK_LOAD(handle, mlx_bitwise_invert);
|
||||||
CHECK_LOAD(handle, mlx_bitwise_or);
|
CHECK_LOAD(handle, mlx_bitwise_or);
|
||||||
CHECK_LOAD(handle, mlx_bitwise_xor);
|
CHECK_LOAD(handle, mlx_bitwise_xor);
|
||||||
|
CHECK_LOAD(handle, mlx_blackman);
|
||||||
CHECK_LOAD(handle, mlx_block_masked_mm);
|
CHECK_LOAD(handle, mlx_block_masked_mm);
|
||||||
CHECK_LOAD(handle, mlx_broadcast_arrays);
|
CHECK_LOAD(handle, mlx_broadcast_arrays);
|
||||||
CHECK_LOAD(handle, mlx_broadcast_to);
|
CHECK_LOAD(handle, mlx_broadcast_to);
|
||||||
@@ -2606,6 +2618,8 @@ int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
|||||||
CHECK_LOAD(handle, mlx_greater);
|
CHECK_LOAD(handle, mlx_greater);
|
||||||
CHECK_LOAD(handle, mlx_greater_equal);
|
CHECK_LOAD(handle, mlx_greater_equal);
|
||||||
CHECK_LOAD(handle, mlx_hadamard_transform);
|
CHECK_LOAD(handle, mlx_hadamard_transform);
|
||||||
|
CHECK_LOAD(handle, mlx_hamming);
|
||||||
|
CHECK_LOAD(handle, mlx_hanning);
|
||||||
CHECK_LOAD(handle, mlx_identity);
|
CHECK_LOAD(handle, mlx_identity);
|
||||||
CHECK_LOAD(handle, mlx_imag);
|
CHECK_LOAD(handle, mlx_imag);
|
||||||
CHECK_LOAD(handle, mlx_inner);
|
CHECK_LOAD(handle, mlx_inner);
|
||||||
|
|||||||
@@ -300,10 +300,12 @@
|
|||||||
#define mlx_atleast_1d mlx_atleast_1d_mlx_gen_orig_
|
#define mlx_atleast_1d mlx_atleast_1d_mlx_gen_orig_
|
||||||
#define mlx_atleast_2d mlx_atleast_2d_mlx_gen_orig_
|
#define mlx_atleast_2d mlx_atleast_2d_mlx_gen_orig_
|
||||||
#define mlx_atleast_3d mlx_atleast_3d_mlx_gen_orig_
|
#define mlx_atleast_3d mlx_atleast_3d_mlx_gen_orig_
|
||||||
|
#define mlx_bartlett mlx_bartlett_mlx_gen_orig_
|
||||||
#define mlx_bitwise_and mlx_bitwise_and_mlx_gen_orig_
|
#define mlx_bitwise_and mlx_bitwise_and_mlx_gen_orig_
|
||||||
#define mlx_bitwise_invert mlx_bitwise_invert_mlx_gen_orig_
|
#define mlx_bitwise_invert mlx_bitwise_invert_mlx_gen_orig_
|
||||||
#define mlx_bitwise_or mlx_bitwise_or_mlx_gen_orig_
|
#define mlx_bitwise_or mlx_bitwise_or_mlx_gen_orig_
|
||||||
#define mlx_bitwise_xor mlx_bitwise_xor_mlx_gen_orig_
|
#define mlx_bitwise_xor mlx_bitwise_xor_mlx_gen_orig_
|
||||||
|
#define mlx_blackman mlx_blackman_mlx_gen_orig_
|
||||||
#define mlx_block_masked_mm mlx_block_masked_mm_mlx_gen_orig_
|
#define mlx_block_masked_mm mlx_block_masked_mm_mlx_gen_orig_
|
||||||
#define mlx_broadcast_arrays mlx_broadcast_arrays_mlx_gen_orig_
|
#define mlx_broadcast_arrays mlx_broadcast_arrays_mlx_gen_orig_
|
||||||
#define mlx_broadcast_to mlx_broadcast_to_mlx_gen_orig_
|
#define mlx_broadcast_to mlx_broadcast_to_mlx_gen_orig_
|
||||||
@@ -356,6 +358,8 @@
|
|||||||
#define mlx_greater mlx_greater_mlx_gen_orig_
|
#define mlx_greater mlx_greater_mlx_gen_orig_
|
||||||
#define mlx_greater_equal mlx_greater_equal_mlx_gen_orig_
|
#define mlx_greater_equal mlx_greater_equal_mlx_gen_orig_
|
||||||
#define mlx_hadamard_transform mlx_hadamard_transform_mlx_gen_orig_
|
#define mlx_hadamard_transform mlx_hadamard_transform_mlx_gen_orig_
|
||||||
|
#define mlx_hamming mlx_hamming_mlx_gen_orig_
|
||||||
|
#define mlx_hanning mlx_hanning_mlx_gen_orig_
|
||||||
#define mlx_identity mlx_identity_mlx_gen_orig_
|
#define mlx_identity mlx_identity_mlx_gen_orig_
|
||||||
#define mlx_imag mlx_imag_mlx_gen_orig_
|
#define mlx_imag mlx_imag_mlx_gen_orig_
|
||||||
#define mlx_inner mlx_inner_mlx_gen_orig_
|
#define mlx_inner mlx_inner_mlx_gen_orig_
|
||||||
@@ -889,10 +893,12 @@
|
|||||||
#undef mlx_atleast_1d
|
#undef mlx_atleast_1d
|
||||||
#undef mlx_atleast_2d
|
#undef mlx_atleast_2d
|
||||||
#undef mlx_atleast_3d
|
#undef mlx_atleast_3d
|
||||||
|
#undef mlx_bartlett
|
||||||
#undef mlx_bitwise_and
|
#undef mlx_bitwise_and
|
||||||
#undef mlx_bitwise_invert
|
#undef mlx_bitwise_invert
|
||||||
#undef mlx_bitwise_or
|
#undef mlx_bitwise_or
|
||||||
#undef mlx_bitwise_xor
|
#undef mlx_bitwise_xor
|
||||||
|
#undef mlx_blackman
|
||||||
#undef mlx_block_masked_mm
|
#undef mlx_block_masked_mm
|
||||||
#undef mlx_broadcast_arrays
|
#undef mlx_broadcast_arrays
|
||||||
#undef mlx_broadcast_to
|
#undef mlx_broadcast_to
|
||||||
@@ -945,6 +951,8 @@
|
|||||||
#undef mlx_greater
|
#undef mlx_greater
|
||||||
#undef mlx_greater_equal
|
#undef mlx_greater_equal
|
||||||
#undef mlx_hadamard_transform
|
#undef mlx_hadamard_transform
|
||||||
|
#undef mlx_hamming
|
||||||
|
#undef mlx_hanning
|
||||||
#undef mlx_identity
|
#undef mlx_identity
|
||||||
#undef mlx_imag
|
#undef mlx_imag
|
||||||
#undef mlx_inner
|
#undef mlx_inner
|
||||||
@@ -1501,8 +1509,10 @@ extern int (*mlx_distributed_sum_scatter_)(
|
|||||||
extern int (*mlx_distributed_group_rank_)(mlx_distributed_group group);
|
extern int (*mlx_distributed_group_rank_)(mlx_distributed_group group);
|
||||||
extern int (*mlx_distributed_group_size_)(mlx_distributed_group group);
|
extern int (*mlx_distributed_group_size_)(mlx_distributed_group group);
|
||||||
extern mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key);
|
extern mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key);
|
||||||
extern bool (*mlx_distributed_is_available_)(void);
|
extern bool (*mlx_distributed_is_available_)(const char* bk /* may be null */);
|
||||||
extern mlx_distributed_group (*mlx_distributed_init_)(bool strict);
|
extern mlx_distributed_group (*mlx_distributed_init_)(
|
||||||
|
bool strict,
|
||||||
|
const char* bk /* may be null */);
|
||||||
extern void (*mlx_set_error_handler_)(
|
extern void (*mlx_set_error_handler_)(
|
||||||
mlx_error_handler_func handler,
|
mlx_error_handler_func handler,
|
||||||
void* data,
|
void* data,
|
||||||
@@ -2099,6 +2109,7 @@ extern int (*mlx_astype_)(
|
|||||||
extern int (*mlx_atleast_1d_)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
extern int (*mlx_atleast_1d_)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
extern int (*mlx_atleast_2d_)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
extern int (*mlx_atleast_2d_)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
extern int (*mlx_atleast_3d_)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
extern int (*mlx_atleast_3d_)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
|
extern int (*mlx_bartlett_)(mlx_array* res, int M, const mlx_stream s);
|
||||||
extern int (*mlx_bitwise_and_)(
|
extern int (*mlx_bitwise_and_)(
|
||||||
mlx_array* res,
|
mlx_array* res,
|
||||||
const mlx_array a,
|
const mlx_array a,
|
||||||
@@ -2115,6 +2126,7 @@ extern int (*mlx_bitwise_xor_)(
|
|||||||
const mlx_array a,
|
const mlx_array a,
|
||||||
const mlx_array b,
|
const mlx_array b,
|
||||||
const mlx_stream s);
|
const mlx_stream s);
|
||||||
|
extern int (*mlx_blackman_)(mlx_array* res, int M, const mlx_stream s);
|
||||||
extern int (*mlx_block_masked_mm_)(
|
extern int (*mlx_block_masked_mm_)(
|
||||||
mlx_array* res,
|
mlx_array* res,
|
||||||
const mlx_array a,
|
const mlx_array a,
|
||||||
@@ -2295,6 +2307,7 @@ extern int (*mlx_dequantize_)(
|
|||||||
mlx_optional_int group_size,
|
mlx_optional_int group_size,
|
||||||
mlx_optional_int bits,
|
mlx_optional_int bits,
|
||||||
const char* mode,
|
const char* mode,
|
||||||
|
const mlx_array global_scale /* may be null */,
|
||||||
mlx_optional_dtype dtype,
|
mlx_optional_dtype dtype,
|
||||||
const mlx_stream s);
|
const mlx_stream s);
|
||||||
extern int (*mlx_diag_)(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
|
extern int (*mlx_diag_)(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
|
||||||
@@ -2431,6 +2444,8 @@ extern int (*mlx_hadamard_transform_)(
|
|||||||
const mlx_array a,
|
const mlx_array a,
|
||||||
mlx_optional_float scale,
|
mlx_optional_float scale,
|
||||||
const mlx_stream s);
|
const mlx_stream s);
|
||||||
|
extern int (*mlx_hamming_)(mlx_array* res, int M, const mlx_stream s);
|
||||||
|
extern int (*mlx_hanning_)(mlx_array* res, int M, const mlx_stream s);
|
||||||
extern int (*mlx_identity_)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s);
|
extern int (*mlx_identity_)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s);
|
||||||
extern int (*mlx_imag_)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
extern int (*mlx_imag_)(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
extern int (*mlx_inner_)(
|
extern int (*mlx_inner_)(
|
||||||
@@ -2723,6 +2738,8 @@ extern int (*mlx_qqmm_)(
|
|||||||
mlx_optional_int group_size,
|
mlx_optional_int group_size,
|
||||||
mlx_optional_int bits,
|
mlx_optional_int bits,
|
||||||
const char* mode,
|
const char* mode,
|
||||||
|
const mlx_array global_scale_x /* may be null */,
|
||||||
|
const mlx_array global_scale_w /* may be null */,
|
||||||
const mlx_stream s);
|
const mlx_stream s);
|
||||||
extern int (*mlx_quantize_)(
|
extern int (*mlx_quantize_)(
|
||||||
mlx_vector_array* res,
|
mlx_vector_array* res,
|
||||||
@@ -2730,6 +2747,7 @@ extern int (*mlx_quantize_)(
|
|||||||
mlx_optional_int group_size,
|
mlx_optional_int group_size,
|
||||||
mlx_optional_int bits,
|
mlx_optional_int bits,
|
||||||
const char* mode,
|
const char* mode,
|
||||||
|
const mlx_array global_scale /* may be null */,
|
||||||
const mlx_stream s);
|
const mlx_stream s);
|
||||||
extern int (*mlx_quantized_matmul_)(
|
extern int (*mlx_quantized_matmul_)(
|
||||||
mlx_array* res,
|
mlx_array* res,
|
||||||
@@ -4033,11 +4051,13 @@ static inline int mlx_distributed_group_size(mlx_distributed_group group) {
|
|||||||
static inline mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, int color, int key) {
|
static inline mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, int color, int key) {
|
||||||
return mlx_distributed_group_split_(group, color, key);
|
return mlx_distributed_group_split_(group, color, key);
|
||||||
}
|
}
|
||||||
static inline bool mlx_distributed_is_available(void) {
|
static inline bool mlx_distributed_is_available(const char* bk /* may be null */) {
|
||||||
return mlx_distributed_is_available_();
|
return mlx_distributed_is_available_(bk);
|
||||||
}
|
}
|
||||||
static inline mlx_distributed_group mlx_distributed_init(bool strict) {
|
static inline mlx_distributed_group mlx_distributed_init(
|
||||||
return mlx_distributed_init_(strict);
|
bool strict,
|
||||||
|
const char* bk /* may be null */) {
|
||||||
|
return mlx_distributed_init_(strict, bk);
|
||||||
}
|
}
|
||||||
static inline void mlx_set_error_handler(
|
static inline void mlx_set_error_handler(
|
||||||
mlx_error_handler_func handler,
|
mlx_error_handler_func handler,
|
||||||
@@ -4939,6 +4959,9 @@ static inline int mlx_atleast_2d(mlx_array* res, const mlx_array a, const mlx_st
|
|||||||
static inline int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s) {
|
static inline int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s) {
|
||||||
return mlx_atleast_3d_(res, a, s);
|
return mlx_atleast_3d_(res, a, s);
|
||||||
}
|
}
|
||||||
|
static inline int mlx_bartlett(mlx_array* res, int M, const mlx_stream s) {
|
||||||
|
return mlx_bartlett_(res, M, s);
|
||||||
|
}
|
||||||
static inline int mlx_bitwise_and(
|
static inline int mlx_bitwise_and(
|
||||||
mlx_array* res,
|
mlx_array* res,
|
||||||
const mlx_array a,
|
const mlx_array a,
|
||||||
@@ -4963,6 +4986,9 @@ static inline int mlx_bitwise_xor(
|
|||||||
const mlx_stream s) {
|
const mlx_stream s) {
|
||||||
return mlx_bitwise_xor_(res, a, b, s);
|
return mlx_bitwise_xor_(res, a, b, s);
|
||||||
}
|
}
|
||||||
|
static inline int mlx_blackman(mlx_array* res, int M, const mlx_stream s) {
|
||||||
|
return mlx_blackman_(res, M, s);
|
||||||
|
}
|
||||||
static inline int mlx_block_masked_mm(
|
static inline int mlx_block_masked_mm(
|
||||||
mlx_array* res,
|
mlx_array* res,
|
||||||
const mlx_array a,
|
const mlx_array a,
|
||||||
@@ -5193,9 +5219,10 @@ static inline int mlx_dequantize(
|
|||||||
mlx_optional_int group_size,
|
mlx_optional_int group_size,
|
||||||
mlx_optional_int bits,
|
mlx_optional_int bits,
|
||||||
const char* mode,
|
const char* mode,
|
||||||
|
const mlx_array global_scale /* may be null */,
|
||||||
mlx_optional_dtype dtype,
|
mlx_optional_dtype dtype,
|
||||||
const mlx_stream s) {
|
const mlx_stream s) {
|
||||||
return mlx_dequantize_(res, w, scales, biases, group_size, bits, mode, dtype, s);
|
return mlx_dequantize_(res, w, scales, biases, group_size, bits, mode, global_scale, dtype, s);
|
||||||
}
|
}
|
||||||
static inline int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s) {
|
static inline int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s) {
|
||||||
return mlx_diag_(res, a, k, s);
|
return mlx_diag_(res, a, k, s);
|
||||||
@@ -5383,6 +5410,12 @@ static inline int mlx_hadamard_transform(
|
|||||||
const mlx_stream s) {
|
const mlx_stream s) {
|
||||||
return mlx_hadamard_transform_(res, a, scale, s);
|
return mlx_hadamard_transform_(res, a, scale, s);
|
||||||
}
|
}
|
||||||
|
static inline int mlx_hamming(mlx_array* res, int M, const mlx_stream s) {
|
||||||
|
return mlx_hamming_(res, M, s);
|
||||||
|
}
|
||||||
|
static inline int mlx_hanning(mlx_array* res, int M, const mlx_stream s) {
|
||||||
|
return mlx_hanning_(res, M, s);
|
||||||
|
}
|
||||||
static inline int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) {
|
static inline int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) {
|
||||||
return mlx_identity_(res, n, dtype, s);
|
return mlx_identity_(res, n, dtype, s);
|
||||||
}
|
}
|
||||||
@@ -5793,8 +5826,10 @@ static inline int mlx_qqmm(
|
|||||||
mlx_optional_int group_size,
|
mlx_optional_int group_size,
|
||||||
mlx_optional_int bits,
|
mlx_optional_int bits,
|
||||||
const char* mode,
|
const char* mode,
|
||||||
|
const mlx_array global_scale_x /* may be null */,
|
||||||
|
const mlx_array global_scale_w /* may be null */,
|
||||||
const mlx_stream s) {
|
const mlx_stream s) {
|
||||||
return mlx_qqmm_(res, x, w, w_scales, group_size, bits, mode, s);
|
return mlx_qqmm_(res, x, w, w_scales, group_size, bits, mode, global_scale_x, global_scale_w, s);
|
||||||
}
|
}
|
||||||
static inline int mlx_quantize(
|
static inline int mlx_quantize(
|
||||||
mlx_vector_array* res,
|
mlx_vector_array* res,
|
||||||
@@ -5802,8 +5837,9 @@ static inline int mlx_quantize(
|
|||||||
mlx_optional_int group_size,
|
mlx_optional_int group_size,
|
||||||
mlx_optional_int bits,
|
mlx_optional_int bits,
|
||||||
const char* mode,
|
const char* mode,
|
||||||
|
const mlx_array global_scale /* may be null */,
|
||||||
const mlx_stream s) {
|
const mlx_stream s) {
|
||||||
return mlx_quantize_(res, w, group_size, bits, mode, s);
|
return mlx_quantize_(res, w, group_size, bits, mode, global_scale, s);
|
||||||
}
|
}
|
||||||
static inline int mlx_quantized_matmul(
|
static inline int mlx_quantized_matmul(
|
||||||
mlx_array* res,
|
mlx_array* res,
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
# Vendored MLX-C Headers
|
# Vendored MLX-C Headers
|
||||||
|
|
||||||
These header files are vendored from [mlx-c](https://github.com/ml-explore/mlx-c).
|
These header files are vendored from [mlx-c](https://github.com/ml-explore/mlx-c).
|
||||||
The pinned version is in `MLX_VERSION` at the repo root.
|
The pinned version is in `MLX_C_VERSION` at the repo root.
|
||||||
|
|
||||||
Headers are automatically refreshed when you run a CMake build:
|
Headers are automatically refreshed when you run a CMake build:
|
||||||
|
|
||||||
|
|||||||
@@ -42,12 +42,14 @@ mlx_distributed_group_split(mlx_distributed_group group, int color, int key);
|
|||||||
/**
|
/**
|
||||||
* Check if distributed is available.
|
* Check if distributed is available.
|
||||||
*/
|
*/
|
||||||
bool mlx_distributed_is_available(void);
|
bool mlx_distributed_is_available(const char* bk /* may be null */);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Initialize distributed.
|
* Initialize distributed.
|
||||||
*/
|
*/
|
||||||
mlx_distributed_group mlx_distributed_init(bool strict);
|
mlx_distributed_group mlx_distributed_init(
|
||||||
|
bool strict,
|
||||||
|
const char* bk /* may be null */);
|
||||||
|
|
||||||
/**@}*/
|
/**@}*/
|
||||||
|
|
||||||
|
|||||||
@@ -166,6 +166,7 @@ int mlx_astype(
|
|||||||
int mlx_atleast_1d(mlx_array* res, const mlx_array a, const mlx_stream s);
|
int mlx_atleast_1d(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
int mlx_atleast_2d(mlx_array* res, const mlx_array a, const mlx_stream s);
|
int mlx_atleast_2d(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s);
|
int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
|
int mlx_bartlett(mlx_array* res, int M, const mlx_stream s);
|
||||||
int mlx_bitwise_and(
|
int mlx_bitwise_and(
|
||||||
mlx_array* res,
|
mlx_array* res,
|
||||||
const mlx_array a,
|
const mlx_array a,
|
||||||
@@ -182,6 +183,7 @@ int mlx_bitwise_xor(
|
|||||||
const mlx_array a,
|
const mlx_array a,
|
||||||
const mlx_array b,
|
const mlx_array b,
|
||||||
const mlx_stream s);
|
const mlx_stream s);
|
||||||
|
int mlx_blackman(mlx_array* res, int M, const mlx_stream s);
|
||||||
int mlx_block_masked_mm(
|
int mlx_block_masked_mm(
|
||||||
mlx_array* res,
|
mlx_array* res,
|
||||||
const mlx_array a,
|
const mlx_array a,
|
||||||
@@ -362,6 +364,7 @@ int mlx_dequantize(
|
|||||||
mlx_optional_int group_size,
|
mlx_optional_int group_size,
|
||||||
mlx_optional_int bits,
|
mlx_optional_int bits,
|
||||||
const char* mode,
|
const char* mode,
|
||||||
|
const mlx_array global_scale /* may be null */,
|
||||||
mlx_optional_dtype dtype,
|
mlx_optional_dtype dtype,
|
||||||
const mlx_stream s);
|
const mlx_stream s);
|
||||||
int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
|
int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
|
||||||
@@ -498,6 +501,8 @@ int mlx_hadamard_transform(
|
|||||||
const mlx_array a,
|
const mlx_array a,
|
||||||
mlx_optional_float scale,
|
mlx_optional_float scale,
|
||||||
const mlx_stream s);
|
const mlx_stream s);
|
||||||
|
int mlx_hamming(mlx_array* res, int M, const mlx_stream s);
|
||||||
|
int mlx_hanning(mlx_array* res, int M, const mlx_stream s);
|
||||||
int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s);
|
int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s);
|
||||||
int mlx_imag(mlx_array* res, const mlx_array a, const mlx_stream s);
|
int mlx_imag(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||||
int mlx_inner(
|
int mlx_inner(
|
||||||
@@ -790,6 +795,8 @@ int mlx_qqmm(
|
|||||||
mlx_optional_int group_size,
|
mlx_optional_int group_size,
|
||||||
mlx_optional_int bits,
|
mlx_optional_int bits,
|
||||||
const char* mode,
|
const char* mode,
|
||||||
|
const mlx_array global_scale_x /* may be null */,
|
||||||
|
const mlx_array global_scale_w /* may be null */,
|
||||||
const mlx_stream s);
|
const mlx_stream s);
|
||||||
int mlx_quantize(
|
int mlx_quantize(
|
||||||
mlx_vector_array* res,
|
mlx_vector_array* res,
|
||||||
@@ -797,6 +804,7 @@ int mlx_quantize(
|
|||||||
mlx_optional_int group_size,
|
mlx_optional_int group_size,
|
||||||
mlx_optional_int bits,
|
mlx_optional_int bits,
|
||||||
const char* mode,
|
const char* mode,
|
||||||
|
const mlx_array global_scale /* may be null */,
|
||||||
const mlx_stream s);
|
const mlx_stream s);
|
||||||
int mlx_quantized_matmul(
|
int mlx_quantized_matmul(
|
||||||
mlx_array* res,
|
mlx_array* res,
|
||||||
|
|||||||
@@ -7,8 +7,44 @@ package mlx
|
|||||||
// #cgo LDFLAGS: -lstdc++
|
// #cgo LDFLAGS: -lstdc++
|
||||||
// #cgo darwin LDFLAGS: -framework Foundation -framework Metal -framework Accelerate
|
// #cgo darwin LDFLAGS: -framework Foundation -framework Metal -framework Accelerate
|
||||||
// #include "generated.h"
|
// #include "generated.h"
|
||||||
|
// #include <string.h>
|
||||||
|
//
|
||||||
|
// static char _mlx_last_error_msg[1024] = {0};
|
||||||
|
// static int _mlx_last_error_flag = 0;
|
||||||
|
//
|
||||||
|
// static void _mlx_capture_error_handler(const char* msg, void* data) {
|
||||||
|
// (void)data;
|
||||||
|
// strncpy(_mlx_last_error_msg, msg, sizeof(_mlx_last_error_msg) - 1);
|
||||||
|
// _mlx_last_error_msg[sizeof(_mlx_last_error_msg) - 1] = '\0';
|
||||||
|
// _mlx_last_error_flag = 1;
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// static void mlx_install_capture_handler(void) {
|
||||||
|
// if (mlx_set_error_handler_) {
|
||||||
|
// mlx_set_error_handler_(_mlx_capture_error_handler, NULL, NULL);
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// static void mlx_clear_last_error(void) {
|
||||||
|
// _mlx_last_error_flag = 0;
|
||||||
|
// _mlx_last_error_msg[0] = '\0';
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// static int mlx_had_last_error(void) {
|
||||||
|
// return _mlx_last_error_flag;
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// static const char* mlx_get_last_error(void) {
|
||||||
|
// return _mlx_last_error_flag ? _mlx_last_error_msg : NULL;
|
||||||
|
// }
|
||||||
import "C"
|
import "C"
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
// Replace the default exit(-1) error handler with one that captures
|
||||||
|
// the error message so we can surface it in Go.
|
||||||
|
C.mlx_install_capture_handler()
|
||||||
|
}
|
||||||
|
|
||||||
// Version returns the MLX core library version string.
|
// Version returns the MLX core library version string.
|
||||||
func Version() string {
|
func Version() string {
|
||||||
str := C.mlx_string_new()
|
str := C.mlx_string_new()
|
||||||
@@ -31,10 +67,19 @@ func doEval(outputs []*Array, async bool) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
C.mlx_clear_last_error()
|
||||||
|
var rc C.int
|
||||||
if async {
|
if async {
|
||||||
C.mlx_async_eval(vector)
|
rc = C.mlx_async_eval(vector)
|
||||||
} else {
|
} else {
|
||||||
C.mlx_eval(vector)
|
rc = C.mlx_eval(vector)
|
||||||
|
}
|
||||||
|
if rc != 0 {
|
||||||
|
msg := "mlx eval failed"
|
||||||
|
if C.mlx_had_last_error() != 0 {
|
||||||
|
msg = C.GoString(C.mlx_get_last_error())
|
||||||
|
}
|
||||||
|
panic("mlx: " + msg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -17,7 +17,8 @@ func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, bias
|
|||||||
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
|
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
|
||||||
res := C.mlx_vector_array_new()
|
res := C.mlx_vector_array_new()
|
||||||
defer C.mlx_vector_array_free(res)
|
defer C.mlx_vector_array_free(res)
|
||||||
C.mlx_quantize(&res, w.ctx, optGroupSize, optBits, cMode, DefaultStream().ctx)
|
var globalScale C.mlx_array
|
||||||
|
C.mlx_quantize(&res, w.ctx, optGroupSize, optBits, cMode, globalScale, DefaultStream().ctx)
|
||||||
|
|
||||||
vecSize := int(C.mlx_vector_array_size(res))
|
vecSize := int(C.mlx_vector_array_size(res))
|
||||||
w0 := New("QUANTIZE_W")
|
w0 := New("QUANTIZE_W")
|
||||||
@@ -45,7 +46,8 @@ func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Arr
|
|||||||
}
|
}
|
||||||
|
|
||||||
out := New("DEQUANTIZE")
|
out := New("DEQUANTIZE")
|
||||||
C.mlx_dequantize(&out.ctx, w.ctx, scales.ctx, b, optGroupSize, optBits, cMode, optDtype, DefaultStream().ctx)
|
var globalScale C.mlx_array
|
||||||
|
C.mlx_dequantize(&out.ctx, w.ctx, scales.ctx, b, optGroupSize, optBits, cMode, globalScale, optDtype, DefaultStream().ctx)
|
||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user