Skip to content

Commit 4df74eb

Browse files
b8zhong and vincentzed authored
[Refactor] Add -fp4-gemm-backend to replace SGLANG_FLASHINFER_FP4_GEMM_BACKEND (#16534)
Co-authored-by: Vincent Zhong <207368749+vincentzed@users.noreply.github.com>
1 parent f3a7c7d commit 4df74eb

File tree

9 files changed

+144
-18
lines changed

9 files changed

+144
-18
lines changed

python/sglang/bench_one_batch.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
from sglang.srt.distributed.parallel_state import destroy_distributed_environment
6868
from sglang.srt.entrypoints.engine import _set_envs_and_config
6969
from sglang.srt.layers.moe import initialize_moe_config
70+
from sglang.srt.layers.quantization.fp4_utils import initialize_fp4_gemm_config
7071
from sglang.srt.layers.quantization.fp8_utils import initialize_fp8_gemm_config
7172
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
7273
from sglang.srt.managers.scheduler_dp_attn_mixin import prepare_mlp_sync_batch_raw
@@ -647,6 +648,7 @@ def latency_test(
647648
):
648649
initialize_moe_config(server_args)
649650
initialize_fp8_gemm_config(server_args)
651+
initialize_fp4_gemm_config(server_args)
650652

651653
# Set CPU affinity
652654
if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):

python/sglang/srt/environ.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,10 @@ def _convert_SGL_to_SGLANG():
503503
"SGLANG_SUPPORT_CUTLASS_BLOCK_FP8",
504504
"It will be completely removed in 0.5.7. Please use '--fp8-gemm-backend=cutlass' instead.",
505505
)
506+
_warn_deprecated_env_to_cli_flag(
507+
"SGLANG_FLASHINFER_FP4_GEMM_BACKEND",
508+
"It will be completely removed in 0.5.9. Please use '--fp4-gemm-backend' instead.",
509+
)
506510
_warn_deprecated_env_to_cli_flag(
507511
"SGLANG_SCHEDULER_DECREASE_PREFILL_IDLE",
508512
"Please use '--enable-prefill-delayer' instead.",

python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
1616
CompressedTensorsScheme,
1717
)
18+
from sglang.srt.layers.quantization.fp4_utils import get_fp4_gemm_runner_backend
1819
from sglang.srt.layers.quantization.modelopt_quant import (
19-
FLASHINFER_FP4_GEMM_BACKEND,
2020
enable_flashinfer_fp4_gemm,
2121
fp4_gemm,
2222
fp4_quantize,
@@ -98,7 +98,7 @@ def process_weights_after_loading(self, layer) -> None:
9898
layer.weight_global_scale.max().to(torch.float32), requires_grad=False
9999
)
100100

101-
if FLASHINFER_FP4_GEMM_BACKEND == "trtllm":
101+
if get_fp4_gemm_runner_backend().is_trtllm():
102102
# FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
103103
# FlashInfer provides nvfp4_quantize to quantize + shuffle the
104104
# layout but we use our own quantization so we have to call
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
from __future__ import annotations
2+
3+
import logging
4+
from enum import Enum
5+
from typing import TYPE_CHECKING
6+
7+
from sglang.srt.environ import envs
8+
9+
if TYPE_CHECKING:
10+
from sglang.srt.server_args import ServerArgs
11+
12+
logger = logging.getLogger(__name__)
13+
14+
15+
class Fp4GemmRunnerBackend(Enum):
    """Backend choices for running NVFP4 GEMM kernels.

    The string values match the accepted values of the ``--fp4-gemm-backend``
    CLI flag ("auto", "cudnn", "cutlass", "trtllm").
    """

    AUTO = "auto"
    CUDNN = "cudnn"
    CUTLASS = "cutlass"
    TRTLLM = "trtllm"

    def is_auto(self) -> bool:
        """True when backend selection is left to automatic detection."""
        return self is Fp4GemmRunnerBackend.AUTO

    def is_cudnn(self) -> bool:
        """True when the cuDNN backend is selected."""
        return self is Fp4GemmRunnerBackend.CUDNN

    def is_cutlass(self) -> bool:
        """True when the CUTLASS backend is selected."""
        return self is Fp4GemmRunnerBackend.CUTLASS

    def is_trtllm(self) -> bool:
        """True when the TensorRT-LLM backend is selected."""
        return self is Fp4GemmRunnerBackend.TRTLLM
