|
19 | 19 | is_dp_attention_enabled, |
20 | 20 | ) |
21 | 21 | from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode |
| 22 | +from sglang.srt.utils import is_gfx95_supported |
22 | 23 |
|
23 | 24 | if TYPE_CHECKING: |
24 | 25 | from sglang.srt.layers.radix_attention import RadixAttention |
|
30 | 31 | flash_attn_varlen_func, |
31 | 32 | get_mla_metadata_info_v1, |
32 | 33 | get_mla_metadata_v1, |
| 34 | + get_ps_metadata_info_v1, |
| 35 | + get_ps_metadata_v1, |
33 | 36 | mha_batch_prefill_func, |
| 37 | + mla_prefill_ps_asm_fwd, |
| 38 | + mla_reduce_v1, |
34 | 39 | paged_attention_ragged, |
35 | 40 | ) |
36 | 41 | from aiter.mla import mla_decode_fwd, mla_prefill_fwd |
|
# Use the aiter MLA persistent-kernel design for the fp8 KV cache
# (enabled by default; opt out with SGLANG_AITER_MLA_PERSIST=0).
_use_mla_ps_kernel = get_bool_env_var("SGLANG_AITER_MLA_PERSIST", "True")

# Use the fp8 prefill attention path only on gfx95 hardware; the
# SGLANG_AITER_FP8_PREFILL_ATTN env var can disable it even there.
_use_fp8_prefill_attn = (
    get_bool_env_var("SGLANG_AITER_FP8_PREFILL_ATTN", "True") and is_gfx95_supported()
)

