mirror of
https://github.com/ollama/ollama.git
synced 2026-04-18 09:54:18 +02:00
This change allows importing bf16 and converting to mxfp4/mxfp8/nvfp4 and also importing fp8 and converting directly to mxfp8.
150 lines
3.5 KiB
Go
150 lines
3.5 KiB
Go
package mlx
|
|
|
|
// #include "generated.h"
|
|
import "C"
|
|
|
|
import (
|
|
"fmt"
|
|
"iter"
|
|
"runtime"
|
|
"unsafe"
|
|
)
|
|
|
|
// SafetensorsFile represents a loaded safetensors file.
|
|
type SafetensorsFile struct {
|
|
arrays C.mlx_map_string_to_array
|
|
metadata C.mlx_map_string_to_string
|
|
}
|
|
|
|
func loadSafetensorsStream() C.mlx_stream {
|
|
if runtime.GOOS == "darwin" {
|
|
return C.mlx_default_cpu_stream_new()
|
|
}
|
|
return C.mlx_default_gpu_stream_new()
|
|
}
|
|
|
|
// LoadSafetensorsNative loads a safetensors file using MLX's native loader.
|
|
func LoadSafetensorsNative(path string) (*SafetensorsFile, error) {
|
|
var arrays C.mlx_map_string_to_array
|
|
var metadata C.mlx_map_string_to_string
|
|
|
|
cPath := C.CString(path)
|
|
defer C.free(unsafe.Pointer(cPath))
|
|
|
|
stream := loadSafetensorsStream()
|
|
defer C.mlx_stream_free(stream)
|
|
|
|
if C.mlx_load_safetensors(&arrays, &metadata, cPath, stream) != 0 {
|
|
return nil, fmt.Errorf("failed to load safetensors: %s", path)
|
|
}
|
|
|
|
return &SafetensorsFile{arrays: arrays, metadata: metadata}, nil
|
|
}
|
|
|
|
// Get retrieves a tensor by name.
|
|
func (s *SafetensorsFile) Get(name string) *Array {
|
|
cName := C.CString(name)
|
|
defer C.free(unsafe.Pointer(cName))
|
|
|
|
value := C.mlx_array_new()
|
|
if C.mlx_map_string_to_array_get(&value, s.arrays, cName) != 0 {
|
|
return nil
|
|
}
|
|
if value.ctx == nil {
|
|
return nil
|
|
}
|
|
|
|
arr := New(name)
|
|
arr.ctx = value
|
|
return arr
|
|
}
|
|
|
|
// GetMetadata retrieves a metadata value by key.
|
|
func (s *SafetensorsFile) GetMetadata(key string) string {
|
|
cKey := C.CString(key)
|
|
defer C.free(unsafe.Pointer(cKey))
|
|
|
|
var cValue *C.char
|
|
if C.mlx_map_string_to_string_get(&cValue, s.metadata, cKey) != 0 {
|
|
return ""
|
|
}
|
|
return C.GoString(cValue)
|
|
}
|
|
|
|
// Free releases the loaded safetensors maps.
|
|
func (s *SafetensorsFile) Free() {
|
|
if s == nil {
|
|
return
|
|
}
|
|
C.mlx_map_string_to_array_free(s.arrays)
|
|
C.mlx_map_string_to_string_free(s.metadata)
|
|
}
|
|
|
|
func Load(path string) iter.Seq2[string, *Array] {
|
|
return func(yield func(string, *Array) bool) {
|
|
sf, err := LoadSafetensorsNative(path)
|
|
if err != nil {
|
|
return
|
|
}
|
|
defer sf.Free()
|
|
|
|
it := C.mlx_map_string_to_array_iterator_new(sf.arrays)
|
|
defer C.mlx_map_string_to_array_iterator_free(it)
|
|
|
|
for {
|
|
var key *C.char
|
|
value := C.mlx_array_new()
|
|
if C.mlx_map_string_to_array_iterator_next(&key, &value, it) != 0 {
|
|
break
|
|
}
|
|
|
|
name := C.GoString(key)
|
|
arr := New(name)
|
|
arr.ctx = value
|
|
if !yield(name, arr) {
|
|
break
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// SaveSafetensors saves arrays to a safetensors file without metadata.
|
|
func SaveSafetensors(path string, arrays map[string]*Array) error {
|
|
return SaveSafetensorsWithMetadata(path, arrays, nil)
|
|
}
|
|
|
|
// SaveSafetensorsWithMetadata saves arrays to a safetensors file with metadata.
|
|
func SaveSafetensorsWithMetadata(path string, arrays map[string]*Array, metadata map[string]string) error {
|
|
cPath := C.CString(path)
|
|
defer C.free(unsafe.Pointer(cPath))
|
|
|
|
cArrays := C.mlx_map_string_to_array_new()
|
|
defer C.mlx_map_string_to_array_free(cArrays)
|
|
|
|
for name, arr := range arrays {
|
|
if arr == nil {
|
|
continue
|
|
}
|
|
cName := C.CString(name)
|
|
C.mlx_map_string_to_array_insert(cArrays, cName, arr.ctx)
|
|
C.free(unsafe.Pointer(cName))
|
|
}
|
|
|
|
cMetadata := C.mlx_map_string_to_string_new()
|
|
defer C.mlx_map_string_to_string_free(cMetadata)
|
|
|
|
for key, value := range metadata {
|
|
cKey := C.CString(key)
|
|
cValue := C.CString(value)
|
|
C.mlx_map_string_to_string_insert(cMetadata, cKey, cValue)
|
|
C.free(unsafe.Pointer(cKey))
|
|
C.free(unsafe.Pointer(cValue))
|
|
}
|
|
|
|
if C.mlx_save_safetensors(cPath, cArrays, cMetadata) != 0 {
|
|
return fmt.Errorf("failed to save safetensors: %s", path)
|
|
}
|
|
|
|
return nil
|
|
}
|