Skip to content

Commit bcc6d84

Browse files
authored
Use fused_sigmoid_gating_delta_rule_update_kernel for KDA (#17108)
1 parent a618202 commit bcc6d84

File tree

3 files changed

+37
-17
lines changed

3 files changed

+37
-17
lines changed

python/sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def fused_sigmoid_gating_delta_rule_update_kernel(
3434
USE_INITIAL_STATE: tl.constexpr,
3535
USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
3636
IS_VARLEN: tl.constexpr,
37+
IS_KDA: tl.constexpr,
3738
):
3839
"""
3940
Fused kernel that combines sigmoid gating computation with recurrent delta rule update.
@@ -64,8 +65,12 @@ def fused_sigmoid_gating_delta_rule_update_kernel(
6465

6566
# Gating computation pointers
6667
p_A_log = A_log + i_hv
67-
p_a = a + bos * HV + i_hv
68-
p_dt_bias = dt_bias + i_hv
68+
if IS_KDA:
69+
p_a = a + (bos * HV + i_hv) * K + o_k
70+
p_dt_bias = dt_bias + i_hv * K + o_k
71+
else:
72+
p_a = a + bos * HV + i_hv
73+
p_dt_bias = dt_bias + i_hv
6974

7075
mask_k = o_k < K
7176
mask_v = o_v < V
@@ -119,7 +124,10 @@ def fused_sigmoid_gating_delta_rule_update_kernel(
119124
b_q = b_q * scale
120125

121126
# Apply gating to hidden state: h *= exp(g)
122-
b_h *= tl.exp(b_g)
127+
if IS_KDA:
128+
b_h *= tl.exp(b_g[:, None])
129+
else:
130+
b_h *= tl.exp(b_g)
123131

124132
# Delta rule: v -= sum(h * k, dim=0)
125133
b_v -= tl.sum(b_h * b_k[:, None], 0)
@@ -172,6 +180,7 @@ def fused_sigmoid_gating_delta_rule_update(
172180
scale: Optional[float] = None,
173181
use_qk_l2norm_in_kernel: bool = False,
174182
cu_seqlens: Optional[torch.Tensor] = None,
183+
is_kda: bool = False,
175184
):
176185
"""
177186
Fused triton implementation of sigmoid gating delta rule update.
@@ -221,6 +230,7 @@ def fused_sigmoid_gating_delta_rule_update(
221230
USE_INITIAL_STATE=initial_state_source is not None,
222231
USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
223232
IS_VARLEN=cu_seqlens is not None,
233+
IS_KDA=is_kda,
224234
num_warps=num_warps,
225235
num_stages=num_stages,
226236
)

python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent import (
1818
fused_sigmoid_gating_delta_rule_update,
1919
)
20-
from sglang.srt.layers.attention.fla.kda import chunk_kda, fused_recurrent_kda
20+
from sglang.srt.layers.attention.fla.kda import chunk_kda
2121
from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
2222
PAD_SLOT_ID,
2323
causal_conv1d_fn,
@@ -647,6 +647,9 @@ def forward_decode(
647647
beta = kwargs["beta"]
648648
g = kwargs["gate"]
649649

650+
A_log = kwargs["A_log"]
651+
dt_bias = kwargs["dt_bias"]
652+
650653
layer_cache = self.req_to_token_pool.mamba2_layer_cache(layer_id)
651654
q_conv_state, k_conv_state, v_conv_state = layer_cache.conv
652655
ssm_states = layer_cache.temporal
@@ -686,21 +689,23 @@ def forward_decode(
686689
lambda x: rearrange(x, "n (h d) -> 1 n h d", d=head_dim), (q, k, v)
687690
)
688691

689-
initial_state = ssm_states[cache_indices].contiguous()
690-
(
691-
core_attn_out,
692-
last_recurrent_state,
693-
) = fused_recurrent_kda(
692+
core_attn_out = fused_sigmoid_gating_delta_rule_update(
693+
A_log=A_log,
694+
dt_bias=dt_bias,
694695
q=q,
695696
k=k,
696697
v=v,
697-
g=g,
698-
beta=beta,
699-
initial_state=initial_state,
700-
use_qk_l2norm_in_kernel=True,
698+
a=g,
699+
b=beta,
700+
initial_state_source=ssm_states,
701+
initial_state_indices=cache_indices,
701702
cu_seqlens=query_start_loc,
703+
use_qk_l2norm_in_kernel=True,
704+
softplus_beta=1.0,
705+
softplus_threshold=20.0,
706+
is_kda=True,
702707
)
703-
ssm_states[cache_indices] = last_recurrent_state
708+
704709
return core_attn_out
705710

706711
def forward_extend(

python/sglang/srt/models/kimi_linear.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -316,9 +316,12 @@ def forward(
316316

317317
beta = self.b_proj(hidden_states)[0].float().sigmoid()
318318
forget_gate = self.f_b_proj(self.f_a_proj(hidden_states)[0])[0]
319-
forget_gate = fused_kda_gate(
320-
forget_gate, self.A_log, self.head_dim, g_bias=self.dt_bias
321-
)
319+
320+
# fused_kda_gate is fused into KimiLinearAttentionBackend during decode
321+
if not forward_batch.forward_mode.is_decode():
322+
forget_gate = fused_kda_gate(
323+
forget_gate, self.A_log, self.head_dim, g_bias=self.dt_bias
324+
)
322325
beta = beta.unsqueeze(0)
323326
forget_gate = forget_gate.unsqueeze(0)
324327

@@ -336,6 +339,8 @@ def forward(
336339
"layer_id": self.layer_idx,
337340
"beta": beta,
338341
"gate": forget_gate,
342+
"A_log": self.A_log,
343+
"dt_bias": self.dt_bias,
339344
}
340345

341346
core_attn_out = forward_batch.attn_backend.forward(

0 commit comments

Comments
 (0)