mirror of
https://github.com/ollama/ollama.git
synced 2026-04-26 02:36:09 +02:00
mlx: add op wrappers for Conv2d, Pad, activations, trig, and masked SDPA (#14913)
* mlx: add op wrappers for Conv2d, Pad, activations, trig, and masked SDPA Add Conv2d, flexible Pad (with axes/mode), PadConstant, Maximum, Minimum, Softplus, ReLU, GLU, Clamp, Sin, Cos, Clip, ScaledDotProductAttentionMasked, and RoPEWithFreqs. Refactor RoPEWithBase to delegate to RoPEWithFreqs. * review comments * mlx: fix ScaledDotProductAttentionMasked to consult the mask argument
This commit is contained in:
@@ -592,7 +592,7 @@ func decodeSourceFP8Tensor(weight, scaleInv *mlx.Array) (*mlx.Array, error) {
|
||||
padBottom := blockRows*scaleShape[0] - rows
|
||||
padSide := blockCols*scaleShape[1] - cols
|
||||
if padBottom > 0 || padSide > 0 {
|
||||
decoded = mlx.Pad(decoded, []int32{0, int32(padBottom), 0, int32(padSide)})
|
||||
decoded = mlx.PadConstant(decoded, []int{0, 1}, []int{0, 0}, []int{padBottom, padSide})
|
||||
}
|
||||
|
||||
decoded = mlx.Reshape(decoded, int32(scaleShape[0]), int32(blockRows), int32(scaleShape[1]), int32(blockCols))
|
||||
|
||||
Reference in New Issue
Block a user