[ModelOpt] Fix broken Qwen3-235B-A22B-Instruct-2507-NVFP4 launch #18189

Merged
b8zhong merged 1 commit into sgl-project:main from bzhng-development:vz/fix-qwen-3-nvfp4-moe-235b
Feb 8, 2026

Conversation

@vincentzed (Contributor)

Motivation

Support https://huggingface.co/nvidia/Qwen3-235B-A22B-Instruct-2507-NVFP4

Previously this checkpoint failed to launch on SGLang, while the 30B NVFP4 checkpoint always worked:
https://huggingface.co/nvidia/Qwen3-30B-A3B-NVFP4/blob/main/hf_quant_config.json
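
For reference, launching the 235B checkpoint along these lines reproduced the failure (the --tp value below is illustrative, not from the original report; any tensor-parallel size that fits the weights applies):

python3 -m sglang.launch_server --model-path nvidia/Qwen3-235B-A22B-Instruct-2507-NVFP4 --quantization modelopt_fp4 --tp 8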

Root cause

Qwen3-235B-A22B-Instruct-2507-NVFP4 (broken)

  • MoE expert MLPs (gate_proj, up_proj, down_proj): NVFP4
  • Attention projections (q_proj, k_proj, v_proj, o_proj): BF16
  • Router gates (mlp.gate): BF16
  • lm_head: BF16

The config.json quantization_config.ignore list contains all 94 layers × 3 projections = 282 entries for q/k/v (plus the router gates and lm_head):
"ignore": [
"model.layers.0.self_attn.k_proj",
"model.layers.0.self_attn.q_proj",
"model.layers.0.self_attn.v_proj",
"model.layers.0.mlp.gate",
// ... repeated for all 94 layers
"lm_head"
]

Qwen3-30B-A3B-NVFP4 (works)

  • MoE expert MLPs (gate_proj, up_proj, down_proj): NVFP4
  • Attention projections (q_proj, k_proj, v_proj, o_proj): NVFP4
  • Router gates (mlp.gate): BF16
  • lm_head: BF16

The config.json quantization_config.ignore list contains only the router gates and lm_head:
"ignore": [
"model.layers.0.mlp.gate",
// ... repeated for all 48 layers
"lm_head"
]

Qwen3MoeForCausalLM has no packed_modules_mapping

Buggy behaviour:

Layer: model.layers.0.self_attn.qkv_proj
├── packed_modules_mapping = {}
├── is_layer_skipped("...qkv_proj", ignore_list, {})
│ ├── proj_name = "qkv_proj"
│ ├── "qkv_proj" in {} → False
│ └── Fallback: is "qkv_proj" in ignore_list? → NO (only q_proj, k_proj, v_proj are)
├── Returns: False (NOT skipped)
├── Quant method: ModelOptFp4LinearMethod
└── Creates param shape: [128, 2048] (FP4 packed, input_size/2 for k_proj shard)

30B
Layer: model.layers.0.self_attn.qkv_proj
├── packed_modules_mapping = {}
├── is_layer_skipped("...qkv_proj", ignore_list, {})
│ ├── proj_name = "qkv_proj"
│ ├── "qkv_proj" in {} → False
│ └── Fallback: is "qkv_proj" in ignore_list? → NO
├── Returns: False (NOT skipped)
├── Quant method: ModelOptFp4LinearMethod
└── Creates param shape: [64, 1024] (FP4 packed)

(Really, the 30B only worked by coincidence: its qkv_proj never matches the ignore list either, but since the 30B's attention projections genuinely are NVFP4, "not skipped" happens to be the correct answer.)

So we add a packed_modules_mapping for Qwen3MoE, which produces:
Layer: model.layers.0.self_attn.qkv_proj
├── packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"], ...}
├── is_layer_skipped("...qkv_proj", ignore_list, mapping)
│ ├── proj_name = "qkv_proj"
│ ├── "qkv_proj" in mapping → True
│ ├── Expand to: [...q_proj, ...k_proj, ...v_proj]
│ └── All three in ignore_list? → YES
├── Returns: True (SKIPPED)
├── Quant method: UnquantizedLinearMethod
└── Creates param shape: [128, 4096] (BF16, full size)

30B behaviour is unchanged.
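
To make the fixed skip logic concrete, here is a minimal sketch of the expansion step, assuming a simplified is_layer_skipped (the real SGLang helper takes additional arguments, so treat the signature as illustrative):

def is_layer_skipped(prefix, ignore_list, packed_modules_mapping):
    """Return True if the (possibly fused) layer should stay unquantized."""
    proj_name = prefix.split(".")[-1]
    if proj_name in packed_modules_mapping:
        # Fused module: skipped only when *every* component
        # (e.g. q_proj/k_proj/v_proj for qkv_proj) is in the ignore list.
        components = [
            prefix.replace(proj_name, component)
            for component in packed_modules_mapping[proj_name]
        ]
        return all(component in ignore_list for component in components)
    # Unfused module: skipped iff its own name is ignored.
    return prefix in ignore_list

# With the mapping in place, the 235B ignore list resolves correctly:
mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
ignore = [
    "model.layers.0.self_attn.k_proj",
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.self_attn.v_proj",
]
assert is_layer_skipped("model.layers.0.self_attn.qkv_proj", ignore, mapping)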

  • Bug:
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 3022, in run_scheduler_process
    scheduler = Scheduler(
                ^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 350, in __init__
    self.init_model_worker()
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 539, in init_model_worker
    self.init_tp_model_worker()
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 501, in init_tp_model_worker
    self.tp_worker = TpModelWorker(
                     ^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker.py", line 242, in __init__
    self._init_model_runner()
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker.py", line 325, in _init_model_runner
    self._model_runner = ModelRunner(
                         ^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 391, in __init__
    self.initialize(min_per_gpu_memory)
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 471, in initialize
    self.load_model()
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 909, in load_model
    self.model = self.loader.load_model(
                 ^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/model_loader/loader.py", line 2618, in load_model
    return super().load_model(
           ^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/model_loader/loader.py", line 671, in load_model
    self.load_weights_and_postprocess(
  File "/sgl-workspace/sglang/python/sglang/srt/model_loader/loader.py", line 684, in load_weights_and_postprocess
    model.load_weights(weights)
  File "/sgl-workspace/sglang/python/sglang/srt/models/qwen3_moe.py", line 1074, in load_weights
    weight_loader(param, loaded_weight, shard_id)
  File "/sgl-workspace/sglang/python/sglang/srt/layers/linear.py", line 1004, in weight_loader_v2
    param.load_qkv_weight(
  File "/sgl-workspace/sglang/python/sglang/srt/layers/parameter.py", line 266, in load_qkv_weight
    param_data.shape == loaded_weight.shape
AssertionError: param_data.shape=torch.Size([128, 2048]), loaded_weight.shape=torch.Size([128, 4096])
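
The shape mismatch itself comes from FP4 packing: NVFP4 stores two 4-bit values per byte, so a quantized weight's input dimension is half that of the BF16 checkpoint shard. A sketch with the k_proj shard sizes from the assertion (dimensions taken from the traceback; dtypes are illustrative):

import torch

out_dim, in_dim = 128, 4096  # per-shard k_proj sizes from the assertion above

# BF16 checkpoint shard: one value per element.
bf16_shard = torch.empty(out_dim, in_dim, dtype=torch.bfloat16)   # [128, 4096]

# FP4-packed parameter: two 4-bit values per uint8 byte, so input dim halves.
fp4_param = torch.empty(out_dim, in_dim // 2, dtype=torch.uint8)  # [128, 2048]

# Loading the BF16 shard into the FP4-shaped param is exactly the failing check.
assert fp4_param.shape != bf16_shard.shape  # [128, 2048] vs [128, 4096]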

Modifications

Add a packed_modules_mapping to Qwen3MoeForCausalLM (sketched below).
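
Based on the expansion shown above and the diff excerpt quoted in the review thread below, the added attribute in python/sglang/srt/models/qwen3_moe.py looks roughly like this (the gate_up_proj entry follows the changelog; exact contents may differ):

# Mapping from fused module names to their component weight names.
# Required for quantization configs (e.g., ModelOpt FP4) to correctly identify
# which layers should be skipped based on the exclude_modules/ignore list.
packed_modules_mapping = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}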

Accuracy Tests

Server launched with:

python3 -m sglang.launch_server --model-path nvidia/Qwen3-30B-A3B-NVFP4 --quantization modelopt_fp4

python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 2500 --parallel 2500 --platinum
Loading GSM8K Platinum dataset from HuggingFace...
100%|█████████████████████████████████████████████████████████| 1209/1209 [00:16<00:00, 74.04it/s]
Accuracy: 0.980
Invalid: 0.000
Latency: 16.442 s
Output throughput: 9916.770 token/s


Signed-off-by: vincentzed <207368749+vincentzed@users.noreply.github.com>
@vincentzed (Author) commented Feb 3, 2026:

/tag-and-rerun-ci again

@vincentzed (Author):

/tag-and-rerun-ci

@gemini-code-assist (Contributor)

Summary of Changes

Hello @vincentzed, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves a critical loading issue for the Qwen3-235B-A22B-Instruct-2507-NVFP4 model by enhancing its quantization configuration. It introduces a mapping for fused modules, ensuring that the model's attention and MLP projections are correctly handled during the quantization process, preventing an AssertionError and allowing the model to launch successfully. This change aligns the behavior of the 235B model with the working 30B NVFP4 model's quantization logic.

Highlights

  • Fix for Qwen3-235B-A22B-Instruct-2507-NVFP4 Launch Failure: Addresses an AssertionError during model loading for the Qwen3-235B-A22B-Instruct-2507-NVFP4 model, which previously failed to launch due to incorrect handling of quantized layers, specifically qkv_proj.
  • Introduction of packed_modules_mapping: Adds a packed_modules_mapping dictionary to the Qwen3MoeForCausalLM class. This mapping enables the quantization configuration to correctly identify and skip fused modules like qkv_proj and gate_up_proj when their component parts are specified in the ignore list, ensuring proper quantization behavior.


Changelog
  • python/sglang/srt/models/qwen3_moe.py
    • Added the packed_modules_mapping dictionary to the Qwen3MoeForCausalLM class. This dictionary defines how fused modules (e.g., qkv_proj and gate_up_proj) correspond to their individual component weights, which is essential for accurate quantization skipping based on the ignore list.
Activity
  • The pull request was created to address a model loading issue for Qwen3-235B-A22B-Instruct-2507-NVFP4.
  • The author implemented changes to introduce packed_modules_mapping in Qwen3MoeForCausalLM to resolve the quantization error.
  • Accuracy tests were performed on the GSM8K Platinum dataset, showing an accuracy of 0.980 after the changes.

@github-actions github-actions bot added the run-ci label Feb 3, 2026
@gemini-code-assist bot left a comment:
Code Review

This pull request introduces a packed_modules_mapping to the Qwen3MoeForCausalLM class to fix an issue with quantization, specifically for the Qwen3-235B model. The mapping correctly identifies fused modules, ensuring that quantization skipping rules are applied properly. The change is well-explained and appears to be a correct and necessary fix. I have one minor suggestion to improve code clarity by using typing.ClassVar for the new class attribute.
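
The ClassVar suggestion would look roughly like this (a sketch; the class body is abbreviated and the real class derives from nn.Module):

from typing import ClassVar

class Qwen3MoeForCausalLM:
    # ClassVar marks the mapping as a class-level constant rather than a
    # per-instance attribute, which helps type checkers and readers alike.
    packed_modules_mapping: ClassVar[dict[str, list[str]]] = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }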

@b8zhong b8zhong enabled auto-merge (squash) February 3, 2026 22:09
# Mapping from fused module names to their component weight names.
# Required for quantization configs (e.g., ModelOpt FP4) to correctly identify
# which layers should be skipped based on the exclude_modules/ignore list.
packed_modules_mapping = {

same question as for this PR #18224

@vincentzed (Author) replied on Feb 4, 2026:
It's because the qkv and o projections are NVFP4 in those recipes, which is not the case for the NVIDIA checkpoint here.

NVIDIA does not publish the recipe for this NVFP4 model.

@ssshinigami left a comment:

Changing the model file doesn't look correct for this fix; this is quantization-specific and should live in the quantization code.

@b8zhong b8zhong merged commit ca36d88 into sgl-project:main Feb 8, 2026
513 of 560 checks passed
charlesHsuGG pushed a commit to charlesHsuGG/sglang that referenced this pull request Feb 9, 2026
…-project#18189)

Signed-off-by: vincentzed <207368749+vincentzed@users.noreply.github.com>
Johnsonms pushed a commit to Johnsonms/sglang that referenced this pull request Feb 14, 2026
…-project#18189)

Signed-off-by: vincentzed <207368749+vincentzed@users.noreply.github.com>
1StepForever pushed a commit to 1StepForever/sglang that referenced this pull request Feb 26, 2026
* www/pr/ks: (265 commits)
  [BugFix][PD]Fix metadata_buffer_index leak when aborted in PD (sgl-project#17483)
  Refactoring Mooncake TE as a shared distributed component (sgl-project#17810)
  [ModelOPT] Support Qwen 3 Next Coder NVFP4 (sgl-project#18224)
  Update author information in pyproject.toml (sgl-project#18453)
  [Kimi-K2.5] Fix missing `quant_config` in `KimiK25` (sgl-project#18440)
  Add tensor parallelism support to LFM2 ShortConv layers (sgl-project#17777)
  [diffusion] chore: revise process title (sgl-project#18446)
  Fix TRT-LLM MLA backend applying k_scale to BF16 KV cache in BMM1 (sgl-project#18396)
  [diffusion] refactor: group component loaders under the component_loaders/ directory (sgl-project#18438)
  [ModelOpt] Fix broken Qwen3-235B-A22B-Instruct-2507-NVFP4 launch (sgl-project#18189)
  [diffusion] feat: support efficient sequence shard (sgl-project#18161)
  [CI] fix: notebook ci may not working (sgl-project#18417)
  fix: sync server_args.kv_cache_dtype when detecting FP8 KV cache (sgl-project#18394)
  [Fix] Fix backend selection after flashinfer version update (sgl-project#18364)
  [diffusion] platform: support WAN/FLUX/Qwen-Image/Qwen-Image-edit on Ascend (sgl-project#13662)
  fix: fix NVFP4 Kimi-K2.5 weight mapping and exclude list (sgl-project#18370)
  [diffusion] feat: support saving videos directly on the server to avoid the overhead of tensor transfer (sgl-project#18253)
  [diffusion] fix: respect dist_timeout option (sgl-project#18386)
  [Doc] Fix outdated `--fp4-gemm-backend` documentation (sgl-project#18350)
  [diffusion] fix: remove unnecessary norm_type argument from GLM-Image dits (sgl-project#18382)
  ...