Commit 3ca29df

support qwen3-next eagle3 (#14607)
1 parent 9bb1260 commit 3ca29df

File tree: 1 file changed, +73 -6 lines

python/sglang/srt/models/qwen3_next.py

Lines changed: 73 additions & 6 deletions
@@ -547,12 +547,18 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         residual: Optional[torch.Tensor],
+        captured_last_layer_outputs: Optional[list[torch.Tensor]] = None,
         **kwargs,
     ):
         forward_batch = kwargs.get("forward_batch", None)
 
-        hidden_states, residual = self.layer_communicator.prepare_attn(
-            hidden_states, residual, forward_batch
+        hidden_states, residual = (
+            self.layer_communicator.prepare_attn_and_capture_last_layer_outputs(
+                hidden_states,
+                residual,
+                forward_batch,
+                captured_last_layer_outputs=captured_last_layer_outputs,
+            )
         )
 
         if not forward_batch.forward_mode.is_idle():
@@ -769,10 +775,16 @@ def forward(
         hidden_states: torch.Tensor,
         residual: Optional[torch.Tensor],
         forward_batch: ForwardBatch,
+        captured_last_layer_outputs: Optional[list[torch.Tensor]] = None,
         **kwargs: Any,
     ):
-        hidden_states, residual = self.layer_communicator.prepare_attn(
-            hidden_states, residual, forward_batch
+        hidden_states, residual = (
+            self.layer_communicator.prepare_attn_and_capture_last_layer_outputs(
+                hidden_states,
+                residual,
+                forward_batch,
+                captured_last_layer_outputs=captured_last_layer_outputs,
+            )
         )
 
         if not forward_batch.forward_mode.is_idle():
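Both decoder-layer variants above (the linear-attention and the full-attention paths) now route through prepare_attn_and_capture_last_layer_outputs instead of prepare_attn, threading an optional list through to it. The helper itself lives in LayerCommunicator and is not part of this diff; the sketch below is only a guess at its shape, assuming it records the previous layer's output into the supplied list and then does the usual prepare_attn work:

# Hypothetical sketch, not the actual LayerCommunicator implementation.
def prepare_attn_and_capture_last_layer_outputs(
    self, hidden_states, residual, forward_batch, captured_last_layer_outputs=None
):
    if captured_last_layer_outputs is not None:
        # Keep the previous layer's output (hidden_states plus the pending residual,
        # in the pre-norm convention) so EAGLE3 can hand it to the draft model later.
        captured_last_layer_outputs.append(
            hidden_states if residual is None else hidden_states + residual
        )
    return self.prepare_attn(hidden_states, residual, forward_batch)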
@@ -844,6 +856,14 @@ def get_layer(idx: int, prefix: str):
         self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.infer_count = 0
 
+        # For EAGLE3 support
+        self.layers_to_capture = []
+
+    def set_eagle3_layers_to_capture(self, layers_to_capture: list[int]):
+        self.layers_to_capture = layers_to_capture
+        for layer_id in self.layers_to_capture:
+            setattr(self.layers[layer_id], "_is_layer_to_capture", True)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -862,6 +882,7 @@ def forward(
         hidden_states = self.embed_tokens(input_ids)
 
         residual = None
+        aux_hidden_states = []
         for i in range(len(self.layers)):
             layer = self.layers[i]
             with get_global_expert_distribution_recorder().with_current_layer(i):
@@ -871,6 +892,11 @@ def forward(
                     hidden_states=hidden_states,
                     residual=residual,
                     forward_batch=forward_batch,
+                    captured_last_layer_outputs=(
+                        aux_hidden_states
+                        if getattr(layer, "_is_layer_to_capture", False)
+                        else None
+                    ),
                 )
 
         if not forward_batch.forward_mode.is_idle():
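The capture mechanism is nothing more than a marker attribute plus a shared list: set_eagle3_layers_to_capture (added above) tags the chosen layers with _is_layer_to_capture, and the forward loop hands exactly those layers the aux_hidden_states list to append into. A self-contained toy, not SGLang code, showing the same pattern:

import torch
from torch import nn

class ToyLayer(nn.Module):
    def forward(self, x, captured_last_layer_outputs=None):
        out = x + 1.0
        if captured_last_layer_outputs is not None:
            # A tagged layer appends its output so the caller can collect it.
            captured_last_layer_outputs.append(out)
        return out

layers = nn.ModuleList(ToyLayer() for _ in range(8))
for layer_id in (2, 4, 5):  # analogous to set_eagle3_layers_to_capture([2, 4, 5])
    setattr(layers[layer_id], "_is_layer_to_capture", True)

aux_hidden_states = []
x = torch.zeros(1)
for layer in layers:
    x = layer(
        x,
        captured_last_layer_outputs=(
            aux_hidden_states if getattr(layer, "_is_layer_to_capture", False) else None
        ),
    )
print(len(aux_hidden_states))  # 3 captured tensors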
@@ -879,7 +905,10 @@ def forward(
         else:
             hidden_states, _ = self.norm(hidden_states, residual)
 
-        return hidden_states
+        if len(aux_hidden_states) == 0:
+            return hidden_states
+
+        return hidden_states, aux_hidden_states
 
 
 class HybridLayerType(enum.Enum):
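With this hunk, Qwen3NextModel.forward returns a bare tensor when nothing was captured and a (hidden_states, aux_hidden_states) tuple otherwise. Qwen3NextForCausalLM (further down) unpacks based on its own capture_aux_hidden_states flag; a caller that does not track such a flag could instead branch on the return type, roughly:

# Illustrative only; the actual wrapper keys off self.capture_aux_hidden_states.
out = self.model(input_ids, positions, forward_batch, inputs_embeds)
if isinstance(out, tuple):
    hidden_states, aux_hidden_states = out
else:
    hidden_states, aux_hidden_states = out, None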
@@ -915,6 +944,8 @@ def __init__(
             use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
         )
         self.logits_processor = LogitsProcessor(config)
+        # For EAGLE3 support
+        self.capture_aux_hidden_states = False
 
         self._routed_experts_weights_of_layer = LazyValue(
             lambda: {
@@ -939,8 +970,12 @@ def forward(
     ):
         hidden_states = self.model(input_ids, positions, forward_batch, inputs_embeds)
 
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
+
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
         )
 
     def get_embed_and_head(self):
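Note that aux_hidden_states stays None unless capture_aux_hidden_states was switched on (see set_eagle3_layers_to_capture below), so the extra positional argument passed to the logits processor changes nothing for ordinary, non-EAGLE3 serving.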
@@ -954,6 +989,21 @@ def set_embed_and_head(self, embed, head):
         torch.cuda.empty_cache()
         torch.cuda.synchronize()
 
+    def get_embed(self):
+        return self.model.embed_tokens.weight
+
+    def set_embed(self, embed):
+        # NOTE: If draft hidden size != target hidden size, the embed weight cannot be shared for EAGLE3
+        if (
+            hasattr(self.config, "target_hidden_size")
+            and self.config.target_hidden_size != self.config.hidden_size
+        ):
+            return
+        del self.model.embed_tokens.weight
+        self.model.embed_tokens.weight = embed
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
     def load_weights(
         self, weights: Iterable[Tuple[str, torch.Tensor]], is_mtp: bool = False
     ) -> Set[str]:
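get_embed and set_embed mirror the existing get_embed_and_head / set_embed_and_head pair so an EAGLE3 draft instance can reuse the target model's token embedding instead of keeping its own copy. A hedged usage sketch; target_model and draft_model are placeholder names, not the actual call sites in the speculative-decoding worker:

# Hypothetical wiring; the real hookup happens inside the EAGLE3 worker.
shared_embed = target_model.get_embed()
draft_model.set_embed(shared_embed)  # silently skipped when target_hidden_size != hidden_size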
@@ -1071,6 +1121,23 @@ def get_model_config_for_expert_location(cls, config):
             num_groups=None,
         )
 
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[list[int]] = None):
+        if not self.pp_group.is_last_rank:
+            return
+
+        self.capture_aux_hidden_states = True
+        if layer_ids is None:
+            num_layers = self.config.num_hidden_layers
+            self.model.set_eagle3_layers_to_capture(
+                [
+                    2,
+                    num_layers // 2,
+                    num_layers - 3,
+                ]
+            )  # Specific layers for EAGLE3 support
+        else:
+            self.model.set_eagle3_layers_to_capture([val + 1 for val in layer_ids])
+
 
 EntryClass = Qwen3NextForCausalLM
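For reference, with the default layer_ids=None the captured layers are the third layer, the middle layer, and the third-from-last layer of the target model; explicitly passed ids are shifted up by one. A quick worked example (the 48-layer count is assumed for illustration, not read from any Qwen3-Next config):

num_layers = 48                                   # assumed layer count, illustration only
default = [2, num_layers // 2, num_layers - 3]    # -> [2, 24, 45]
explicit = [val + 1 for val in [1, 23, 44]]       # user-supplied ids shifted by one -> [2, 24, 45]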
