diff --git a/x/mlxrunner/mlx/mlx.go b/x/mlxrunner/mlx/mlx.go index 4d8b88ab8..f110aa692 100644 --- a/x/mlxrunner/mlx/mlx.go +++ b/x/mlxrunner/mlx/mlx.go @@ -12,21 +12,16 @@ package mlx // #include "generated.h" import "C" -import ( - "unsafe" -) - func doEval(outputs []*Array, async bool) { - vectorData := make([]C.mlx_array, 0, len(outputs)) + vector := C.mlx_vector_array_new() + defer C.mlx_vector_array_free(vector) + for _, output := range outputs { if output.Valid() { - vectorData = append(vectorData, output.ctx) + C.mlx_vector_array_append_value(vector, output.ctx) } } - vector := C.mlx_vector_array_new_data(unsafe.SliceData(vectorData), C.size_t(len(vectorData))) - defer C.mlx_vector_array_free(vector) - if async { C.mlx_async_eval(vector) } else { diff --git a/x/mlxrunner/mlx/ops.go b/x/mlxrunner/mlx/ops.go index a8c7f3bce..08f9ebc2c 100644 --- a/x/mlxrunner/mlx/ops.go +++ b/x/mlxrunner/mlx/ops.go @@ -66,15 +66,13 @@ func (t *Array) AsStrided(shape []int, strides []int, offset int) *Array { } func (t *Array) Concatenate(axis int, others ...*Array) *Array { - vectorData := make([]C.mlx_array, len(others)+1) - vectorData[0] = t.ctx - for i := range others { - vectorData[i+1] = others[i].ctx - } - - vector := C.mlx_vector_array_new_data(unsafe.SliceData(vectorData), C.size_t(len(vectorData))) + vector := C.mlx_vector_array_new() defer C.mlx_vector_array_free(vector) + for _, other := range append([]*Array{t}, others...) { + C.mlx_vector_array_append_value(vector, other.ctx) + } + out := New("CONCATENATE", t) C.mlx_concatenate_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx) return out