
Commit a8eef53

1am9trash and kkHuang-amd authored
Fp8 prefill attn kernel integration (sgl-project#18528)
Co-authored-by: kkHuang-amd <wunhuang@amd.com>
1 parent 2cc235e commit a8eef53

File tree

1 file changed: +230 -16 lines changed


python/sglang/srt/layers/attention/aiter_backend.py

Lines changed: 230 additions & 16 deletions
@@ -19,6 +19,7 @@
     is_dp_attention_enabled,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.utils import is_gfx95_supported
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -30,7 +31,11 @@
     flash_attn_varlen_func,
     get_mla_metadata_info_v1,
     get_mla_metadata_v1,
+    get_ps_metadata_info_v1,
+    get_ps_metadata_v1,
     mha_batch_prefill_func,
+    mla_prefill_ps_asm_fwd,
+    mla_reduce_v1,
     paged_attention_ragged,
 )
 from aiter.mla import mla_decode_fwd, mla_prefill_fwd
@@ -49,6 +54,11 @@
 # Use aiter mla persist design for fp8-kv cache
 _use_mla_ps_kernel = get_bool_env_var("SGLANG_AITER_MLA_PERSIST", "True")
 
+# Use fp8 prefill only on gfx95
+_use_fp8_prefill_attn = (
+    get_bool_env_var("SGLANG_AITER_FP8_PREFILL_ATTN", "True") and is_gfx95_supported()
+)
+
 # Persist
 # fast_mode=True if _use_mla_ps_kernel else False
 # intra_batch_mode=False if _use_mla_ps_kernel else True
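
Note on the gate above: the new path is on by default but only takes effect on gfx95 hardware. A minimal standalone sketch of the same default-true env-flag pattern follows; the helper bodies are simplified stand-ins for sglang.srt.utils, assumed for illustration only.

import os

def get_bool_env_var(name: str, default: str = "false") -> bool:
    # Simplified stand-in for sglang.srt.utils.get_bool_env_var:
    # any of "1"/"true"/"yes" (case-insensitive) counts as True.
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes")

def is_gfx95_supported() -> bool:
    # Hypothetical probe result; the real check queries the ROCm device.
    return False

# Mirrors the gating in the hunk above: enabled by default via the env var,
# but only active when the GPU is gfx95.
_use_fp8_prefill_attn = (
    get_bool_env_var("SGLANG_AITER_FP8_PREFILL_ATTN", "True") and is_gfx95_supported()
)
print(_use_fp8_prefill_attn)  # False here, since the stand-in probe returns False

Setting SGLANG_AITER_FP8_PREFILL_ATTN=0 before launch disables the path even on gfx95.
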
@@ -308,6 +318,94 @@ def make_mla_meta_data(
             dtype_kv=dtype,
         )
 
+    def make_mla_prefill_ps_meta_data_buffer(
+        self, batch_size: int, max_qlen: int, qlen_granularity: int
+    ):
+        (
+            (work_meta_data_size, work_meta_data_type),
+            (work_indptr_size, work_indptr_type),
+            (work_info_size, work_info_type),
+            (reduce_indptr_size, reduce_indptr_type),
+            (reduce_final_map_size, reduce_final_map_type),
+            (reduce_partial_map_size, reduce_partial_map_type),
+        ) = get_ps_metadata_info_v1(
+            batch_size=batch_size,
+            num_head_k=self.num_kv_head,
+            max_qlen=max_qlen,
+            qlen_granularity=qlen_granularity,
+        )
+
+        device = self.device
+        work_metadata_ptrs = torch.empty(
+            work_meta_data_size, dtype=work_meta_data_type, device=device
+        )
+        work_indptr = torch.empty(
+            work_indptr_size, dtype=work_indptr_type, device=device
+        )
+        work_info = torch.empty(work_info_size, dtype=work_info_type, device=device)
+        reduce_indptr = torch.empty(
+            reduce_indptr_size, dtype=reduce_indptr_type, device=device
+        )
+        reduce_final_map = torch.empty(
+            reduce_final_map_size, dtype=reduce_final_map_type, device=device
+        )
+        reduce_partial_map = torch.empty(
+            reduce_partial_map_size, dtype=reduce_partial_map_type, device=device
+        )
+
+        return (
+            work_metadata_ptrs,
+            work_indptr,
+            work_info,
+            reduce_indptr,
+            reduce_final_map,
+            reduce_partial_map,
+        )
+
+    def make_mla_prefill_ps_meta_data(
+        self,
+        qo_indptr: torch.Tensor,
+        kv_indptr: torch.Tensor,
+        seq_lens: torch.Tensor,
+        work_metadata: torch.Tensor,
+        work_indptr: torch.Tensor,
+        work_info: torch.Tensor,
+        reduce_indptr: torch.Tensor,
+        reduce_final_map: torch.Tensor,
+        reduce_partial_map: torch.Tensor,
+        is_causal: bool = True,
+    ):
+        gqa_ratio = self.num_head // self.num_kv_head
+        num_heads_k = self.num_kv_head
+        tile_q = 256
+        qhead_granularity = gqa_ratio
+        qlen_granularity = tile_q // qhead_granularity
+        kvlen_granularity = max(128, self.page_size)
+        block_size = self.page_size
+
+        qo_indptr_cpu = qo_indptr.to("cpu", dtype=torch.int32)
+        kv_indptr_cpu = kv_indptr.to("cpu", dtype=torch.int32)
+        seq_lens_cpu = seq_lens.to("cpu", dtype=torch.int32)
+
+        get_ps_metadata_v1(
+            qo_indptr_cpu,
+            kv_indptr_cpu,
+            seq_lens_cpu,
+            gqa_ratio,
+            num_heads_k,
+            work_metadata,
+            work_indptr,
+            work_info,
+            reduce_indptr,
+            reduce_final_map,
+            reduce_partial_map,
+            qhead_granularity=qhead_granularity,
+            qlen_granularity=qlen_granularity,
+            kvlen_granularity=kvlen_granularity,
+            block_size=block_size,
+            is_causal=is_causal,
+        )
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init auxiliary variables for triton attention backend."""
 
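
The granularities in make_mla_prefill_ps_meta_data derive from a fixed 256-token query tile split across the GQA group. Below is a small worked example of that arithmetic; the head counts and page size (num_head=16, num_kv_head=1, page_size=1) are illustrative assumptions, not values from the diff.

# Illustrative head counts; the backend reads these from the model config.
num_head, num_kv_head, page_size = 16, 1, 1

gqa_ratio = num_head // num_kv_head             # 16 query heads per KV head
tile_q = 256                                    # fixed query tile, as in the diff
qhead_granularity = gqa_ratio                   # 16
qlen_granularity = tile_q // qhead_granularity  # 256 // 16 = 16 query tokens per work item
kvlen_granularity = max(128, page_size)         # at least 128 KV tokens per work item
print(qlen_granularity, kvlen_granularity)      # 16 128
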
@@ -587,15 +685,56 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
                     spec_info=None,
                 )
 
-                kv_indices = self.mla_indices_updater_prefill.kv_indices
+                max_q_len = self.mla_indices_updater_prefill.max_q_len
+                qo_indptr = self.mla_indices_updater_prefill.qo_indptr
+
+                work_metadata = None
+                work_indptr = None
+                work_info_set = None
+                reduce_indptr = None
+                reduce_final_map = None
+                reduce_partial_map = None
+
+                if _use_fp8_prefill_attn:
+                    tile_q = 256
+                    qlen_granularity = tile_q // (self.num_head // self.num_kv_head)
+                    (
+                        work_metadata,
+                        work_indptr,
+                        work_info_set,
+                        reduce_indptr,
+                        reduce_final_map,
+                        reduce_partial_map,
+                    ) = self.make_mla_prefill_ps_meta_data_buffer(
+                        bs, max_q_len, qlen_granularity
+                    )
+
+                    self.make_mla_prefill_ps_meta_data(
+                        qo_indptr,
+                        qo_indptr,
+                        forward_batch.seq_lens,
+                        work_metadata,
+                        work_indptr,
+                        work_info_set,
+                        reduce_indptr,
+                        reduce_final_map,
+                        reduce_partial_map,
+                        is_causal=True,
+                    )
 
                 self.forward_metadata = ForwardMetadata(
                     self.mla_indices_updater_prefill.kv_indptr,
-                    kv_indices,
-                    self.mla_indices_updater_prefill.qo_indptr,
+                    self.mla_indices_updater_prefill.kv_indices,
+                    qo_indptr,
                     self.kv_last_page_len[:bs],
-                    self.mla_indices_updater_prefill.max_q_len,
+                    max_q_len,
                     self.mla_indices_updater_prefill.max_kv_len,
+                    work_metadata=work_metadata,
+                    work_info_set=work_info_set,
+                    work_indptr=work_indptr,
+                    reduce_indptr=reduce_indptr,
+                    reduce_final_map=reduce_final_map,
+                    reduce_partial_map=reduce_partial_map,
                 )
             else:
                 self.indices_updater_prefill.update(
@@ -1047,18 +1186,93 @@ def forward_extend(
             ):
                 extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
                 if kv_indices.shape[0] == 0 or extend_no_prefix:
-                    o = flash_attn_varlen_func(
-                        q,
-                        k,
-                        v,
-                        qo_indptr,
-                        qo_indptr,
-                        max_q_len,
-                        max_q_len,
-                        softmax_scale=layer.scaling,
-                        causal=True,
-                    )
-                    return o
+                    if _use_fp8_prefill_attn:
+                        total_s = q.shape[0]
+                        nhead = layer.tp_q_head_num
+                        v_head_dim = layer.v_head_dim
+
+                        if q.dtype != fp8_dtype:
+                            q = q.float().to(fp8_dtype)
+                        if k.dtype != fp8_dtype:
+                            k = k.float().to(fp8_dtype)
+                        if v.dtype != fp8_dtype:
+                            v = v.float().to(fp8_dtype)
+                        one_scale = torch.tensor(
+                            1.0, dtype=torch.float32, device=q.device
+                        )
+
+                        kv_indptr_asm = qo_indptr
+                        kv_indices_asm = torch.arange(
+                            total_s, device=q.device, dtype=torch.int32
+                        )
+
+                        tile_q = 256
+                        reduce_indptr = self.forward_metadata.reduce_indptr
+                        reduce_final_map = self.forward_metadata.reduce_final_map
+                        reduce_partial_map = self.forward_metadata.reduce_partial_map
+
+                        logits = torch.empty(
+                            (reduce_partial_map.size(0) * tile_q, nhead, v_head_dim),
+                            dtype=torch.float32,
+                            device=q.device,
+                        )
+                        attn_lse = torch.empty(
+                            (reduce_partial_map.size(0) * tile_q, nhead),
+                            dtype=torch.float32,
+                            device=q.device,
+                        )
+                        final_lse = torch.empty(
+                            (total_s, nhead),
+                            dtype=torch.float32,
+                            device=q.device,
+                        )
+                        output = q.new_empty(
+                            (total_s, nhead, v_head_dim),
+                            dtype=self.input_dtype,
+                        )
+
+                        mla_prefill_ps_asm_fwd(
+                            q,
+                            k,
+                            v,
+                            qo_indptr,
+                            kv_indptr_asm,
+                            kv_indices_asm,
+                            self.forward_metadata.work_indptr,
+                            self.forward_metadata.work_info_set,
+                            max_q_len,
+                            layer.scaling,
+                            True,
+                            logits,
+                            attn_lse,
+                            output,
+                            one_scale,
+                            one_scale,
+                            one_scale,
+                        )
+                        mla_reduce_v1(
+                            logits,
+                            attn_lse,
+                            reduce_indptr,
+                            reduce_final_map,
+                            reduce_partial_map,
+                            tile_q,
+                            output,
+                            final_lse,
+                        )
+                    else:
+                        output = flash_attn_varlen_func(
+                            q,
+                            k,
+                            v,
+                            qo_indptr,
+                            qo_indptr,
+                            max_q_len,
+                            max_q_len,
+                            softmax_scale=layer.scaling,
+                            causal=True,
+                        )
+                    return output
             elif layer.qk_head_dim != (kv_lora_rank + qk_rope_head_dim):
                 K_Buffer = torch.index_select(K_Buffer, 0, kv_indices)
                 kvc, k_pe = torch.split(
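
In the fp8 branch, the persistent kernel writes tile-level partial outputs (logits) and per-tile log-sum-exp values (attn_lse), which mla_reduce_v1 then folds into the final output. The sketch below shows the standard log-sum-exp merge that a split-attention reduce step performs conceptually; it only illustrates why an LSE accompanies each partial result and does not claim the exact signature or buffer layout of the aiter kernels.

import torch

# Reference merge of two partial attention results with their log-sum-exp
# statistics, as split/partial-attention reduction does conceptually.
def merge_partials(o1, lse1, o2, lse2):
    # o*: [tokens, heads, v_head_dim] partial outputs, each softmax-weighted
    #     within its own KV split; lse*: [tokens, heads] log-sum-exp values.
    lse = torch.logaddexp(lse1, lse2)
    w1 = torch.exp(lse1 - lse)[..., None]
    w2 = torch.exp(lse2 - lse)[..., None]
    return w1 * o1 + w2 * o2, lse

o1, o2 = torch.randn(4, 16, 128), torch.randn(4, 16, 128)
lse1, lse2 = torch.randn(4, 16), torch.randn(4, 16)
out, lse = merge_partials(o1, lse1, o2, lse2)
print(out.shape, lse.shape)  # torch.Size([4, 16, 128]) torch.Size([4, 16])
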
