
[FIX] Always support TP > 4 for FP4 Gemm #17300

Merged
Fridge003 merged 6 commits into sgl-project:main from danielafrimi:tp4_nvfp4
Feb 5, 2026

Conversation

@danielafrimi
Contributor

Summary
This PR enables FP4 (NVFP4) quantization to work with TP >= 4

Background
Previously, FP4 quantized models would fail to initialize with TP=4/8 due to kernel alignment requirements. The FlashInfer-CUTLASS FP4 GEMM kernels require the N/K-dimension to be divisible by 32.

Changes
This PR supports TP >= 4 by padding the weights to meet that alignment requirement. In cases where the K dim of the weights is padded, the activations must also be padded in the forward pass so the GEMM contraction dimensions stay consistent.
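As an illustration of the padding arithmetic (a minimal sketch, not the PR's actual code; `pad_to_multiple` and `padded_weight_shape` are hypothetical helper names):

```python
def pad_to_multiple(dim: int, align: int) -> int:
    """Smallest non-negative padding that makes `dim` divisible by `align`."""
    return (-dim) % align

# FlashInfer-CUTLASS FP4 GEMM requires the N and K dimensions to be divisible by 32.
KERNEL_ALIGN = 32

def padded_weight_shape(n: int, k: int, align: int = KERNEL_ALIGN) -> tuple:
    """Logical (unpacked) weight shape after padding for kernel alignment."""
    return (n + pad_to_multiple(n, align), k + pad_to_multiple(k, align))

# e.g. a per-rank N of 2100 is not 32-aligned and gets padded up to 2112:
print(padded_weight_shape(2100, 7168))  # -> (2112, 7168)
```

Whenever K is padded this way, the activation's K dimension has to receive the same padding at runtime, since the GEMM contracts over K.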

@github-actions github-actions bot added the quant LLM Quantization label Jan 18, 2026
@gemini-code-assist
Contributor

Summary of Changes

Hello @danielafrimi, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves a critical issue preventing FP4 quantized models from initializing with high Tensor Parallelism (TP >= 4). It achieves this by implementing a robust padding strategy for both weights and activations. This ensures that the dimensions of the matrices conform to the strict alignment requirements of FlashInfer-CUTLASS FP4 GEMM kernels, thereby enhancing the compatibility and performance of FP4 quantization across various parallel configurations.

Highlights

  • FP4 Quantization Support for High Tensor Parallelism: This pull request enables FP4 (NVFP4) quantized models to function correctly with Tensor Parallelism (TP) values of 4 or greater, addressing previous initialization failures.
  • Kernel Alignment for FlashInfer-CUTLASS: The changes introduce padding mechanisms to ensure that the N and K dimensions of matrices are divisible by 32, satisfying the strict alignment requirements of FlashInfer-CUTLASS FP4 GEMM kernels for efficient memory access and tensor core operations.
  • Dynamic Weight and Activation Padding: New utility functions have been added to dynamically pad both the packed NVFP4 weights and activations. Weights are padded in both N (rows) and K (columns) dimensions, and activations are padded in the K-dimension to match the padded weights.
  • Output Slicing to Remove Padding: A mechanism is implemented to slice the output tensor after GEMM operations, effectively removing any padding introduced in the N-dimension of the weights to restore the original output size.
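The output-slicing step in the last highlight can be sketched with NumPy standing in for the framework tensors (`slice_output` and `orig_n` are illustrative names, not the PR's API):

```python
import numpy as np

def slice_output(out: np.ndarray, orig_n: int) -> np.ndarray:
    """Drop the trailing columns produced by N-dimension weight padding."""
    return out[..., :orig_n]

# A GEMM against weights padded from N=100 to N=128 yields shape [M, 128];
# slicing restores the original output width of 100.
padded_out = np.zeros((4, 128))
print(slice_output(padded_out, 100).shape)  # -> (4, 100)
```

Slicing is cheap here because it returns a view; no copy of the output tensor is needed.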


@gemini-code-assist bot left a comment

Code Review

This pull request introduces padding for FP4 quantized weights and activations to ensure they meet the alignment requirements of CUTLASS/FlashInfer kernels, specifically for tensor parallelism (TP) levels of 4 or greater. The changes include adding helper functions for padding and slicing tensors, and integrating this logic into the ModelOptFp4LinearMethod. The implementation correctly stores the original tensor dimensions, applies padding during weight processing and activation quantization, and slices the output to remove padding after the GEMM operation. My review focuses on the correctness and efficiency of the new padding logic. I've suggested a minor optimization to combine padding operations for better performance.

@Fridge003
Collaborator

Fridge003 commented Jan 19, 2026

@danielafrimi Have you tried other fp4 gemm implementations other than cutlass?
We have two other options (flashinfer_trtllm and flashinfer_cudnn). If either of them works, then we just need to change the fp4 gemm kernel to use

@netanel-haber
Contributor

/tag-and-rerun-ci

@b8zhong (Collaborator) left a comment

Thanks. Btw, what models with weird hidden dims are there issues with? (guess: nemotron? Just curious). Because recently we fixed something related to GLM scale divisibility as well.

            layer.weights_padding_cols = 0
            return

            # Pad weights for CUTLASS/FlashInfer kernel alignment (K and N divisible by 32)
@b8zhong (Collaborator) commented Jan 21, 2026

QQ: if this comment is accurate, don't we also need to pad for the trtllm backend? Should we also do it under the block above?

@b8zhong
Collaborator

b8zhong commented Jan 29, 2026

Hi @danielafrimi, can you fix the merge conflicts?

Signed-off-by: root <root@gpu-139.slurm-workers-slurm.slurm.svc.cluster.local>
Signed-off-by: root <dafrimi@nvidia.com>
@danielafrimi
Contributor Author

danielafrimi commented Feb 2, 2026

@b8zhong @Fridge003
I tried the other FP4 kernels (flashinfer_trtllm and flashinfer_cudnn) in addition to the cutlass one.
flashinfer_cudnn and flashinfer_cutlass have the same constraints, which are fixed by this PR.

flashinfer_trtllm has other constraints (see flashinfer/utils.py for shuffle_matrix_a), which require the following:

            # Alignment requirements:
            #   - shuffle_matrix_a: weight.shape[0] (N) % 32 == 0
            #   - shuffle_matrix_sf_a: scale.shape[0] (N) % 128 == 0, scale.shape[1] (K/16) % 4 == 0

So to match the weight and scale N-dims we pad N to a multiple of 128, and for the K-dim constraint ((K/16) % 4 == 0) we pad the scales and weights accordingly, which forces us to pad the activations as well.

BTW, for flashinfer_trtllm tp=2 didn't work for our nemotron-nano-v3 nvfp4 model - this PR fixed it as well.
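The trtllm alignment arithmetic described above can be sketched as follows (a hedged illustration; helper names are hypothetical, and the multiples are derived from the constraints quoted in the comment):

```python
def pad_to_multiple(dim: int, align: int) -> int:
    """Smallest non-negative padding that makes `dim` divisible by `align`."""
    return (-dim) % align

def trtllm_padded_dims(n: int, k: int) -> tuple:
    """Pad N/K so both shuffle_matrix_a and shuffle_matrix_sf_a constraints hold.

    - weight N % 32 == 0 and scale N % 128 == 0  -> pad N to a multiple of 128
    - scale (K/16) % 4 == 0                      -> pad K to a multiple of 16*4 = 64
    """
    return (n + pad_to_multiple(n, 128), k + pad_to_multiple(k, 64))

print(trtllm_padded_dims(2100, 7000))  # -> (2176, 7040)
```

Since 128 is itself a multiple of 32, padding N to a multiple of 128 satisfies both the weight and scale constraints at once.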

@Fridge003
Collaborator

/tag-and-rerun-ci

@Fridge003 Fridge003 merged commit 3f1df32 into sgl-project:main Feb 5, 2026
367 of 404 checks passed
charlesHsuGG pushed a commit to charlesHsuGG/sglang that referenced this pull request Feb 9, 2026
Johnsonms pushed a commit to Johnsonms/sglang that referenced this pull request Feb 14, 2026

Labels

quant LLM Quantization run-ci


4 participants