@@ -176,6 +176,10 @@ def _check_cutlass_block_fp8_hardware_support() -> bool:
176176if is_blackwell_supported () and is_flashinfer_available ():
177177 from flashinfer .gemm import gemm_fp8_nt_groupwise
178178
179+ if is_sm90_supported () and is_flashinfer_available ():
180+ # FlashInfer SM90 DeepGEMM with automatic swapAB optimization for small M
181+ from flashinfer .gemm import fp8_blockscale_gemm_sm90
182+
179183
180184def dispatch_w8a8_block_fp8_linear () -> Callable :
181185 """
@@ -359,8 +363,9 @@ def flashinfer_deepgemm_w8a8_block_fp8_linear_with_fallback(
359363 """
360364 FlashInfer DeepGEMM backend for SM90 (Hopper) with swapAB optimization.
361365
362- This backend uses FlashInfer's TensorRT-LLM DeepGEMM JIT compiler which includes
363- the swapAB optimization for small M dimensions (decoding/low batch sizes).
366+ Uses flashinfer.gemm.fp8_blockscale_gemm_sm90 which automatically selects
367+ the swapAB kernel for small M dimensions (M < 32) for better performance
368+ during decoding/low batch size scenarios.
364369
365370 For SM90 (Hopper), this uses the DeepGEMM JIT with automatic swapAB selection.
366371 """
@@ -369,6 +374,7 @@ def flashinfer_deepgemm_w8a8_block_fp8_linear_with_fallback(
369374 output_dtype = input .dtype
370375 dtype_supported = output_dtype == torch .bfloat16
371376
377+ # fp8_blockscale_gemm_sm90 requires: N % 64 == 0, K % 128 == 0
372378 shape_supported = weight .shape [0 ] % 64 == 0 and weight .shape [1 ] % 128 == 0
373379
374380 if not (shape_supported and dtype_supported ):
@@ -383,20 +389,21 @@ def flashinfer_deepgemm_w8a8_block_fp8_linear_with_fallback(
383389 input_2d = input .view (- 1 , input .shape [- 1 ])
384390 output_shape = [* input .shape [:- 1 ], weight .shape [0 ]]
385391
386- q_input , x_scale = sglang_per_token_group_quant_fp8 (
392+ # - input: (M, K) BF16 or FP8
393+ # - weight: (N, K) FP8 with weight_scale
394+ # - weight_scale: (N, K//128) for per-token or (N//128, K//128) for per-block
395+
396+ output = fp8_blockscale_gemm_sm90 (
387397 input_2d ,
388- block_size [ 1 ] ,
389- column_major_scales = True ,
390- scale_tma_aligned = True ,
391- scale_ue8m0 = deep_gemm_wrapper . DEEPGEMM_SCALE_UE8M0 ,
398+ weight ,
399+ input_scale = None , # BF16 input, internal quantization
400+ weight_scale = weight_scale ,
401+ out_dtype = output_dtype ,
392402 )
393403
394- output = w8a8_block_fp8_matmul_deepgemm (
395- q_input , weight , x_scale , weight_scale , block_size , output_dtype = output_dtype
396- )
397404 if bias is not None :
398405 output += bias
399- return output .to ( dtype = output_dtype ). view (* output_shape )
406+ return output .view (* output_shape )
400407
401408
402409def cutlass_w8a8_block_fp8_linear_with_fallback (