
[Perf] refactor piecewise cuda graph support of Qwen3-Next#17613

Merged
ispobock merged 12 commits into sgl-project:main from zminglei:refactor-qwen-pcg
Feb 14, 2026

Conversation

@zminglei
Collaborator

@zminglei zminglei commented Jan 23, 2026

Motivation

Currently in Qwen3-Next, we hide all ops inside Qwen3GatedDeltaNet.forward in a single custom fake op, so even with piecewise CUDA graph enabled, the whole forward of Qwen3GatedDeltaNet runs in eager mode.
Here we refactor RadixLinearAttention to follow the same pattern as RadixAttention, where only the attention op is skipped during piecewise CUDA graph capture. This allows more ops to be captured and improves prefill performance.

The gsm8k benchmark shows this change could speed up Qwen3-Next with PCG by 14.3%.

Modifications

  • Refactored RadixLinearAttention to follow the same pattern as RadixAttention, wrapping only the attention computation in a custom op (unified_linear_attention_with_output)
  • Updated Qwen3GatedDeltaNet.forward to allow input projections, split/reshape, norm, and output projections to be piecewise CUDA graph captured
  • Fixed calc_rows_per_block in layernorm_gated.py to use a constant value when piecewise CUDA graph is enabled, avoiding torch.compile guards on dynamic batch dimensions
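To illustrate the shape of the refactor described above, here is a minimal, stdlib-only sketch (all function names and operations are illustrative stand-ins, not the actual sglang implementation). Before the change, the whole forward lived inside one opaque custom op, so the graph capturer saw a single uncapturable node; after the change, only the attention core stays opaque while the projections and norm around it become capturable:

```python
from typing import Callable, List

# Hypothetical stand-ins for the capturable pieces of the forward pass.
def in_proj(x: List[float]) -> List[float]:
    return [2.0 * v for v in x]          # stand-in for the input projection

def gated_norm(x: List[float]) -> List[float]:
    s = sum(abs(v) for v in x) or 1.0
    return [v / s for v in x]            # stand-in for the gated norm

def out_proj(x: List[float]) -> List[float]:
    return [v + 1.0 for v in x]          # stand-in for the output projection

# Stand-in for the linear-attention kernel: the one step that stays
# opaque (runs eagerly) under piecewise CUDA graph capture.
def attention_core(x: List[float]) -> List[float]:
    return list(reversed(x))

def forward(x: List[float], attn: Callable[[List[float]], List[float]]) -> List[float]:
    # Before the refactor, all four steps were hidden inside one custom op.
    # After, only `attn` is opaque; in_proj, gated_norm, and out_proj are
    # ordinary traceable ops, so the compiler can capture them.
    return out_proj(gated_norm(attn(in_proj(x))))
```

The point of the restructuring is that the boundary of the opaque op shrinks to the single kernel that genuinely cannot be captured, mirroring how RadixAttention wraps only its attention computation.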

Accuracy Tests

Benchmarking and Profiling

python3 -m sglang.launch_server --model-path /shared/public/elr-models/Qwen/Qwen3-Next-80B-A3B-Instruct/9c7f2fbe84465e40164a94cc16cd30b6999b0cc7/ --tp 4 --enable-piecewise-cuda-graph

Before the change:

python benchmark/gsm8k/bench_sglang.py --data-path /shared/public/data/gsm8k/test.jsonl
100%|████████| 200/200 [00:12<00:00, 16.23it/s]
Accuracy: 0.950
Invalid: 0.000
Latency: 12.379 s
Output throughput: 2592.343 token/s

After the change:

python benchmark/gsm8k/bench_sglang.py --data-path /shared/public/data/gsm8k/test.jsonl
100%|████████| 200/200 [00:11<00:00, 18.18it/s]
Accuracy: 0.960
Invalid: 0.000
Latency: 11.041 s
Output throughput: 2963.644 token/s

The throughput is improved by 14.3%.

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist
Contributor

Summary of Changes

Hello @zminglei, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the piecewise CUDA graph support for the Qwen3Next model. By refactoring attention mechanisms and addressing dynamic dimension issues, the changes aim to improve the efficiency and performance of model inference, as evidenced by the provided benchmarking results.

Highlights

  • Refactored RadixLinearAttention: The RadixLinearAttention module has been refactored to align with RadixAttention, now wrapping only the core attention computation within a new custom operation called unified_linear_attention_with_output. This change improves modularity and control over CUDA graph capture.
  • Enhanced Qwen3GatedDeltaNet CUDA Graph Support: The Qwen3GatedDeltaNet.forward method has been updated to enable piecewise CUDA graph capture for various operations including input projections, split/reshape, normalization, and output projections, optimizing the execution flow for Qwen3Next models.
  • Fixed Dynamic Batch Dimension Issue: A fix was implemented in calc_rows_per_block within layernorm_gated.py to use a constant value when piecewise CUDA graph is active. This prevents torch.compile from creating guards on dynamic batch dimensions, ensuring smoother operation with CUDA graphs.
  • Performance Improvement: Benchmarking on the GSM8K dataset shows a significant performance improvement, with output throughput increasing from 2592.343 token/s to 2963.644 token/s, and latency decreasing from 12.379s to 11.041s.



@gemini-code-assist bot left a comment


Code Review

This pull request refactors the piecewise CUDA graph support for Qwen3Next models to improve performance. The changes involve moving the custom op wrapper for linear attention to a more granular level, allowing more operations to be captured in the main graph. Additionally, a fix is included for calc_rows_per_block to avoid torch.compile guards on dynamic batch dimensions. The changes look good and are well-aligned with the goal of improving CUDA graph capture. I have one suggestion to improve code readability.

@zminglei zminglei changed the title refactor piecewise cuda graph support of Qwen3Next [Perf] refactor piecewise cuda graph support of Qwen3-Next Jan 23, 2026
@zminglei zminglei marked this pull request as ready for review January 23, 2026 02:14
@gemini-code-assist
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@zminglei
Collaborator Author

zminglei commented Jan 23, 2026

/tag-and-rerun-ci again

@zminglei zminglei requested a review from hnyls2002 as a code owner January 23, 2026 04:51
    # torch.compile creating guards on the dynamic batch dimension.
    try:
        if get_global_server_args().enable_piecewise_cuda_graph:
            return 4
Collaborator


Why 4? Maybe avoid the magic number by declaring a constant.

Collaborator Author


The 4 is the maximum rows-per-block value, capped for kernel performance. It appears in the last line of this function:
rows_per_block = min(rows_per_block, 4)
I've declared it as a constant for better readability, PTAL.

    except ValueError:
        # Global server args not initialized (e.g., in unit tests)
        pass
    sm_count = _get_sm_count(device)
Collaborator


This is a constant value like 128; why would it affect torch.compile?

Collaborator Author


This function returns rows_per_block, which is consumed by the Triton kernel _layer_norm_fwd_1pass_kernel as a tl.constexpr. With different values of M, it can produce a different rows_per_block and trigger a torch.compile recompile.

Contributor


I think M is a constant during the compilation of a single graph; why would it trigger recompilation?

Collaborator Author


It changes when num_tokens changes, which breaks the torch.compile guards and triggers many recompilations while capturing graphs for all token counts, making the capture take forever to finish.
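The fix discussed in this thread can be sketched in plain Python (the real heuristic in layernorm_gated.py may differ; the constant name, the ceil-division formula, and the parameter names here are illustrative assumptions). The key shape of the fix: under piecewise CUDA graph, return a fixed constant so the value never depends on the dynamic batch dimension M; otherwise compute it from M as before, capped at the same constant:

```python
# Illustrative constant: the cap the original function already applied
# via `rows_per_block = min(rows_per_block, 4)`.
MAX_ROWS_PER_BLOCK = 4

def calc_rows_per_block(m: int, sm_count: int, piecewise_cuda_graph: bool) -> int:
    """Hypothetical sketch of the rows-per-block heuristic.

    When piecewise CUDA graph is enabled, return a constant so the result
    (consumed downstream as a tl.constexpr) cannot vary with the dynamic
    batch dimension M and cannot break torch.compile guards.
    """
    if piecewise_cuda_graph:
        return MAX_ROWS_PER_BLOCK
    # Illustrative M-dependent heuristic: spread rows across SMs
    # (ceil division), capped for kernel performance.
    rows_per_block = max(1, -(-m // sm_count))
    return min(rows_per_block, MAX_ROWS_PER_BLOCK)
```

Because the M-dependent branch yields different values for different token counts, each distinct value would specialize the Triton kernel and guard the compiled graph on M; the constant branch sidesteps that entirely during capture.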

@Chen-0210
Contributor

Qwen3-Next with PCG still has some accuracy issues, so the CI is currently skipped.
I think it would be better to fix the issue and re-enable CI first, and then merge this PR.

@hebiao064
Collaborator

Qwen3-Next with PCG still has some accuracy issues, so the CI is currently skipped. I think it would be better to fix the issue and re-enable CI first, and then merge this PR.

is anyone working on this? cc @yizhang2077

@Chen-0210
Contributor

@hebiao064 This pr #17706 is expected to fix it.

@zminglei
Collaborator Author

Qwen3-Next with PCG still has some accuracy issues, so the CI is currently skipped. I think it would be better to fix the issue and re-enable CI first, and then merge this PR.

Actually, with this change I'm not seeing any accuracy issue for Qwen3-Next with PCG. The default attention backend here is FA3. You can confirm by trying the command in this PR description. I've run the test locally many times (>10) to ensure there is no intermittent issue, so I re-enabled the CI here.

@zminglei
Collaborator Author

zminglei commented Feb 11, 2026

/rerun-failed-ci again

@zminglei
Collaborator Author

zminglei commented Feb 12, 2026

/tag-and-rerun-ci retry

@Chen-0210
Contributor

Actually, with this change I'm not seeing any accuracy issue for Qwen3-Next with PCG. The default attention backend here is FA3. You can confirm by trying the command in this PR description. I've run the test locally many times (>10) to ensure there is no intermittent issue, so I re-enabled the CI here.

Yes, the issue is mainly related to FlashInfer. I think it's fine to enable CI here.

@ispobock ispobock merged commit 8be18c6 into sgl-project:main Feb 14, 2026
466 of 501 checks passed
Johnsonms pushed a commit to Johnsonms/sglang that referenced this pull request Feb 14, 2026
alisonshao added a commit that referenced this pull request Feb 17, 2026
Reverts the piecewise cuda graph refactor for Qwen3-Next to determine
if it caused the KL divergence increase in the Qwen3-Next KL tests.
Also restores original KL thresholds (0.0025 and 0.008).

4 participants