Make TE Sequential Grouped linear Op CUDA graphable #2923

vthumbe1503 wants to merge 16 commits into NVIDIA:main
Conversation
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
Greptile Summary

This PR adds a CUDA-graph-safe forward/backward path (

Confidence Score: 4/5

Safe to merge on non-FSDP setups; the main_grad.view(-1) bug only triggers on Blackwell + accumulate_into_main_grad=True + FSDP when main_grad is non-contiguous with the grouped shape. One confirmed P1 (unsafe .view(-1) on a potentially non-contiguous FSDP main_grad in _fuser_backward_grouped_tensor) caps the score at 4. All other paths are correct, and the refactor of shared helpers is clean.

Important files changed: transformer_engine/pytorch/ops/basic/grouped_linear.py — specifically _fuser_backward_grouped_tensor, around the make_grouped_tensor_from_rowwise_data call with main_grad.view(-1).
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    FW[fuser_forward] --> CHECK{SM100+ and BF16/FP16\nor MXFP8?}
    CHECK -->|Yes| GT_FW[_fuser_forward_grouped_tensor\nCUDA-graph safe\nGPU-resident offsets]
    CHECK -->|No| SQ_FW[_fuser_forward_split_quantize\nLegacy path\nCPU split sizes sync]
    GT_FW --> SAVE_GT[ctx.use_grouped_tensor_path = True\nSave split_sizes, base_offsets,\nx_data, x_scale, weights]
    SQ_FW --> SAVE_SQ[ctx.use_grouped_tensor_path = False\nSave split_sizes, xs, ws]
    BW[fuser_backward] --> DISPATCH{ctx.use_grouped_tensor_path?}
    DISPATCH -->|True| GT_BW[_fuser_backward_grouped_tensor\nRebuild GroupedTensors\nGrouped GEMM dgrad+wgrad]
    DISPATCH -->|False| SQ_BW[_fuser_backward_split_quantize\nCPU list ops\ntex.split_quantize + general_grouped_gemm]
    GT_BW --> MAIN_GRAD{accumulate_into_main_grad?}
    MAIN_GRAD -->|Yes| MG_VIEW[get_main_grad_from_param\nview_main_grad_as_grouped_buffer\nmain_grad.view⚠️]
    MAIN_GRAD -->|No| ALLOC[GroupedTensor.make_grouped_tensor_with_shapes]
```
Reviews (5): Last reviewed commit: "Merge branch 'main' into grouped_linear_..."
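The dispatch at the top of the flowchart can be sketched as a predicate. This is a sketch only: `use_grouped_tensor_path` and its arguments are illustrative names, not TE's actual API.

```python
import torch

def use_grouped_tensor_path(sm_arch: int, dtype: torch.dtype,
                            with_mxfp8: bool) -> bool:
    """Hypothetical predicate mirroring the flowchart's dispatch:
    the CUDA-graph-safe grouped-tensor path is taken on SM100+ when
    compute runs in BF16/FP16, or when MXFP8 quantized compute is
    requested. Otherwise the legacy split-quantize path is used."""
    if sm_arch < 100:
        return False
    return dtype in (torch.bfloat16, torch.float16) or with_mxfp8
```

On the legacy path, split sizes must be synchronized to the CPU, which is what breaks CUDA graph capture; the grouped-tensor path keeps offsets GPU-resident.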
```python
bias_scale: Optional[torch.Tensor] = None
if has_bias:
    # Bias always needs to be passed as a GroupedTensor for the grouped GEMM.
    grouped_bias = self._get_grouped_bias_for_gemm(dtype, device)
    if self._scale_bias:
        bias_scale = scales.reshape(-1)
        if bias_scale.dtype != torch.float32:
            bias_scale = bias_scale.to(dtype=torch.float32)

# Forward grouped GEMM (TN layout: out[i] = x[i] @ w[i]^T)
general_grouped_gemm_for_grouped_tensor(
    grouped_weights,
    grouped_x,
    grouped_out,
```
Missing contiguity error handling for main_grad.view(-1)

main_grad.view(-1) will raise a generic RuntimeError if main_grad is non-contiguous (e.g. when returned by get_main_grad() via __fsdp_param__). The equivalent code in backward_grouped_mlp.py wraps the reshape in a try/except and re-raises with an actionable message that includes the shape and stride. Without that guard, users hitting this case will see an opaque PyTorch error instead of a clear diagnostic.
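The failure mode and the suggested guard can be reproduced in isolation. This is a minimal sketch, assuming the guard style described for backward_grouped_mlp.py; `flatten_main_grad` is a hypothetical helper, not TE code.

```python
import torch

def flatten_main_grad(main_grad: torch.Tensor) -> torch.Tensor:
    """Flatten main_grad for the grouped wgrad buffer, re-raising the
    contiguity failure with an actionable diagnostic (hypothetical
    helper mirroring the try/except guard in backward_grouped_mlp.py)."""
    try:
        return main_grad.view(-1)
    except RuntimeError as exc:
        raise RuntimeError(
            f"main_grad is not contiguous (shape={tuple(main_grad.shape)}, "
            f"stride={main_grad.stride()}); cannot view(-1) into the "
            "grouped wgrad buffer"
        ) from exc

# A transposed tensor is the classic non-contiguous case: it shares
# storage with non-unit strides, so .view(-1) cannot reinterpret it.
grad = torch.zeros(4, 8).t()
assert not grad.is_contiguous()
```

Calling `flatten_main_grad(grad)` on the transposed tensor raises the wrapped RuntimeError with shape and stride, instead of PyTorch's generic view-incompatibility message.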
```python
if ctx.requires_grad:
    saved: list[Optional[torch.Tensor]] = [split_sizes, base_offsets]
    if self._scale_bias:
        saved.append(scales)
    # For the wgrad input we save (data, scale_inv).
    # * Quantized path saves columnwise data + scale.
    # * Unquantized path saves the raw rowwise data and a None scale.
    if grouped_x is not None:
        if with_quantized_compute:
            saved.extend(
                [
                    grouped_x.columnwise_data,
                    grouped_x.columnwise_scale_inv,
                ]
            )
        else:
            saved.extend([grouped_x.rowwise_data, None])
    else:
        saved.extend([None, None])
    if self.single_grouped_weight:
        saved.append(grouped_weights)
    else:
        saved.extend(grouped_weights)
    ctx.save_for_backward(*saved)
    ctx.use_grouped_tensor_path = True
    ctx.with_quantized_compute = with_quantized_compute
    ctx.input_quantizers = input_quantizers
    ctx.weight_quantizers = weight_quantizers
```
Comment contradicts implementation for weight saving

The block comment says "we save the GroupedTensor's component buffers (rather than the wrapper) and rebuild it in backward" — but the code that follows saves the entire GroupedTensor wrapper for grouped_weights (when single_grouped_weight=True, saved.append(grouped_weights)). Component-buffer saving only applies to grouped_x (which saves columnwise_data/rowwise_data). The misleading comment could cause confusion when debugging or extending this path.
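What the comment describes — saving only component buffers and rebuilding the wrapper in backward — would look roughly like the sketch below. `GroupedTensorSketch` is a hypothetical stand-in for TE's GroupedTensor, used only to illustrate the two schemes.

```python
from dataclasses import dataclass
from typing import Optional
import torch

@dataclass
class GroupedTensorSketch:
    """Minimal stand-in for TE's GroupedTensor (hypothetical)."""
    rowwise_data: torch.Tensor
    rowwise_scale_inv: Optional[torch.Tensor] = None

def save_components(gt: GroupedTensorSketch) -> list:
    # Comment-consistent scheme: stash only the raw component buffers
    # (this is what the code actually does for grouped_x)...
    return [gt.rowwise_data, gt.rowwise_scale_inv]

def rebuild(saved: list) -> GroupedTensorSketch:
    # ...and reconstruct the wrapper from them in backward,
    # instead of saving the wrapper object itself (as the code
    # currently does for grouped_weights).
    data, scale_inv = saved
    return GroupedTensorSketch(rowwise_data=data, rowwise_scale_inv=scale_inv)
```

Either scheme works; the review point is only that the comment and the weight-saving code should agree on which one is in use.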
```python
grouped_wgrad = GroupedTensor.make_grouped_tensor_from_rowwise_data(
    num_tensors=num_groups,
    tensor_shape=weight_shape,
    rowwise_data=main_grad.view(-1),
    dtype=main_grad.dtype,
)
```
Non-contiguous main_grad causes opaque RuntimeError in graph-safe backward

main_grad.view(-1) will raise a plain RuntimeError whenever main_grad is non-contiguous (e.g., when FSDP returns a main_grad that is already shaped (num_groups, out_features, in_features) but lives as a non-unit-stride slice of a larger gradient buffer). view_main_grad_as_grouped_buffer only guards the reshape-to-grouped-shape step — once the helper returns the tensor unchanged (shape already matches), the subsequent .view(-1) is unprotected.
The equivalent code in backward_grouped_mlp.py avoids this by passing main_grad directly to make_grouped_tensor_from_rowwise_data, which internally calls .contiguous() when needed. grouped_linear.py should do the same:
```diff
 grouped_wgrad = GroupedTensor.make_grouped_tensor_from_rowwise_data(
     num_tensors=num_groups,
     tensor_shape=weight_shape,
-    rowwise_data=main_grad.view(-1),
+    rowwise_data=main_grad,
     dtype=main_grad.dtype,
 )
```