From a1c8153073d7da318ebaf210b9568e697f4164fe Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 25 Mar 2024 16:00:05 -0700 Subject: [PATCH] [Cutlass] Add check for group gemm param shapes --- src/runtime/contrib/cutlass/fp8_group_gemm.cu | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/runtime/contrib/cutlass/fp8_group_gemm.cu b/src/runtime/contrib/cutlass/fp8_group_gemm.cu index c93da6ff5766..31ad4367afcf 100644 --- a/src/runtime/contrib/cutlass/fp8_group_gemm.cu +++ b/src/runtime/contrib/cutlass/fp8_group_gemm.cu @@ -54,9 +54,11 @@ void tvm_cutlass_fp8_group_gemm(NDArray x, NDArray weight, NDArray indptr, NDArr CHECK_EQ(out->ndim, 2); CHECK_EQ(alpha->dtype.code, kDLFloat); CHECK_EQ(alpha->dtype.bits, 32); + CHECK_EQ(alpha->ndim, 1); + CHECK_EQ(alpha->shape[0], 1); int num_groups = weight->shape[0]; int n = weight->shape[1]; - int k = weight->shape[2]; + int k = x->shape[1]; const float* beta = nullptr; cudaStream_t stream = static_cast((*func)().operator void*()); cutlass_group_gemm(static_cast(x->data), static_cast(weight->data),