Files
ollama/x/mlxrunner/mlx/slice.go
Jesse Gross 6f8ddbb26b mlxrunner: fix Slice(0, 0) returning full dimension instead of empty
Slice used cmp.Or to resolve a zero stop value to the dimension size,
intended to support open-ended slices like a[i:]. This made Slice(0, 0)
indistinguishable from Slice(), so any slice with a zero stop would
silently include the entire dimension instead of being empty.

Replace cmp.Or with an explicit End sentinel and resolve negative
indices against the dimension size, matching Python/PyTorch semantics.
2026-03-18 16:06:33 -07:00

101 lines
2.1 KiB
Go

package mlx
// #include "generated.h"
import "C"
import (
"math"
"unsafe"
)
// End is the sentinel index meaning "through the end of the dimension".
// Passing it as a stop plays the role of an omitted stop in Python slicing,
// as in a[i:]. Its value is math.MaxInt32, and it is an untyped constant.
const (
	End = math.MaxInt32
)
// slice holds the raw integer arguments of one per-dimension slice spec,
// exactly as they were passed to Slice. Interpretation happens later, when the
// spec is converted into start/stop/stride values.
type slice struct{ args []int }

// Slice builds a per-dimension slice spec for Array.Slice and
// Array.SliceUpdate. The number of arguments selects the form:
//
//	Slice()        [:]      the whole dimension
//	Slice(i)       [i:i+1]  a single index
//	Slice(i, j)    [i:j]
//	Slice(i, j, k) [i:j:k]
//
// Negative indices count back from the end of the dimension, and End can be
// used as the stop to mean "through the end". The arguments are stored as-is,
// without validation or copying, so Slice(xs...) retains xs.
func Slice(args ...int) slice {
	var spec slice
	spec.args = args
	return spec
}
// resolve maps one Python-style slice index onto the value handed to
// mlx_slice for a dimension of size dim:
//
//   - End becomes dim (one past the last element).
//   - A negative index in [-dim, -1] counts back from the end, so -1
//     becomes dim-1 and -dim becomes 0.
//   - Anything else is forwarded unchanged. This covers non-negative
//     indices, including ones past the end, and negative indices below -dim.
//
// Out-of-range negatives are deliberately not shifted by dim. mlx_slice does
// its own negative-index wrap (adding dim once) and then clamps to the
// dimension bounds, which already gives Python semantics for them. Shifting
// here as well would wrap them twice. For example, -7 on a dimension of 5
// would become -2 here and 3 inside mlx, selecting [3:5] instead of Python's
// [0:5].
// NOTE(review): this relies on MLX's slice normalization (wrap once, then
// clamp). Confirm against the mlx version this package is built against.
func resolve(val, dim int) C.int {
	switch {
	case val == End:
		return C.int(dim)
	case val < 0 && val >= -dim:
		return C.int(dim + val)
	default:
		return C.int(val)
	}
}
// makeSlices converts one slice spec per dimension into the parallel start,
// stop and stride arrays taken by mlx_slice and mlx_slice_update.
//
// dims holds the size of each dimension. slices must have the same length,
// otherwise makeSlices panics. Each spec is interpreted by its argument count:
//
//	0 args: [:]      start 0, stop dim, stride 1
//	1 arg:  [i]      start i, stop i+1, stride 1
//	2 args: [i:j]    stride 1
//	3 args: [i:j:k]  stride k, which must be non-zero
//
// Indices go through resolve, so negative values count from the end and End
// means "through the end of the dimension". With a negative stride, the end
// in the direction of travel is index 0, so an End stop is encoded
// differently (see below). Any other argument count panics.
func makeSlices(dims []int, slices ...slice) (starts, stops, strides []C.int) {
	if len(slices) != len(dims) {
		panic("number of slice arguments must match number of tensor dimensions")
	}

	n := len(slices)
	starts = make([]C.int, n)
	stops = make([]C.int, n)
	strides = make([]C.int, n)

	for i, s := range slices {
		dim := dims[i]
		switch len(s.args) {
		case 0:
			// slice[:]
			starts[i], stops[i], strides[i] = 0, C.int(dim), 1
		case 1:
			// slice[i]. The index is resolved before adding 1 so that -1
			// yields [dim-1:dim] rather than [-1:0].
			start := resolve(s.args[0], dim)
			starts[i], stops[i], strides[i] = start, start+1, 1
		case 2:
			// slice[i:j]
			starts[i] = resolve(s.args[0], dim)
			stops[i] = resolve(s.args[1], dim)
			strides[i] = 1
		case 3:
			// slice[i:j:k]
			step := s.args[2]
			if step == 0 {
				// Python rejects a zero step as well. Fail here instead of
				// handing a zero stride to mlx_slice.
				panic("slice stride cannot be zero")
			}
			starts[i] = resolve(s.args[0], dim)
			if s.args[1] == End && step < 0 {
				// A reversed slice with an omitted stop runs through index 0.
				// Encoding End as dim would make mlx clamp the stop to the
				// start and return an empty slice. -dim-1 is wrapped once by
				// mlx to -1, i.e. "stop after index 0".
				// NOTE(review): this mirrors how MLX's Python bindings encode
				// an omitted stop for negative steps. Confirm it against the
				// pinned mlx version.
				stops[i] = C.int(-dim - 1)
			} else {
				stops[i] = resolve(s.args[1], dim)
			}
			strides[i] = C.int(step)
		default:
			panic("invalid slice arguments")
		}
	}
	return starts, stops, strides
}
// Slice returns a new array holding the region of t selected by one slice
// spec per dimension. The specs are built with the package-level Slice
// function. The number of specs must match the length of t.Dims(), otherwise
// makeSlices panics.
//
// The output handle is created with New("SLICE") and filled by mlx_slice on
// the default stream.
func (t *Array) Slice(slices ...slice) *Array {
	starts, stops, strides := makeSlices(t.Dims(), slices...)

	res := New("SLICE")
	stream := DefaultStream().ctx
	C.mlx_slice(
		&res.ctx,
		t.ctx,
		unsafe.SliceData(starts),
		C.size_t(len(starts)),
		unsafe.SliceData(stops),
		C.size_t(len(stops)),
		unsafe.SliceData(strides),
		C.size_t(len(strides)),
		stream,
	)
	return res
}
// SliceUpdate returns a new array computed by mlx_slice_update from t, other,
// and the region of t selected by one slice spec per dimension. Presumably
// this is t with that region overwritten by other. NOTE(review): the
// shape/broadcast requirements on other are enforced by MLX, not here.
//
// The number of specs must match the length of t.Dims(), otherwise
// makeSlices panics. The output handle is created with New("SLICE_UPDATE")
// and the operation runs on the default stream.
func (t *Array) SliceUpdate(other *Array, slices ...slice) *Array {
	starts, stops, strides := makeSlices(t.Dims(), slices...)

	res := New("SLICE_UPDATE")
	stream := DefaultStream().ctx
	C.mlx_slice_update(
		&res.ctx,
		t.ctx,
		other.ctx,
		unsafe.SliceData(starts),
		C.size_t(len(starts)),
		unsafe.SliceData(stops),
		C.size_t(len(stops)),
		unsafe.SliceData(strides),
		C.size_t(len(strides)),
		stream,
	)
	return res
}