|
192 | 192 | "aiter", |
193 | 193 | ] |
194 | 194 |
|
# Backend names accepted for the NVFP4 GEMM runner (used as the argparse
# ``choices`` of ``--fp4-gemm-backend``). This must stay a mutable list:
# add_fp4_gemm_runner_backend_choices() extends it in place.
FP4_GEMM_RUNNER_BACKEND_CHOICES = [
    "auto",
    "cudnn",
    "cutlass",
    "trtllm",
]
| 201 | + |
# NOTE(review): presumably the accepted dtype names for the Mamba SSM state;
# usage is not visible in this excerpt, so confirm against the CLI parser.
MAMBA_SSM_DTYPE_CHOICES = ["float32", "bfloat16"]

# NOTE(review): presumably the accepted Mamba scheduler strategy names;
# confirm against the CLI parser.
MAMBA_SCHEDULER_STRATEGY_CHOICES = ["auto", "no_buffer", "extra_buffer"]
@@ -226,6 +233,10 @@ def add_fp8_gemm_runner_backend_choices(choices): |
226 | 233 | FP8_GEMM_RUNNER_BACKEND_CHOICES.extend(choices) |
227 | 234 |
|
228 | 235 |
|
def add_fp4_gemm_runner_backend_choices(choices) -> None:
    """Register additional FP4 GEMM runner backend names.

    Every entry of ``choices`` is appended, in order, to the module-level
    ``FP4_GEMM_RUNNER_BACKEND_CHOICES`` list. The list is mutated in place,
    so code that already holds a reference to it sees the new entries.

    Args:
        choices: Iterable of backend names to append.
    """
    FP4_GEMM_RUNNER_BACKEND_CHOICES.extend(choices)
| 238 | + |
| 239 | + |
def add_deterministic_attention_backend_choices(choices) -> None:
    """Register additional deterministic attention backend names.

    Every entry of ``choices`` is appended, in order, to the module-level
    ``DETERMINISTIC_ATTENTION_BACKEND_CHOICES`` list, which is mutated in
    place.

    Args:
        choices: Iterable of backend names to append.
    """
    DETERMINISTIC_ATTENTION_BACKEND_CHOICES.extend(choices)
231 | 242 |
|
@@ -423,6 +434,7 @@ class ServerArgs: |
423 | 434 | grammar_backend: Optional[str] = None |
424 | 435 | mm_attention_backend: Optional[str] = None |
425 | 436 | fp8_gemm_runner_backend: str = "auto" |
| 437 | + fp4_gemm_runner_backend: str = "auto" |
426 | 438 | nsa_prefill_backend: str = "flashmla_sparse" |
427 | 439 | nsa_decode_backend: str = "fa3" |
428 | 440 | disable_flashinfer_autotune: bool = False |
@@ -3538,6 +3550,20 @@ def add_cli_args(parser: argparse.ArgumentParser): |
3538 | 3550 | "NOTE: This replaces the deprecated environment variables " |
3539 | 3551 | "SGLANG_ENABLE_FLASHINFER_FP8_GEMM and SGLANG_SUPPORT_CUTLASS_BLOCK_FP8.", |
3540 | 3552 | ) |
| 3553 | + parser.add_argument( |
| 3554 | + "--fp4-gemm-backend", |
| 3555 | + type=str, |
| 3556 | + choices=FP4_GEMM_RUNNER_BACKEND_CHOICES, |
| 3557 | + default=ServerArgs.fp4_gemm_runner_backend, |
| 3558 | + dest="fp4_gemm_runner_backend", |
| 3559 | + help="Choose the runner backend for NVFP4 GEMM operations. " |
| 3560 | + "Options: 'auto' (default, selects between cudnn/cutlass based on CUDA/cuDNN version), " |
| 3561 | + "'cudnn' (cuDNN backend, optimal on CUDA 13+ with cuDNN 9.15+), " |
| 3562 | + "'cutlass' (CUTLASS backend, optimal on CUDA 12), " |
| 3563 | + "'trtllm' (TensorRT-LLM backend, requires different weight preparation with shuffling). " |
| 3564 | + "NOTE: This replaces the deprecated environment variable " |
| 3565 | + "SGLANG_FLASHINFER_FP4_GEMM_BACKEND.", |
| 3566 | + ) |
3541 | 3567 | parser.add_argument( |
3542 | 3568 | "--disable-flashinfer-autotune", |
3543 | 3569 | default=ServerArgs.disable_flashinfer_autotune, |
|
0 commit comments