diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index c46617ae9..af3e22e56 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -96,6 +96,7 @@ def __init__(self, kvargs): "eagle_with_att", "vanilla_no_att", "eagle_no_att", + "eagle_frozen_kv", ] self.prefill_graph: PrefillCudaGraph = None @@ -659,14 +660,19 @@ def prefill_func(input_tensors, infer_state): last_input_embs = infer_state._all_to_all_unbalance_get(data=last_input_embs) predict_logits = self.post_infer.token_forward(last_input_embs, infer_state, self.pre_post_weight) + if isinstance(predict_logits, tuple): + predict_logits, mtp_main_output_hiddens = predict_logits + else: + mtp_main_output_hiddens = None model_output = ModelOutput(logits=predict_logits) # 特殊模型特殊模式的额外输出 if self.is_mtp_mode: - input_embs = self.pre_infer._tpsp_allgather(input=input_embs, infer_state=infer_state) - if infer_state.need_dp_prefill_balance: - input_embs = infer_state._all_to_all_unbalance_get(data=input_embs) - model_output.mtp_main_output_hiddens = input_embs.contiguous() + if mtp_main_output_hiddens is None: + mtp_main_output_hiddens = self.pre_infer._tpsp_allgather(input=input_embs, infer_state=infer_state) + if infer_state.need_dp_prefill_balance: + mtp_main_output_hiddens = infer_state._all_to_all_unbalance_get(data=mtp_main_output_hiddens) + model_output.mtp_main_output_hiddens = mtp_main_output_hiddens.contiguous() # 在开启使用deepep的时候,需要调用clear_deepep_buffer做资源清理,没有启用的时候 # 该调用没有实际意义 @@ -689,12 +695,20 @@ def _token_forward(self, infer_state: InferStateInfo): last_input_embs, infer_state=infer_state, layer_weight=self.pre_post_weight ) + if isinstance(predict_logits, tuple): + predict_logits, mtp_main_output_hiddens = predict_logits + else: + mtp_main_output_hiddens = None + model_output = ModelOutput(logits=predict_logits.contiguous()) # 特殊模型特殊模式的额外输出 if self.is_mtp_mode: - input_embs = self.pre_infer._tpsp_allgather(input=input_embs, infer_state=infer_state) - model_output.mtp_main_output_hiddens = input_embs.contiguous() + if mtp_main_output_hiddens is None: + mtp_main_output_hiddens = self.pre_infer._tpsp_allgather(input=input_embs, infer_state=infer_state) + if infer_state.need_dp_prefill_balance: + mtp_main_output_hiddens = infer_state._all_to_all_unbalance_get(data=mtp_main_output_hiddens) + model_output.mtp_main_output_hiddens = mtp_main_output_hiddens.contiguous() # 在 cuda graph 模式下,输出需要转为 no ref tensor, 加强mem pool 的复用,降低显存的使用。 if infer_state.is_cuda_graph: @@ -930,19 +944,30 @@ def _overlap_tpsp_context_forward(self, infer_state: InferStateInfo, infer_state predict_logits, predict_logits1 = self.post_infer.overlap_tpsp_token_forward( last_input_embs, last_input_embs1, infer_state, infer_state1, self.pre_post_weight ) + if isinstance(predict_logits, tuple): + predict_logits, mtp_main_output_hiddens = predict_logits + else: + mtp_main_output_hiddens = None + if isinstance(predict_logits1, tuple): + predict_logits1, mtp_main_output_hiddens1 = predict_logits1 + else: + mtp_main_output_hiddens1 = None g_cache_manager.cache_env_out() model_output = ModelOutput(logits=predict_logits.contiguous()) model_output1 = ModelOutput(logits=predict_logits1.contiguous()) if self.is_mtp_mode: - input_embs = self.pre_infer._tpsp_allgather(input=input_embs, infer_state=infer_state) - input_embs1 = self.pre_infer._tpsp_allgather(input=input_embs1, infer_state=infer_state1) - if infer_state.need_dp_prefill_balance: - input_embs = infer_state._all_to_all_unbalance_get(data=input_embs) - input_embs1 = infer_state1._all_to_all_unbalance_get(data=input_embs1) - model_output.mtp_main_output_hiddens = input_embs.contiguous() - model_output1.mtp_main_output_hiddens = input_embs1.contiguous() + if mtp_main_output_hiddens is None: + mtp_main_output_hiddens = self.pre_infer._tpsp_allgather(input=input_embs, infer_state=infer_state) + if infer_state.need_dp_prefill_balance: + mtp_main_output_hiddens = infer_state._all_to_all_unbalance_get(data=mtp_main_output_hiddens) + model_output.mtp_main_output_hiddens = mtp_main_output_hiddens.contiguous() + if mtp_main_output_hiddens1 is None: + mtp_main_output_hiddens1 = self.pre_infer._tpsp_allgather(input=input_embs1, infer_state=infer_state1) + if infer_state1.need_dp_prefill_balance: + mtp_main_output_hiddens1 = infer_state1._all_to_all_unbalance_get(data=mtp_main_output_hiddens1) + model_output1.mtp_main_output_hiddens = mtp_main_output_hiddens1.contiguous() return model_output, model_output1 @@ -969,15 +994,29 @@ def _overlap_tpsp_token_forward(self, infer_state: InferStateInfo, infer_state1: predict_logits, predict_logits1 = self.post_infer.overlap_tpsp_token_forward( last_input_embs, last_input_embs1, infer_state, infer_state1, self.pre_post_weight ) + if isinstance(predict_logits, tuple): + predict_logits, mtp_main_output_hiddens = predict_logits + else: + mtp_main_output_hiddens = None + if isinstance(predict_logits1, tuple): + predict_logits1, mtp_main_output_hiddens1 = predict_logits1 + else: + mtp_main_output_hiddens1 = None model_output = ModelOutput(logits=predict_logits.contiguous()) model_output1 = ModelOutput(logits=predict_logits1.contiguous()) if self.is_mtp_mode: - input_embs = self.pre_infer._tpsp_allgather(input=input_embs, infer_state=infer_state) - input_embs1 = self.pre_infer._tpsp_allgather(input=input_embs1, infer_state=infer_state1) - model_output.mtp_main_output_hiddens = input_embs.contiguous() - model_output1.mtp_main_output_hiddens = input_embs1.contiguous() + if mtp_main_output_hiddens is None: + mtp_main_output_hiddens = self.pre_infer._tpsp_allgather(input=input_embs, infer_state=infer_state) + if infer_state.need_dp_prefill_balance: + mtp_main_output_hiddens = infer_state._all_to_all_unbalance_get(data=mtp_main_output_hiddens) + model_output.mtp_main_output_hiddens = mtp_main_output_hiddens.contiguous() + if mtp_main_output_hiddens1 is None: + mtp_main_output_hiddens1 = self.pre_infer._tpsp_allgather(input=input_embs1, infer_state=infer_state1) + if infer_state1.need_dp_prefill_balance: + mtp_main_output_hiddens1 = infer_state1._all_to_all_unbalance_get(data=mtp_main_output_hiddens1) + model_output1.mtp_main_output_hiddens = mtp_main_output_hiddens1.contiguous() if infer_state.is_cuda_graph: model_output.to_no_ref_tensor() @@ -1185,15 +1224,21 @@ def _init_padded_req(self): def _gen_special_model_input(self, token_num: int): special_model_input = {} + cls_name = str(self.__class__) is_mtp_draft_model = ( - "Deepseek3MTPModel" in str(self.__class__) - or "Qwen3MOEMTPModel" in str(self.__class__) - or "MistralMTPModel" in str(self.__class__) - or "Glm4MoeLiteMTPModel" in str(self.__class__) + "Deepseek3MTPModel" in cls_name + or "Qwen3MOEMTPModel" in cls_name + or "MistralMTPModel" in cls_name + or "Glm4MoeLiteMTPModel" in cls_name + or "Gemma4MTPModel" in cls_name ) if is_mtp_draft_model: + # Gemma-4's drafter consumes the recurrent hidden state in backbone + # width (the target's hidden size), not its own draft width; the other + # MTP drafters have draft width == backbone width so hidden_size fits. + hidden_size = self.config.get("backbone_hidden_size", self.config["hidden_size"]) special_model_input["mtp_draft_input_hiddens"] = torch.randn( - token_num, self.config["hidden_size"], dtype=self.data_type, device="cuda" + token_num, hidden_size, dtype=self.data_type, device="cuda" ) else: special_model_input["mtp_draft_input_hiddens"] = None diff --git a/lightllm/models/gemma4/layer_infer/post_layer_infer.py b/lightllm/models/gemma4/layer_infer/post_layer_infer.py index 22bcf0508..46c707de4 100644 --- a/lightllm/models/gemma4/layer_infer/post_layer_infer.py +++ b/lightllm/models/gemma4/layer_infer/post_layer_infer.py @@ -1,5 +1,7 @@ import torch +import numpy as np from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer +from lightllm.distributed.communication_op import all_gather class Gemma4PostLayerInfer(LlamaPostLayerInfer): @@ -13,8 +15,28 @@ def __init__(self, network_config): self.final_logit_softcapping = float(network_config.get("final_logit_softcapping")) def token_forward(self, input_embdings, infer_state, layer_weight): - logits = super().token_forward(input_embdings, infer_state, layer_weight) + last_hidden, token_num = self._slice_get_last_input(input_embdings, infer_state) + input_embdings_dtype = input_embdings.dtype + last_hidden = self._norm(last_hidden, infer_state, layer_weight) + lm_input = last_hidden.permute(1, 0).view(-1, token_num) + logic_batch = layer_weight.lm_head_weight_(input=lm_input, alloc_func=self.alloc_tensor) + vocab_size = layer_weight.lm_head_weight_.vocab_size + if self.tp_world_size_ == 1: + gather_data = logic_batch + else: + gather_data = self.alloc_tensor((vocab_size, token_num), dtype=input_embdings_dtype) + split_indexes = np.linspace(0, vocab_size, self.tp_world_size_ + 1, dtype=np.int64) + all_gather( + [gather_data[split_indexes[i] : split_indexes[i + 1], :] for i in range(self.tp_world_size_)], + logic_batch, + group=infer_state.dist_group, + async_op=False, + ) + logic_batch = None + logits = self.alloc_tensor((token_num, vocab_size), dtype=torch.float32) + logits[:, :] = gather_data.permute(1, 0) + gather_data = None if self.final_logit_softcapping is not None and self.final_logit_softcapping > 0: cap = self.final_logit_softcapping logits = torch.tanh(logits / cap) * cap - return logits + return logits, last_hidden diff --git a/lightllm/models/gemma4_mtp/__init__.py b/lightllm/models/gemma4_mtp/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/gemma4_mtp/layer_infer/__init__.py b/lightllm/models/gemma4_mtp/layer_infer/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/gemma4_mtp/layer_infer/post_layer_infer.py b/lightllm/models/gemma4_mtp/layer_infer/post_layer_infer.py new file mode 100644 index 000000000..7bd4b8eac --- /dev/null +++ b/lightllm/models/gemma4_mtp/layer_infer/post_layer_infer.py @@ -0,0 +1,90 @@ +import torch +import numpy as np +from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer +from lightllm.distributed.communication_op import all_gather + + +class Gemma4MTPPostLayerInfer(LlamaPostLayerInfer): + def __init__(self, network_config): + super().__init__(network_config) + cap = network_config.get("final_logit_softcapping") + self.final_logit_softcapping = float(cap) if cap else None + + self.use_ordered_embeddings_ = bool(network_config.get("use_ordered_embeddings")) + self._post_projection_weight_ = None + if self.use_ordered_embeddings_: + self.num_centroids_ = network_config["num_centroids"] + self.centroid_top_k_ = network_config["centroid_intermediate_top_k"] + self.vocab_size_ = network_config["vocab_size"] + assert ( + self.vocab_size_ % self.num_centroids_ == 0 + ), f"vocab_size={self.vocab_size_} must be divisible by num_centroids={self.num_centroids_}" + self._vocab_per_centroid_ = self.vocab_size_ // self.num_centroids_ + # token -> centroid mapping is derived lazily from the loaded + # token_ordering buffer (weights are not loaded yet at __init__). + self._centroid_of_token_ = None + + def _dense_logits(self, last_hidden, token_num, input_embdings_dtype, infer_state, layer_weight): + lm_input = last_hidden.permute(1, 0).view(-1, token_num) + logic_batch = layer_weight.lm_head_weight_(input=lm_input, alloc_func=self.alloc_tensor) + vocab_size = layer_weight.lm_head_weight_.vocab_size + if self.tp_world_size_ == 1: + gather_data = logic_batch + else: + gather_data = self.alloc_tensor((vocab_size, token_num), dtype=input_embdings_dtype) + split_indexes = np.linspace(0, vocab_size, self.tp_world_size_ + 1, dtype=np.int64) + all_gather( + [gather_data[split_indexes[i] : split_indexes[i + 1], :] for i in range(self.tp_world_size_)], + logic_batch, + group=infer_state.dist_group, + async_op=False, + ) + logic_batch = None + ans_logics = self.alloc_tensor((token_num, vocab_size), dtype=torch.float32) + ans_logics[:, :] = gather_data.permute(1, 0) + gather_data = None + return ans_logics + + def _centroid_logits(self, last_hidden, token_num, layer_weight): + """Gather lm_head rows for the per-token top-K centroid blocks, + dot with the post-norm hidden, scatter into a [N, vocab] -inf tensor + at the original vocab positions. Mathematically equivalent to + dense logits + mask but avoids the [N, vocab] bool tensor and matches + the reference implementations exactly. + """ + centroid_scores = layer_weight.centroids_weight_.mm(last_hidden) # [N, num_centroids] + topk_centroids = torch.topk(centroid_scores, k=self.centroid_top_k_, dim=-1).indices # [N, K] + # token_ordering[i] = original vocab id at reordered position i; + # row c of the (C, vpc) view holds the vocab ids of centroid c. + token_ordering = layer_weight.token_ordering_.weight # [vocab] int64 + clusters = token_ordering.view(self.num_centroids_, self._vocab_per_centroid_) # [C, vpc] + selected_vocab = clusters[topk_centroids] # [N, K, vpc] - original vocab ids + num_selected = self.centroid_top_k_ * self._vocab_per_centroid_ + selected_vocab = selected_vocab.reshape(token_num, num_selected) # [N, num_selected] + # Gather lm_head rows for the selected vocab ids. + lm_head_w = layer_weight.lm_head_weight_.weight # [vocab, draft_hidden] + selected_embeddings = lm_head_w[selected_vocab] # [N, num_selected, H] + # Sparse logits: dot product per token vs its selected rows. + selected_logits = torch.einsum("nh,nsh->ns", last_hidden, selected_embeddings) + # Scatter to [N, vocab] with -inf elsewhere. + output = torch.full( + (token_num, self.vocab_size_), + float("-inf"), + dtype=selected_logits.dtype, + device=selected_logits.device, + ) + output.scatter_(-1, selected_vocab, selected_logits) + return output + + def token_forward(self, input_embdings, infer_state, layer_weight): + last_hidden, token_num = self._slice_get_last_input(input_embdings, infer_state) + last_hidden = self._norm(last_hidden, infer_state, layer_weight) + if self.use_ordered_embeddings_: + logits = self._centroid_logits(last_hidden, token_num, layer_weight) + else: + logits = self._dense_logits(last_hidden, token_num, input_embdings.dtype, infer_state, layer_weight) + if self.final_logit_softcapping is not None and self.final_logit_softcapping > 0: + cap = self.final_logit_softcapping + logits = torch.tanh(logits / cap) * cap + assert self._post_projection_weight_ is not None, "post_projection weight is not initialized" + return logits, self._post_projection_weight_.mm(last_hidden) diff --git a/lightllm/models/gemma4_mtp/layer_infer/pre_layer_infer.py b/lightllm/models/gemma4_mtp/layer_infer/pre_layer_infer.py new file mode 100644 index 000000000..a2794b768 --- /dev/null +++ b/lightllm/models/gemma4_mtp/layer_infer/pre_layer_infer.py @@ -0,0 +1,41 @@ +import torch +from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer +from lightllm.models.gemma4_mtp.layer_weights.pre_and_post_layer_weight import Gemma4MTPPreAndPostLayerWeight + + +class Gemma4MTPPreLayerInfer(LlamaPreLayerInfer): + """ + Gemma-4 assistant input/output glue. + + Input fusion (context_forward / token_forward): + embed = target_wte(input_ids) * sqrt(backbone_hidden) # backbone width + prev = mtp_draft_input_hiddens # backbone width + fused = pre_projection(concat[embed, prev]) # -> draft width + + Output projection is handled by Gemma4MTPPostLayerInfer, which returns + post_projection(norm(draft_hidden)) as mtp_main_output_hiddens. + """ + + def __init__(self, network_config): + super().__init__(network_config) + self.draft_hidden_ = network_config["hidden_size"] + self.backbone_hidden_ = network_config["backbone_hidden_size"] + self.embed_scale_ = float(self.backbone_hidden_) ** 0.5 + + def _mtp_fuse(self, input_embdings, infer_state, layer_weight: Gemma4MTPPreAndPostLayerWeight): + prev = infer_state.mtp_draft_input_hiddens + assert ( + input_embdings.shape[0] == prev.shape[0] + ), f"token count mismatch: embed {input_embdings.shape} vs prev_hidden {prev.shape}" + # The target embedding is backbone width; scale like the target model does. + input_embdings = input_embdings * self.embed_scale_ + cat = torch.cat((input_embdings, prev), dim=-1) + return layer_weight.pre_projection_weight_.mm(cat) + + def context_forward(self, input_ids, infer_state, layer_weight): + input_embdings = super().context_forward(input_ids, infer_state, layer_weight) + return self._mtp_fuse(input_embdings, infer_state, layer_weight) + + def token_forward(self, input_ids, infer_state, layer_weight): + input_embdings = super().token_forward(input_ids, infer_state, layer_weight) + return self._mtp_fuse(input_embdings, infer_state, layer_weight) diff --git a/lightllm/models/gemma4_mtp/layer_infer/transformer_layer_infer.py b/lightllm/models/gemma4_mtp/layer_infer/transformer_layer_infer.py new file mode 100644 index 000000000..b54a7b85c --- /dev/null +++ b/lightllm/models/gemma4_mtp/layer_infer/transformer_layer_infer.py @@ -0,0 +1,21 @@ +from lightllm.models.gemma4.layer_infer.transformer_layer_infer import Gemma4TransformerLayerInfer + + +class Gemma4MTPTransformerLayerInfer(Gemma4TransformerLayerInfer): + """ + Gemma-4 assistant decoder block. Identical to the target's block except the + attention is forced into the KV-shared (Q-only) path: K/V are read from the + *target model's* committed cache, and nothing is computed or written here. + + `layer_num` is the assistant-local index (0..num_mtp_layers-1) - used for + config lookups (layer_types / RoPE table / per-layer shapes). + `kv_share_target_layer` is the *target model's* layer index whose KV cache + this layer reads (the target's last non-KV-shared layer of the same attention + type). The MTP network_config has num_kv_shared_layers forced to 0 so the + parent __init__ leaves is_kv_shared_ False; it is forced True here. + """ + + def __init__(self, layer_num, network_config, kv_share_target_layer): + super().__init__(layer_num, network_config) + self.is_kv_shared_ = True + self.kv_share_target_layer_ = kv_share_target_layer diff --git a/lightllm/models/gemma4_mtp/layer_weights/__init__.py b/lightllm/models/gemma4_mtp/layer_weights/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/models/gemma4_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/gemma4_mtp/layer_weights/pre_and_post_layer_weight.py new file mode 100644 index 000000000..2eab38873 --- /dev/null +++ b/lightllm/models/gemma4_mtp/layer_weights/pre_and_post_layer_weight.py @@ -0,0 +1,95 @@ +import torch +from lightllm.common.basemodel import PreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ( + EmbeddingWeight, + LMHeadWeight, + ParameterWeight, + RMSNormWeight, + ROWMMWeight, +) + + +class Gemma4MTPPreAndPostLayerWeight(PreAndPostLayerWeight): + """ + Pre/post weights for the Gemma-4 assistant (frozen-KV MTP drafter). + + Layout differs from a normal model: + * pre_projection : Linear(2 * backbone_hidden -> draft_hidden) - fuses the + target token embedding with the recurrent hidden state. + * post_projection: Linear(draft_hidden -> backbone_hidden) - maps the draft + trunk output back to backbone width for the next recurrent step. + * model.norm : final RMSNorm in draft_hidden width. + * lm_head : tied to the assistant's own model.embed_tokens.weight + (draft_hidden width); this is NOT the input embedding. + * wte_weight_ : the *target's* input embedding (backbone_hidden width), + aliased from the main model in Gemma4MTPModel._init_weights - never loaded + here. + """ + + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) + draft_hidden = network_config["hidden_size"] + backbone_hidden = network_config["backbone_hidden_size"] + vocab_size = network_config["vocab_size"] + + self.pre_projection_weight_ = ROWMMWeight( + in_dim=backbone_hidden * 2, + out_dims=[draft_hidden], + weight_names="pre_projection.weight", + data_type=self.data_type_, + tp_rank=0, + tp_world_size=1, + ) + self.post_projection_weight_ = ROWMMWeight( + in_dim=draft_hidden, + out_dims=[backbone_hidden], + weight_names="post_projection.weight", + data_type=self.data_type_, + tp_rank=0, + tp_world_size=1, + ) + self.final_norm_weight_ = RMSNormWeight( + dim=draft_hidden, + weight_name="model.norm.weight", + data_type=self.data_type_, + ) + # The assistant ships model.embed_tokens.weight in draft_hidden width; + # with tie_word_embeddings it serves only as the tied lm_head matrix + # (the input embedding comes from the target model, see wte_weight_). + self._mtp_lm_head_embed_ = EmbeddingWeight( + dim=draft_hidden, + vocab_size=vocab_size, + weight_name="model.embed_tokens.weight", + data_type=self.data_type_, + ) + self.lm_head_weight_ = LMHeadWeight( + dim=draft_hidden, + vocab_size=vocab_size, + weight_name="lm_head.weight", + data_type=self.data_type_, + embedding_weight=self._mtp_lm_head_embed_, + ) + # The input token embedding is the *target's* (backbone_hidden width); + # aliased from the main model in Gemma4MTPModel._init_weights. + self.wte_weight_: EmbeddingWeight = None + + # E-series centroid sparse-logits head (only the E* assistants ship these). + # token_ordering[v] gives the centroid id of vocab token v; at decode time + # Gemma4MTPPostLayerInfer masks logits to the per-query top-K centroids' + # vocab slice. + if network_config.get("use_ordered_embeddings"): + num_centroids = network_config["num_centroids"] + self.centroids_weight_ = ROWMMWeight( + in_dim=draft_hidden, + out_dims=[num_centroids], + weight_names="masked_embedding.centroids.weight", + data_type=self.data_type_, + tp_rank=0, + tp_world_size=1, + ) + self.token_ordering_ = ParameterWeight( + weight_name="masked_embedding.token_ordering", + data_type=torch.int64, + weight_shape=(vocab_size,), + ) + return diff --git a/lightllm/models/gemma4_mtp/layer_weights/transformer_layer_weight.py b/lightllm/models/gemma4_mtp/layer_weights/transformer_layer_weight.py new file mode 100644 index 000000000..4bb6001d7 --- /dev/null +++ b/lightllm/models/gemma4_mtp/layer_weights/transformer_layer_weight.py @@ -0,0 +1,80 @@ +from lightllm.common.basemodel.layer_weights.meta_weights import RMSNormWeight, ROWMMWeight, ParameterWeight +from lightllm.models.gemma4.layer_weights.transformer_layer_weight import Gemma4TransformerLayerWeight + + +class Gemma4MTPTransformerLayerWeight(Gemma4TransformerLayerWeight): + """ + Gemma-4 assistant decoder-layer weights. Same block shape as the target's + Gemma4TransformerLayerWeight, but: + * checkpoint prefix is `model.layers.{i}` (the target uses + `model.language_model.layers.{i}`), + * attention is Q-projection only - the assistant has no k/v_proj or k_norm + (it reads the target's KV cache), + * never MoE / never PLE (the assistant trunk is always dense). + `layer_num_` here is the assistant-local index (0..num_mtp_layers-1). + """ + + def _init_weight_names(self): + prefix = f"model.layers.{self.layer_num_}" + self._q_weight_name = f"{prefix}.self_attn.q_proj.weight" + self._q_bias_name = None + self._o_weight_name = f"{prefix}.self_attn.o_proj.weight" + self._o_bias_name = None + self._q_norm_weight_name = f"{prefix}.self_attn.q_norm.weight" + + self._gate_weight_name = f"{prefix}.mlp.gate_proj.weight" + self._up_weight_name = f"{prefix}.mlp.up_proj.weight" + self._down_weight_name = f"{prefix}.mlp.down_proj.weight" + + self._att_norm_weight_name = f"{prefix}.input_layernorm.weight" + self._ffn_norm_weight_name = f"{prefix}.post_attention_layernorm.weight" + self._pre_feedforward_layernorm_name = f"{prefix}.pre_feedforward_layernorm.weight" + self._post_feedforward_layernorm_name = f"{prefix}.post_feedforward_layernorm.weight" + + self._layer_scalar_name = f"{prefix}.layer_scalar" + + def _init_qkv(self): + # Q-projection only: the assistant reads K/V from the target's cache. + self.q_proj = ROWMMWeight( + in_dim=self.n_embed, + out_dims=[self.q_head_num_ * self.head_dim], + weight_names=self._q_weight_name, + data_type=self.data_type_, + bias_names=self._q_bias_name, + quant_method=self.get_quant_method("q_proj"), + ) + + def _init_norm(self): + hidden_size = self.network_config_["hidden_size"] + # Standard RMSNorm (not the gemma2/3 (1+w) variant). No k_norm: there is + # no k_proj on the assistant. + self.q_norm_weight_ = RMSNormWeight( + dim=self._layer_head_dim, + weight_name=self._q_norm_weight_name, + data_type=self.data_type_, + ) + self.att_norm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._att_norm_weight_name, + data_type=self.data_type_, + ) + self.ffn_norm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._ffn_norm_weight_name, + data_type=self.data_type_, + ) + self.pre_feedforward_layernorm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._pre_feedforward_layernorm_name, + data_type=self.data_type_, + ) + self.post_feedforward_layernorm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._post_feedforward_layernorm_name, + data_type=self.data_type_, + ) + self.layer_scalar_ = ParameterWeight( + weight_name=self._layer_scalar_name, + data_type=self.data_type_, + weight_shape=(1,), + ) diff --git a/lightllm/models/gemma4_mtp/model.py b/lightllm/models/gemma4_mtp/model.py new file mode 100644 index 000000000..fe5f74a46 --- /dev/null +++ b/lightllm/models/gemma4_mtp/model.py @@ -0,0 +1,118 @@ +import os +import json +from typing import List +from lightllm.models.gemma4.model import Gemma4TpPartModel +from lightllm.models.gemma4_mtp.layer_weights.pre_and_post_layer_weight import Gemma4MTPPreAndPostLayerWeight +from lightllm.models.gemma4_mtp.layer_weights.transformer_layer_weight import Gemma4MTPTransformerLayerWeight +from lightllm.models.gemma4_mtp.layer_infer.pre_layer_infer import Gemma4MTPPreLayerInfer +from lightllm.models.gemma4_mtp.layer_infer.transformer_layer_infer import Gemma4MTPTransformerLayerInfer +from lightllm.models.gemma4_mtp.layer_infer.post_layer_infer import Gemma4MTPPostLayerInfer +from lightllm.common.basemodel import TpPartBaseModel +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class Gemma4MTPModel(Gemma4TpPartModel): + """ + Gemma-4 assistant drafter (frozen-KV MTP). Subclasses the Gemma-4 target model + and reuses its decoder block / RoPE tables / attention backends, but: + * shares the target's mem_manager and req_manager - the drafter allocates no + KV of its own (frozen-KV: it reads the target's committed cache), + * builds only num_hidden_layers (4) draft layers, each forced into the + KV-shared Q-only attention path pointed at a target-model layer, + * fuses the target's recurrent hidden state with the token embedding via + pre_projection / post_projection (see Gemma4MTPPreLayerInfer). + + Instantiated directly by ModeBackend.init_mtp_draft_model - not registered via + @ModelRegistry and not imported in lightllm/models/__init__.py. + """ + + pre_and_post_weight_class = Gemma4MTPPreAndPostLayerWeight + transformer_weight_class = Gemma4MTPTransformerLayerWeight + + pre_layer_infer_class = Gemma4MTPPreLayerInfer + transformer_layer_infer_class = Gemma4MTPTransformerLayerInfer + post_layer_infer_class = Gemma4MTPPostLayerInfer + + def __init__(self, kvargs): + self._pre_init(kvargs) + super().__init__(kvargs) + + def _pre_init(self, kvargs): + self.main_model: TpPartBaseModel = kvargs.pop("main_model") + self.mtp_previous_draft_models: List[TpPartBaseModel] = kvargs.pop("mtp_previous_draft_models") + + def _init_config(self): + with open(os.path.join(self.weight_dir_, "config.json"), "r") as f: + outer_cfg = json.load(f) + super()._init_config() + # backbone_hidden_size lives on the outer assistant config, above text_config. + self.config["backbone_hidden_size"] = outer_cfg["backbone_hidden_size"] + # E-series centroid sparse-logits head (gemma-4-E4B-it-assistant etc.): + # forward outer-config fields used by Gemma4MTPPreAndPostLayerWeight / + # Gemma4MTPPostLayerInfer to mask draft logits to the top-K centroids' + # vocab slice. + self.config["use_ordered_embeddings"] = bool(outer_cfg.get("use_ordered_embeddings")) + if self.config["use_ordered_embeddings"]: + self.config["num_centroids"] = outer_cfg["num_centroids"] + self.config["centroid_intermediate_top_k"] = outer_cfg["centroid_intermediate_top_k"] + # The assistant config marks every layer KV-shared - this denotes + # cross-model sharing with the target, not the intra-model sharing the + # inherited layer infer expects. Force it to 0 so + # Gemma4TransformerLayerInfer.__init__ leaves is_kv_shared_ False; + # Gemma4MTPTransformerLayerInfer then sets it True with the correct + # target-model layer index. + self.config["num_kv_shared_layers"] = 0 + + def _init_custom(self): + # Reuse the target's RoPE tables (built for the target's max_seq_len, which + # covers every position the drafter ever queries). Skip the deepep group + # setup - the assistant trunk is always dense. + self._cos_cached_sliding = self.main_model._cos_cached_sliding + self._sin_cached_sliding = self.main_model._sin_cached_sliding + self._cos_cached_full = self.main_model._cos_cached_full + self._sin_cached_full = self.main_model._sin_cached_full + + def _init_req_manager(self): + self.req_manager = self.main_model.req_manager + + def _init_mem_manager(self): + # Frozen-KV: the drafter never writes KV, it reads the target's cache. + self.mem_manager = self.main_model.mem_manager + + def _init_weights(self): + super()._init_weights() + # The input token embedding is the target's (backbone width); the + # assistant's own model.embed_tokens serves only as the tied lm_head. + self.pre_post_weight.wte_weight_ = self.main_model.pre_post_weight.wte_weight_ + + def _init_infer_layer(self): + self.pre_infer = self.pre_layer_infer_class(network_config=self.config) + self.post_infer = self.post_layer_infer_class(network_config=self.config) + # post_projection lifts normed draft hidden to backbone width for the next + # recurrent input; token_forward returns it as mtp_main_output_hiddens. + self.post_infer._post_projection_weight_ = self.pre_post_weight.post_projection_weight_ + + # Map each assistant layer to the target model's layer whose KV cache it + # reads: the target's last non-KV-shared layer of the same attention type. + target_cfg = self.main_model.config + target_layer_types = target_cfg["layer_types"] + target_kv_shared = target_cfg.get("num_kv_shared_layers") or 0 + target_cutoff = len(target_layer_types) - target_kv_shared + last_of_type = {} + for j in range(target_cutoff): + last_of_type[target_layer_types[j]] = j + + draft_layer_types = self.config["layer_types"] + self.layers_infer = [] + for i in range(self.config["num_hidden_layers"]): + layer_type = draft_layer_types[i] + self.layers_infer.append( + self.transformer_layer_infer_class( + i, network_config=self.config, kv_share_target_layer=last_of_type[layer_type] + ) + ) + + def autotune_layers(self): + return self.config["num_hidden_layers"] diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 6b30ab687..8a531652b 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -689,13 +689,16 @@ def make_argument_parser() -> argparse.ArgumentParser: "eagle_with_att", "vanilla_no_att", "eagle_no_att", + "eagle_frozen_kv", None, ], default=None, help="""Supported MTP modes. None: Disables MTP. *_with_att: Uses the MTP model with an attention mechanism to predict the next draft token. - *_no_att: Uses the MTP model without an attention module to predict the next draft token.""", + *_no_att: Uses the MTP model without an attention module to predict the next draft token. + eagle_frozen_kv: Eagle-style draft whose attention reads the target model's KV cache and + writes none of its own (Gemma-4 assistant drafters).""", ) parser.add_argument( "--mtp_draft_model_dir", diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 05ff2658e..29eb144d1 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -163,6 +163,7 @@ class StartArgs: "eagle_with_att", "vanilla_no_att", "eagle_no_att", + "eagle_frozen_kv", "qwen3next_vanilla", "qwen3next_eagle", None, diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index ef93ad5fd..24c573f68 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -46,6 +46,7 @@ from lightllm.models.qwen3_moe_mtp.model import Qwen3MOEMTPModel from lightllm.models.mistral_mtp.model import MistralMTPModel from lightllm.models.glm4_moe_lite_mtp.model import Glm4MoeLiteMTPModel +from lightllm.models.gemma4_mtp.model import Gemma4MTPModel from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token from lightllm.server.pd_io_struct import NIXLChunckedTransTaskRet @@ -316,7 +317,7 @@ def init_mtp_draft_model(self, main_kvargs: dict): if self.args.mtp_mode in ["vanilla_with_att", "vanilla_no_att"]: num_mtp_modules = self.args.mtp_step - elif self.args.mtp_mode in ["eagle_with_att", "eagle_no_att"]: + elif self.args.mtp_mode in ["eagle_with_att", "eagle_no_att", "eagle_frozen_kv"]: num_mtp_modules = 1 else: assert False, f"error mtp mode {self.args.mtp_mode}" @@ -360,6 +361,9 @@ def init_mtp_draft_model(self, main_kvargs: dict): elif mtp_model_cfg["model_type"] == "glm4_moe_lite": assert self.args.mtp_mode in ["vanilla_with_att", "eagle_with_att"] self.draft_models.append(Glm4MoeLiteMTPModel(mtp_model_kvargs)) + elif model_type == "gemma4_assistant": + assert self.args.mtp_mode == "eagle_frozen_kv" + self.draft_models.append(Gemma4MTPModel(mtp_model_kvargs)) else: raise ValueError(f"Unsupported MTP model type: {model_type}") diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index 60045fab6..8314e7286 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -40,12 +40,23 @@ def __init__(self) -> None: if get_env_start_args().mtp_mode: self.prefill = self.prefill_mtp self.decode = self.decode_mtp - self.is_mtp_eagle = get_env_start_args().mtp_mode in ["eagle_with_att", "eagle_no_att"] + mtp_mode = get_env_start_args().mtp_mode + # frozen-KV (Gemma-4 assistant): eagle-style recurrence (1 draft module, + # run mtp_step times) but the draft writes no KV of its own - it reads + # the target model's committed KV cache. + self.is_mtp_frozen_kv = mtp_mode == "eagle_frozen_kv" + self.is_mtp_eagle = mtp_mode in ["eagle_with_att", "eagle_no_att", "eagle_frozen_kv"] self.num_mtp_models = 1 if self.is_mtp_eagle else get_env_start_args().mtp_step - self._draft_decode_func = self._draft_decode_eagle if self.is_mtp_eagle else self._draft_decode_vanilla + if self.is_mtp_frozen_kv: + self._draft_decode_func = self._draft_decode_frozen_kv + elif self.is_mtp_eagle: + self._draft_decode_func = self._draft_decode_eagle + else: + self._draft_decode_func = self._draft_decode_vanilla else: self.prefill = self.prefill_normal self.decode = self.decode_normal + self.is_mtp_frozen_kv = False self.classed_req_strict_prefill = False return @@ -326,6 +337,10 @@ def decode_mtp( def _draft_prefill_forward(self, model_input: ModelInput, model_output: ModelOutput, next_token_ids: torch.Tensor): # spec prefill: MTP, 这个地方只是为了填充draft model的 kv, 并不会使用生成的token_id。 + # frozen-KV draft 没有自己的 kv 需要填充,它读取 target 模型已提交的 kv cache; + # 第一个 decode step 的 draft 输入来自该 step 的 target forward,无需 prefill 阶段准备。 + if self.is_mtp_frozen_kv: + return draft_model_input = model_input draft_model_output = model_output draft_next_token_ids_gpu = next_token_ids @@ -425,3 +440,40 @@ def _draft_decode_eagle( mtp_accept_len=mtp_accept_len, ) return eagle_mem_indexes_cpu + + def _draft_decode_frozen_kv( + self, + main_model_input: ModelInput, + main_model_output: ModelOutput, + next_token_ids: torch.Tensor, + mtp_accept_len: torch.Tensor, + b_req_mtp_start_loc: torch.Tensor, + ): + # frozen-KV draft (Gemma-4 assistant): like _draft_decode_eagle but the + # draft has no KV cache of its own - it reads the target model's committed + # KV. So there is no mem_manager.alloc, no b_seq_len / max_kv_seq_len + # advance, and no mem_indexes roll: every draft step queries at the same + # (target-committed) position, and the recurrent hidden state carries the + # progression. Nothing extra to free, so returns None. + draft_model_input = main_model_input + draft_model_output = main_model_output + draft_next_token_ids = next_token_ids + all_next_token_ids = [] + all_next_token_ids.append(next_token_ids) + for _step in range(self.mtp_step): + draft_model_input.input_ids = draft_next_token_ids + draft_model_input.mtp_draft_input_hiddens = draft_model_output.mtp_main_output_hiddens + draft_model_output: ModelOutput = self.draft_models[0].forward(draft_model_input) + draft_next_token_ids = self._gen_argmax_token_ids(draft_model_output) + all_next_token_ids.append(draft_next_token_ids) + + all_next_token_ids = torch.stack(all_next_token_ids, dim=1) # [batch_size, mtp_step + 1] + + mtp_scatter_next_token_ids( + req_to_next_token_ids=self.model.req_manager.req_sampling_params_manager.req_to_next_token_ids, + b_req_mtp_start_loc=b_req_mtp_start_loc, + all_next_token_ids=all_next_token_ids, + b_req_idx=main_model_input.b_req_idx, + mtp_accept_len=mtp_accept_len, + ) + return None