# Persist
# fast_mode=True if _use_mla_ps_kernel else False
# intra_batch_mode=False if _use_mla_ps_kernel else True
@@ -308,6 +318,94 @@ def make_mla_meta_data( |
308 | 318 | dtype_kv=dtype, |
309 | 319 | ) |
310 | 320 |
|
| 321 | + def make_mla_prefill_ps_meta_data_buffer( |
| 322 | + self, batch_size: int, max_qlen: int, qlen_granularity: int |
| 323 | + ): |
| 324 | + ( |
| 325 | + (work_meta_data_size, work_meta_data_type), |
| 326 | + (work_indptr_size, work_indptr_type), |
| 327 | + (work_info_size, work_info_type), |
| 328 | + (reduce_indptr_size, reduce_indptr_type), |
| 329 | + (reduce_final_map_size, reduce_final_map_type), |
| 330 | + (reduce_partial_map_size, reduce_partial_map_type), |
| 331 | + ) = get_ps_metadata_info_v1( |
| 332 | + batch_size=batch_size, |
| 333 | + num_head_k=self.num_kv_head, |
| 334 | + max_qlen=max_qlen, |
| 335 | + qlen_granularity=qlen_granularity, |
| 336 | + ) |
| 337 | + |
| 338 | + device = self.device |
| 339 | + work_metadata_ptrs = torch.empty( |
| 340 | + work_meta_data_size, dtype=work_meta_data_type, device=device |
| 341 | + ) |
| 342 | + work_indptr = torch.empty( |
| 343 | + work_indptr_size, dtype=work_indptr_type, device=device |
| 344 | + ) |
| 345 | + work_info = torch.empty(work_info_size, dtype=work_info_type, device=device) |
| 346 | + reduce_indptr = torch.empty( |
| 347 | + reduce_indptr_size, dtype=reduce_indptr_type, device=device |
| 348 | + ) |
| 349 | + reduce_final_map = torch.empty( |
| 350 | + reduce_final_map_size, dtype=reduce_final_map_type, device=device |
| 351 | + ) |
| 352 | + reduce_partial_map = torch.empty( |
| 353 | + reduce_partial_map_size, dtype=reduce_partial_map_type, device=device |
| 354 | + ) |
| 355 | + |
| 356 | + return ( |
| 357 | + work_metadata_ptrs, |
| 358 | + work_indptr, |
| 359 | + work_info, |
| 360 | + reduce_indptr, |
| 361 | + reduce_final_map, |
| 362 | + reduce_partial_map, |
| 363 | + ) |
| 364 | + |
| 365 | + def make_mla_prefill_ps_meta_data( |
| 366 | + self, |
| 367 | + qo_indptr: torch.Tensor, |
| 368 | + kv_indptr: torch.Tensor, |
| 369 | + seq_lens: torch.Tensor, |
| 370 | + work_metadata: torch.Tensor, |
| 371 | + work_indptr: torch.Tensor, |
| 372 | + work_info: torch.Tensor, |
| 373 | + reduce_indptr: torch.Tensor, |
| 374 | + reduce_final_map: torch.Tensor, |
| 375 | + reduce_partial_map: torch.Tensor, |
| 376 | + is_causal: bool = True, |
| 377 | + ): |
| 378 | + gqa_ratio = self.num_head // self.num_kv_head |
| 379 | + num_heads_k = self.num_kv_head |
| 380 | + tile_q = 256 |
| 381 | + qhead_granularity = gqa_ratio |
| 382 | + qlen_granularity = tile_q // qhead_granularity |
| 383 | + kvlen_granularity = max(128, self.page_size) |
| 384 | + block_size = self.page_size |
| 385 | + |
| 386 | + qo_indptr_cpu = qo_indptr.to("cpu", dtype=torch.int32) |
| 387 | + kv_indptr_cpu = kv_indptr.to("cpu", dtype=torch.int32) |
| 388 | + seq_lens_cpu = seq_lens.to("cpu", dtype=torch.int32) |
| 389 | + |
| 390 | + get_ps_metadata_v1( |
| 391 | + qo_indptr_cpu, |
| 392 | + kv_indptr_cpu, |
| 393 | + seq_lens_cpu, |
| 394 | + gqa_ratio, |
| 395 | + num_heads_k, |
| 396 | + work_metadata, |
| 397 | + work_indptr, |
| 398 | + work_info, |
| 399 | + reduce_indptr, |
| 400 | + reduce_final_map, |
| 401 | + reduce_partial_map, |
| 402 | + qhead_granularity=qhead_granularity, |
| 403 | + qlen_granularity=qlen_granularity, |
| 404 | + kvlen_granularity=kvlen_granularity, |
| 405 | + block_size=block_size, |
| 406 | + is_causal=is_causal, |
| 407 | + ) |
| 408 | + |
311 | 409 | def init_forward_metadata(self, forward_batch: ForwardBatch): |
312 | 410 | """Init auxiliary variables for triton attention backend.""" |
313 | 411 |
|
@@ -587,15 +685,56 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): |
587 | 685 | spec_info=None, |
588 | 686 | ) |
589 | 687 |
|
590 | | - kv_indices = self.mla_indices_updater_prefill.kv_indices |
| 688 | + max_q_len = self.mla_indices_updater_prefill.max_q_len |
| 689 | + qo_indptr = self.mla_indices_updater_prefill.qo_indptr |
| 690 | + |
| 691 | + work_metadata = None |
| 692 | + work_indptr = None |
| 693 | + work_info_set = None |
| 694 | + reduce_indptr = None |
| 695 | + reduce_final_map = None |
| 696 | + reduce_partial_map = None |
| 697 | + |
| 698 | + if _use_fp8_prefill_attn: |
| 699 | + tile_q = 256 |
| 700 | + qlen_granularity = tile_q // (self.num_head // self.num_kv_head) |
| 701 | + ( |
| 702 | + work_metadata, |
| 703 | + work_indptr, |
| 704 | + work_info_set, |
| 705 | + reduce_indptr, |
| 706 | + reduce_final_map, |
| 707 | + reduce_partial_map, |
| 708 | + ) = self.make_mla_prefill_ps_meta_data_buffer( |
| 709 | + bs, max_q_len, qlen_granularity |
| 710 | + ) |
| 711 | + |
| 712 | + self.make_mla_prefill_ps_meta_data( |
| 713 | + qo_indptr, |
| 714 | + qo_indptr, |
| 715 | + forward_batch.seq_lens, |
| 716 | + work_metadata, |
| 717 | + work_indptr, |
| 718 | + work_info_set, |
| 719 | + reduce_indptr, |
| 720 | + reduce_final_map, |
| 721 | + reduce_partial_map, |
| 722 | + is_causal=True, |
| 723 | + ) |
591 | 724 |
|
592 | 725 | self.forward_metadata = ForwardMetadata( |
593 | 726 | self.mla_indices_updater_prefill.kv_indptr, |
594 | | - kv_indices, |
595 | | - self.mla_indices_updater_prefill.qo_indptr, |
| 727 | + self.mla_indices_updater_prefill.kv_indices, |
| 728 | + qo_indptr, |
596 | 729 | self.kv_last_page_len[:bs], |
597 | | - self.mla_indices_updater_prefill.max_q_len, |
| 730 | + max_q_len, |
598 | 731 | self.mla_indices_updater_prefill.max_kv_len, |
| 732 | + work_metadata=work_metadata, |
| 733 | + work_info_set=work_info_set, |
| 734 | + work_indptr=work_indptr, |
| 735 | + reduce_indptr=reduce_indptr, |
| 736 | + reduce_final_map=reduce_final_map, |
| 737 | + reduce_partial_map=reduce_partial_map, |
599 | 738 | ) |
600 | 739 | else: |
601 | 740 | self.indices_updater_prefill.update( |
@@ -1047,18 +1186,93 @@ def forward_extend( |
1047 | 1186 | ): |
1048 | 1187 | extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu) |
1049 | 1188 | if kv_indices.shape[0] == 0 or extend_no_prefix: |
1050 | | - o = flash_attn_varlen_func( |
1051 | | - q, |
1052 | | - k, |
1053 | | - v, |
1054 | | - qo_indptr, |
1055 | | - qo_indptr, |
1056 | | - max_q_len, |
1057 | | - max_q_len, |
1058 | | - softmax_scale=layer.scaling, |
1059 | | - causal=True, |
1060 | | - ) |
1061 | | - return o |
| 1189 | + if _use_fp8_prefill_attn: |
| 1190 | + total_s = q.shape[0] |
| 1191 | + nhead = layer.tp_q_head_num |
| 1192 | + v_head_dim = layer.v_head_dim |
| 1193 | + |
| 1194 | + if q.dtype != fp8_dtype: |
| 1195 | + q = q.float().to(fp8_dtype) |
| 1196 | + if k.dtype != fp8_dtype: |
| 1197 | + k = k.float().to(fp8_dtype) |
| 1198 | + if v.dtype != fp8_dtype: |
| 1199 | + v = v.float().to(fp8_dtype) |
| 1200 | + one_scale = torch.tensor( |
| 1201 | + 1.0, dtype=torch.float32, device=q.device |
| 1202 | + ) |
| 1203 | + |
| 1204 | + kv_indptr_asm = qo_indptr |
| 1205 | + kv_indices_asm = torch.arange( |
| 1206 | + total_s, device=q.device, dtype=torch.int32 |
| 1207 | + ) |
| 1208 | + |
| 1209 | + tile_q = 256 |
| 1210 | + reduce_indptr = self.forward_metadata.reduce_indptr |
| 1211 | + reduce_final_map = self.forward_metadata.reduce_final_map |
| 1212 | + reduce_partial_map = self.forward_metadata.reduce_partial_map |
| 1213 | + |
| 1214 | + logits = torch.empty( |
| 1215 | + (reduce_partial_map.size(0) * tile_q, nhead, v_head_dim), |
| 1216 | + dtype=torch.float32, |
| 1217 | + device=q.device, |
| 1218 | + ) |
| 1219 | + attn_lse = torch.empty( |
| 1220 | + (reduce_partial_map.size(0) * tile_q, nhead), |
| 1221 | + dtype=torch.float32, |
| 1222 | + device=q.device, |
| 1223 | + ) |
| 1224 | + final_lse = torch.empty( |
| 1225 | + (total_s, nhead), |
| 1226 | + dtype=torch.float32, |
| 1227 | + device=q.device, |
| 1228 | + ) |
| 1229 | + output = q.new_empty( |
| 1230 | + (total_s, nhead, v_head_dim), |
| 1231 | + dtype=self.input_dtype, |
| 1232 | + ) |
| 1233 | + |
| 1234 | + mla_prefill_ps_asm_fwd( |
| 1235 | + q, |
| 1236 | + k, |
| 1237 | + v, |
| 1238 | + qo_indptr, |
| 1239 | + kv_indptr_asm, |
| 1240 | + kv_indices_asm, |
| 1241 | + self.forward_metadata.work_indptr, |
| 1242 | + self.forward_metadata.work_info_set, |
| 1243 | + max_q_len, |
| 1244 | + layer.scaling, |
| 1245 | + True, |
| 1246 | + logits, |
| 1247 | + attn_lse, |
| 1248 | + output, |
| 1249 | + one_scale, |
| 1250 | + one_scale, |
| 1251 | + one_scale, |
| 1252 | + ) |
| 1253 | + mla_reduce_v1( |
| 1254 | + logits, |
| 1255 | + attn_lse, |
| 1256 | + reduce_indptr, |
| 1257 | + reduce_final_map, |
| 1258 | + reduce_partial_map, |
| 1259 | + tile_q, |
| 1260 | + output, |
| 1261 | + final_lse, |
| 1262 | + ) |
| 1263 | + else: |
| 1264 | + output = flash_attn_varlen_func( |
| 1265 | + q, |
| 1266 | + k, |
| 1267 | + v, |
| 1268 | + qo_indptr, |
| 1269 | + qo_indptr, |
| 1270 | + max_q_len, |
| 1271 | + max_q_len, |
| 1272 | + softmax_scale=layer.scaling, |
| 1273 | + causal=True, |
| 1274 | + ) |
| 1275 | + return output |
1062 | 1276 | elif layer.qk_head_dim != (kv_lora_rank + qk_rope_head_dim): |
1063 | 1277 | K_Buffer = torch.index_select(K_Buffer, 0, kv_indices) |
1064 | 1278 | kvc, k_pe = torch.split( |
|
0 commit comments