From 850c00d07475dfaf32383afb9deea704e4ea0472 Mon Sep 17 00:00:00 2001 From: Minglei Zhu Date: Fri, 23 Jan 2026 01:39:02 +0000 Subject: [PATCH 1/8] refactor piecewise cuda graph support of Qwen3Next --- .../layers/attention/fla/layernorm_gated.py | 5 ++ .../srt/layers/radix_linear_attention.py | 80 +++++++++++++++++-- python/sglang/srt/models/qwen3_next.py | 42 ---------- 3 files changed, 78 insertions(+), 49 deletions(-) diff --git a/python/sglang/srt/layers/attention/fla/layernorm_gated.py b/python/sglang/srt/layers/attention/fla/layernorm_gated.py index 5d55247da3f5..d1913b10c7c1 100644 --- a/python/sglang/srt/layers/attention/fla/layernorm_gated.py +++ b/python/sglang/srt/layers/attention/fla/layernorm_gated.py @@ -14,6 +14,7 @@ import triton.language as tl from einops import rearrange +from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import cdiv, device_context, is_npu, next_power_of_2 _is_npu = is_npu() @@ -158,6 +159,10 @@ def _get_sm_count(device: torch.device) -> int: def calc_rows_per_block(M: int, device: torch.device) -> int: + # When piecewise cuda graph is enabled, use a constant value to avoid + # torch.compile creating guards on the dynamic batch dimension. + if get_global_server_args().enable_piecewise_cuda_graph: + return 4 sm_count = _get_sm_count(device) rows_per_block = next_power_of_2(cdiv(M, 2 * sm_count)) rows_per_block = min(rows_per_block, 4) diff --git a/python/sglang/srt/layers/radix_linear_attention.py b/python/sglang/srt/layers/radix_linear_attention.py index 2fe1dc74921d..8d475a7e86a8 100644 --- a/python/sglang/srt/layers/radix_linear_attention.py +++ b/python/sglang/srt/layers/radix_linear_attention.py @@ -19,6 +19,10 @@ import torch from torch import nn +from sglang.srt.compilation.compilation_config import register_split_op +from sglang.srt.compilation.piecewise_context_manager import get_forward_context +from sglang.srt.utils.custom_op import register_custom_op + if TYPE_CHECKING: from sglang.srt.model_executor.forward_batch_info import ForwardBatch @@ -74,10 +78,72 @@ def forward( a: torch.Tensor, b: torch.Tensor, ) -> torch.Tensor: - return forward_batch.attn_backend.forward( - layer=self, - forward_batch=forward_batch, - mixed_qkv=mixed_qkv, - a=a, - b=b, - ) + if forward_batch.forward_mode.is_extend() and get_forward_context() is not None: + # Output shape from linear attention: (1, seq_len, num_v_heads, head_v_dim) + seq_len = mixed_qkv.shape[0] + output = torch.empty( + (1, seq_len, self.num_v_heads, self.head_v_dim), + dtype=mixed_qkv.dtype, + device=mixed_qkv.device, + ) + unified_linear_attention_with_output( + mixed_qkv, + a, + b, + output, + self.layer_id, + ) + return output + else: + return forward_batch.attn_backend.forward( + layer=self, + forward_batch=forward_batch, + mixed_qkv=mixed_qkv, + a=a, + b=b, + ) + + +@register_custom_op(mutates_args=["output"]) +@register_split_op() +def unified_linear_attention_with_output( + mixed_qkv: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + output: torch.Tensor, + layer_id: int, +) -> None: + """ + Custom op wrapper for linear attention computation only. + """ + context = get_forward_context() + forward_batch = context.forward_batch + attention_layers = context.attention_layers + parent_layer = attention_layers[layer_id] + + # For models like Qwen3Next, the RadixLinearAttention is stored + # as a sub-component (parent_layer.linear_attn.linear_attn) + # Navigate to get the actual RadixLinearAttention + if hasattr(parent_layer, 'linear_attn'): + gdn_layer = parent_layer.linear_attn + if hasattr(gdn_layer, 'linear_attn'): + attention_layer = gdn_layer.linear_attn + else: + attention_layer = gdn_layer + else: + attention_layer = parent_layer + + ret = forward_batch.attn_backend.forward( + layer=attention_layer, + forward_batch=forward_batch, + mixed_qkv=mixed_qkv, + a=a, + b=b, + ) + + assert ( + output.numel() == ret.numel() + ), f"Output tensor element mismatch: {output.numel()} != {ret.numel()}" + + output.view(ret.shape).copy_(ret) + return diff --git a/python/sglang/srt/models/qwen3_next.py b/python/sglang/srt/models/qwen3_next.py index 6512afa0810f..70dade62946f 100644 --- a/python/sglang/srt/models/qwen3_next.py +++ b/python/sglang/srt/models/qwen3_next.py @@ -5,8 +5,6 @@ import torch from torch import nn -from sglang.srt.compilation.compilation_config import register_split_op -from sglang.srt.compilation.piecewise_context_manager import get_forward_context from sglang.srt.configs.qwen3_next import Qwen3NextConfig from sglang.srt.distributed import get_pp_group from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder @@ -51,7 +49,6 @@ make_layers, set_weight_attrs, ) -from sglang.srt.utils.custom_op import register_custom_op logger = logging.getLogger(__name__) _is_cuda = is_cuda() @@ -392,23 +389,6 @@ def forward( hidden_states: torch.Tensor, forward_batch: ForwardBatch, ): - if forward_batch.forward_mode.is_extend() and get_forward_context() is not None: - output = torch.empty_like(hidden_states) - gdn_with_output( - hidden_states, - output, - self.layer_id, - ) - return output - else: - return self._forward(hidden_states, forward_batch) - - def _forward( - self, - hidden_states: torch.Tensor, - forward_batch: ForwardBatch, - ): - seq_len, _ = hidden_states.shape is_cuda_graph = forward_batch.forward_mode.is_cuda_graph() projected_states_qkvz, projected_states_ba = self._forward_input_proj( @@ -1046,25 +1026,3 @@ def get_model_config_for_expert_location(cls, config): EntryClass = Qwen3NextForCausalLM - - -@register_custom_op(mutates_args=["output"]) -@register_split_op() -def gdn_with_output( - hidden_states: torch.Tensor, - output: torch.Tensor, - layer_id: int, -) -> None: - context = get_forward_context() - forward_batch = context.forward_batch - attention_layers = context.attention_layers - attention_layer = attention_layers[layer_id] - - ret = attention_layer._forward(hidden_states, forward_batch) - - assert ( - output.numel() == ret.numel() - ), f"Output tensor element mismatch: {output.numel()} != {ret.numel()}" - - output.view(ret.shape).copy_(ret) - return From cee01993f2e6156c91ef2fd6d65e28b197ce7172 Mon Sep 17 00:00:00 2001 From: Minglei Zhu Date: Fri, 23 Jan 2026 01:48:09 +0000 Subject: [PATCH 2/8] lint --- python/sglang/srt/layers/radix_linear_attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/radix_linear_attention.py b/python/sglang/srt/layers/radix_linear_attention.py index 8d475a7e86a8..4606b8b89607 100644 --- a/python/sglang/srt/layers/radix_linear_attention.py +++ b/python/sglang/srt/layers/radix_linear_attention.py @@ -124,9 +124,9 @@ def unified_linear_attention_with_output( # For models like Qwen3Next, the RadixLinearAttention is stored # as a sub-component (parent_layer.linear_attn.linear_attn) # Navigate to get the actual RadixLinearAttention - if hasattr(parent_layer, 'linear_attn'): + if hasattr(parent_layer, "linear_attn"): gdn_layer = parent_layer.linear_attn - if hasattr(gdn_layer, 'linear_attn'): + if hasattr(gdn_layer, "linear_attn"): attention_layer = gdn_layer.linear_attn else: attention_layer = gdn_layer From 2b8f82edaec26ac5aa1875989b1fb5145964ddd4 Mon Sep 17 00:00:00 2001 From: Minglei Zhu Date: Fri, 23 Jan 2026 04:51:12 +0000 Subject: [PATCH 3/8] update --- .../srt/layers/attention/fla/layernorm_gated.py | 8 ++++++-- python/sglang/srt/layers/radix_linear_attention.py | 14 +------------- python/sglang/srt/model_executor/model_runner.py | 5 ++++- python/sglang/srt/models/qwen3_next.py | 4 ++-- 4 files changed, 13 insertions(+), 18 deletions(-) diff --git a/python/sglang/srt/layers/attention/fla/layernorm_gated.py b/python/sglang/srt/layers/attention/fla/layernorm_gated.py index d1913b10c7c1..d58577cc4d48 100644 --- a/python/sglang/srt/layers/attention/fla/layernorm_gated.py +++ b/python/sglang/srt/layers/attention/fla/layernorm_gated.py @@ -161,8 +161,12 @@ def _get_sm_count(device: torch.device) -> int: def calc_rows_per_block(M: int, device: torch.device) -> int: # When piecewise cuda graph is enabled, use a constant value to avoid # torch.compile creating guards on the dynamic batch dimension. - if get_global_server_args().enable_piecewise_cuda_graph: - return 4 + try: + if get_global_server_args().enable_piecewise_cuda_graph: + return 4 + except ValueError: + # Global server args not initialized (e.g., in unit tests) + pass sm_count = _get_sm_count(device) rows_per_block = next_power_of_2(cdiv(M, 2 * sm_count)) rows_per_block = min(rows_per_block, 4) diff --git a/python/sglang/srt/layers/radix_linear_attention.py b/python/sglang/srt/layers/radix_linear_attention.py index 4606b8b89607..d6c056dbb776 100644 --- a/python/sglang/srt/layers/radix_linear_attention.py +++ b/python/sglang/srt/layers/radix_linear_attention.py @@ -119,19 +119,7 @@ def unified_linear_attention_with_output( context = get_forward_context() forward_batch = context.forward_batch attention_layers = context.attention_layers - parent_layer = attention_layers[layer_id] - - # For models like Qwen3Next, the RadixLinearAttention is stored - # as a sub-component (parent_layer.linear_attn.linear_attn) - # Navigate to get the actual RadixLinearAttention - if hasattr(parent_layer, "linear_attn"): - gdn_layer = parent_layer.linear_attn - if hasattr(gdn_layer, "linear_attn"): - attention_layer = gdn_layer.linear_attn - else: - attention_layer = gdn_layer - else: - attention_layer = parent_layer + attention_layer = attention_layers[layer_id] ret = forward_batch.attn_backend.forward( layer=attention_layer, diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 46fe65ff9445..d834ba4cbd11 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -2049,7 +2049,10 @@ def init_piecewise_cuda_graphs(self): elif hasattr(layer, "attn"): self.attention_layers.append(layer.attn) elif hasattr(layer, "linear_attn"): - self.attention_layers.append(layer.linear_attn) + if hasattr(layer.linear_attn, "attn"): + self.attention_layers.append(layer.linear_attn.attn) + else: + self.attention_layers.append(layer.linear_attn) # For InternVL model elif hasattr(layer, "attention"): if hasattr(layer.attention, "attn"): diff --git a/python/sglang/srt/models/qwen3_next.py b/python/sglang/srt/models/qwen3_next.py index 70dade62946f..e4e22fbfa626 100644 --- a/python/sglang/srt/models/qwen3_next.py +++ b/python/sglang/srt/models/qwen3_next.py @@ -301,7 +301,7 @@ def __init__( prefix=add_prefix("out_proj", prefix), ) - self.linear_attn = RadixLinearAttention( + self.attn = RadixLinearAttention( layer_id=layer_id, num_qk_heads=self.num_k_heads // self.attn_tp_size, num_v_heads=self.num_v_heads // self.attn_tp_size, @@ -413,7 +413,7 @@ def forward( ) mixed_qkv = torch.cat((query, key, value), dim=-1) - core_attn_out = self.linear_attn( + core_attn_out = self.attn( forward_batch, mixed_qkv=mixed_qkv, a=a, From f286029a8d7dbd383ce90cb0b2cc8961c8894c79 Mon Sep 17 00:00:00 2001 From: Minglei Zhu Date: Fri, 23 Jan 2026 22:11:31 +0000 Subject: [PATCH 4/8] update --- python/sglang/srt/layers/attention/fla/layernorm_gated.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/attention/fla/layernorm_gated.py b/python/sglang/srt/layers/attention/fla/layernorm_gated.py index d58577cc4d48..a2759194457f 100644 --- a/python/sglang/srt/layers/attention/fla/layernorm_gated.py +++ b/python/sglang/srt/layers/attention/fla/layernorm_gated.py @@ -19,6 +19,9 @@ _is_npu = is_npu() +# Maximum rows per Triton block for layernorm gated kernel +MAX_ROWS_PER_BLOCK = 4 + def rms_norm_ref( x, @@ -163,13 +166,13 @@ def calc_rows_per_block(M: int, device: torch.device) -> int: # torch.compile creating guards on the dynamic batch dimension. try: if get_global_server_args().enable_piecewise_cuda_graph: - return 4 + return MAX_ROWS_PER_BLOCK except ValueError: # Global server args not initialized (e.g., in unit tests) pass sm_count = _get_sm_count(device) rows_per_block = next_power_of_2(cdiv(M, 2 * sm_count)) - rows_per_block = min(rows_per_block, 4) + rows_per_block = min(rows_per_block, MAX_ROWS_PER_BLOCK) return rows_per_block From b8e29e76ecdea4e999ed29f67334cbe8c032fc94 Mon Sep 17 00:00:00 2001 From: Minglei Zhu Date: Wed, 11 Feb 2026 04:58:50 +0000 Subject: [PATCH 5/8] re-enable ci --- test/registered/models/test_qwen3_next_models_pcg.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/test/registered/models/test_qwen3_next_models_pcg.py b/test/registered/models/test_qwen3_next_models_pcg.py index 775227f92183..898be3093c42 100644 --- a/test/registered/models/test_qwen3_next_models_pcg.py +++ b/test/registered/models/test_qwen3_next_models_pcg.py @@ -1,9 +1,5 @@ """ Qwen3 Next piecewise CUDA graph tests. - -DISABLED: See https://github.com/sgl-project/sglang/issues/17039 -PCG tests for Qwen3 Next have intermittent failures (5-10% probability). -Investigation ongoing by @YuweiAn. """ import unittest @@ -22,7 +18,6 @@ register_cuda_ci( est_time=400, suite="stage-c-test-4-gpu-h100", - disabled="Intermittent failures, see #17039", ) QWEN3_NEXT_MODEL = "Qwen/Qwen3-Next-80B-A3B-Instruct" @@ -31,8 +26,6 @@ QWEN3_NEXT_MODEL: {"kl_div": 0.0025, "gsm8k": 0.93}, } - -@unittest.skip("Disabled: intermittent failures, see #17039") class TestQwen3NextPiecewiseCudaGraph(CustomTestCase): @classmethod From aa7601e98483a2f720d12e60b3fb859eb193e181 Mon Sep 17 00:00:00 2001 From: Minglei Zhu Date: Wed, 11 Feb 2026 05:14:37 +0000 Subject: [PATCH 6/8] lint --- test/registered/models/test_qwen3_next_models_pcg.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/registered/models/test_qwen3_next_models_pcg.py b/test/registered/models/test_qwen3_next_models_pcg.py index 898be3093c42..bcba9aee9286 100644 --- a/test/registered/models/test_qwen3_next_models_pcg.py +++ b/test/registered/models/test_qwen3_next_models_pcg.py @@ -26,6 +26,7 @@ QWEN3_NEXT_MODEL: {"kl_div": 0.0025, "gsm8k": 0.93}, } + class TestQwen3NextPiecewiseCudaGraph(CustomTestCase): @classmethod From 3d8b2a1cc3c7bcd1cedc8a0a78397f94e7850080 Mon Sep 17 00:00:00 2001 From: Minglei Zhu Date: Thu, 12 Feb 2026 19:48:13 +0000 Subject: [PATCH 7/8] keep gdn_with_output at now for backward compatibility --- python/sglang/srt/models/qwen3_next.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/python/sglang/srt/models/qwen3_next.py b/python/sglang/srt/models/qwen3_next.py index 9115c7b9965e..0f1eb5ddb0a9 100644 --- a/python/sglang/srt/models/qwen3_next.py +++ b/python/sglang/srt/models/qwen3_next.py @@ -5,6 +5,8 @@ import torch from torch import nn +from sglang.srt.compilation.compilation_config import register_split_op +from sglang.srt.compilation.piecewise_context_manager import get_forward_context from sglang.srt.configs.qwen3_next import Qwen3NextConfig from sglang.srt.distributed import get_pp_group from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder @@ -51,6 +53,7 @@ make_layers, set_weight_attrs, ) +from sglang.srt.utils.custom_op import register_custom_op logger = logging.getLogger(__name__) _is_cuda = is_cuda() @@ -1152,3 +1155,25 @@ def set_eagle3_layers_to_capture(self, layer_ids: Optional[list[int]] = None): EntryClass = Qwen3NextForCausalLM + + +@register_custom_op(mutates_args=["output"]) +@register_split_op() +def gdn_with_output( + hidden_states: torch.Tensor, + output: torch.Tensor, + layer_id: int, +) -> None: + context = get_forward_context() + forward_batch = context.forward_batch + attention_layers = context.attention_layers + attention_layer = attention_layers[layer_id] + + ret = attention_layer.forward(hidden_states, forward_batch) + + assert ( + output.numel() == ret.numel() + ), f"Output tensor element mismatch: {output.numel()} != {ret.numel()}" + + output.view(ret.shape).copy_(ret) + return From e6e49b85907b7170d152a1a4700f758fb963b534 Mon Sep 17 00:00:00 2001 From: Minglei Zhu Date: Thu, 12 Feb 2026 20:01:13 +0000 Subject: [PATCH 8/8] update --- python/sglang/srt/models/qwen3_next.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/models/qwen3_next.py b/python/sglang/srt/models/qwen3_next.py index 0f1eb5ddb0a9..ba89973f55c7 100644 --- a/python/sglang/srt/models/qwen3_next.py +++ b/python/sglang/srt/models/qwen3_next.py @@ -1169,7 +1169,7 @@ def gdn_with_output( attention_layers = context.attention_layers attention_layer = attention_layers[layer_id] - ret = attention_layer.forward(hidden_states, forward_batch) + ret = attention_layer._forward(hidden_states, forward_batch) assert ( output.numel() == ret.numel()