Skip to content

Commit 6a8b09a

Browse files
b8zhong
authored and committed
[Perf] Add Flashinfer DeepGEMM SM90 for SwapAB Optimization (sgl-project#15514)
Co-authored-by: Brayden Zhong <b8zhong@users.noreply.github.com>
1 parent 08e3bba commit 6a8b09a

File tree

3 files changed

+80
-2
lines changed

3 files changed

+80
-2
lines changed

python/sglang/srt/layers/quantization/fp8_utils.py

Lines changed: 73 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,8 @@ class Fp8GemmRunnerBackend(Enum):
134134
"""Enum for FP8 GEMM runner backend selection."""
135135

136136
AUTO = "auto"
137-
FLASHINFER = "flashinfer_trtllm"
137+
FLASHINFER_TRTLLM = "flashinfer_trtllm"
138+
FLASHINFER_DEEPGEMM = "flashinfer_deepgemm"
138139
CUTLASS = "cutlass"
139140
DEEP_GEMM = "deep_gemm"
140141
TRITON = "triton"
@@ -144,7 +145,10 @@ def is_auto(self) -> bool:
144145
return self == Fp8GemmRunnerBackend.AUTO
145146

146147
def is_flashinfer(self) -> bool:
147-
return self == Fp8GemmRunnerBackend.FLASHINFER
148+
return self == Fp8GemmRunnerBackend.FLASHINFER_TRTLLM
149+
150+
def is_flashinfer_deepgemm(self) -> bool:
151+
return self == Fp8GemmRunnerBackend.FLASHINFER_DEEPGEMM
148152

149153
def is_cutlass(self) -> bool:
150154
return self == Fp8GemmRunnerBackend.CUTLASS
@@ -170,6 +174,10 @@ def _check_cutlass_block_fp8_hardware_support() -> bool:
170174
if is_blackwell_supported() and is_flashinfer_available():
171175
from flashinfer.gemm import gemm_fp8_nt_groupwise
172176

177+
if is_sm90_supported() and is_flashinfer_available():
178+
# FlashInfer SM90 DeepGEMM with automatic swapAB optimization for small M
179+
from flashinfer.gemm import fp8_blockscale_gemm_sm90
180+
173181

174182
def dispatch_w8a8_block_fp8_linear() -> Callable:
175183
"""
@@ -200,6 +208,15 @@ def _dispatch_explicit_backend(backend: Fp8GemmRunnerBackend) -> Callable:
200208
)
201209
return flashinfer_gemm_w8a8_block_fp8_linear_with_fallback
202210

211+
elif backend.is_flashinfer_deepgemm():
212+
if not (is_sm90_supported() and is_flashinfer_available()):
213+
raise RuntimeError(
214+
"FlashInfer DeepGEMM with swapAB requested via --fp8-gemm-backend=flashinfer_deepgemm, "
215+
"but it's not available. This backend requires Hopper (SM90) GPUs and FlashInfer "
216+
"to be installed."
217+
)
218+
return flashinfer_deepgemm_w8a8_block_fp8_linear_with_fallback
219+
203220
elif backend.is_cutlass():
204221
if not _check_cutlass_block_fp8_hardware_support():
205222
raise RuntimeError(
@@ -333,6 +350,60 @@ def flashinfer_gemm_w8a8_block_fp8_linear_with_fallback(
333350
return output.to(dtype=input_2d.dtype).view(*output_shape)
334351

335352

353+
def flashinfer_deepgemm_w8a8_block_fp8_linear_with_fallback(
    input: torch.Tensor,
    weight: torch.Tensor,
    block_size: List[int],
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """FP8 block-scaled linear layer using FlashInfer's SM90 DeepGEMM kernel.

    Dispatches to ``flashinfer.gemm.fp8_blockscale_gemm_sm90``, which selects
    a swapAB kernel for small M dimensions (useful for decode / low batch
    sizes on Hopper). Problems the kernel cannot handle fall back to the
    Triton implementation.

    Args:
        input: Activation tensor, flattened internally to (M, K).
        weight: FP8 weight tensor of shape (N, K).
        block_size: Quantization block sizes, forwarded to the Triton fallback.
        weight_scale: Block-wise weight scales; ``int32`` dtype indicates
            packed UE8M0 scales, which are unpacked before the Triton path.
        input_scale: Must be None — activations are quantized inside the kernel.
        bias: Optional bias added to the GEMM output.

    Returns:
        Output tensor shaped like ``input`` with the last dim replaced by N.
    """
    # Dynamic (in-kernel) activation quantization only.
    assert input_scale is None

    out_dtype = input.dtype
    n, k = weight.shape[0], weight.shape[1]

    # Kernel constraints: BF16 activations, N % 64 == 0, K % 128 == 0.
    kernel_ok = out_dtype == torch.bfloat16 and n % 64 == 0 and k % 128 == 0

    if not kernel_ok:
        scales = weight_scale
        if scales.dtype == torch.int32:
            # Triton cannot consume packed UE8M0 scales; unpack them first.
            scales = _unpack_ue8m0_scale_for_triton(scales, weight.shape, block_size)
        return triton_w8a8_block_fp8_linear(
            input, weight, block_size, scales, input_scale, bias
        )

    flat_input = input.view(-1, input.shape[-1])
    result_shape = [*input.shape[:-1], n]

    # - flat_input: (M, K) activations
    # - weight: (N, K) FP8 with block-wise weight_scale
    result = fp8_blockscale_gemm_sm90(
        flat_input,
        weight,
        input_scale=None,  # BF16 input; quantization happens inside the kernel
        weight_scale=weight_scale,
        out_dtype=out_dtype,
    )

    if bias is not None:
        result += bias
    return result.view(*result_shape)
405+
406+
336407
def cutlass_w8a8_block_fp8_linear_with_fallback(
337408
input: torch.Tensor,
338409
weight: torch.Tensor,

python/sglang/srt/server_args.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@
198198
"auto",
199199
"deep_gemm",
200200
"flashinfer_trtllm",
201+
"flashinfer_deepgemm",
201202
"cutlass",
202203
"triton",
203204
"aiter",
@@ -3685,6 +3686,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
36853686
"Options: 'auto' (default, auto-selects based on hardware), "
36863687
"'deep_gemm' (JIT-compiled; enabled by default on NVIDIA Hopper (SM90) and Blackwell (SM100) when DeepGEMM is installed), "
36873688
"'flashinfer_trtllm' (optimal for Blackwell and low-latency), "
3689+
"'flashinfer_deepgemm' (Hopper SM90 only; uses swapAB optimization for small M dimensions in decoding), "
36883690
"'cutlass' (optimal for Hopper/Blackwell GPUs and high-throughput), "
36893691
"'triton' (fallback, widely compatible), "
36903692
"'aiter' (ROCm only). "

test/srt/test_fp8_blockwise_gemm.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,5 +69,10 @@ class TestFP8BlockwiseGemmFlashinferTrtllm(FP8BlockwiseGemmBase, unittest.TestCa
6969
backend = "flashinfer_trtllm"
7070

7171

72+
@unittest.skipIf(get_device_sm() != 90, "Test requires CUDA SM 90")
class TestFP8BlockwiseGemmFlashinferDeepGemm(FP8BlockwiseGemmBase, unittest.TestCase):
    """FP8 blockwise GEMM tests pinned to the flashinfer_deepgemm backend.

    Skipped unless running on Hopper (SM90), the only architecture this
    backend supports.
    """

    # Backend identifier — presumably consumed by FP8BlockwiseGemmBase to
    # select the GEMM implementation under test; verify against the base class.
    backend = "flashinfer_deepgemm"
75+
76+
7277
if __name__ == "__main__":
7378
unittest.main()

0 commit comments

Comments
 (0)