Skip to content
5 changes: 5 additions & 0 deletions python/sglang/srt/layers/attention/fla/layernorm_gated.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import triton.language as tl
from einops import rearrange

from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import cdiv, device_context, is_npu, next_power_of_2

_is_npu = is_npu()
Expand Down Expand Up @@ -158,6 +159,10 @@ def _get_sm_count(device: torch.device) -> int:


def calc_rows_per_block(M: int, device: torch.device) -> int:
# When piecewise cuda graph is enabled, use a constant value to avoid
# torch.compile creating guards on the dynamic batch dimension.
if get_global_server_args().enable_piecewise_cuda_graph:
return 4
sm_count = _get_sm_count(device)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a constant value like 128 — why would it affect torch compile?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function returns `rows_per_block`, which is consumed by the Triton kernel `_layer_norm_fwd_1pass_kernel` as a `tl.constexpr`. With different values of M it can produce different `rows_per_block` values, which triggers torch recompilation.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think M is a constant during the compilation of a single graph; why would it trigger recompilation?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It changes when `num_tokens` changes, which breaks the torch.compile guards and triggers many recompilations while capturing all token sizes, making the capture take forever to finish.

rows_per_block = next_power_of_2(cdiv(M, 2 * sm_count))
rows_per_block = min(rows_per_block, 4)
Expand Down
80 changes: 73 additions & 7 deletions python/sglang/srt/layers/radix_linear_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
import torch
from torch import nn

from sglang.srt.compilation.compilation_config import register_split_op
from sglang.srt.compilation.piecewise_context_manager import get_forward_context
from sglang.srt.utils.custom_op import register_custom_op

if TYPE_CHECKING:
from sglang.srt.model_executor.forward_batch_info import ForwardBatch

Expand Down Expand Up @@ -74,10 +78,72 @@ def forward(
a: torch.Tensor,
b: torch.Tensor,
) -> torch.Tensor:
return forward_batch.attn_backend.forward(
layer=self,
forward_batch=forward_batch,
mixed_qkv=mixed_qkv,
a=a,
b=b,
)
if forward_batch.forward_mode.is_extend() and get_forward_context() is not None:
# Output shape from linear attention: (1, seq_len, num_v_heads, head_v_dim)
seq_len = mixed_qkv.shape[0]
output = torch.empty(
(1, seq_len, self.num_v_heads, self.head_v_dim),
dtype=mixed_qkv.dtype,
device=mixed_qkv.device,
)
unified_linear_attention_with_output(
mixed_qkv,
a,
b,
output,
self.layer_id,
)
return output
else:
return forward_batch.attn_backend.forward(
layer=self,
forward_batch=forward_batch,
mixed_qkv=mixed_qkv,
a=a,
b=b,
)


@register_custom_op(mutates_args=["output"])
@register_split_op()
def unified_linear_attention_with_output(
    mixed_qkv: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    output: torch.Tensor,
    layer_id: int,
) -> None:
    """
    Custom op wrapper for linear attention computation only.

    Resolves the attention layer identified by ``layer_id`` from the current
    forward context, runs the attention backend on the given inputs, and
    writes the result into ``output`` in place (the op mutates ``output``
    and returns ``None``).
    """
    ctx = get_forward_context()
    batch = ctx.forward_batch

    # For models like Qwen3Next, the RadixLinearAttention is stored as a
    # sub-component (parent.linear_attn.linear_attn). Walk down at most two
    # "linear_attn" levels to reach the actual RadixLinearAttention; if a
    # level is absent, stay at the current object.
    layer = ctx.attention_layers[layer_id]
    for _ in range(2):
        layer = getattr(layer, "linear_attn", layer)

    result = batch.attn_backend.forward(
        layer=layer,
        forward_batch=batch,
        mixed_qkv=mixed_qkv,
        a=a,
        b=b,
    )

    # The preallocated output buffer must hold exactly the backend's result.
    assert output.numel() == result.numel(), (
        f"Output tensor element mismatch: {output.numel()} != {result.numel()}"
    )
    output.view(result.shape).copy_(result)
42 changes: 0 additions & 42 deletions python/sglang/srt/models/qwen3_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
import torch
from torch import nn

from sglang.srt.compilation.compilation_config import register_split_op
from sglang.srt.compilation.piecewise_context_manager import get_forward_context
from sglang.srt.configs.qwen3_next import Qwen3NextConfig
from sglang.srt.distributed import get_pp_group
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
Expand Down Expand Up @@ -51,7 +49,6 @@
make_layers,
set_weight_attrs,
)
from sglang.srt.utils.custom_op import register_custom_op

logger = logging.getLogger(__name__)
_is_cuda = is_cuda()
Expand Down Expand Up @@ -392,23 +389,6 @@ def forward(
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
):
if forward_batch.forward_mode.is_extend() and get_forward_context() is not None:
output = torch.empty_like(hidden_states)
gdn_with_output(
hidden_states,
output,
self.layer_id,
)
return output
else:
return self._forward(hidden_states, forward_batch)

def _forward(
self,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
):
seq_len, _ = hidden_states.shape
is_cuda_graph = forward_batch.forward_mode.is_cuda_graph()

projected_states_qkvz, projected_states_ba = self._forward_input_proj(
Expand Down Expand Up @@ -1046,25 +1026,3 @@ def get_model_config_for_expert_location(cls, config):


EntryClass = Qwen3NextForCausalLM


@register_custom_op(mutates_args=["output"])
@register_split_op()
def gdn_with_output(
    hidden_states: torch.Tensor,
    output: torch.Tensor,
    layer_id: int,
) -> None:
    """
    Custom op wrapper: run the ``_forward`` of the attention layer identified
    by ``layer_id`` (from the current forward context) and copy the result
    into ``output`` in place (the op mutates ``output`` and returns ``None``).
    """
    ctx = get_forward_context()
    layer = ctx.attention_layers[layer_id]

    result = layer._forward(hidden_states, ctx.forward_batch)

    # The preallocated output buffer must hold exactly the layer's result.
    assert output.numel() == result.numel(), (
        f"Output tensor element mismatch: {output.numel()} != {result.numel()}"
    )
    output.view(result.shape).copy_(result)
Loading