ci: include mlx jit headers on linux (#15083)

* ci: include mlx jit headers on linux

* handle CUDA JIT headers
This commit is contained in:
Daniel Hiltgen
2026-03-26 23:10:07 -07:00
committed by GitHub
parent f567abc63f
commit 516ebd8548
3 changed files with 83 additions and 4 deletions

View File

@@ -400,6 +400,21 @@ func (c *Client) Load(ctx context.Context, _ ml.SystemInfo, gpus []ml.DeviceInfo
slog.Debug("mlx subprocess library path", libPathEnvVar, pathEnvVal)
}
// Point MLX's JIT compiler at our bundled CUDA runtime headers.
// MLX resolves headers via $CUDA_PATH/include/*.h (and checks CUDA_HOME first).
// Always use bundled headers to avoid version mismatches with any
// system-installed CUDA toolkit.
if mlxDirs, err := filepath.Glob(filepath.Join(ml.LibOllamaPath, "mlx_cuda_*")); err == nil {
for _, d := range mlxDirs {
if _, err := os.Stat(filepath.Join(d, "include")); err == nil {
setEnv(cmd, "CUDA_PATH", d)
setEnv(cmd, "CUDA_HOME", d)
slog.Debug("mlx subprocess CUDA headers", "CUDA_PATH", d)
break
}
}
}
c.cmd = cmd
// Forward subprocess stdout/stderr to server logs
@@ -519,3 +534,16 @@ func (c *Client) VRAMByGPU(id ml.DeviceID) uint64 {
}
var _ llm.LlamaServer = (*Client)(nil)
// setEnv sets or replaces an environment variable in cmd.Env.
func setEnv(cmd *exec.Cmd, key, value string) {
entry := key + "=" + value
prefix := strings.ToUpper(key + "=")
for i, e := range cmd.Env {
if strings.HasPrefix(strings.ToUpper(e), prefix) {
cmd.Env[i] = entry
return
}
}
cmd.Env = append(cmd.Env, entry)
}