Skip to content

Commit 35b6375

Browse files
Committed: update test_fp8_blockwise_gemm.py
1 parent: 6668293 · commit: 35b6375

File tree

2 files changed

+23
-11
lines changed

2 files changed

+23
-11
lines changed

python/sglang/srt/layers/quantization/fp8_utils.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,10 @@ def _check_cutlass_block_fp8_hardware_support() -> bool:
176176
if is_blackwell_supported() and is_flashinfer_available():
177177
from flashinfer.gemm import gemm_fp8_nt_groupwise
178178

179+
if is_sm90_supported() and is_flashinfer_available():
180+
# FlashInfer SM90 DeepGEMM with automatic swapAB optimization for small M
181+
from flashinfer.gemm import fp8_blockscale_gemm_sm90
182+
179183

180184
def dispatch_w8a8_block_fp8_linear() -> Callable:
181185
"""
@@ -359,8 +363,9 @@ def flashinfer_deepgemm_w8a8_block_fp8_linear_with_fallback(
359363
"""
360364
FlashInfer DeepGEMM backend for SM90 (Hopper) with swapAB optimization.
361365
362-
This backend uses FlashInfer's TensorRT-LLM DeepGEMM JIT compiler which includes
363-
the swapAB optimization for small M dimensions (decoding/low batch sizes).
366+
Uses flashinfer.gemm.fp8_blockscale_gemm_sm90 which automatically selects
367+
the swapAB kernel for small M dimensions (M < 32) for better performance
368+
during decoding/low batch size scenarios.
364369
365370
For SM90 (Hopper), this uses the DeepGEMM JIT with automatic swapAB selection.
366371
"""
@@ -369,6 +374,7 @@ def flashinfer_deepgemm_w8a8_block_fp8_linear_with_fallback(
369374
output_dtype = input.dtype
370375
dtype_supported = output_dtype == torch.bfloat16
371376

377+
# fp8_blockscale_gemm_sm90 requires: N % 64 == 0, K % 128 == 0
372378
shape_supported = weight.shape[0] % 64 == 0 and weight.shape[1] % 128 == 0
373379

374380
if not (shape_supported and dtype_supported):
@@ -383,20 +389,21 @@ def flashinfer_deepgemm_w8a8_block_fp8_linear_with_fallback(
383389
input_2d = input.view(-1, input.shape[-1])
384390
output_shape = [*input.shape[:-1], weight.shape[0]]
385391

386-
q_input, x_scale = sglang_per_token_group_quant_fp8(
392+
# - input: (M, K) BF16 or FP8
393+
# - weight: (N, K) FP8 with weight_scale
394+
# - weight_scale: (N, K//128) for per-token or (N//128, K//128) for per-block
395+
396+
output = fp8_blockscale_gemm_sm90(
387397
input_2d,
388-
block_size[1],
389-
column_major_scales=True,
390-
scale_tma_aligned=True,
391-
scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
398+
weight,
399+
input_scale=None, # BF16 input, internal quantization
400+
weight_scale=weight_scale,
401+
out_dtype=output_dtype,
392402
)
393403

394-
output = w8a8_block_fp8_matmul_deepgemm(
395-
q_input, weight, x_scale, weight_scale, block_size, output_dtype=output_dtype
396-
)
397404
if bias is not None:
398405
output += bias
399-
return output.to(dtype=output_dtype).view(*output_shape)
406+
return output.view(*output_shape)
400407

401408

402409
def cutlass_w8a8_block_fp8_linear_with_fallback(

test/srt/test_fp8_blockwise_gemm.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,5 +69,10 @@ class TestFP8BlockwiseGemmFlashinferTrtllm(FP8BlockwiseGemmBase, unittest.TestCa
6969
backend = "flashinfer_trtllm"
7070

7171

72+
@unittest.skipIf(get_device_sm() != 90, "Test requires CUDA SM 90")
73+
class TestFP8BlockwiseGemmFlashinferDeepGemm(FP8BlockwiseGemmBase, unittest.TestCase):
74+
backend = "flashinfer_deepgemm"
75+
76+
7277
if __name__ == "__main__":
7378
unittest.main()

0 commit comments

Comments (0)