[Perf] refactor piecewise cuda graph support of Qwen3-Next #17613
ispobock merged 12 commits into sgl-project:main
Conversation
Summary of Changes

Hello @zminglei, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly enhances the piecewise CUDA graph support for the Qwen3Next model. By refactoring attention mechanisms and addressing dynamic dimension issues, the changes aim to improve the efficiency and performance of model inference, as evidenced by the provided benchmarking results.
Code Review
This pull request refactors the piecewise CUDA graph support for Qwen3Next models to improve performance. The changes involve moving the custom op wrapper for linear attention to a more granular level, allowing more operations to be captured in the main graph. Additionally, a fix is included for calc_rows_per_block to avoid torch.compile guards on dynamic batch dimensions. The changes look good and are well-aligned with the goal of improving CUDA graph capture. I have one suggestion to improve code readability.
/tag-and-rerun-ci again
```python
# torch.compile creating guards on the dynamic batch dimension.
try:
    if get_global_server_args().enable_piecewise_cuda_graph:
        return 4
```
Why 4? Maybe avoid the magic number by declaring a constant.
The 4 is the maximum rows-per-block value, capped for kernel performance. It appears at the last line of this function:
```python
rows_per_block = min(rows_per_block, 4)
```
I've updated it to declare it as a constant for better readability, PTAL.
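For illustration, a minimal standalone sketch of that constant extraction. The function body and the `sm_count` default here are simplified stand-ins, not the actual SGLang implementation:

```python
# Hypothetical, simplified sketch of calc_rows_per_block after pulling
# the magic number into a named constant; not the real SGLang code.
MAX_ROWS_PER_BLOCK = 4  # cap for kernel performance


def calc_rows_per_block(m: int, sm_count: int = 128) -> int:
    # Aim for roughly one row group per SM, but never exceed the cap.
    rows_per_block = max(1, m // sm_count)
    return min(rows_per_block, MAX_ROWS_PER_BLOCK)
```

With the constant named, the early-return path for piecewise CUDA graph can return `MAX_ROWS_PER_BLOCK` instead of a bare `4`.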
```python
except ValueError:
    # Global server args not initialized (e.g., in unit tests)
    pass
sm_count = _get_sm_count(device)
```
This is a constant value like 128; why would it affect torch.compile?
This function returns `rows_per_block`, which is consumed by the Triton kernel `_layer_norm_fwd_1pass_kernel` as a `tl.constexpr`. With different `M` values it can produce different `rows_per_block` results and trigger torch recompiles.
I think `M` is a constant during the compilation of a single graph, so why would it trigger recompilation?
It changes when `num_tokens` changes, which breaks the torch.compile guards and triggers a lot of recompilations while capturing all token sizes, making the capture take forever to finish.
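A toy, pure-Python illustration of the issue discussed above (no torch or Triton involved; the formula is a simplified stand-in for the real `calc_rows_per_block`): because the value fed to the kernel as a `tl.constexpr` depends on the token count, each distinct result forces a new specialization while capturing a range of token sizes.

```python
# Toy stand-in: the result depends on m (the number of tokens),
# so it is not a compile-time constant across capture sizes.
def calc_rows_per_block(m: int, sm_count: int = 128) -> int:
    return min(max(1, m // sm_count), 4)


# Sweeping token counts during capture yields several distinct values;
# each new value would break a torch.compile guard and recompile.
distinct_values = {calc_rows_per_block(m) for m in range(1, 1024)}
```

Pinning the value to a single constant when piecewise CUDA graph is enabled collapses this set to one element, so the guard never breaks.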
qwen3-next with pcg still has some accuracy issues, so the ci is currently skipped.
is anyone working on this? cc @yizhang2077
@hebiao064 This PR #17706 is expected to fix it.
Actually, with this change, I'm not seeing any accuracy issue for qwen3-next with pcg. The default attention backend here is fa3; you can confirm by trying the command in this PR's description. I've run the test locally many times (>10) to ensure there's no intermittent issue, so I've re-enabled the ci here.
/rerun-failed-ci again
/tag-and-rerun-ci retry
Yes, the issue is mainly related to FlashInfer. I think it’s fine to enable ci here. |
Reverts the piecewise cuda graph refactor for Qwen3-Next to determine if it caused the KL divergence increase in the Qwen3-Next KL tests. Also restores original KL thresholds (0.0025 and 0.008).
Motivation
Currently in Qwen3Next, we hide all ops inside `Qwen3GatedDeltaNet.forward` in a custom fake op, so that even with piecewise cuda graph, the whole forward of Qwen3GatedDeltaNet runs in eager mode.

Here we refactor `RadixLinearAttention` to follow the same pattern as `RadixAttention`, where only the attention op is skipped during piecewise cuda graph capturing, allowing more ops to be captured and improving prefill performance. The gsm8k benchmark shows this change speeds up Qwen3-Next with PCG by 14.3%.
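To make the pattern concrete, here is a heavily simplified, hypothetical sketch of the structural change; all functions are toy stand-ins, not the real SGLang projection or attention ops:

```python
# Toy stand-in for the opaque custom op: during piecewise CUDA graph
# capture this call is skipped (run eagerly); everything outside it
# is visible to the compiler.
def opaque_linear_attention(q, k, v):
    return q + k + v


def forward_before(x):
    # Old pattern: projections, attention, and output scaling all sit
    # behind one opaque boundary, so none of it is capturable.
    q, k, v = x * 1.0, x * 2.0, x * 3.0
    return opaque_linear_attention(q, k, v) * 0.5


def forward_after(x):
    # New pattern: only the attention call stays opaque; the
    # projections and output scaling become capturable pieces.
    q, k, v = x * 1.0, x * 2.0, x * 3.0   # capturable
    o = opaque_linear_attention(q, k, v)  # skipped during capture
    return o * 0.5                        # capturable
```

Numerically the two forwards are identical; the refactor only moves the opaque boundary so that more of the surrounding graph can be captured.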
Modifications
- Refactor `RadixLinearAttention` to follow the same pattern as `RadixAttention`, wrapping only the attention computation in a custom op (`unified_linear_attention_with_output`)
- Split `Qwen3GatedDeltaNet.forward` to allow input projections, split/reshape, norm, and output projections to be piecewise CUDA graph captured
- Fix `calc_rows_per_block` in `layernorm_gated.py` to use a constant value when piecewise CUDA graph is enabled, avoiding torch.compile guards on dynamic batch dimensions

Accuracy Tests
Benchmarking and Profiling
```shell
python3 -m sglang.launch_server --model-path /shared/public/elr-models/Qwen/Qwen3-Next-80B-A3B-Instruct/9c7f2fbe84465e40164a94cc16cd30b6999b0cc7/ --tp 4 --enable-piecewise-cuda-graph
```

Before the change:
After the change:
The throughput is improved by 14.3%.
Checklist
Review Process
`/tag-run-ci-label`, `/rerun-failed-ci`, `/tag-and-rerun-ci`