[PyTorch][Core] Fix CUBLAS GGEMM when weight dims are not divisible by 128 #2954

vthumbe1503 wants to merge 6 commits into NVIDIA:main
Conversation
**Greptile Summary**

This PR fixes incorrect MXFP8 scale-inverse pointer offsets in the cuBLAS grouped GEMM kernel when per-expert weight dimensions are not multiples of 128. The old code applied a uniform per-tensor stride to every scale-inverse pointer; the fix sums each preceding tensor's padded scale-inverse size when the per-tensor first/last dimensions are non-uniform.

**Confidence Score: 5/5** — Safe to merge; the core offset-calculation fix is correct and the only finding is a cosmetic inconsistency in a zero-work test. No P0/P1 issues found. The kernel change is logically sound. No files require special attention beyond the minor test inconsistency.
**Flowchart**

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[launch_grouped_gemm_setup] --> B[Reads A_sel.rowwise / B_sel.rowwise]
    B --> C[setup_grouped_gemm_kernel\na_rowwise, b_rowwise args]
    C --> D{scaling_mode ==\nNVTE_MXFP8_1D_SCALING?}
    D -- Yes --> E[compute_grouped_tensor_mxfp8_scale_inv_offset\nA_meta, idx, a_rowwise]
    D -- Yes --> F[compute_grouped_tensor_mxfp8_scale_inv_offset\nB_meta, idx, b_rowwise]
    D -- No --> G[scale_inv_ptr = base + idx]
    E --> H{first_dims or last_dims non-null?}
    F --> H
    H -- Yes: non-uniform --> I[Loop: sum padded_mxfp8_scale_inv_bytes\nfor each tensor i < idx]
    H -- No: uniform --> J[idx x padded_mxfp8_scale_inv_bytes\nuniform_first, uniform_last]
    I --> K[padded_mxfp8_scale_inv_bytes\nrounds Y to 128, X_scales to 4]
    J --> K
    subgraph select_grouped_operand
        L[use_rowwise lambda] -- sets --> M[sel.rowwise = true]
        N[use_columnwise lambda] -- sets --> O[sel.rowwise = false]
    end
```
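To make the offset rule concrete, here is a minimal Python sketch of the logic the flowchart describes. The 128-row and 4-column rounding targets come from the flowchart; the 32-element MXFP8 block size and one-byte scales are assumptions for illustration, and the function names merely mirror the kernel helpers rather than reproduce them.

```python
import math

def padded_mxfp8_scale_inv_bytes(first_dim: int, last_dim: int) -> int:
    # Assumption for this sketch: one scale byte per 32-element block along the
    # last dim; cuBLAS expects the first dim padded to a multiple of 128 and
    # the scale columns padded to a multiple of 4 (per the flowchart).
    padded_rows = math.ceil(first_dim / 128) * 128
    scale_cols = last_dim // 32
    padded_cols = math.ceil(scale_cols / 4) * 4
    return padded_rows * padded_cols

def scale_inv_offset(idx, first_dims, last_dims, uniform_first, uniform_last):
    # Non-uniform shapes: accumulate the padded scale-inverse footprint of
    # every preceding tensor (the loop branch in the flowchart).
    if first_dims is not None or last_dims is not None:
        total = 0
        for i in range(idx):
            fd = first_dims[i] if first_dims is not None else uniform_first
            ld = last_dims[i] if last_dims is not None else uniform_last
            total += padded_mxfp8_scale_inv_bytes(fd, ld)
        return total
    # Uniform shapes: a single multiply suffices.
    return idx * padded_mxfp8_scale_inv_bytes(uniform_first, uniform_last)
```

With per-expert first dims of, say, 96, 160, and 224, the non-uniform branch yields a different offset than a uniform stride would, which is exactly the mismatch this PR fixes for dims not divisible by 128.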
Reviews (3): Last reviewed commit: "ci fix"
For future reference, this fix should be applied to the NVFP4 recipe as well.
/te-ci pytorch
/te-ci
```diff
  quantizer.optimize_for_gemm = not is_weight
  grouped_input = torch.cat(tensors, dim=0)
- first_dims = torch.tensor([t.shape[0] for t in tensors], dtype=torch.int64, device=device)
+ if is_weight:
+     first_dims = None
+ else:
+     first_dims = torch.tensor([t.shape[0] for t in tensors], dtype=torch.int64, device=device)
  return tex.group_quantize(grouped_input, quantizer, len(tensors), first_dims)
```
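For context, a hedged sketch of what this branch means at the call site (shapes and names are illustrative, not taken from the test file):

```python
import torch

# Weights in a grouped GEMM test share one uniform per-expert shape, so the
# grouped quantizer needs no explicit per-tensor dims; activations have
# varying token counts per expert and must pass their first dims explicitly.
weights = [torch.randn(256, 96) for _ in range(3)]       # uniform shape
acts = [torch.randn(m, 256) for m in (96, 160, 224)]     # varying first dims

first_dims_w = None                                       # is_weight=True path
first_dims_a = torch.tensor([t.shape[0] for t in acts], dtype=torch.int64)
```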
```python
def _per_tensor_quantize_mxfp8(
    tensors: List[torch.Tensor],
    *,
    rowwise: bool,
```
**`_per_tensor_quantize_mxfp8` produces non-swizzled scales that break the `discrete_in` case**

`MXFP8Quantizer.optimize_for_gemm` defaults to `False` (set in `Quantizer.__init__`), so every tensor returned by this helper has `with_gemm_swizzled_scales=False`. When `case == "discrete_in"`, the test passes `A_fp8` directly to `general_grouped_gemm_for_grouped_tensor`, which routes to `nvte_grouped_gemm_with_discrete_inputA`. That function contains a hard `NVTE_CHECK(A_list_info.with_gemm_swizzled_scales, "MXFP8 grouped GEMM: A scales must be swizzled for GEMM.")`, so the test will throw for every MXFP8 shape when `case == "discrete_in"`.

The old code called `grouped_A.split_into_quantized_tensors()` on a grouped tensor built with `optimize_for_gemm=True`, so the split tensors inherited `with_gemm_swizzled_scales=True`.
Suggested change:

```python
def _per_tensor_quantize_mxfp8(
    tensors: List[torch.Tensor],
    *,
    rowwise: bool,
    columnwise: bool,
) -> List:
    """Quantize each tensor individually with MXFP8.

    Used to build reference discrete inputs for grouped GEMM.
    """
    quantizer = MXFP8Quantizer(
        fp8_dtype=tex.DType.kFloat8E4M3,
        rowwise=rowwise,
        columnwise=columnwise,
    )
    quantizer.optimize_for_gemm = True
    return [quantizer(t) for t in tensors]
```
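A possible usage of the helper above, assuming a CUDA device and illustrative shapes (the expert first dims are deliberately not multiples of 128, matching the bug this PR targets; `tex` and `MXFP8Quantizer` are the transformer_engine imports already used in the test):

```python
import torch

# Quantize per-expert activations individually; the rowwise-scaled results
# serve as discrete reference inputs for the grouped GEMM under test.
tensors = [
    torch.randn(m, 256, device="cuda", dtype=torch.bfloat16)
    for m in (96, 160, 224)
]
a_fp8_list = _per_tensor_quantize_mxfp8(tensors, rowwise=True, columnwise=False)
```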
If the inputs are not already swizzled, TE ensures that is done before the GEMM.

So it is not needed for a functionality test.
/te-ci |
**Description**

Fixes incorrect MXFP8 scale-inverse pointer offsets in the cuBLAS grouped GEMM kernel when per-expert weight dimensions are not multiples of 128.

Fixes # (issue)

**Type of change**

Bug fix

**Changes**

- Compute each grouped tensor's MXFP8 scale-inverse pointer by summing the padded scale-inverse sizes of all preceding tensors when per-tensor first/last dims are non-uniform, instead of applying a uniform stride.
- Update the grouped GEMM tests: weights are quantized without per-tensor first dims, and a `_per_tensor_quantize_mxfp8` helper builds discrete per-tensor MXFP8 reference inputs.

**Checklist:**