Skip to content

Commit ded068a

Browse files
authored
Add LMF2 MoE model architecture (sgl-project#17997)
1 parent 5d185ef commit ded068a

File tree

5 files changed

+913
-1
lines changed

5 files changed

+913
-1
lines changed

python/sglang/srt/configs/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from sglang.srt.configs.kimi_vl import KimiVLConfig
1515
from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
1616
from sglang.srt.configs.lfm2 import Lfm2Config
17+
from sglang.srt.configs.lfm2_moe import Lfm2MoeConfig
1718
from sglang.srt.configs.longcat_flash import LongcatFlashConfig
1819
from sglang.srt.configs.nano_nemotron_vl import NemotronH_Nano_VL_V2_Config
1920
from sglang.srt.configs.nemotron_h import NemotronHConfig
@@ -50,6 +51,7 @@
5051
"DotsOCRConfig",
5152
"FalconH1Config",
5253
"Lfm2Config",
54+
"Lfm2MoeConfig",
5355
"NemotronHConfig",
5456
"NemotronH_Nano_VL_V2_Config",
5557
"JetNemotronConfig",
Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
# Copyright 2025 SGLang Team
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
"""LFM2-MoE (Liquid Foundation Model 2 - Mixture of Experts) configuration
14+
15+
Note: HF transformers has Lfm2MoeConfig in v5.0.0rc2 (unreleased).
16+
Once released, we could inherit from it like Lfm2Config does with HFLfm2Config.
17+
For now, we define a standalone config to support the model immediately.
18+
"""
19+
20+
from typing import List, Optional
21+
22+
from transformers import CONFIG_MAPPING
23+
from transformers.configuration_utils import PretrainedConfig
24+
25+
from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
26+
27+
28+
class Lfm2MoeConfig(PretrainedConfig):
    """Configuration for LFM2-MoE models (e.g., LiquidAI/LFM2-8B-A1B).

    LFM2-MoE is a hybrid stack mixing full-attention layers with ShortConv
    layers (as in dense LFM2), where most FFNs are replaced by a
    sigmoid-routed Mixture of Experts:

    - the first ``num_dense_layers`` layers keep a dense MLP, the rest use MoE;
    - routing uses sigmoid scores (not softmax) with an ``expert_bias`` term
      for load balancing;
    - the expert bias is kept in fp32 for numerical stability.
    """

    model_type = "lfm2_moe"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size: int = 65536,
        hidden_size: int = 2048,
        intermediate_size: int = 7168,
        moe_intermediate_size: int = 1792,
        num_hidden_layers: int = 32,
        num_attention_heads: int = 32,
        num_key_value_heads: int = 8,
        max_position_embeddings: int = 128000,
        initializer_range: float = 0.02,
        norm_eps: float = 1e-5,
        use_cache: bool = True,
        pad_token_id: int = 0,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = True,
        rope_parameters: Optional[dict] = None,
        conv_bias: bool = False,
        conv_L_cache: int = 3,
        # MoE-specific parameters
        num_dense_layers: int = 2,
        num_experts: int = 32,
        num_experts_per_tok: int = 4,
        use_expert_bias: bool = True,
        routed_scaling_factor: float = 1.0,
        norm_topk_prob: bool = True,
        # Layer types
        layer_types: Optional[List[str]] = None,
        **kwargs,
    ):
        # A per-layer schedule, when supplied, must describe every layer.
        if layer_types is not None and len(layer_types) != num_hidden_layers:
            raise ValueError(
                f"layer_types length ({len(layer_types)}) must match "
                f"num_hidden_layers ({num_hidden_layers})"
            )

        # Core transformer dimensions.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.moe_intermediate_size = moe_intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.norm_eps = norm_eps
        self.use_cache = use_cache

        # ShortConv parameters.
        self.conv_bias = conv_bias
        self.conv_L_cache = conv_L_cache

        # MoE routing parameters.
        self.num_dense_layers = num_dense_layers
        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.use_expert_bias = use_expert_bias
        self.routed_scaling_factor = routed_scaling_factor
        self.norm_topk_prob = norm_topk_prob

        # Per-layer kind schedule (attention vs conv) and RoPE settings.
        self.layer_types = layer_types
        self.rope_parameters = rope_parameters

        # Original checkpoint configs may spell the flag "tie_embedding".
        tie_word_embeddings = kwargs.pop("tie_embedding", tie_word_embeddings)

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    @property
    def full_attention_layer_ids(self) -> List[int]:
        """Indices of full-attention layers (the ones needing KV cache)."""
        schedule = self.layer_types or []
        return [idx for idx, kind in enumerate(schedule) if kind == "full_attention"]

    @property
    def linear_layer_ids(self) -> List[int]:
        """Indices of conv layers (the ones needing conv-state cache)."""
        schedule = self.layer_types or []
        return [
            idx for idx, kind in enumerate(schedule) if kind in ("conv", "short_conv")
        ]

    @property
    def mamba_chunk_size(self) -> int:
        """Chunk size for the Mamba2 backend; LFM2 doesn't use chunking."""
        return 1

    @property
    def mamba2_cache_params(self) -> Optional[Mamba2CacheParams]:
        """Cache params for HybridReqToTokenPool initialization.

        LFM2-MoE's ShortConv layers need only a small fixed-size cache;
        returns ``None`` when the model has no conv layers at all.
        """
        # Imported here rather than at module top — presumably to avoid an
        # import cycle at config-load time; confirm before hoisting.
        from sglang.srt.layers.dp_attention import get_attention_tp_size

        conv_layers = self.linear_layer_ids
        if not conv_layers:
            return None

        # conv_L_cache in the checkpoint config is the kernel size (e.g., 3);
        # the actual cache size is kernel_size - 1 (e.g., 2 for kernel=3).
        kernel_size = int(self.conv_L_cache)

        try:
            tp_world = get_attention_tp_size()
        except (AssertionError, RuntimeError):
            # NOTE(review): likely means TP groups aren't initialized yet;
            # fall back to a single rank.
            tp_world = 1

        state_shape = Mamba2StateShape.create(
            tp_world_size=tp_world,
            intermediate_size=self.hidden_size,
            n_groups=1,
            num_heads=tp_world,  # Ensures divide works; temporal state is empty anyway
            head_dim=self.hidden_size,
            state_size=0,
            conv_kernel=kernel_size,
        )

        # Uses default mamba2_state_dtype() which reads SGLANG_MAMBA_CONV_DTYPE env var
        # (defaults to bfloat16). Set SGLANG_MAMBA_CONV_DTYPE=float16 for fp16 inference.
        return Mamba2CacheParams(
            shape=state_shape,
            layers=conv_layers,
        )
