From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Tue, 3 Feb 2026 12:00:00 -0800 Subject: [PATCH] ggml: metal solve_tri --- ggml/src/ggml-metal/ggml-metal-device.cpp | 20 +++++++ ggml/src/ggml-metal/ggml-metal-device.h | 1 + ggml/src/ggml-metal/ggml-metal-device.m | 11 ++++ ggml/src/ggml-metal/ggml-metal-impl.h | 21 ++++++++ ggml/src/ggml-metal/ggml-metal-ops.cpp | 63 +++++++++++++++++++++++ ggml/src/ggml-metal/ggml-metal-ops.h | 1 + ggml/src/ggml-metal/ggml-metal.metal | 60 +++++++++++++++++++++ 7 files changed, 177 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 680904d13..83385c9ef 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -1370,6 +1370,26 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_met return res; } +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_metal_library_t lib, const ggml_tensor * op) { /* Returns the pipeline named "kernel_solve_tri_f32" from lib, compiling it when the lookup yields no pipeline. op must be a GGML_OP_SOLVE_TRI node. */ + assert(op->op == GGML_OP_SOLVE_TRI); /* plain assert: only checked in builds without NDEBUG */ + + GGML_ASSERT(ggml_is_contiguous(op->src[0])); /* both source tensors are required to be contiguous */ + GGML_ASSERT(ggml_is_contiguous(op->src[1])); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_solve_tri_f32"); + snprintf(name, 256, "%s", base); /* name is a verbatim copy of base: no variant suffix is appended for this op */ + + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { /* lookup by name returned no pipeline, so build it now */ + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); /* NOTE(review): the nullptr is presumably "no function constants" -- confirm against the helper's signature */ + } + + return res; +} + ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_GROUP_NORM); diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 0a8b9211a..8a9d17460 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -133,6 +133,7 @@ struct ggml_metal_pipeline_with_params 
ggml_metal_library_get_pipeline_top_k struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 7b5ee968c..4e5acfbe5 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1023,6 +1023,17 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]); case GGML_OP_L2_NORM: return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0])); + case GGML_OP_SOLVE_TRI: /* accepted only when both sources are contiguous and src0, src1 and the result are all F32; tensor shapes and op parameters are not inspected here */ + return ggml_is_contiguous(op->src[0]) && + ggml_is_contiguous(op->src[1]) && + op->src[0]->type == GGML_TYPE_F32 && + op->src[1]->type == GGML_TYPE_F32 && + op->type == GGML_TYPE_F32; + case GGML_OP_COUNT_EQUAL: /* NOTE(review): the patch subject only mentions solve_tri and none of the visible hunks adds a COUNT_EQUAL encoder or kernel -- confirm the backend already implements it and that this case label is not a duplicate */ + return has_simdgroup_reduction && + op->src[0]->type == GGML_TYPE_I32 && + op->src[1]->type == GGML_TYPE_I32 && + op->type == GGML_TYPE_I64; case GGML_OP_ARGMAX: return has_simdgroup_reduction; case GGML_OP_NORM: 
diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 8944b07e9..cfdea9c07 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -500,6 +500,27 @@ typedef struct { float eps; } ggml_metal_kargs_l2_norm; +typedef struct { /* Argument block passed to kernel_solve_tri_f32. ggml_metal_op_solve_tri fills the ne0x/nb0x fields from its op->src[0] locals, ne10/ne11 and nb1x from its op->src[1] locals, and nb0..nb3 from the locals of op itself. */ + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne10; /* only ne10 and ne11 of src1 are carried; there are no ne12/ne13 fields */ + int32_t ne11; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + uint64_t nb0; /* nb0..nb3: strides of the result tensor; its element counts are not carried */ + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ggml_metal_kargs_solve_tri; + typedef struct { int64_t ne00; int64_t ne01; diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 80864f303..4ac135603 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -357,6 +357,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_l2_norm(ctx, idx); } break; + case GGML_OP_SOLVE_TRI: /* hand SOLVE_TRI nodes to ggml_metal_op_solve_tri */ + { + n_fuse = ggml_metal_op_solve_tri(ctx, idx); + } break; case GGML_OP_GROUP_NORM: { n_fuse = ggml_metal_op_group_norm(ctx, idx); @@ -2931,6 +2935,65 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) { return 1; } +int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) { /* Encodes graph node idx as a solve_tri dispatch: packs dims and strides into ggml_metal_kargs_solve_tri, binds the args at index 0 and src0, src1 and the result buffer at indices 1, 2 and 3, then launches a 1-D grid that covers nr = ne02 * ne03 * ne10 work items. Always returns 1. */ + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); /* NOTE(review): dims are declared as int32_t here to match the kargs fields -- confirm tensor dims always fit */ + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + ggml_metal_kargs_solve_tri args = { /* positional initializer: the order must match the field order of ggml_metal_kargs_solve_tri */ + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + 
/*.ne11 =*/ ne11, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + + auto pipeline = ggml_metal_library_get_pipeline_solve_tri(lib, op); + + const int64_t ncols = ne10; + const int64_t n_batches = (int64_t)ne02 * ne03; /* cast first so the product of the two int32_t dims is computed in 64 bits */ + const int64_t nr = n_batches * ncols; /* total work items: the kernel maps one thread to each of them */ + + int nth = 64; /* preferred threadgroup width, capped by the pipeline limit on the next line */ + nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + if (nth < 1) { /* keeps nth usable as divisor and dispatch size should the queried limit be non-positive */ + nth = 1; + } + + const int64_t n_tg = (nr + nth - 1) / nth; /* ceil(nr / nth); surplus threads in the last group leave early through the kernel's gid >= nr check */ + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); /* binding indices: 0 = args, 1 = op->src[0], 2 = op->src[1], 3 = op */ + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); + + ggml_metal_encoder_dispatch_threadgroups(enc, n_tg, 1, 1, nth, 1, 1); /* n_tg x 1 x 1 threadgroups of nth x 1 x 1 threads */ + + return 1; /* NOTE(review): the caller stores this in n_fuse; presumably the number of graph nodes consumed -- confirm */ +} + int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h index 902b54452..a475183d3 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ggml/src/ggml-metal/ggml-metal-ops.h @@ -68,6 +68,7 @@ int ggml_metal_op_add_id (ggml_metal_op_t ctx, int idx); int ggml_metal_op_flash_attn_ext (ggml_metal_op_t ctx, int idx); int ggml_metal_op_bin (ggml_metal_op_t ctx, int idx); int ggml_metal_op_l2_norm (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_solve_tri (ggml_metal_op_t ctx, int idx); int ggml_metal_op_group_norm (ggml_metal_op_t ctx, int idx); int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx); int ggml_metal_op_rope (ggml_metal_op_t ctx, int idx); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 114767785..876a9eecc 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -3012,6 
+3012,66 @@ kernel void kernel_l2_norm_f32( } } +kernel void kernel_solve_tri_f32( /* Solves A * X = B by forward substitution, one thread per (i03, i02, i01) triple. Row i00 of the solve reads only A[i00][t] for t < i00 and the diagonal A[i00][i00], so A is consumed as a lower-triangular matrix. src0 is A, src1 is B, dst receives X, all read as float. */ + constant ggml_metal_kargs_solve_tri & args, + device const char * src0, + device const char * src1, + device char * dst, + uint tgpig[[threadgroup_position_in_grid]], + ushort tpitg[[thread_position_in_threadgroup]], + ushort ntg[[threads_per_threadgroup]]) { + const uint64_t ncols = (uint64_t) args.ne10; + const uint64_t n_batches = (uint64_t) args.ne02 * (uint64_t) args.ne03; + const uint64_t nr = n_batches * ncols; /* number of independent solves: batches times ne10 */ + + const uint64_t gid = (uint64_t) tgpig * (uint64_t) ntg + (uint64_t) tpitg; /* flat thread index across the 1-D grid */ + if (gid >= nr) { /* threads beyond the last work item have nothing to do */ + return; + } + + const uint64_t i03 = gid / ((uint64_t) args.ne02 * ncols); /* unflatten gid into (i03, i02, i01), with i01 varying fastest */ + const uint64_t rem = gid - i03 * (uint64_t) args.ne02 * ncols; + const uint64_t i02 = rem / ncols; + const uint64_t i01 = rem - i02 * ncols; /* index in [0, ne10): selects which dim-0 element of B and X this thread handles */ + + const uint64_t sa0 = args.nb00 / sizeof(float); /* byte strides become float-element strides; exact only when each stride is a multiple of sizeof(float) */ + const uint64_t sa1 = args.nb01 / sizeof(float); + const uint64_t sa2 = args.nb02 / sizeof(float); + const uint64_t sa3 = args.nb03 / sizeof(float); + + const uint64_t sb0 = args.nb10 / sizeof(float); + const uint64_t sb1 = args.nb11 / sizeof(float); + const uint64_t sb2 = args.nb12 / sizeof(float); + const uint64_t sb3 = args.nb13 / sizeof(float); + + const uint64_t sx0 = args.nb0 / sizeof(float); + const uint64_t sx1 = args.nb1 / sizeof(float); + const uint64_t sx2 = args.nb2 / sizeof(float); + const uint64_t sx3 = args.nb3 / sizeof(float); + + device const float * A = (device const float *) src0; + device const float * B = (device const float *) src1; + device float * X = (device float *) dst; + + const uint64_t A_base = i02 * sa2 + i03 * sa3; /* the same (i02, i03) batch offsets are applied to A, B and X */ + const uint64_t B_base = i02 * sb2 + i03 * sb3; + const uint64_t X_base = i02 * sx2 + i03 * sx3; + + const uint64_t n = (uint64_t) args.ne11; /* number of rows to solve. NOTE(review): args.ne00 and args.ne01 are never read here, so A is presumably n x n -- confirm the op guarantees it */ + + for (uint64_t i00 = 0; i00 < n; ++i00) { /* rows are solved in increasing order because row i00 depends on all earlier rows of X */ + float sum = 0.0f; + for (uint64_t t = 0; t < i00; ++t) { /* the X entries read here were written by this same thread in earlier i00 iterations */ + sum += A[A_base + i00 * sa1 + t * sa0] * + X[X_base + t * sx1 + i01 * sx0]; + } + + const float diag = A[A_base + 
i00 * sa1 + i00 * sa0]; /* diagonal entry of row i00 */ + X[X_base + i00 * sx1 + i01 * sx0] = + (B[B_base + i00 * sb1 + i01 * sb0] - sum) / diag; /* there is no guard against diag == 0 */ + } +} + kernel void kernel_group_norm_f32( constant ggml_metal_kargs_group_norm & args, device const float * src0,