mirror of
https://github.com/ollama/ollama.git
synced 2026-04-17 21:54:08 +02:00
ggml: skip cublasGemmBatchedEx during graph reservation
cublasGemmBatchedEx fails during graph capture when pool allocations return fake pointers. This is triggered when NUM_PARALLEL is greater than 1 for models like gemma4 that use batched matmuls. Skip it during reservation since the memory tracking is already handled by the pool allocations. Fixes #15249
This commit is contained in:
@@ -11,9 +11,9 @@ must be recreated with no-alloc set to false before loading data.
|
||||
ggml/include/ggml-backend.h | 1 +
|
||||
ggml/src/ggml-backend-impl.h | 16 +++
|
||||
ggml/src/ggml-backend.cpp | 75 ++++++++++-
|
||||
ggml/src/ggml-cuda/common.cuh | 62 ++++++++-
|
||||
ggml/src/ggml-cuda/common.cuh | 83 +++++++++++-
|
||||
ggml/src/ggml-cuda/ggml-cuda.cu | 224 ++++++++++++++++++++++++++------
|
||||
5 files changed, 333 insertions(+), 45 deletions(-)
|
||||
5 files changed, 354 insertions(+), 45 deletions(-)
|
||||
|
||||
diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
|
||||
index 93c95602d..dbbb61d9c 100644
|
||||
@@ -229,7 +229,7 @@ diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
|
||||
index 9fcb2f9fd..e800ee8f6 100644
|
||||
--- a/ggml/src/ggml-cuda/common.cuh
|
||||
+++ b/ggml/src/ggml-cuda/common.cuh
|
||||
@@ -37,6 +37,41 @@
|
||||
@@ -37,6 +37,62 @@
|
||||
#include "vendors/cuda.h"
|
||||
#endif // defined(GGML_USE_HIP)
|
||||
|
||||
@@ -261,12 +261,33 @@ index 9fcb2f9fd..e800ee8f6 100644
|
||||
+ }
|
||||
+}
|
||||
+
|
||||
+static cublasStatus_t cublasGemmBatchedExReserve(
|
||||
+ cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
|
||||
+ int m, int n, int k,
|
||||
+ const void *alpha,
|
||||
+ const void *const Aarray[], cudaDataType_t Atype, int lda,
|
||||
+ const void *const Barray[], cudaDataType_t Btype, int ldb,
|
||||
+ const void *beta,
|
||||
+ void *const Carray[], cudaDataType_t Ctype, int ldc,
|
||||
+ int batchCount,
|
||||
+ cublasComputeType_t computeType, cublasGemmAlgo_t algo) {
|
||||
+ if (!reserving_graph) {
|
||||
+ return cublasGemmBatchedEx(handle, transa, transb, m, n, k,
|
||||
+ alpha, Aarray, Atype, lda, Barray, Btype, ldb,
|
||||
+ beta, Carray, Ctype, ldc, batchCount, computeType, algo);
|
||||
+ } else {
|
||||
+ return CUBLAS_STATUS_SUCCESS;
|
||||
+ }
|
||||
+}
|
||||
+
|
||||
+#undef cudaMemcpyAsync
|
||||
+#define cudaMemcpyAsync cudaMemcpyAsyncReserve
|
||||
+#undef cudaMemcpy2DAsync
|
||||
+#define cudaMemcpy2DAsync cudaMemcpy2DAsyncReserve
|
||||
+#undef cudaMemsetAsync
|
||||
+#define cudaMemsetAsync cudaMemsetAsyncReserve
|
||||
+#undef cublasGemmBatchedEx
|
||||
+#define cublasGemmBatchedEx cublasGemmBatchedExReserve
|
||||
+
|
||||
#define STRINGIZE_IMPL(...) #__VA_ARGS__
|
||||
#define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__)
|
||||
|
||||
21 changed lines in ml/backend/ggml/ggml/src/ggml-cuda/common.cuh (vendored)
@@ -65,12 +65,33 @@ static cudaError_t cudaMemsetAsyncReserve ( void* devPtr, int value, size_t coun
|
||||
}
|
||||
}
|
||||
|
||||
// Interception wrapper for cublasGemmBatchedEx, swapped in via the #define
// remapping below. While the backend is reserving a CUDA graph
// (reserving_graph set elsewhere in this header), pool allocations hand out
// fake pointers, and launching the real batched GEMM against them fails
// during capture. In that mode we report success without launching; memory
// accounting is already handled by the pool allocations. Outside of
// reservation this forwards to the real cuBLAS entry point unchanged.
static cublasStatus_t cublasGemmBatchedExReserve(
    cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k,
    const void *alpha,
    const void *const Aarray[], cudaDataType_t Atype, int lda,
    const void *const Barray[], cudaDataType_t Btype, int ldb,
    const void *beta,
    void *const Carray[], cudaDataType_t Ctype, int ldc,
    int batchCount,
    cublasComputeType_t computeType, cublasGemmAlgo_t algo) {
    // Guard clause: skip the launch entirely during graph reservation.
    if (reserving_graph) {
        return CUBLAS_STATUS_SUCCESS;
    }
    // Normal path: call the real cuBLAS batched GEMM (the #undef below
    // ensures this name still refers to the vendor function at this point).
    return cublasGemmBatchedEx(handle, transa, transb, m, n, k,
                               alpha, Aarray, Atype, lda,
                               Barray, Btype, ldb,
                               beta, Carray, Ctype, ldc,
                               batchCount, computeType, algo);
}
|
||||
|
||||
// Remap the intercepted CUDA runtime / cuBLAS calls to the *Reserve wrappers
// defined above, so that all call sites compiled after this header
// transparently become no-ops during graph reservation. Each #undef clears
// any prior mapping before the name is redefined.
#undef cudaMemcpyAsync
#define cudaMemcpyAsync cudaMemcpyAsyncReserve
#undef cudaMemcpy2DAsync
#define cudaMemcpy2DAsync cudaMemcpy2DAsyncReserve
#undef cudaMemsetAsync
#define cudaMemsetAsync cudaMemsetAsyncReserve
#undef cublasGemmBatchedEx
#define cublasGemmBatchedEx cublasGemmBatchedExReserve
|
||||
|
||||
#define STRINGIZE_IMPL(...) #__VA_ARGS__
|
||||
#define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__)
|
||||
|
||||
Reference in New Issue
Block a user