ggml: skip cublasGemmBatchedEx during graph reservation

cublasGemmBatchedEx fails during graph capture when pool allocations
return fake pointers. This is triggered when NUM_PARALLEL is greater
than 1 for models like gemma4 that use batched matmuls. Skip it
during reservation since the memory tracking is already handled by
the pool allocations.

Fixes #15249
This commit is contained in:
Jesse Gross
2026-04-03 11:26:03 -07:00
parent 036ed1b9b5
commit bb0c58e134
2 changed files with 45 additions and 3 deletions

View File

@@ -11,9 +11,9 @@ must be recreated with no-alloc set to false before loading data.
ggml/include/ggml-backend.h | 1 +
ggml/src/ggml-backend-impl.h | 16 +++
ggml/src/ggml-backend.cpp | 75 ++++++++++-
ggml/src/ggml-cuda/common.cuh | 83 +++++++++++-
ggml/src/ggml-cuda/ggml-cuda.cu | 224 ++++++++++++++++++++++++++------
5 files changed, 354 insertions(+), 45 deletions(-)
diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
index 93c95602d..dbbb61d9c 100644
@@ -229,7 +229,7 @@ diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 9fcb2f9fd..e800ee8f6 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -37,6 +37,62 @@
#include "vendors/cuda.h"
#endif // defined(GGML_USE_HIP)
@@ -261,12 +261,33 @@ index 9fcb2f9fd..e800ee8f6 100644
+ }
+}
+
+static cublasStatus_t cublasGemmBatchedExReserve(
+ cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
+ int m, int n, int k,
+ const void *alpha,
+ const void *const Aarray[], cudaDataType_t Atype, int lda,
+ const void *const Barray[], cudaDataType_t Btype, int ldb,
+ const void *beta,
+ void *const Carray[], cudaDataType_t Ctype, int ldc,
+ int batchCount,
+ cublasComputeType_t computeType, cublasGemmAlgo_t algo) {
+ if (!reserving_graph) {
+ return cublasGemmBatchedEx(handle, transa, transb, m, n, k,
+ alpha, Aarray, Atype, lda, Barray, Btype, ldb,
+ beta, Carray, Ctype, ldc, batchCount, computeType, algo);
+ } else {
+ return CUBLAS_STATUS_SUCCESS;
+ }
+}
+
+#undef cudaMemcpyAsync
+#define cudaMemcpyAsync cudaMemcpyAsyncReserve
+#undef cudaMemcpy2DAsync
+#define cudaMemcpy2DAsync cudaMemcpy2DAsyncReserve
+#undef cudaMemsetAsync
+#define cudaMemsetAsync cudaMemsetAsyncReserve
+#undef cublasGemmBatchedEx
+#define cublasGemmBatchedEx cublasGemmBatchedExReserve
+
#define STRINGIZE_IMPL(...) #__VA_ARGS__
#define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__)

View File

@@ -65,12 +65,33 @@ static cudaError_t cudaMemsetAsyncReserve ( void* devPtr, int value, size_t coun
}
}
// Drop-in replacement for cublasGemmBatchedEx that becomes a no-op while a
// CUDA graph is being reserved. Per the commit description: during graph
// reservation the pool allocator returns fake pointers, which makes the real
// cublasGemmBatchedEx call fail, so it is skipped and success is reported
// instead (memory tracking is already handled by the pool allocations).
// NOTE(review): `reserving_graph` is a flag defined elsewhere in this file —
// confirm it is set for the full duration of reservation on this thread.
static cublasStatus_t cublasGemmBatchedExReserve(
cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
int m, int n, int k,
const void *alpha,
const void *const Aarray[], cudaDataType_t Atype, int lda,
const void *const Barray[], cudaDataType_t Btype, int ldb,
const void *beta,
void *const Carray[], cudaDataType_t Ctype, int ldc,
int batchCount,
cublasComputeType_t computeType, cublasGemmAlgo_t algo) {
// Normal execution path: forward every argument to the real cuBLAS entry
// point unchanged.
if (!reserving_graph) {
return cublasGemmBatchedEx(handle, transa, transb, m, n, k,
alpha, Aarray, Atype, lda, Barray, Btype, ldb,
beta, Carray, Ctype, ldc, batchCount, computeType, algo);
} else {
// Reservation path: pretend the GEMM succeeded without touching the
// (fake) device pointers.
return CUBLAS_STATUS_SUCCESS;
}
}
// Redirect the async memory operations and the batched GEMM to their
// *Reserve wrappers so every call site in this translation unit is
// automatically skipped while a graph is being reserved. The #undef guards
// against any earlier definition of the same names.
#undef cudaMemcpyAsync
#define cudaMemcpyAsync cudaMemcpyAsyncReserve
#undef cudaMemcpy2DAsync
#define cudaMemcpy2DAsync cudaMemcpy2DAsyncReserve
#undef cudaMemsetAsync
#define cudaMemsetAsync cudaMemsetAsyncReserve
#undef cublasGemmBatchedEx
#define cublasGemmBatchedEx cublasGemmBatchedExReserve
// Standard two-level stringification helper: expands macro arguments before
// turning them into a string literal.
#define STRINGIZE_IMPL(...) #__VA_ARGS__
#define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__)