Skip to content

Commit 22af14c

Browse files
mmangkadJohnsonms
authored and committed
fix: fix NVFP4 Kimi-K2.5 weight mapping and exclude list (sgl-project#18370)
1 parent f6b8216 commit 22af14c

File tree

2 files changed

+30
-1
lines changed

2 files changed

+30
-1
lines changed

python/sglang/srt/layers/quantization/modelopt_quant.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
CombineInput,
6565
StandardDispatchOutput,
6666
)
67+
from sglang.srt.models.utils import WeightsMapper
6768

6869
fp4_quantize = None
6970
try:
@@ -304,6 +305,22 @@ def get_config_filenames(cls) -> List[str]:
304305
def get_scaled_act_names(self) -> List[str]:
305306
return []
306307

308+
def apply_weight_name_mapper(
    self, hf_to_sglang_mapper: "WeightsMapper"
):  # noqa: B027
    """Rewrite ``self.exclude_modules`` from HF naming into sglang naming.

    The exclude patterns shipped in the checkpoint's ``hf_quant_config.json``
    (e.g. nvidia/Kimi-K2.5-NVFP4) use HF module paths. After mapping them
    through *hf_to_sglang_mapper*, also register a ``language_model.``-less
    alias for each mapped name so patterns match whether or not the model is
    wrapped in a multimodal shell.

    No-op when ``self.exclude_modules`` is empty or None.
    """
    if not self.exclude_modules:
        return
    expanded: List[str] = []
    for mapped_name in hf_to_sglang_mapper.apply_list(self.exclude_modules):
        expanded.append(mapped_name)
        # Alias without the multimodal wrapper prefix.
        if mapped_name.startswith("language_model."):
            expanded.append(mapped_name.removeprefix("language_model."))
    # dict.fromkeys keeps first-seen order while dropping duplicates.
    self.exclude_modules = list(dict.fromkeys(expanded))
323+
307324

308325
class ModelOptFp8Config(ModelOptQuantConfig):
309326
"""Configuration for ModelOpt FP8 quantization, including serialization and compatibility checks."""

python/sglang/srt/models/kimi_k25.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from sglang.srt.model_loader.weight_utils import default_weight_loader
3535
from sglang.srt.models.deepseek_v2 import DeepseekV3ForCausalLM
3636
from sglang.srt.models.kimi_vl_moonvit import MLP2
37+
from sglang.srt.models.utils import WeightsMapper
3738
from sglang.srt.utils import add_prefix
3839

3940
KIMIV_VT_INFER_MAX_PATCH_NUM = 16328
@@ -643,6 +644,15 @@ def vision_tower_forward_auto(
643644

644645

645646
class KimiK25ForConditionalGeneration(nn.Module):
647+
# Support nvidia/Kimi-K2.5-NVFP4 naming: language_model.layers.*.
648+
# Ref: HF config.json for nvidia/Kimi-K2.5-NVFP4
649+
# https://huggingface.co/nvidia/Kimi-K2.5-NVFP4/blob/main/config.json
650+
hf_to_sglang_mapper = WeightsMapper(
651+
orig_to_new_prefix={
652+
"language_model.layers.": "language_model.model.layers.",
653+
}
654+
)
655+
646656
def __init__(
647657
self,
648658
config: KimiK25Config,
@@ -710,7 +720,9 @@ def forward(
710720

711721
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
712722
"""Load weights for the model, separating vision and language weights"""
713-
weights = list(weights)
723+
mapper = getattr(self, "hf_to_sglang_mapper", None)
724+
if mapper is not None:
725+
weights = mapper.apply(weights)
714726

715727
# Separate vision tower weights and language model weights
716728
vision_weights = []

0 commit comments

Comments
 (0)