@@ -134,7 +134,8 @@ class Fp8GemmRunnerBackend(Enum):
134134 """Enum for FP8 GEMM runner backend selection."""
135135
136136 AUTO = "auto"
137- FLASHINFER = "flashinfer_trtllm"
137+ FLASHINFER_TRTLLM = "flashinfer_trtllm"
138+ FLASHINFER_DEEPGEMM = "flashinfer_deepgemm"
138139 CUTLASS = "cutlass"
139140 DEEP_GEMM = "deep_gemm"
140141 TRITON = "triton"
@@ -144,7 +145,10 @@ def is_auto(self) -> bool:
144145 return self == Fp8GemmRunnerBackend .AUTO
145146
    def is_flashinfer(self) -> bool:
        """Return True when this backend selects the FlashInfer TRT-LLM FP8 GEMM runner."""
        return self == Fp8GemmRunnerBackend.FLASHINFER_TRTLLM
    def is_flashinfer_deepgemm(self) -> bool:
        """Return True when this backend selects the FlashInfer DeepGEMM (SM90) runner."""
        return self == Fp8GemmRunnerBackend.FLASHINFER_DEEPGEMM
148152
    def is_cutlass(self) -> bool:
        """Return True when this backend selects the CUTLASS FP8 GEMM runner."""
        return self == Fp8GemmRunnerBackend.CUTLASS
@@ -170,6 +174,10 @@ def _check_cutlass_block_fp8_hardware_support() -> bool:
170174if is_blackwell_supported () and is_flashinfer_available ():
171175 from flashinfer .gemm import gemm_fp8_nt_groupwise
172176
177+ if is_sm90_supported () and is_flashinfer_available ():
178+ # FlashInfer SM90 DeepGEMM with automatic swapAB optimization for small M
179+ from flashinfer .gemm import fp8_blockscale_gemm_sm90
180+
173181
174182def dispatch_w8a8_block_fp8_linear () -> Callable :
175183 """
@@ -200,6 +208,15 @@ def _dispatch_explicit_backend(backend: Fp8GemmRunnerBackend) -> Callable:
200208 )
201209 return flashinfer_gemm_w8a8_block_fp8_linear_with_fallback
202210
211+ elif backend .is_flashinfer_deepgemm ():
212+ if not (is_sm90_supported () and is_flashinfer_available ()):
213+ raise RuntimeError (
214+ "FlashInfer DeepGEMM with swapAB requested via --fp8-gemm-backend=flashinfer_deepgemm, "
215+ "but it's not available. This backend requires Hopper (SM90) GPUs and FlashInfer "
216+ "to be installed."
217+ )
218+ return flashinfer_deepgemm_w8a8_block_fp8_linear_with_fallback
219+
203220 elif backend .is_cutlass ():
204221 if not _check_cutlass_block_fp8_hardware_support ():
205222 raise RuntimeError (
@@ -333,6 +350,60 @@ def flashinfer_gemm_w8a8_block_fp8_linear_with_fallback(
333350 return output .to (dtype = input_2d .dtype ).view (* output_shape )
334351
335352
def flashinfer_deepgemm_w8a8_block_fp8_linear_with_fallback(
    input: torch.Tensor,
    weight: torch.Tensor,
    block_size: List[int],
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    FlashInfer DeepGEMM backend for SM90 (Hopper) with swapAB optimization.

    Dispatches to ``flashinfer.gemm.fp8_blockscale_gemm_sm90``, which picks
    the swapAB kernel automatically for small M (M < 32) — the common
    decoding / low-batch case. Inputs that the SM90 kernel cannot handle
    (non-BF16 activations, or N not a multiple of 64 / K not a multiple of
    128) fall back to the Triton block-FP8 kernel.
    """
    # Dynamic (per-token) activation quantization only: a precomputed input
    # scale is not supported on this path.
    assert input_scale is None

    out_dtype = input.dtype
    n, k = weight.shape[0], weight.shape[1]

    # Kernel constraints: BF16 activations, N % 64 == 0, K % 128 == 0.
    kernel_ok = out_dtype == torch.bfloat16 and n % 64 == 0 and k % 128 == 0

    if not kernel_ok:
        # Triton fallback. Packed UE8M0 (int32) scales must be expanded
        # first, since the Triton kernel consumes unpacked scales.
        if weight_scale.dtype == torch.int32:
            weight_scale = _unpack_ue8m0_scale_for_triton(
                weight_scale, weight.shape, block_size
            )
        return triton_w8a8_block_fp8_linear(
            input, weight, block_size, weight_scale, input_scale, bias
        )

    # Collapse leading dims: the kernel works on a 2-D (M, K) activation.
    flat_input = input.view(-1, input.shape[-1])
    out_shape = [*input.shape[:-1], n]

    # - flat_input: (M, K) BF16 or FP8
    # - weight: (N, K) FP8 with weight_scale
    # - weight_scale: (N, K//128) for per-token or (N//128, K//128) for per-block
    result = fp8_blockscale_gemm_sm90(
        flat_input,
        weight,
        input_scale=None,  # BF16 input, internal quantization
        weight_scale=weight_scale,
        out_dtype=out_dtype,
    )

    if bias is not None:
        result += bias
    return result.view(*out_shape)
406+
336407def cutlass_w8a8_block_fp8_linear_with_fallback (
337408 input : torch .Tensor ,
338409 weight : torch .Tensor ,
0 commit comments