# Register with transformers CONFIG_MAPPING so AutoConfig.from_pretrained()
# can instantiate our config class when loading models with model_type="lfm2_moe".
try:
    CONFIG_MAPPING.register("lfm2_moe", Lfm2MoeConfig)
except ValueError:
    # transformers raises ValueError when the key is already registered
    # (e.g., repeated import, or once HF ships its own Lfm2MoeConfig).
    # Overwrite the entry directly so our config class wins. Narrowed from a
    # bare `except Exception` so unrelated failures are no longer swallowed.
    CONFIG_MAPPING._extra_content["lfm2_moe"] = Lfm2MoeConfig

python/sglang/srt/model_executor/model_runner.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
JetVLMConfig,
3737
KimiLinearConfig,
3838
Lfm2Config,
39+
Lfm2MoeConfig,
3940
NemotronH_Nano_VL_V2_Config,
4041
NemotronHConfig,
4142
Qwen3_5Config,
@@ -1571,7 +1572,9 @@ def mamba2_config(self):
15711572
pattern = getattr(config, "mtp_hybrid_override_pattern", None)
15721573
if pattern is not None and "M" not in pattern:
15731574
return None
1574-
if isinstance(config, FalconH1Config | NemotronHConfig | Lfm2Config):
1575+
if isinstance(
1576+
config, FalconH1Config | NemotronHConfig | Lfm2Config | Lfm2MoeConfig
1577+
):
15751578
return config
15761579
if isinstance(config, NemotronH_Nano_VL_V2_Config):
15771580
return config.llm_config

0 commit comments

Comments
 (0)