
[ROCm] Enable VLLM triton FP8 moe for gfx1201, tuned for Qwen3-30B-A3B-FP8 tp=2 and Qwen/Qwen3.5-35B-A3B-FP8 tp=2#79

Open
big-yellow-duck wants to merge 2 commits into main from rdna4-moe

Conversation


@big-yellow-duck big-yellow-duck commented Mar 12, 2026

Purpose

Enable Triton FP8 MoE for RDNA4 (gfx12xx) in vLLM so FP8 MoE models can run on ROCm with these GPUs.

This resolves vllm/issues/36105, where FP8 MoE model startup failed with NotImplementedError: No FP8 MoE backend supports the deployment configuration.

This PR also includes tuned Triton MoE performance improvements for:

  • Qwen/Qwen3-30B-A3B-Instruct-2507-FP8
  • Qwen/Qwen3.5-35B-A3B-FP8
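At its core, the enabling change widens the ROCm architecture gate for the Triton FP8 MoE backend so gfx12x passes alongside gfx9. A minimal illustrative sketch of that gate (the helper names `is_rdna4` and `triton_fp8_moe_supported` are assumptions for illustration, not vLLM's actual API):

```python
# Illustrative sketch only: the real gate lives in vLLM's ROCm platform code,
# and these helper names are assumptions, not vLLM's actual API.
def is_rdna4(gcn_arch: str) -> bool:
    # gfx1200/gfx1201 are RDNA4; the Radeon AI PRO 9700 reports gfx1201.
    return gcn_arch.startswith("gfx12")

def triton_fp8_moe_supported(gcn_arch: str) -> bool:
    # Before this PR only gfx9 (CDNA) passed the check, so FP8 MoE models
    # failed at startup with NotImplementedError; gfx12x is now accepted too.
    return gcn_arch.startswith("gfx9") or is_rdna4(gcn_arch)
```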

Test Plan

Benchmark Qwen/Qwen3-30B-A3B-Instruct-2507-FP8 and Qwen/Qwen3.5-35B-A3B-FP8 with the tuned Triton MoE configs on 2x Radeon AI PRO 9700.

# Qwen/Qwen3-30B-A3B-Instruct-2507-FP8
VLLM_ROCM_USE_AITER=0 vllm serve Qwen/Qwen3-30B-A3B-Instruct-2507-FP8 -tp 2 --enable-expert-parallel --gpu-memory-utilization 0.95 --max-model-len 65536

# Qwen/Qwen3.5-35B-A3B-FP8
VLLM_ROCM_USE_AITER=0 vllm serve Qwen/Qwen3.5-35B-A3B-FP8 -tp 2 --enable-expert-parallel --gpu-memory-utilization 0.95 --max-model-len 65536

Test Results

Qwen/Qwen3-30B-A3B-Instruct-2507-FP8

TTFT (ms)

| ISL-OSL | Triton MoE Untuned | Triton MoE Tuned |
| --- | --- | --- |
| 512-512 | 1176.38 | 982.67 |
| 1024-1024 | 884.88 | 711.35 |
| 2048-2048 | 1573.39 | 1250.83 |
| 4096-4096 | 5349.30 | 4329.35 |
| 8192-1024 | 12695.60 | 11351.34 |
| 16384-2048 | 126817.12 | 131106.68 |
| Average | 24749.45 | 24955.37 |

TPOT (ms)

| ISL-OSL | Triton MoE Untuned | Triton MoE Tuned |
| --- | --- | --- |
| 512-512 | 35.04 | 33.18 |
| 1024-1024 | 35.87 | 34.94 |
| 2048-2048 | 40.13 | 36.65 |
| 4096-4096 | 47.98 | 45.05 |
| 8192-1024 | 81.05 | 73.46 |
| 16384-2048 | 81.00 | 70.46 |
| Average | 53.51 | 48.95 |

E2E Latency (ms)

| ISL-OSL | Triton MoE Untuned | Triton MoE Tuned |
| --- | --- | --- |
| 512-512 | 19084.12 | 17936.15 |
| 1024-1024 | 37578.58 | 36449.97 |
| 2048-2048 | 83725.59 | 76264.82 |
| 4096-4096 | 201846.60 | 188823.78 |
| 8192-1024 | 95608.87 | 86497.21 |
| 16384-2048 | 292618.43 | 275340.95 |
| Average | 121743.70 | 113552.15 |

Qwen/Qwen3.5-35B-A3B-FP8

TTFT (ms)

| ISL-OSL | Triton MoE Untuned | Triton MoE Tuned |
| --- | --- | --- |
| 512-512 | 12534.25 | 6276.20 |
| 1024-1024 | 1376.32 | 1112.45 |
| 2048-2048 | 2428.18 | 2005.34 |
| 4096-4096 | 4696.00 | 3932.73 |
| 8192-1024 | 10276.23 | 8854.24 |
| 16384-2048 | 25947.62 | 22757.39 |
| Average | 9543.10 | 7489.73 |

TPOT (ms)

| ISL-OSL | Triton MoE Untuned | Triton MoE Tuned |
| --- | --- | --- |
| 512-512 | 43.96 | 40.95 |
| 1024-1024 | 44.70 | 41.82 |
| 2048-2048 | 44.98 | 42.79 |
| 4096-4096 | 47.20 | 44.48 |
| 8192-1024 | 67.49 | 61.32 |
| 16384-2048 | 80.14 | 73.03 |
| Average | 54.75 | 50.73 |

E2E Latency (ms)

| ISL-OSL | Triton MoE Untuned | Triton MoE Tuned |
| --- | --- | --- |
| 512-512 | 35000.01 | 27203.91 |
| 1024-1024 | 47107.52 | 43893.47 |
| 2048-2048 | 94500.88 | 89586.83 |
| 4096-4096 | 197986.20 | 186077.49 |
| 8192-1024 | 79313.44 | 71585.55 |
| 16384-2048 | 189994.94 | 172245.95 |
| Average | 107317.17 | 98432.20 |

Accuracy checks

GSM8K Accuracy

Qwen/Qwen3-30B-A3B-Instruct-2507-FP8

| Metric | Triton MoE Untuned | Triton MoE Tuned |
| --- | --- | --- |
| exact_match, strict-match | 83.24% | 83.40% |
| exact_match, flexible-extract | 85.82% | 86.43% |

Qwen/Qwen3.5-35B-A3B-FP8

| Metric | Triton MoE Untuned | Triton MoE Tuned |
| --- | --- | --- |
| exact_match, strict-match | 86.50% | 87.11% |
| exact_match, flexible-extract | 87.49% | 88.02% |

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

"""
 if current_platform.is_rocm() and IS_AITER_FOUND:
-    from vllm.platforms.rocm import on_gfx9
+    from vllm.platforms.rocm import on_gfx9, on_gfx12x
Member

@tjtanaa tjtanaa Mar 12, 2026


This should not be included here, because this PR is only about pushing the tuned Triton config JSON.

@big-yellow-duck big-yellow-duck changed the title [ROCm] Enable VLLM triton FP8 moe for gfx1201 [ROCm] Enable VLLM triton FP8 moe for gfx1201, tuned for Qwen3-30B-A3B-FP8 tp=2 Mar 12, 2026
@big-yellow-duck big-yellow-duck marked this pull request as ready for review March 16, 2026 04:52
@big-yellow-duck big-yellow-duck changed the title [ROCm] Enable VLLM triton FP8 moe for gfx1201, tuned for Qwen3-30B-A3B-FP8 tp=2 [ROCm] Enable VLLM triton FP8 moe for gfx1201, tuned for Qwen3-30B-A3B-FP8 tp=2 and Qwen/Qwen3.5-35B-A3B-FP8 tp=2 Mar 25, 2026

nie3e commented Mar 25, 2026

@big-yellow-duck great job on this PR.
How did you create the configs for R9700?

@big-yellow-duck
Author

> @big-yellow-duck great job on this PR. How did you create the configs for R9700?

It's a modified version of benchmarks/kernels/benchmark_moe.py; I will add it to this PR.

@big-yellow-duck
Author

The previous tuning was run with ROCm 7.2, which is not in rocm/vllm-dev:nightly. The tuning was redone with the patched tuning script in rocm/vllm-dev:nightly, which ships ROCm 7.0.0.

Tuning steps

Tuned on 2x Radeon AI PRO 9700

# run in vllm root
HIP_VISIBLE_DEVICES='0,1' python benchmarks/kernels/benchmark_moe.py --model Qwen/Qwen3-30B-A3B-Instruct-2507-FP8 --tp-size 2 --enable-expert-parallel --tune --save-dir rocm7.0-tune --dtype fp8_w8a8 

Then move the resulting JSON to the config dir:

# example
mv rocm7.0-tune/E=64,N=768,device_name=AMD_Radeon_R9700,dtype=fp8_w8a8,block_shape=[128,128].json vllm/model_executor/layers/fused_moe/configs/

The same steps are repeated for Qwen/Qwen3.5-35B-A3B-FP8.

@big-yellow-duck big-yellow-duck force-pushed the rdna4-moe branch 2 times, most recently from e1e23e8 to 0b1dbe5 Compare March 27, 2026 13:38
Member

@tjtanaa tjtanaa left a comment


LGTM

@big-yellow-duck big-yellow-duck force-pushed the rdna4-moe branch 2 times, most recently from e3e741e to c4eb217 Compare March 28, 2026 08:27
Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com>
Signed-off-by: big-yellow-duck <jeffaw99@hotmail.com>
