@@ -547,12 +547,18 @@ def forward(
547547 self ,
548548 hidden_states : torch .Tensor ,
549549 residual : Optional [torch .Tensor ],
550+ captured_last_layer_outputs : Optional [list [torch .Tensor ]] = None ,
550551 ** kwargs ,
551552 ):
552553 forward_batch = kwargs .get ("forward_batch" , None )
553554
554- hidden_states , residual = self .layer_communicator .prepare_attn (
555- hidden_states , residual , forward_batch
555+ hidden_states , residual = (
556+ self .layer_communicator .prepare_attn_and_capture_last_layer_outputs (
557+ hidden_states ,
558+ residual ,
559+ forward_batch ,
560+ captured_last_layer_outputs = captured_last_layer_outputs ,
561+ )
556562 )
557563
558564 if not forward_batch .forward_mode .is_idle ():
@@ -769,10 +775,16 @@ def forward(
769775 hidden_states : torch .Tensor ,
770776 residual : Optional [torch .Tensor ],
771777 forward_batch : ForwardBatch ,
778+ captured_last_layer_outputs : Optional [list [torch .Tensor ]] = None ,
772779 ** kwargs : Any ,
773780 ):
774- hidden_states , residual = self .layer_communicator .prepare_attn (
775- hidden_states , residual , forward_batch
781+ hidden_states , residual = (
782+ self .layer_communicator .prepare_attn_and_capture_last_layer_outputs (
783+ hidden_states ,
784+ residual ,
785+ forward_batch ,
786+ captured_last_layer_outputs = captured_last_layer_outputs ,
787+ )
776788 )
777789
778790 if not forward_batch .forward_mode .is_idle ():
@@ -844,6 +856,14 @@ def get_layer(idx: int, prefix: str):
844856 self .norm = GemmaRMSNorm (config .hidden_size , eps = config .rms_norm_eps )
845857 self .infer_count = 0
846858
859+ # For EAGLE3 support
860+ self .layers_to_capture = []
861+
def set_eagle3_layers_to_capture(self, layers_to_capture: list[int]):
    """Select which decoder layers receive the EAGLE3 capture list.

    Each selected layer gets ``_is_layer_to_capture = True``; the model's
    ``forward`` reads that attribute with ``getattr`` to decide whether to
    hand the layer the ``aux_hidden_states`` list.

    The call is safe to repeat: markers left by an earlier call are cleared
    first, so only the layers named in the latest call stay flagged.

    Args:
        layers_to_capture: Indices into ``self.layers``.

    Raises:
        IndexError: If an index is out of range for ``self.layers``. The
            lookup happens before any state is changed, so a failed call
            leaves the previous configuration intact.
    """
    # Snapshot the ids so later mutation of the caller's list cannot
    # desynchronise ``self.layers_to_capture`` from the per-layer flags.
    selected_ids = list(layers_to_capture)

    # Resolve every target up front: a bad index fails here, before any
    # flag has been touched.
    selected_layers = [self.layers[layer_id] for layer_id in selected_ids]

    # Drop stale markers from a previous configuration; otherwise those
    # layers would keep appending to the capture list.
    for layer in self.layers:
        if getattr(layer, "_is_layer_to_capture", False):
            setattr(layer, "_is_layer_to_capture", False)

    self.layers_to_capture = selected_ids
    for layer in selected_layers:
        setattr(layer, "_is_layer_to_capture", True)
866+
847867 def forward (
848868 self ,
849869 input_ids : torch .Tensor ,
@@ -862,6 +882,7 @@ def forward(
862882 hidden_states = self .embed_tokens (input_ids )
863883
864884 residual = None
885+ aux_hidden_states = []
865886 for i in range (len (self .layers )):
866887 layer = self .layers [i ]
867888 with get_global_expert_distribution_recorder ().with_current_layer (i ):
@@ -871,6 +892,11 @@ def forward(
871892 hidden_states = hidden_states ,
872893 residual = residual ,
873894 forward_batch = forward_batch ,
895+ captured_last_layer_outputs = (
896+ aux_hidden_states
897+ if getattr (layer , "_is_layer_to_capture" , False )
898+ else None
899+ ),
874900 )
875901
876902 if not forward_batch .forward_mode .is_idle ():
@@ -879,7 +905,10 @@ def forward(
879905 else :
880906 hidden_states , _ = self .norm (hidden_states , residual )
881907
882- return hidden_states
908+ if len (aux_hidden_states ) == 0 :
909+ return hidden_states
910+
911+ return hidden_states , aux_hidden_states
883912
884913
885914class HybridLayerType (enum .Enum ):
@@ -915,6 +944,8 @@ def __init__(
915944 use_attn_tp_group = get_global_server_args ().enable_dp_lm_head ,
916945 )
917946 self .logits_processor = LogitsProcessor (config )
947+ # For EAGLE3 support
948+ self .capture_aux_hidden_states = False
918949
919950 self ._routed_experts_weights_of_layer = LazyValue (
920951 lambda : {
@@ -939,8 +970,12 @@ def forward(
939970 ):
940971 hidden_states = self .model (input_ids , positions , forward_batch , inputs_embeds )
941972
973+ aux_hidden_states = None
974+ if self .capture_aux_hidden_states :
975+ hidden_states , aux_hidden_states = hidden_states
976+
942977 return self .logits_processor (
943- input_ids , hidden_states , self .lm_head , forward_batch
978+ input_ids , hidden_states , self .lm_head , forward_batch , aux_hidden_states
944979 )
945980
946981 def get_embed_and_head (self ):
@@ -954,6 +989,21 @@ def set_embed_and_head(self, embed, head):
954989 torch .cuda .empty_cache ()
955990 torch .cuda .synchronize ()
956991
def get_embed(self):
    """Return the weight held by the model's ``embed_tokens`` module."""
    embedding = self.model.embed_tokens
    return embedding.weight
994+
def set_embed(self, embed):
    """Replace the token-embedding weight with ``embed``.

    Does nothing when the config defines ``target_hidden_size`` and it
    differs from ``hidden_size``. Otherwise the existing weight attribute is
    deleted, ``embed`` is installed in its place, and the CUDA cache is
    emptied and the device synchronised.

    Args:
        embed: The weight object to install on ``self.model.embed_tokens``.
    """
    cfg = self.config
    absent = object()
    target_size = getattr(cfg, "target_hidden_size", absent)

    # NOTE: If draft hidden size != target hidden size, the embed weight
    # cannot be shared for EAGLE3.
    if target_size is not absent and target_size != cfg.hidden_size:
        return

    embed_tokens = self.model.embed_tokens
    del embed_tokens.weight
    embed_tokens.weight = embed

    torch.cuda.empty_cache()
    torch.cuda.synchronize()
1006+
9571007 def load_weights (
9581008 self , weights : Iterable [Tuple [str , torch .Tensor ]], is_mtp : bool = False
9591009 ) -> Set [str ]:
@@ -1071,6 +1121,23 @@ def get_model_config_for_expert_location(cls, config):
10711121 num_groups = None ,
10721122 )
10731123
def set_eagle3_layers_to_capture(self, layer_ids: Optional[list[int]] = None):
    """Turn on aux-hidden-state capture and choose the layers to capture.

    Returns immediately, changing nothing, unless this is the last rank of
    ``self.pp_group``. Otherwise ``capture_aux_hidden_states`` is set and
    the layer selection is forwarded to ``self.model``.

    Args:
        layer_ids: Layer ids to capture. Each id is shifted by one before
            being forwarded (presumably because a flagged layer captures
            the output of the layer before it; confirm against the layer
            communicator). When ``None``, the defaults are
            ``2``, ``num_hidden_layers // 2`` and ``num_hidden_layers - 3``.
    """
    if not self.pp_group.is_last_rank:
        return

    self.capture_aux_hidden_states = True

    if layer_ids is not None:
        shifted_ids = [layer_id + 1 for layer_id in layer_ids]
        self.model.set_eagle3_layers_to_capture(shifted_ids)
        return

    # Specific layers for EAGLE3 support
    depth = self.config.num_hidden_layers
    default_ids = [2, depth // 2, depth - 3]
    self.model.set_eagle3_layers_to_capture(default_ids)
1140+
10741141
10751142EntryClass = Qwen3NextForCausalLM
10761143
0 commit comments