Make TE Sequential Grouped linear Op CUDA graphable #2923

Open
vthumbe1503 wants to merge 16 commits into NVIDIA:main from vthumbe1503:grouped_linear_integration_v2

Conversation

@vthumbe1503
Collaborator

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503 vthumbe1503 marked this pull request as draft April 24, 2026 20:05
@vthumbe1503 vthumbe1503 changed the title Grouped linear integration v2 Make Grouped linear TE Sequential Op CUDA graphable Apr 24, 2026
@vthumbe1503 vthumbe1503 changed the title Make Grouped linear TE Sequential Op CUDA graphable Make TE Sequential Grouped linear Op CUDA graphable Apr 24, 2026
@greptile-apps
Contributor

greptile-apps Bot commented Apr 24, 2026

Greptile Summary

This PR adds a CUDA-graph-safe forward/backward path (_fuser_forward_grouped_tensor / _fuser_backward_grouped_tensor) for GroupedLinear that uses general_grouped_gemm_for_grouped_tensor with GPU-resident split offsets, gated behind an SM100+ (Blackwell) + BF16/FP16/MXFP8 check. It also introduces single_grouped_weight and single_grouped_bias parameters, extracts repeated main-grad boilerplate into four shared helpers in _common.py, and adds a dedicated CUDA-graph capture test.

  • P1: rowwise_data=main_grad.view(-1) in _fuser_backward_grouped_tensor will raise an opaque RuntimeError for any non-contiguous main_grad whose shape already equals the grouped shape (the common FSDP case). The equivalent code in backward_grouped_mlp.py correctly passes main_grad directly and lets make_grouped_tensor_from_rowwise_data call .contiguous() as needed.

Confidence Score: 4/5

Safe to merge on non-FSDP setups; the main_grad.view(-1) bug only triggers on Blackwell + accumulate_into_main_grad=True + FSDP when main_grad is non-contiguous with the grouped shape.

One confirmed P1 (unsafe .view(-1) on potentially non-contiguous FSDP main_grad in _fuser_backward_grouped_tensor) caps the score at 4. All other paths are correct and the refactor of shared helpers is clean.

transformer_engine/pytorch/ops/basic/grouped_linear.py — specifically _fuser_backward_grouped_tensor around the make_grouped_tensor_from_rowwise_data call with main_grad.view(-1)
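To make the failure mode concrete, here is a minimal, hypothetical repro in plain PyTorch (not TE code): a main_grad that already has the grouped shape but is a non-contiguous slice of a larger fused gradient buffer, as FSDP may hand back.

```python
import torch

# Hypothetical stand-in for an FSDP main_grad: grouped shape, but a
# non-unit-stride slice of a larger fused gradient buffer.
num_groups, out_features, in_features = 2, 4, 8
buf = torch.zeros(num_groups, out_features, in_features + 2)
main_grad = buf[:, :, :in_features]

assert main_grad.shape == (num_groups, out_features, in_features)
assert not main_grad.is_contiguous()

# What the flagged code does: .view requires compatible strides, so this
# raises a generic RuntimeError on the non-contiguous slice.
try:
    main_grad.view(-1)
except RuntimeError as exc:
    print("view failed:", exc)

# What the safe path amounts to: copy only when the layout requires it,
# then flatten.
flat = main_grad.contiguous().view(-1)
assert flat.numel() == num_groups * out_features * in_features
```

A contiguous main_grad takes the same path without any copy, which is why the bug stays hidden outside the FSDP case.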

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/ops/basic/grouped_linear.py | Core change: adds graph-safe grouped-tensor forward/backward paths and helper methods; P1 issue with main_grad.view(-1) in _fuser_backward_grouped_tensor for non-contiguous FSDP main_grads |
| transformer_engine/pytorch/ops/_common.py | Adds shared helpers extracted from repeated inline code: get_main_grad_from_param, get_accumulate_flag_in_param, view_main_grad_as_grouped_buffer, get_dummy_wgrads_for_params |
| transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py | Refactors _compute_grad_params to use shared helpers; correctly passes main_grad directly to make_grouped_tensor_from_rowwise_data |
| tests/pytorch/test_fusible_ops.py | Adds test_grouped_linear_cuda_graph_safe and extends test_grouped_linear with single_grouped_weight/single_grouped_bias axes |
| transformer_engine/pytorch/ops/fused/backward_linear_add.py | Simple cleanup: replaces inline main_grad logic with shared helpers |
| transformer_engine/pytorch/ops/fused/backward_linear_scale.py | Same cleanup as backward_linear_add.py; no behavioral changes |
| transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py | Same cleanup; improved error label in get_main_grad_from_param call |
| transformer_engine/pytorch/ops/basic/basic_linear.py | Replaces inline main_grad boilerplate with new _common helpers; clean refactor |

Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    FW[fuser_forward] --> CHECK{SM100+ and BF16/FP16\nor MXFP8?}
    CHECK -->|Yes| GT_FW[_fuser_forward_grouped_tensor\nCUDA-graph safe\nGPU-resident offsets]
    CHECK -->|No| SQ_FW[_fuser_forward_split_quantize\nLegacy path\nCPU split sizes sync]
    GT_FW --> SAVE_GT[ctx.use_grouped_tensor_path = True\nSave split_sizes, base_offsets,\nx_data, x_scale, weights]
    SQ_FW --> SAVE_SQ[ctx.use_grouped_tensor_path = False\nSave split_sizes, xs, ws]
    BW[fuser_backward] --> DISPATCH{ctx.use_grouped_tensor_path?}
    DISPATCH -->|True| GT_BW[_fuser_backward_grouped_tensor\nRebuild GroupedTensors\nGrouped GEMM dgrad+wgrad]
    DISPATCH -->|False| SQ_BW[_fuser_backward_split_quantize\nCPU list ops\ntex.split_quantize + general_grouped_gemm]
    GT_BW --> MAIN_GRAD{accumulate_into_main_grad?}
    MAIN_GRAD -->|Yes| MG_VIEW[get_main_grad_from_param\nview_main_grad_as_grouped_buffer\nmain_grad.view⚠️]
    MAIN_GRAD -->|No| ALLOC[GroupedTensor.make_grouped_tensor_with_shapes]
```

Reviews (5): Last reviewed commit: "Merge branch 'main' into grouped_linear_..."

Comment on lines +1157 to +1170
```python
bias_scale: Optional[torch.Tensor] = None
if has_bias:
    # Bias always needs to be passed as a GroupedTensor for the grouped GEMM.
    grouped_bias = self._get_grouped_bias_for_gemm(dtype, device)
    if self._scale_bias:
        bias_scale = scales.reshape(-1)
        if bias_scale.dtype != torch.float32:
            bias_scale = bias_scale.to(dtype=torch.float32)

# Forward grouped GEMM (TN layout: out[i] = x[i] @ w[i]^T)
general_grouped_gemm_for_grouped_tensor(
    grouped_weights,
    grouped_x,
    grouped_out,
```
Contributor


P2 Missing contiguity error handling for main_grad.view(-1)

main_grad.view(-1) will raise a generic RuntimeError if main_grad is non-contiguous (e.g. when returned by get_main_grad() via __fsdp_param__). The equivalent code in backward_grouped_mlp.py wraps the reshape in try/except and re-raises with an actionable message that includes the shape and stride. Without that guard, users hitting this case will see an opaque PyTorch error instead of a clear diagnostic.
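The guard described above could look roughly like the following sketch. The function name and message wording are illustrative, not the actual backward_grouped_mlp.py code; the point is re-raising with the shape and stride so the user gets an actionable diagnostic.

```python
import torch

def view_as_flat_with_diagnostic(main_grad: torch.Tensor) -> torch.Tensor:
    """Illustrative guard: flatten main_grad, re-raising view failures
    with shape/stride information instead of PyTorch's generic error."""
    try:
        return main_grad.view(-1)
    except RuntimeError as exc:
        raise RuntimeError(
            "Could not view main_grad as a flat buffer for the grouped wgrad "
            f"(shape={tuple(main_grad.shape)}, stride={main_grad.stride()}). "
            "main_grad is likely non-contiguous; pass it through .contiguous() first."
        ) from exc
```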

Comment on lines +1191 to 1218
```python
if ctx.requires_grad:
    saved: list[Optional[torch.Tensor]] = [split_sizes, base_offsets]
    if self._scale_bias:
        saved.append(scales)
    # For the wgrad input we save (data, scale_inv).
    # * Quantized path saves columnwise data + scale.
    # * Unquantized path saves the raw rowwise data and a None scale.
    if grouped_x is not None:
        if with_quantized_compute:
            saved.extend(
                [
                    grouped_x.columnwise_data,
                    grouped_x.columnwise_scale_inv,
                ]
            )
        else:
            saved.extend([grouped_x.rowwise_data, None])
    else:
        saved.extend([None, None])
    if self.single_grouped_weight:
        saved.append(grouped_weights)
    else:
        saved.extend(grouped_weights)
    ctx.save_for_backward(*saved)
ctx.use_grouped_tensor_path = True
ctx.with_quantized_compute = with_quantized_compute
ctx.input_quantizers = input_quantizers
ctx.weight_quantizers = weight_quantizers
```
Contributor


P2 Comment contradicts implementation for weight saving

The block comment says "we save the GroupedTensor's component buffers (rather than the wrapper) and rebuild it in backward" — but the code that follows saves the entire GroupedTensor wrapper for grouped_weights (when single_grouped_weight=True, saved.append(grouped_weights)). Component-buffer saving only applies to grouped_x (which saves columnwise_data/rowwise_data). The misleading comment could cause confusion when debugging or extending this path.

Comment thread transformer_engine/pytorch/ops/basic/grouped_linear.py
vthumbe1503 and others added 4 commits April 25, 2026 00:34
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Comment thread transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py Outdated
@vthumbe1503 vthumbe1503 marked this pull request as ready for review April 28, 2026 23:32
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Comment on lines +1586 to +1591
```python
grouped_wgrad = GroupedTensor.make_grouped_tensor_from_rowwise_data(
    num_tensors=num_groups,
    tensor_shape=weight_shape,
    rowwise_data=main_grad.view(-1),
    dtype=main_grad.dtype,
)
```
Contributor


P1 Non-contiguous main_grad causes opaque RuntimeError in graph-safe backward

main_grad.view(-1) will raise a plain RuntimeError whenever main_grad is non-contiguous (e.g., when FSDP returns a main_grad that is already shaped (num_groups, out_features, in_features) but lives as a non-unit-stride slice of a larger gradient buffer). view_main_grad_as_grouped_buffer only guards the reshape-to-grouped-shape step — once the helper returns the tensor unchanged (shape already matches), the subsequent .view(-1) is unprotected.

The equivalent code in backward_grouped_mlp.py avoids this by passing main_grad directly to make_grouped_tensor_from_rowwise_data, which internally calls .contiguous() when needed. grouped_linear.py should do the same:

Suggested change
```diff
 grouped_wgrad = GroupedTensor.make_grouped_tensor_from_rowwise_data(
     num_tensors=num_groups,
     tensor_shape=weight_shape,
-    rowwise_data=main_grad.view(-1),
+    rowwise_data=main_grad,
     dtype=main_grad.dtype,
 )
```
