Skip to content

Commit 2ff0880

Browse files
authored
[Fix] GLM 4.7 + NVFP4 + MTP (#17166)
1 parent 2c1b164 commit 2ff0880

File tree

6 files changed

+114
-9
lines changed

6 files changed

+114
-9
lines changed

python/sglang/srt/configs/model_config.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -834,20 +834,29 @@ def _verify_quantization(self) -> None:
834834
if self.quantization is None:
835835
self.quantization = quant_method
836836
elif self.quantization != quant_method:
837-
# Allow auto-detection of quantization from checkpoint for draft model
838-
# even if it differs from main model's quantization
839-
if self.is_draft_model:
837+
# Check if the CLI-specified quantization is compatible with HF config's quant_method
838+
is_compatible = (
839+
self.quantization in compatible_quantization_methods
840+
and quant_method
841+
in compatible_quantization_methods[self.quantization]
842+
)
843+
if is_compatible:
844+
# Keep the CLI-specified quantization (e.g., modelopt_fp4) even if
845+
# HF config says "modelopt" - they are compatible
846+
logger.info(
847+
f"Using CLI-specified quantization ({self.quantization}) which is "
848+
f"compatible with HF config quant_method ({quant_method})."
849+
)
850+
elif self.is_draft_model:
851+
# Allow auto-detection of quantization from checkpoint for draft model
852+
# only if the CLI quantization is not compatible
840853
logger.info(
841854
f"Draft model quantization ({quant_method}) differs from "
842855
f"main model quantization ({self.quantization}). "
843856
f"Using draft model's detected quantization: {quant_method}"
844857
)
845858
self.quantization = quant_method
846-
elif (
847-
self.quantization not in compatible_quantization_methods
848-
or quant_method
849-
not in compatible_quantization_methods[self.quantization]
850-
):
859+
else:
851860
raise ValueError(
852861
"Quantization method specified in the model config "
853862
f"({quant_method}) does not match the quantization "

python/sglang/srt/model_loader/loader.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@
9696
get_quant_config,
9797
gguf_quant_weights_iterator,
9898
initialize_dummy_weights,
99+
maybe_add_mtp_safetensors,
99100
multi_thread_pt_weights_iterator,
100101
multi_thread_safetensors_weights_iterator,
101102
np_cache_weights_iterator,
@@ -321,13 +322,17 @@ class Source:
321322
fall_back_to_pt: bool = True
322323
"""Whether .pt weights can be used."""
323324

325+
model_config: Optional["ModelConfig"] = None
326+
"""The model configuration (for checking architecture, etc.)."""
327+
324328
@classmethod
325329
def init_new(cls, model_config: ModelConfig, model):
326330
return cls(
327331
model_config.model_path,
328332
model_config.revision,
329333
prefix="",
330334
fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True),
335+
model_config=model_config,
331336
)
332337

333338
def __init__(self, load_config: LoadConfig):
@@ -471,6 +476,15 @@ def _get_weights_iterator(
471476
hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
472477
source.model_or_path, source.revision, source.fall_back_to_pt
473478
)
479+
480+
if use_safetensors and source.model_config is not None:
481+
hf_weights_files = maybe_add_mtp_safetensors(
482+
hf_weights_files,
483+
hf_folder,
484+
"model.safetensors.index.json",
485+
source.model_config.hf_config,
486+
)
487+
474488
if self.load_config.load_format == LoadFormat.NPCACHE:
475489
# Currently np_cache only support *.bin checkpoints
476490
assert use_safetensors is False

python/sglang/srt/model_loader/weight_utils.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -594,6 +594,44 @@ def filter_duplicate_safetensors_files(
594594
return hf_weights_files
595595

596596

597+
def maybe_add_mtp_safetensors(
598+
hf_weights_files: List[str], hf_folder: str, index_file: str, hf_config
599+
) -> List[str]:
600+
"""
601+
Auto-detect and add mtp.safetensors for GLM4Moe MTP/NextN models if:
602+
1. mtp.safetensors exists in the model directory
603+
2. mtp.safetensors is NOT in the index (checkpoint packaging bug)
604+
3. Model architecture is Glm4MoeForCausalLM (or its NextN variant) with num_nextn_predict_layers > 0
605+
606+
This works around incorrectly packaged FP4 checkpoints like
607+
baseten-admin/glm-4.7-fp4 where mtp.safetensors exists but
608+
isn't referenced in model.safetensors.index.json.
609+
"""
610+
# Only apply for GLM4Moe architecture with nextn layers
611+
arch = getattr(hf_config, "architectures", [None])[0]
612+
num_nextn_layers = getattr(hf_config, "num_nextn_predict_layers", 0)
613+
if not (
614+
arch in ["Glm4MoeForCausalLM", "Glm4MoeForCausalLMNextN"]
615+
and num_nextn_layers > 0
616+
):
617+
return hf_weights_files
618+
619+
# Check if mtp.safetensors exists and is not already in the file list
620+
mtp_path = os.path.join(hf_folder, "mtp.safetensors")
621+
if not os.path.isfile(mtp_path) or mtp_path in hf_weights_files:
622+
return hf_weights_files
623+
624+
# mtp.safetensors exists but not in index - this is a bug
625+
logger.warning(
626+
f"Found mtp.safetensors but it's not referenced in {index_file}. "
627+
f"This is a checkpoint packaging bug. Auto-adding it for loading. "
628+
f"Please report this to the checkpoint provider."
629+
)
630+
631+
# Add it to the files list
632+
return hf_weights_files + [mtp_path]
633+
634+
597635
def filter_files_not_needed_for_inference(hf_weights_files: List[str]) -> List[str]:
598636
"""
599637
Exclude files that are not needed for inference.

python/sglang/srt/models/glm4_moe.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,10 @@
6363
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
6464
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
6565
from sglang.srt.layers.moe.topk import TopK
66-
from sglang.srt.layers.moe.utils import filter_moe_weight_param_global_expert
66+
from sglang.srt.layers.moe.utils import (
67+
RoutingMethodType,
68+
filter_moe_weight_param_global_expert,
69+
)
6770
from sglang.srt.layers.quantization.base_config import QuantizationConfig
6871
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
6972
from sglang.srt.layers.radix_attention import RadixAttention
@@ -376,6 +379,7 @@ def __init__(
376379
intermediate_size=config.moe_intermediate_size,
377380
quant_config=quant_config,
378381
routed_scaling_factor=self.routed_scaling_factor,
382+
routing_method_type=RoutingMethodType.DeepSeekV3,
379383
prefix=add_prefix("experts", prefix),
380384
)
381385

python/sglang/srt/server_args.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from sglang.srt.utils.common import (
3636
LORA_TARGET_ALL_MODULES,
3737
SUPPORTED_LORA_TARGET_MODULES,
38+
check_pkg_version_at_least,
3839
configure_ipv6,
3940
cpu_has_amx_support,
4041
get_bool_env_var,
@@ -1508,6 +1509,27 @@ def _handle_model_specific_adjustments(self):
15081509
)
15091510
self.disable_radix_cache = True
15101511
self.disable_overlap_schedule = False
1512+
elif model_arch in ["Glm4MoeForCausalLM"]:
1513+
if is_sm100_supported():
1514+
quantization_config = getattr(hf_config, "quantization_config", None)
1515+
quant_method = (
1516+
quantization_config.get("quant_method")
1517+
if quantization_config is not None
1518+
else None
1519+
)
1520+
if self.quantization is None and quant_method is not None:
1521+
self.quantization = quant_method
1522+
if (
1523+
self.quantization == "modelopt_fp4"
1524+
and self.moe_a2a_backend == "none"
1525+
and self.moe_runner_backend == "auto"
1526+
):
1527+
# Only enable flashinfer_trtllm if flashinfer-python version is >= 0.6.2
1528+
if check_pkg_version_at_least("flashinfer-python", "0.6.2"):
1529+
self.moe_runner_backend = "flashinfer_trtllm"
1530+
logger.info(
1531+
"Use flashinfer_trtllm as MoE runner backend on sm100 for Glm4MoeForCausalLM"
1532+
)
15111533

15121534
# Mamba radix cache v2
15131535
if self.enable_mamba_extra_buffer():

python/sglang/srt/utils/common.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1045,6 +1045,24 @@ def assert_pkg_version(pkg: str, min_version: str, message: str):
10451045
)
10461046

10471047

1048+
def check_pkg_version_at_least(pkg: str, min_version: str) -> bool:
1049+
"""
1050+
Check if a package is installed and meets the minimum version requirement.
1051+
1052+
Args:
1053+
pkg: Package name (distribution name, e.g., "flashinfer-python")
1054+
min_version: Minimum version required (e.g., "0.6.2")
1055+
1056+
Returns:
1057+
True if package is installed and version >= min_version, False otherwise
1058+
"""
1059+
try:
1060+
installed_version = version(pkg)
1061+
return pkg_version.parse(installed_version) >= pkg_version.parse(min_version)
1062+
except PackageNotFoundError:
1063+
return False
1064+
1065+
10481066
def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = None):
10491067
"""Kill the process and all its child processes."""
10501068
# Remove sigchld handler to avoid spammy logs.

0 commit comments

Comments
 (0)