Optimizations for MXFP8/NVFP4 dequantize kernels #2865

YigongQin wants to merge 14 commits into NVIDIA:main
Conversation
The following relevant unit tests passed on SM100 (with the drop …)
After this PR, the forward pass is around 3-4% faster for DeepSeek-shape MoE.
Greptile Summary

This PR extends the MXFP8 and NVFP4 dequantize kernels to support GEMM-swizzled scale layouts by templating both kernels on WITH_GEMM_SWIZZLED_SCALES.

Confidence Score: 5/5

Safe to merge; no P0/P1 issues found, and all findings are P2 suggestions. The kernel index math for both rowwise and colwise swizzled paths was verified to be within bounds. The empty-tensor guard is correctly placed at the dispatch level. The Python workaround removals are consistent with the new dequantize capability. One P2 asymmetry exists in basic_linear.py, where the removed condition was broader than in the analogous sibling modules, but it does not affect current behavior: transformer_engine/pytorch/ops/basic/basic_linear.py removed a broader condition (backward_override is not None) than linear.py and layernorm_linear.py (== "dequantized").
Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[nvte_dequantize] --> B[dequantize_helper dispatch]
B --> C{numel == 0?}
C -->|yes| D[early return — CUDA graph safe]
C -->|no| E{scaling_mode}
E -->|NVTE_MXFP8_1D_SCALING| F[mxfp8::dequantize]
E -->|NVTE_NVFP4_1D_SCALING| G[nvfp4::dequantize]
F --> H{with_gemm_swizzled_scales?}
G --> I{with_gemm_swizzled_scales?}
H -->|true| J[dequantize_mxfp8_kernel WITH_GEMM_SWIZZLED_SCALES=true]
H -->|false| K[dequantize_mxfp8_kernel WITH_GEMM_SWIZZLED_SCALES=false]
I -->|true| L[dequantize_fp4_kernel WITH_GEMM_SWIZZLED_SCALES=true]
I -->|false| M[dequantize_fp4_kernel WITH_GEMM_SWIZZLED_SCALES=false]
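
As a rough illustration of the dispatch flow above, here is a minimal Python sketch. The function names and the runtime swizzled flag are hypothetical stand-ins for the C++ entry points; in the real code, WITH_GEMM_SWIZZLED_SCALES is a compile-time template parameter rather than a runtime branch.

import torch

# Hypothetical Python stand-ins for the templated CUDA kernels. The real
# kernels are C++ and are instantiated with WITH_GEMM_SWIZZLED_SCALES as a
# template parameter, so the layout branch costs nothing at runtime.
def _mxfp8_dequantize(t: torch.Tensor, swizzled_scales: bool) -> torch.Tensor:
    return t.float()  # placeholder: the real kernel applies MXFP8 block scales

def _nvfp4_dequantize(t: torch.Tensor, swizzled_scales: bool) -> torch.Tensor:
    return t.float()  # placeholder: the real kernel applies NVFP4 block scales

def dequantize(t, scaling_mode, with_gemm_swizzled_scales=False):
    if t.numel() == 0:
        return t.float()  # early return at the dispatch level: CUDA-graph safe
    if scaling_mode == "NVTE_MXFP8_1D_SCALING":
        return _mxfp8_dequantize(t, with_gemm_swizzled_scales)
    if scaling_mode == "NVTE_NVFP4_1D_SCALING":
        return _nvfp4_dequantize(t, with_gemm_swizzled_scales)
    raise ValueError(f"unsupported scaling mode: {scaling_mode}")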
Reviews (14). Last reviewed commit: "Remove unnecessary scale from NVFP4 C++ ..."
std::vector<std::pair<size_t, size_t>> nvfp4_tensor_dims = {
There is one edge case:

For MXFP8, when the input shape is 64x64, it will produce a scaling-factor shape of 64x2, which is then zero-padded to 128x4. We should be able to inject some very large random values into the padded region during allocation (because we allocate with torch.empty rather than torch.zeros) and check whether the dequantize result is affected. If things work as expected, this line will be triggered and the dequantize numerics won't be affected.

For NVFP4, I think optimize-for-GEMM (or swizzle fusion) is actually not enabled, and the same applies to the zero-out edge-case handling logic? So there shouldn't be any unswizzle logic needed here?
For NVFP4, I believe currently only device-init grouped quantize with RHT has the swizzle fusion feature, so the scaling factor zero-out is the job of the dedicated swizzle kernel. So if we dequantize + unswizzle for NVFP4, the unswizzle logic might not be correct.
For both MXFP8 and NVFP4, the unit test logic is:
1. Generate compact scales (or obtain them from quantization).
2. Call nvte_swizzle_scaling_factors to swizzle the compact scales.
3. Compare the results of nvte_dequantize with compact scales against the results with swizzled scales.
Quantize with swizzle fusion is never enabled for either MXFP8 or NVFP4.
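
As a concrete illustration of the padded-garbage check discussed above, here is a self-contained toy sketch in PyTorch. It is not the actual TE test (which drives nvte_swizzle_scaling_factors and nvte_dequantize); toy_dequantize and the 32-element block size are illustrative assumptions.

import torch

# Toy dequantize: each scale covers a 32-element block along the last dim
# (MXFP8-style). A correct kernel indexes only the valid scale region.
def toy_dequantize(data, scales):
    return data * scales.repeat_interleave(32, dim=1)

data = torch.randn(64, 64)
compact = torch.rand(64, 2) + 0.5     # compact scales for a 64x64 input

padded = torch.empty(128, 4)          # padded buffer, as with torch.empty
padded.fill_(1e30)                    # inject huge values everywhere...
padded[:64, :2] = compact             # ...then write only the valid region

ref = toy_dequantize(data, compact)
out = toy_dequantize(data, padded[:64, :2])  # reads only the valid region
torch.testing.assert_close(out, ref)         # numerics unaffected by padding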
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
if fp8_recipe.backward_override == "dequantized" and (
    fp8_recipe.mxfp8() or fp8_recipe.nvfp4()
):
    input_quantizer.optimize_for_gemm = False
    if grad_output_quantizer is not None:
        grad_output_quantizer.optimize_for_gemm = False
I'm of two minds about this:
- Logically, GEMM-optimized data is not guaranteed to support anything except GEMMs. Even if MXFP8 and NVFP4 dequant happens to support it, these are custom optimizations. Future recipes cannot be expected to support dequantizing GEMM-optimized data by default.
- It's a little pedantic to have edge-case logic that won't be triggered by any of our existing use-cases. Given how subtle this is, I worry about it becoming stale and distracting.
I think for now, this change is fine. However, if we encounter problems in a future recipe, we should reimplement it properly:
# LOGICALLY WRONG!
# Fails if we add a new recipe
if recipe.backward_override == "dequantized" and recipe.future_recipe():
    input_quantizer.optimize_for_gemm = False

# LOGICALLY RIGHT!
# Automatically handles new recipes
if recipe.backward_override == "dequantized" and not (
    recipe.float8_per_tensor_scaling()
    or recipe.float8_block_scaling()
    or recipe.mxfp8()
    or recipe.nvfp4()
):
    input_quantizer.optimize_for_gemm = False
/te-ci pytorch L1

/te-ci core
timmoon10 left a comment:

LGTM, pending CI. These kernels will be very useful.
/te-ci
Description
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: