mlx: add op wrappers for Conv2d, Pad, activations, trig, and masked SDPA (#14913)

* mlx: add op wrappers for Conv2d, Pad, activations, trig, and masked SDPA

Add Conv2d, flexible Pad (with axes/mode), PadConstant, Maximum,
Minimum, Softplus, ReLU, GLU, Clamp, Sin, Cos, Clip,
ScaledDotProductAttentionMasked, and RoPEWithFreqs. Refactor
RoPEWithBase to delegate to RoPEWithFreqs.

* review comments

* mlx: fix ScaledDotProductAttentionMasked to consult the mask argument
This commit is contained in:
Daniel Hiltgen
2026-04-13 11:43:24 -07:00
committed by GitHub
parent d3da29cbfc
commit c88fb286ec
3 changed files with 216 additions and 39 deletions

View File

@@ -592,7 +592,7 @@ func decodeSourceFP8Tensor(weight, scaleInv *mlx.Array) (*mlx.Array, error) {
padBottom := blockRows*scaleShape[0] - rows
padSide := blockCols*scaleShape[1] - cols
if padBottom > 0 || padSide > 0 {
decoded = mlx.Pad(decoded, []int32{0, int32(padBottom), 0, int32(padSide)})
decoded = mlx.PadConstant(decoded, []int{0, 1}, []int{0, 0}, []int{padBottom, padSide})
}
decoded = mlx.Reshape(decoded, int32(scaleShape[0]), int32(blockRows), int32(scaleShape[1]), int32(blockCols))