34+
35+
36+
# Module-level cache of the resolved backend; None until initialization runs.
FP4_GEMM_RUNNER_BACKEND: Fp4GemmRunnerBackend | None = None


def initialize_fp4_gemm_config(server_args: ServerArgs) -> None:
    """Resolve and cache the FP4 GEMM backend from server arguments.

    The CLI flag (``server_args.fp4_gemm_runner_backend``) takes precedence
    over the deprecated ``SGLANG_FLASHINFER_FP4_GEMM_BACKEND`` environment
    variable; the env var is only honored while the flag is still "auto".
    A deprecation or override warning is logged when the env var is set.
    """
    global FP4_GEMM_RUNNER_BACKEND

    chosen = server_args.fp4_gemm_runner_backend

    # Backward compatibility with the deprecated environment variable.
    # TODO: Remove this in a future version
    if envs.SGLANG_FLASHINFER_FP4_GEMM_BACKEND.is_set():
        legacy = envs.SGLANG_FLASHINFER_FP4_GEMM_BACKEND.get()
        if chosen != "auto":
            # An explicit CLI choice wins; tell the user the env var is ignored.
            logger.warning(
                f"FP4 GEMM backend set to '{chosen}' via --fp4-gemm-backend overrides "
                "environment variable SGLANG_FLASHINFER_FP4_GEMM_BACKEND. "
                "Using server argument value."
            )
        else:
            logger.warning(
                "SGLANG_FLASHINFER_FP4_GEMM_BACKEND is deprecated. "
                f"Please use '--fp4-gemm-backend={legacy}' instead."
            )
            chosen = legacy

    FP4_GEMM_RUNNER_BACKEND = Fp4GemmRunnerBackend(chosen)
63+
64+
65+
def get_fp4_gemm_runner_backend() -> Fp4GemmRunnerBackend:
    """Return the active FP4 GEMM backend.

    If ``initialize_fp4_gemm_config`` has not run yet (e.g. an early import
    path), fall back to ``Fp4GemmRunnerBackend.AUTO`` and cache that fallback
    in the module-level global so subsequent calls see a consistent value.
    """
    global FP4_GEMM_RUNNER_BACKEND
    backend = FP4_GEMM_RUNNER_BACKEND
    if backend is None:
        # Uninitialized: default to automatic backend selection.
        backend = Fp4GemmRunnerBackend.AUTO
        FP4_GEMM_RUNNER_BACKEND = backend
    return backend

