diff --git a/python/sglang/srt/layers/attention/fla/layernorm_gated.py b/python/sglang/srt/layers/attention/fla/layernorm_gated.py index 7bc7b9f47c48..1e653e350f9a 100644 --- a/python/sglang/srt/layers/attention/fla/layernorm_gated.py +++ b/python/sglang/srt/layers/attention/fla/layernorm_gated.py @@ -14,6 +14,7 @@ import triton.language as tl from einops import rearrange +from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import ( cdiv, cpu_has_amx_support, @@ -26,6 +27,9 @@ _is_npu = is_npu() _use_cpu = is_cpu() and cpu_has_amx_support() +# Maximum rows per Triton block for layernorm gated kernel +MAX_ROWS_PER_BLOCK = 4 + def rms_norm_ref( x, @@ -166,9 +170,17 @@ def _get_sm_count(device: torch.device) -> int: def calc_rows_per_block(M: int, device: torch.device) -> int: + # When piecewise cuda graph is enabled, use a constant value to avoid + # torch.compile creating guards on the dynamic batch dimension. + try: + if get_global_server_args().enable_piecewise_cuda_graph: + return MAX_ROWS_PER_BLOCK + except ValueError: + # Global server args not initialized (e.g., in unit tests) + pass sm_count = _get_sm_count(device) rows_per_block = next_power_of_2(cdiv(M, 2 * sm_count)) - rows_per_block = min(rows_per_block, 4) + rows_per_block = min(rows_per_block, MAX_ROWS_PER_BLOCK) return rows_per_block diff --git a/python/sglang/srt/layers/radix_linear_attention.py b/python/sglang/srt/layers/radix_linear_attention.py index 4cc44ff4f279..7955cb0a0beb 100644 --- a/python/sglang/srt/layers/radix_linear_attention.py +++ b/python/sglang/srt/layers/radix_linear_attention.py @@ -19,6 +19,10 @@ import torch from torch import nn +from sglang.srt.compilation.compilation_config import register_split_op +from sglang.srt.compilation.piecewise_context_manager import get_forward_context +from sglang.srt.utils.custom_op import register_custom_op + if TYPE_CHECKING: from sglang.srt.model_executor.forward_batch_info import ForwardBatch @@ -70,10 +74,60 @@ def forward( a: torch.Tensor, b: torch.Tensor, ) -> torch.Tensor: - return forward_batch.attn_backend.forward( - layer=self, - forward_batch=forward_batch, - mixed_qkv=mixed_qkv, - a=a, - b=b, - ) + if forward_batch.forward_mode.is_extend() and get_forward_context() is not None: + # Output shape from linear attention: (1, seq_len, num_v_heads, head_v_dim) + seq_len = mixed_qkv.shape[0] + output = torch.empty( + (1, seq_len, self.num_v_heads, self.head_v_dim), + dtype=mixed_qkv.dtype, + device=mixed_qkv.device, + ) + unified_linear_attention_with_output( + mixed_qkv, + a, + b, + output, + self.layer_id, + ) + return output + else: + return forward_batch.attn_backend.forward( + layer=self, + forward_batch=forward_batch, + mixed_qkv=mixed_qkv, + a=a, + b=b, + ) + + +@register_custom_op(mutates_args=["output"]) +@register_split_op() +def unified_linear_attention_with_output( + mixed_qkv: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + output: torch.Tensor, + layer_id: int, +) -> None: + """ + Custom op wrapper for linear attention computation only. + """ + context = get_forward_context() + forward_batch = context.forward_batch + attention_layers = context.attention_layers + attention_layer = attention_layers[layer_id] + + ret = forward_batch.attn_backend.forward( + layer=attention_layer, + forward_batch=forward_batch, + mixed_qkv=mixed_qkv, + a=a, + b=b, + ) + + assert ( + output.numel() == ret.numel() + ), f"Output tensor element mismatch: {output.numel()} != {ret.numel()}" + + output.view(ret.shape).copy_(ret) + return diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index d23d8e9ad891..f0d8cc8e3449 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -2138,7 +2138,10 @@ def init_piecewise_cuda_graphs(self): elif hasattr(layer, "attn"): self.attention_layers.append(layer.attn) elif hasattr(layer, "linear_attn"): - self.attention_layers.append(layer.linear_attn) + if hasattr(layer.linear_attn, "attn"): + self.attention_layers.append(layer.linear_attn.attn) + else: + self.attention_layers.append(layer.linear_attn) # For InternVL model elif hasattr(layer, "attention"): if hasattr(layer.attention, "attn"): diff --git a/python/sglang/srt/models/qwen3_next.py b/python/sglang/srt/models/qwen3_next.py index a013dca531e4..ba89973f55c7 100644 --- a/python/sglang/srt/models/qwen3_next.py +++ b/python/sglang/srt/models/qwen3_next.py @@ -316,7 +316,7 @@ def __init__( prefix=add_prefix("out_proj", prefix), ) - self.linear_attn = RadixLinearAttention( + self.attn = RadixLinearAttention( layer_id=layer_id, num_q_heads=self.num_k_heads // self.attn_tp_size, num_k_heads=self.num_k_heads // self.attn_tp_size, @@ -405,23 +405,6 @@ def forward( hidden_states: torch.Tensor, forward_batch: ForwardBatch, ): - if forward_batch.forward_mode.is_extend() and get_forward_context() is not None: - output = torch.empty_like(hidden_states) - gdn_with_output( - hidden_states, - output, - self.layer_id, - ) - return output - else: - return self._forward(hidden_states, forward_batch) - - def _forward( - self, - hidden_states: torch.Tensor, - forward_batch: ForwardBatch, - ): - seq_len, _ = hidden_states.shape is_cuda_graph = forward_batch.forward_mode.is_cuda_graph() projected_states_qkvz, projected_states_ba = self._forward_input_proj( @@ -460,7 +443,7 @@ def _forward( lambda x: x.reshape(x.shape[0], -1), (query, key, value) ) mixed_qkv = torch.cat((query, key, value), dim=-1) - core_attn_out = self.linear_attn( + core_attn_out = self.attn( forward_batch, mixed_qkv=mixed_qkv, a=a, diff --git a/test/registered/models/test_qwen3_next_models_pcg.py b/test/registered/models/test_qwen3_next_models_pcg.py index 775227f92183..bcba9aee9286 100644 --- a/test/registered/models/test_qwen3_next_models_pcg.py +++ b/test/registered/models/test_qwen3_next_models_pcg.py @@ -1,9 +1,5 @@ """ Qwen3 Next piecewise CUDA graph tests. - -DISABLED: See https://github.com/sgl-project/sglang/issues/17039 -PCG tests for Qwen3 Next have intermittent failures (5-10% probability). -Investigation ongoing by @YuweiAn. """ import unittest @@ -22,7 +18,6 @@ register_cuda_ci( est_time=400, suite="stage-c-test-4-gpu-h100", - disabled="Intermittent failures, see #17039", ) QWEN3_NEXT_MODEL = "Qwen/Qwen3-Next-80B-A3B-Instruct" @@ -32,7 +27,6 @@ } -@unittest.skip("Disabled: intermittent failures, see #17039") class TestQwen3NextPiecewiseCudaGraph(CustomTestCase): @classmethod