
[PyTorch][Core] Fix CUBLAS GGEMM when weight dims are not divisible by 128 #2954

Open

vthumbe1503 wants to merge 6 commits into NVIDIA:main from vthumbe1503:fix_cublas_grouped_gemm_gptoss_sizes

Conversation

@vthumbe1503
Collaborator

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

vthumbe1503 and others added 2 commits May 1, 2026 21:00
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@greptile-apps
Contributor

greptile-apps Bot commented May 1, 2026

Greptile Summary

This PR fixes incorrect MXFP8 scale-inverse pointer offsets in the cuBLAS grouped GEMM kernel when per-expert weight dimensions are not multiples of 128. The old code used data_offset / 32 (a flat, unpadded stride) to index into the swizzled scale buffer; the fix introduces compute_grouped_tensor_mxfp8_scale_inv_offset, which correctly rounds both the Y dimension to 128-element tiles and the X scale dimension to 4-element tiles before computing the cumulative byte offset. A regression test for the non-divisible shape (2, 256, 2880, 1440) is added.
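The rounding described above can be made concrete with a small Python sketch (the function name follows the helper the PR introduces, but this is an illustration of the arithmetic, not the kernel's C++ code):

```python
def padded_mxfp8_scale_inv_bytes(y_dim, x_dim, tile_y=128, tile_x=4):
    """Swizzled scale-inv footprint of one tensor, in bytes (sketch).

    MXFP8 stores one E8M0 scale byte per 32 elements along X; the
    GEMM-swizzled layout pads Y up to 128-row tiles and the scale
    dimension up to 4-column tiles.
    """
    x_scales = x_dim // 32                       # one scale per 32 elements
    padded_y = -(-y_dim // tile_y) * tile_y      # round Y up to a multiple of 128
    padded_x = -(-x_scales // tile_x) * tile_x   # round scale dim up to a multiple of 4
    return padded_y * padded_x                   # 1 byte per scale

# Weight shape from the new regression test: 2880 x 1440 per expert.
flat = (2880 * 1440) // 32                  # the old data_offset / 32 stride
padded = padded_mxfp8_scale_inv_bytes(2880, 1440)
print(flat, padded)                         # the two strides disagree for this shape
```

Under this arithmetic, the flat stride for a 2880×1440 expert is 129600 bytes while the padded swizzled footprint is 141312, which is exactly the kind of mismatch the fix addresses; the two agree only when the dimensions already fall on tile boundaries.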

Confidence Score: 5/5

Safe to merge; the core offset-calculation fix is correct and the only finding is a cosmetic inconsistency in a zero-work test.

No P0/P1 issues found. The kernel change is logically sound: padded_mxfp8_scale_inv_bytes matches the 128×4 swizzle tile layout defined in swizzle.cuh, rowwise is correctly threaded through all call paths via the use_rowwise/use_columnwise lambdas, and the uniform-shape fast path is intact. The single P2 finding (missing is_weight=True in a zero-work test) has no runtime impact.

No files require special attention beyond the minor test inconsistency at tests/pytorch/test_numerics.py line 3086.

Important Files Changed

Filename Overview
transformer_engine/common/gemm/cublaslt_grouped_gemm.cu Core fix: replaces a_offset/32 flat scale-inv offset with compute_grouped_tensor_mxfp8_scale_inv_offset, which properly accounts for 128×4 swizzle-tile padding in MXFP8 scale buffers. Adds rowwise field to GroupedOperandSelection and threads it through to the kernel; select_grouped_operand sets it correctly via the use_rowwise/use_columnwise lambdas.
transformer_engine/common/cast/mxfp8/swizzle.cuh Promotes local tile-dimension constants TILE_DIM_X=4 / TILE_DIM_Y=128 to namespace-level GEMM_SWIZZLED_SCALE_TILE_DIM_X/Y, allowing the GEMM kernel to import them without repeating magic numbers. No logic change.
tests/pytorch/test_numerics.py Refactors _make_grouped_tensor_quantized_mxfp8 to accept explicit rowwise/columnwise/is_weight args, adds _per_tensor_quantize_mxfp8 helper for the discrete_in reference path, and adds the non-128-divisible shape (2, 256, 2880, 1440). Minor: test_grouped_gemm_grouped_tensor_zero_work is missing is_weight=True for its uniform-shape weight tensors.

Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[launch_grouped_gemm_setup] --> B[Reads A_sel.rowwise / B_sel.rowwise]
    B --> C[setup_grouped_gemm_kernel\na_rowwise, b_rowwise args]
    C --> D{scaling_mode ==\nNVTE_MXFP8_1D_SCALING?}
    D -- Yes --> E[compute_grouped_tensor_mxfp8_scale_inv_offset\nA_meta, idx, a_rowwise]
    D -- Yes --> F[compute_grouped_tensor_mxfp8_scale_inv_offset\nB_meta, idx, b_rowwise]
    D -- No  --> G[scale_inv_ptr = base + idx]
    E --> H{first_dims or last_dims non-null?}
    F --> H
    H -- Yes: non-uniform --> I[Loop: sum padded_mxfp8_scale_inv_bytes\nfor each tensor i < idx]
    H -- No: uniform --> J[idx x padded_mxfp8_scale_inv_bytes\nuniform_first, uniform_last]
    I --> K[padded_mxfp8_scale_inv_bytes\nrounds Y to 128, X_scales to 4]
    J --> K
    subgraph select_grouped_operand
        L[use_rowwise lambda] -- sets --> M[sel.rowwise = true]
        N[use_columnwise lambda] -- sets --> O[sel.rowwise = false]
    end
```
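The uniform/non-uniform branch in the chart can be sketched in Python (a simplification under assumed semantics; the real kernel works on device pointers and also handles rowwise/columnwise dimension ordering):

```python
def padded_mxfp8_scale_inv_bytes(y_dim, x_dim, tile_y=128, tile_x=4):
    # Padding rule from the PR: scale dim = X / 32, rounded to 128x4 tiles.
    x_scales = x_dim // 32
    return (-(-y_dim // tile_y) * tile_y) * (-(-x_scales // tile_x) * tile_x)

def scale_inv_offset(idx, first_dims, last_dims, uniform_first, uniform_last):
    """Byte offset of tensor idx's scales within the grouped swizzled buffer."""
    if first_dims is None and last_dims is None:
        # Uniform fast path: every tensor has the same padded footprint.
        return idx * padded_mxfp8_scale_inv_bytes(uniform_first, uniform_last)
    # Non-uniform shapes: accumulate the padded footprint of each
    # preceding tensor, falling back to the uniform dim where a list is absent.
    offset = 0
    for i in range(idx):
        y = first_dims[i] if first_dims is not None else uniform_first
        x = last_dims[i] if last_dims is not None else uniform_last
        offset += padded_mxfp8_scale_inv_bytes(y, x)
    return offset
```

The key design point mirrored here is that each tensor's contribution is its *padded* footprint, so offsets stay aligned to the 128×4 swizzle tiles even when per-expert dims are not multiples of 128.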

Reviews (3): Last reviewed commit: "ci fix"

Comment thread transformer_engine/common/gemm/cublaslt_grouped_gemm.cu
Comment thread transformer_engine/common/gemm/cublaslt_grouped_gemm.cu Outdated
Comment thread tests/pytorch/test_numerics.py Outdated
Comment thread tests/pytorch/test_numerics.py Outdated
Collaborator

@zhongbozhu zhongbozhu left a comment


LGTM

@zhongbozhu
Collaborator

For future reference, this fix PR should be applied to NVFP4 recipe as well.

@vthumbe1503
Collaborator Author

/te-ci pytorch

@vthumbe1503 vthumbe1503 requested a review from timmoon10 May 1, 2026 22:22
Collaborator

@timmoon10 timmoon10 left a comment


Overall this looks good. My only serious suggestion is to change the name of GroupedOperandSelection.scale_rowwise to rowwise, since that actually changes the intent of the kernel.

Comment thread transformer_engine/common/gemm/cublaslt_grouped_gemm.cu Outdated
Comment thread transformer_engine/common/gemm/cublaslt_grouped_gemm.cu Outdated
Comment thread tests/pytorch/test_numerics.py Outdated
timmoon10
timmoon10 previously approved these changes May 2, 2026
@vthumbe1503
Collaborator Author

/te-ci

@vthumbe1503 vthumbe1503 changed the title Fix CUBLAS GGEMM when weight dims are not divisible by 128 [PyTorch][Core] Fix CUBLAS GGEMM when weight dims are not divisible by 128 May 2, 2026
Comment on lines +3163 to +3175

```diff
     quantizer.optimize_for_gemm = not is_weight
     grouped_input = torch.cat(tensors, dim=0)
-    first_dims = torch.tensor([t.shape[0] for t in tensors], dtype=torch.int64, device=device)
+    if is_weight:
+        first_dims = None
+    else:
+        first_dims = torch.tensor([t.shape[0] for t in tensors], dtype=torch.int64, device=device)
     return tex.group_quantize(grouped_input, quantizer, len(tensors), first_dims)


 def _per_tensor_quantize_mxfp8(
     tensors: List[torch.Tensor],
     *,
     rowwise: bool,
```
Contributor


P1 _per_tensor_quantize_mxfp8 produces non-swizzled scales that break discrete_in case

MXFP8Quantizer.optimize_for_gemm defaults to False (set in Quantizer.__init__), so every tensor returned by this helper has with_gemm_swizzled_scales=False. When case == "discrete_in", the test passes A_fp8 directly to general_grouped_gemm_for_grouped_tensor, which routes to nvte_grouped_gemm_with_discrete_inputA. That function contains a hard NVTE_CHECK(A_list_info.with_gemm_swizzled_scales, "MXFP8 grouped GEMM: A scales must be swizzled for GEMM."), so the test will throw for every MXFP8 shape when case == "discrete_in".

The old code called grouped_A.split_into_quantized_tensors() on a grouped tensor built with optimize_for_gemm=True, so the split tensors inherited with_gemm_swizzled_scales=True.

Suggested change

```diff
-    quantizer.optimize_for_gemm = not is_weight
-    grouped_input = torch.cat(tensors, dim=0)
-    if is_weight:
-        first_dims = None
-    else:
-        first_dims = torch.tensor([t.shape[0] for t in tensors], dtype=torch.int64, device=device)
-    return tex.group_quantize(grouped_input, quantizer, len(tensors), first_dims)
-
-
-def _per_tensor_quantize_mxfp8(
-    tensors: List[torch.Tensor],
-    *,
-    rowwise: bool,
+def _per_tensor_quantize_mxfp8(
+    tensors: List[torch.Tensor],
+    *,
+    rowwise: bool,
+    columnwise: bool,
+) -> List:
+    """Quantize each tensor individually with MXFP8.
+
+    Used to build reference discrete inputs for grouped GEMM.
+    """
+    quantizer = MXFP8Quantizer(
+        fp8_dtype=tex.DType.kFloat8E4M3,
+        rowwise=rowwise,
+        columnwise=columnwise,
+    )
+    quantizer.optimize_for_gemm = True
+    return [quantizer(t) for t in tensors]
```

Collaborator Author


If inputs are not swizzled already, then TE ensures to do that before GEMM

Collaborator Author


So not needed for a functionality test

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503
Collaborator Author

/te-ci



3 participants