python/sglang/srt/layers/quantization/modelopt_quant.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
QuantizationConfig,
3131
QuantizeMethodBase,
3232
)
33+
from sglang.srt.layers.quantization.fp4_utils import get_fp4_gemm_runner_backend
3334
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
3435
from sglang.srt.layers.quantization.fp8_utils import (
3536
apply_fp8_linear,
@@ -126,7 +127,10 @@ def fp4_gemm(
126127
out_dtype: torch.dtype,
127128
out_features: int,
128129
) -> torch.Tensor:
129-
backend = FLASHINFER_FP4_GEMM_BACKEND if FLASHINFER_FP4_GEMM_BACKEND else "cutlass"
130+
fp4_backend = get_fp4_gemm_runner_backend()
131+
# TODO(shuw@nvidia.com): Remove the "cutlass" default override after flashinfer 0.6.0
132+
# and let flashinfer's auto backend selection handle it.
133+
backend = fp4_backend.value if not fp4_backend.is_auto() else "cutlass"
130134
if enable_flashinfer_fp4_gemm:
131135
return flashinfer_fp4_gemm(
132136
input, weight, input_sf, weight_sf, alpha, out_dtype, backend=backend
@@ -150,7 +154,6 @@ def _sgl_kernel_scaled_fp4_quant_fake(
150154

151155
# TODO make it true by default when the DeepEP PR is merged
152156
MOE_NVFP4_DISPATCH = envs.SGLANG_MOE_NVFP4_DISPATCH.get()
153-
FLASHINFER_FP4_GEMM_BACKEND = envs.SGLANG_FLASHINFER_FP4_GEMM_BACKEND.get()
154157
# Supported activation schemes for the current configuration
155158
ACTIVATION_SCHEMES = ["static"]
156159

@@ -1152,7 +1155,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
11521155
layer.input_scale_inv = Parameter(
11531156
(1 / input_scale_2).to(torch.float32), requires_grad=False
11541157
)
1155-
if FLASHINFER_FP4_GEMM_BACKEND == "trtllm":
1158+
if get_fp4_gemm_runner_backend().is_trtllm():
11561159
# FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
11571160
# FlashInfer provides nvfp4_quantize to quantize + shuffle the
11581161
# layout but we use our own quantization so we have to call
@@ -1221,11 +1224,6 @@ def apply(
12211224
if enable_flashinfer_fp4_gemm:
12221225
w = layer.weight.T
12231226
w_scale_interleaved = layer.weight_scale_interleaved.T
1224-
# TODO(shuw@nvidia.com)
1225-
# Remove the default after flashinfer bumped to 0.5.1
1226-
backend = (
1227-
FLASHINFER_FP4_GEMM_BACKEND if FLASHINFER_FP4_GEMM_BACKEND else "cutlass"
1228-
)
12291227
out = fp4_gemm(
12301228
x_fp4,
12311229
w,

python/sglang/srt/managers/scheduler.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
get_attention_tp_group,
6666
)
6767
from sglang.srt.layers.moe import initialize_moe_config
68+
from sglang.srt.layers.quantization.fp4_utils import initialize_fp4_gemm_config
6869
from sglang.srt.layers.quantization.fp8_utils import initialize_fp8_gemm_config
6970
from sglang.srt.managers.io_struct import (
7071
AbortReq,
@@ -473,10 +474,9 @@ def init_moe_gemm_config(self):
473474
if hasattr(self.model_config.hf_config, "num_experts_per_tok"):
474475
initialize_moe_config(self.server_args)
475476

476-
# Initialize GEMM-related configuration (currently FP8 Blockwise GEMM backend).
477-
# Other GEMM backends (e.g. FP4, BF16, etc.) can be added here in the future.
478-
# This is needed for FP8 quantization.
477+
# Initialize GEMM-related configuration for FP8 and FP4 backends.
479478
initialize_fp8_gemm_config(self.server_args)
479+
initialize_fp4_gemm_config(self.server_args)
480480

481481
# This must be called after initialize_moe_config
482482
self.require_mlp_sync = require_mlp_sync(self.server_args)

python/sglang/srt/server_args.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,13 @@
192192
"aiter",
193193
]
194194

195+
FP4_GEMM_RUNNER_BACKEND_CHOICES = [
196+
"auto",
197+
"cudnn",
198+
"cutlass",
199+
"trtllm",
200+
]
201+
195202
MAMBA_SSM_DTYPE_CHOICES = ["float32", "bfloat16"]
196203

197204
MAMBA_SCHEDULER_STRATEGY_CHOICES = ["auto", "no_buffer", "extra_buffer"]
@@ -226,6 +233,10 @@ def add_fp8_gemm_runner_backend_choices(choices):
226233
FP8_GEMM_RUNNER_BACKEND_CHOICES.extend(choices)
227234

228235

236+
def add_fp4_gemm_runner_backend_choices(choices):
237+
FP4_GEMM_RUNNER_BACKEND_CHOICES.extend(choices)
238+
239+
229240
def add_deterministic_attention_backend_choices(choices):
230241
DETERMINISTIC_ATTENTION_BACKEND_CHOICES.extend(choices)
231242

@@ -423,6 +434,7 @@ class ServerArgs:
423434
grammar_backend: Optional[str] = None
424435
mm_attention_backend: Optional[str] = None
425436
fp8_gemm_runner_backend: str = "auto"
437+
fp4_gemm_runner_backend: str = "auto"
426438
nsa_prefill_backend: str = "flashmla_sparse"
427439
nsa_decode_backend: str = "fa3"
428440
disable_flashinfer_autotune: bool = False
@@ -3538,6 +3550,20 @@ def add_cli_args(parser: argparse.ArgumentParser):
35383550
"NOTE: This replaces the deprecated environment variables "
35393551
"SGLANG_ENABLE_FLASHINFER_FP8_GEMM and SGLANG_SUPPORT_CUTLASS_BLOCK_FP8.",
35403552
)
3553+
parser.add_argument(
3554+
"--fp4-gemm-backend",
3555+
type=str,
3556+
choices=FP4_GEMM_RUNNER_BACKEND_CHOICES,
3557+
default=ServerArgs.fp4_gemm_runner_backend,
3558+
dest="fp4_gemm_runner_backend",
3559+
help="Choose the runner backend for NVFP4 GEMM operations. "
3560+
"Options: 'auto' (default, selects between cudnn/cutlass based on CUDA/cuDNN version), "
3561+
"'cudnn' (cuDNN backend, optimal on CUDA 13+ with cuDNN 9.15+), "
3562+
"'cutlass' (CUTLASS backend, optimal on CUDA 12), "
3563+
"'trtllm' (TensorRT-LLM backend, requires different weight preparation with shuffling). "
3564+
"NOTE: This replaces the deprecated environment variable "
3565+
"SGLANG_FLASHINFER_FP4_GEMM_BACKEND.",
3566+
)
35413567
parser.add_argument(
35423568
"--disable-flashinfer-autotune",
35433569
default=ServerArgs.disable_flashinfer_autotune,

test/srt/run_suite.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
TestFile("test_deepseek_v3_fp4_4gpu.py", 1500),
3636
TestFile("test_fp8_blockwise_gemm.py", 280),
3737
TestFile("test_gpt_oss_4gpu.py", 700),
38-
TestFile("test_llama31_fp4.py", 90),
38+
TestFile("test_nvfp4_gemm.py", 360),
3939
],
4040
# "per-commit-8-gpu-b200": [
4141
# TestFile("test_mistral_large3_basic.py", 275), # Moved to nightly - large model
Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,27 @@
88
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
99
DEFAULT_URL_FOR_TEST,
1010
popen_launch_server,
11+
try_cached_model,
1112
)
1213

13-
MODEL_PATH = "nvidia/Llama-3.1-8B-Instruct-FP4"
14+
MODEL_PATH = "nvidia/Llama-3.1-8B-Instruct-NVFP4"
1415

1516

16-
@unittest.skipIf(get_device_sm() < 100, "Test requires CUDA SM 100 or higher")
17-
class TestLlama31FP4(unittest.TestCase):
17+
class FP4GemmBase:
18+
backend = None
19+
1820
@classmethod
1921
def setUpClass(cls):
20-
cls.model = MODEL_PATH
22+
if cls.backend is None:
23+
raise NotImplementedError("Subclass must set 'backend' attribute")
24+
cls.model = try_cached_model(MODEL_PATH)
2125
cls.base_url = DEFAULT_URL_FOR_TEST
2226
other_args = [
2327
"--trust-remote-code",
2428
"--quantization",
2529
"modelopt_fp4",
30+
"--fp4-gemm-backend",
31+
cls.backend,
2632
]
2733
cls.process = popen_launch_server(
2834
cls.model,
@@ -52,5 +58,25 @@ def test_gsm8k(self):
5258
self.assertGreater(metrics["accuracy"], 0.64)
5359

5460

61+
@unittest.skipIf(get_device_sm() < 100, "Test requires CUDA SM 100 or higher")
62+
class TestFP4GemmAuto(FP4GemmBase, unittest.TestCase):
63+
backend = "auto"
64+
65+
66+
@unittest.skipIf(get_device_sm() < 100, "Test requires CUDA SM 100 or higher")
67+
class TestFP4GemmCutlass(FP4GemmBase, unittest.TestCase):
68+
backend = "cutlass"
69+
70+
71+
@unittest.skipIf(get_device_sm() < 100, "Test requires CUDA SM 100 or higher")
72+
class TestFP4GemmCudnn(FP4GemmBase, unittest.TestCase):
73+
backend = "cudnn"
74+
75+
76+
@unittest.skipIf(get_device_sm() < 100, "Test requires CUDA SM 100 or higher")
77+
class TestFP4GemmTrtllm(FP4GemmBase, unittest.TestCase):
78+
backend = "trtllm"
79+
80+
5581
if __name__ == "__main__":
5682
unittest.main()

0 commit comments

Comments
 (0)