Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
21c3eeb
support 31B
WANDY666 Apr 30, 2026
99b790c
fix
WANDY666 May 6, 2026
4c30c73
Merge branch 'main' of https://github.com/ModelTC/LightLLM into suppo…
WANDY666 May 6, 2026
15a5379
support moe
WANDY666 May 7, 2026
83f4983
support e4b (PLE and shared_kv)
WANDY666 May 9, 2026
d969a5f
support visual module
WANDY666 May 11, 2026
08f066d
optimize sliding window
WANDY666 May 12, 2026
7678de8
fix
WANDY666 May 12, 2026
63c658a
simplify
WANDY666 May 13, 2026
300e577
minor improvements
WANDY666 May 13, 2026
50822f0
fix
WANDY666 May 13, 2026
b4b13cc
fix attention cuda graph
WANDY666 May 13, 2026
f19074b
fused gelu gate up
WANDY666 May 14, 2026
5b61450
add out_dtype
WANDY666 May 14, 2026
c0ca212
minor improvements
WANDY666 May 14, 2026
9499a00
fix eos_token_ids
WANDY666 May 14, 2026
de7e220
for HF format
WANDY666 May 14, 2026
bfc59ff
Merge branch 'main' of https://github.com/ModelTC/LightLLM into suppo…
WANDY666 May 14, 2026
109d27c
fix window_size
WANDY666 May 14, 2026
2ea258e
fix window_size
WANDY666 May 14, 2026
b297af5
fix
WANDY666 May 14, 2026
7a81e85
add reasoning_parser for gemma4
WANDY666 May 15, 2026
d619534
[fix]ple support cudagraph
WANDY666 May 16, 2026
c2578c0
fix PLE illegal memory access
WANDY666 May 18, 2026
d744cbc
support sliding_window_right
WANDY666 May 18, 2026
05a0db8
fix notes
WANDY666 May 18, 2026
6f1bd2e
tune in H200
WANDY666 May 19, 2026
90643db
fix
hiworldwzj May 19, 2026
a2b74ab
fix
hiworldwzj May 19, 2026
e606e05
fix
hiworldwzj May 19, 2026
7354da2
fix
hiworldwzj May 20, 2026
afa0194
fix
hiworldwzj May 20, 2026
46ce6af
fix
hiworldwzj May 20, 2026
0188c10
fix
hiworldwzj May 20, 2026
393ec69
fix
hiworldwzj May 20, 2026
e96c2b7
fix
WANDY666 May 20, 2026
f806326
fix
hiworldwzj May 20, 2026
c5b2b81
fix
hiworldwzj May 20, 2026
3bd46d7
fix
hiworldwzj May 20, 2026
91051f0
Merge branch 'support_gemma4' of https://github.com/ModelTC/LightLLM …
WANDY666 May 20, 2026
fb75045
fix
WANDY666 May 20, 2026
7c664c3
fix
hiworldwzj May 20, 2026
0d35e8b
fix
hiworldwzj May 20, 2026
74a4b1f
fix
hiworldwzj May 20, 2026
d2df0a0
fix
hiworldwzj May 20, 2026
3491641
fix
hiworldwzj May 20, 2026
c8812f2
fix
hiworldwzj May 20, 2026
8f160b5
fix
hiworldwzj May 20, 2026
ee92fee
fix
hiworldwzj May 21, 2026
131a163
fix
hiworldwzj May 21, 2026
6d7729f
fix
hiworldwzj May 21, 2026
87da477
fix
WANDY666 May 21, 2026
819497c
Merge branch 'main' of https://github.com/ModelTC/LightLLM into suppo…
WANDY666 May 21, 2026
c57e062
format
WANDY666 May 21, 2026
6ebf9db
finish
WANDY666 May 21, 2026
e682f9b
fix
WANDY666 May 21, 2026
c28f085
Merge https://github.com/ModelTC/LightLLM into gemma4_mtp
WANDY666 May 22, 2026
6099413
format
WANDY666 May 22, 2026
ba47045
format
WANDY666 May 22, 2026
33d7ceb
format
WANDY666 May 22, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 67 additions & 22 deletions lightllm/common/basemodel/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def __init__(self, kvargs):
"eagle_with_att",
"vanilla_no_att",
"eagle_no_att",
"eagle_frozen_kv",
]
self.prefill_graph: PrefillCudaGraph = None

Expand Down Expand Up @@ -659,14 +660,19 @@ def prefill_func(input_tensors, infer_state):
last_input_embs = infer_state._all_to_all_unbalance_get(data=last_input_embs)

predict_logits = self.post_infer.token_forward(last_input_embs, infer_state, self.pre_post_weight)
if isinstance(predict_logits, tuple):
predict_logits, mtp_main_output_hiddens = predict_logits
else:
mtp_main_output_hiddens = None
model_output = ModelOutput(logits=predict_logits)

# 特殊模型特殊模式的额外输出
if self.is_mtp_mode:
input_embs = self.pre_infer._tpsp_allgather(input=input_embs, infer_state=infer_state)
if infer_state.need_dp_prefill_balance:
input_embs = infer_state._all_to_all_unbalance_get(data=input_embs)
model_output.mtp_main_output_hiddens = input_embs.contiguous()
if mtp_main_output_hiddens is None:
mtp_main_output_hiddens = self.pre_infer._tpsp_allgather(input=input_embs, infer_state=infer_state)
if infer_state.need_dp_prefill_balance:
mtp_main_output_hiddens = infer_state._all_to_all_unbalance_get(data=mtp_main_output_hiddens)
model_output.mtp_main_output_hiddens = mtp_main_output_hiddens.contiguous()

# 在开启使用deepep的时候,需要调用clear_deepep_buffer做资源清理,没有启用的时候
# 该调用没有实际意义
Expand All @@ -689,12 +695,20 @@ def _token_forward(self, infer_state: InferStateInfo):
last_input_embs, infer_state=infer_state, layer_weight=self.pre_post_weight
)

if isinstance(predict_logits, tuple):
predict_logits, mtp_main_output_hiddens = predict_logits
else:
mtp_main_output_hiddens = None

model_output = ModelOutput(logits=predict_logits.contiguous())

# 特殊模型特殊模式的额外输出
if self.is_mtp_mode:
input_embs = self.pre_infer._tpsp_allgather(input=input_embs, infer_state=infer_state)
model_output.mtp_main_output_hiddens = input_embs.contiguous()
if mtp_main_output_hiddens is None:
mtp_main_output_hiddens = self.pre_infer._tpsp_allgather(input=input_embs, infer_state=infer_state)
if infer_state.need_dp_prefill_balance:
mtp_main_output_hiddens = infer_state._all_to_all_unbalance_get(data=mtp_main_output_hiddens)
model_output.mtp_main_output_hiddens = mtp_main_output_hiddens.contiguous()

# 在 cuda graph 模式下,输出需要转为 no ref tensor, 加强mem pool 的复用,降低显存的使用。
if infer_state.is_cuda_graph:
Expand Down Expand Up @@ -930,19 +944,30 @@ def _overlap_tpsp_context_forward(self, infer_state: InferStateInfo, infer_state
predict_logits, predict_logits1 = self.post_infer.overlap_tpsp_token_forward(
last_input_embs, last_input_embs1, infer_state, infer_state1, self.pre_post_weight
)
if isinstance(predict_logits, tuple):
predict_logits, mtp_main_output_hiddens = predict_logits
else:
mtp_main_output_hiddens = None
if isinstance(predict_logits1, tuple):
predict_logits1, mtp_main_output_hiddens1 = predict_logits1
else:
mtp_main_output_hiddens1 = None
g_cache_manager.cache_env_out()

model_output = ModelOutput(logits=predict_logits.contiguous())
model_output1 = ModelOutput(logits=predict_logits1.contiguous())

if self.is_mtp_mode:
input_embs = self.pre_infer._tpsp_allgather(input=input_embs, infer_state=infer_state)
input_embs1 = self.pre_infer._tpsp_allgather(input=input_embs1, infer_state=infer_state1)
if infer_state.need_dp_prefill_balance:
input_embs = infer_state._all_to_all_unbalance_get(data=input_embs)
input_embs1 = infer_state1._all_to_all_unbalance_get(data=input_embs1)
model_output.mtp_main_output_hiddens = input_embs.contiguous()
model_output1.mtp_main_output_hiddens = input_embs1.contiguous()
if mtp_main_output_hiddens is None:
mtp_main_output_hiddens = self.pre_infer._tpsp_allgather(input=input_embs, infer_state=infer_state)
if infer_state.need_dp_prefill_balance:
mtp_main_output_hiddens = infer_state._all_to_all_unbalance_get(data=mtp_main_output_hiddens)
model_output.mtp_main_output_hiddens = mtp_main_output_hiddens.contiguous()
if mtp_main_output_hiddens1 is None:
mtp_main_output_hiddens1 = self.pre_infer._tpsp_allgather(input=input_embs1, infer_state=infer_state1)
if infer_state1.need_dp_prefill_balance:
mtp_main_output_hiddens1 = infer_state1._all_to_all_unbalance_get(data=mtp_main_output_hiddens1)
model_output1.mtp_main_output_hiddens = mtp_main_output_hiddens1.contiguous()

return model_output, model_output1

Expand All @@ -969,15 +994,29 @@ def _overlap_tpsp_token_forward(self, infer_state: InferStateInfo, infer_state1:
predict_logits, predict_logits1 = self.post_infer.overlap_tpsp_token_forward(
last_input_embs, last_input_embs1, infer_state, infer_state1, self.pre_post_weight
)
if isinstance(predict_logits, tuple):
predict_logits, mtp_main_output_hiddens = predict_logits
else:
mtp_main_output_hiddens = None
if isinstance(predict_logits1, tuple):
predict_logits1, mtp_main_output_hiddens1 = predict_logits1
else:
mtp_main_output_hiddens1 = None

model_output = ModelOutput(logits=predict_logits.contiguous())
model_output1 = ModelOutput(logits=predict_logits1.contiguous())

if self.is_mtp_mode:
input_embs = self.pre_infer._tpsp_allgather(input=input_embs, infer_state=infer_state)
input_embs1 = self.pre_infer._tpsp_allgather(input=input_embs1, infer_state=infer_state1)
model_output.mtp_main_output_hiddens = input_embs.contiguous()
model_output1.mtp_main_output_hiddens = input_embs1.contiguous()
if mtp_main_output_hiddens is None:
mtp_main_output_hiddens = self.pre_infer._tpsp_allgather(input=input_embs, infer_state=infer_state)
if infer_state.need_dp_prefill_balance:
mtp_main_output_hiddens = infer_state._all_to_all_unbalance_get(data=mtp_main_output_hiddens)
model_output.mtp_main_output_hiddens = mtp_main_output_hiddens.contiguous()
if mtp_main_output_hiddens1 is None:
mtp_main_output_hiddens1 = self.pre_infer._tpsp_allgather(input=input_embs1, infer_state=infer_state1)
if infer_state1.need_dp_prefill_balance:
mtp_main_output_hiddens1 = infer_state1._all_to_all_unbalance_get(data=mtp_main_output_hiddens1)
model_output1.mtp_main_output_hiddens = mtp_main_output_hiddens1.contiguous()

if infer_state.is_cuda_graph:
model_output.to_no_ref_tensor()
Expand Down Expand Up @@ -1185,15 +1224,21 @@ def _init_padded_req(self):
def _gen_special_model_input(self, token_num: int):
special_model_input = {}

cls_name = str(self.__class__)
is_mtp_draft_model = (
"Deepseek3MTPModel" in str(self.__class__)
or "Qwen3MOEMTPModel" in str(self.__class__)
or "MistralMTPModel" in str(self.__class__)
or "Glm4MoeLiteMTPModel" in str(self.__class__)
"Deepseek3MTPModel" in cls_name
or "Qwen3MOEMTPModel" in cls_name
or "MistralMTPModel" in cls_name
or "Glm4MoeLiteMTPModel" in cls_name
or "Gemma4MTPModel" in cls_name
)
if is_mtp_draft_model:
# Gemma-4's drafter consumes the recurrent hidden state in backbone
# width (the target's hidden size), not its own draft width; the other
# MTP drafters have draft width == backbone width so hidden_size fits.
hidden_size = self.config.get("backbone_hidden_size", self.config["hidden_size"])
special_model_input["mtp_draft_input_hiddens"] = torch.randn(
token_num, self.config["hidden_size"], dtype=self.data_type, device="cuda"
token_num, hidden_size, dtype=self.data_type, device="cuda"
)
else:
special_model_input["mtp_draft_input_hiddens"] = None
Expand Down
26 changes: 24 additions & 2 deletions lightllm/models/gemma4/layer_infer/post_layer_infer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import torch
import numpy as np
from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer
from lightllm.distributed.communication_op import all_gather


class Gemma4PostLayerInfer(LlamaPostLayerInfer):
Expand All @@ -13,8 +15,28 @@ def __init__(self, network_config):
self.final_logit_softcapping = float(network_config.get("final_logit_softcapping"))

def token_forward(self, input_embdings, infer_state, layer_weight):
logits = super().token_forward(input_embdings, infer_state, layer_weight)
last_hidden, token_num = self._slice_get_last_input(input_embdings, infer_state)
input_embdings_dtype = input_embdings.dtype
last_hidden = self._norm(last_hidden, infer_state, layer_weight)
lm_input = last_hidden.permute(1, 0).view(-1, token_num)
logic_batch = layer_weight.lm_head_weight_(input=lm_input, alloc_func=self.alloc_tensor)
vocab_size = layer_weight.lm_head_weight_.vocab_size
if self.tp_world_size_ == 1:
gather_data = logic_batch
else:
gather_data = self.alloc_tensor((vocab_size, token_num), dtype=input_embdings_dtype)
split_indexes = np.linspace(0, vocab_size, self.tp_world_size_ + 1, dtype=np.int64)
all_gather(
[gather_data[split_indexes[i] : split_indexes[i + 1], :] for i in range(self.tp_world_size_)],
logic_batch,
group=infer_state.dist_group,
async_op=False,
)
logic_batch = None
logits = self.alloc_tensor((token_num, vocab_size), dtype=torch.float32)
logits[:, :] = gather_data.permute(1, 0)
gather_data = None
if self.final_logit_softcapping is not None and self.final_logit_softcapping > 0:
cap = self.final_logit_softcapping
logits = torch.tanh(logits / cap) * cap
return logits
return logits, last_hidden
Empty file.
Empty file.
90 changes: 90 additions & 0 deletions lightllm/models/gemma4_mtp/layer_infer/post_layer_infer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import torch
import numpy as np
from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer
from lightllm.distributed.communication_op import all_gather


class Gemma4MTPPostLayerInfer(LlamaPostLayerInfer):
    """Post-layer (final norm + LM head) inference for the Gemma-4 MTP draft model.

    ``token_forward`` returns a ``(logits, hidden)`` tuple, where ``hidden`` is
    ``self._post_projection_weight_.mm(norm(last_hidden))``.

    Logits come from one of two paths, selected by the ``use_ordered_embeddings``
    config flag:

    * dense: full LM-head projection, gathered across tensor-parallel ranks;
    * centroid: only the vocab ids that belong to each token's top-K centroids
      are scored; every other vocab position is filled with ``-inf``.
    """

    def __init__(self, network_config):
        """Read the softcapping and ordered-embedding settings from ``network_config``.

        Raises:
            KeyError: if ``use_ordered_embeddings`` is truthy and ``num_centroids``,
                ``centroid_intermediate_top_k`` or ``vocab_size`` is missing.
            AssertionError: if ``vocab_size`` is not a multiple of ``num_centroids``.
        """
        super().__init__(network_config)
        cap = network_config.get("final_logit_softcapping")
        # A missing, None or zero cap disables softcapping.
        self.final_logit_softcapping = float(cap) if cap else None

        self.use_ordered_embeddings_ = bool(network_config.get("use_ordered_embeddings"))
        # Nothing in this class assigns the projection weight; it has to be set
        # on the instance before token_forward is called.
        self._post_projection_weight_ = None
        if self.use_ordered_embeddings_:
            self.num_centroids_ = network_config["num_centroids"]
            self.centroid_top_k_ = network_config["centroid_intermediate_top_k"]
            self.vocab_size_ = network_config["vocab_size"]
            assert (
                self.vocab_size_ % self.num_centroids_ == 0
            ), f"vocab_size={self.vocab_size_} must be divisible by num_centroids={self.num_centroids_}"
            self._vocab_per_centroid_ = self.vocab_size_ // self.num_centroids_
            # token -> centroid mapping is derived lazily from the loaded
            # token_ordering buffer (weights are not loaded yet at __init__).
            self._centroid_of_token_ = None

    def _dense_logits(self, last_hidden, token_num, input_embdings_dtype, infer_state, layer_weight):
        """Project ``last_hidden`` through the full LM head.

        Args:
            last_hidden: normalized hidden states of the tokens to score; it is
                transposed and viewed as ``(-1, token_num)`` before the LM head.
            token_num: number of tokens being scored.
            input_embdings_dtype: dtype of the gather buffer used when
                ``tp_world_size_ > 1``.
            infer_state: supplies ``dist_group`` for the all_gather.
            layer_weight: supplies ``lm_head_weight_``.

        Returns:
            A float32 tensor of shape ``(token_num, vocab_size)``.
        """
        lm_input = last_hidden.permute(1, 0).view(-1, token_num)
        logic_batch = layer_weight.lm_head_weight_(input=lm_input, alloc_func=self.alloc_tensor)
        vocab_size = layer_weight.lm_head_weight_.vocab_size
        if self.tp_world_size_ == 1:
            gather_data = logic_batch
        else:
            # Each rank's output lands in its own row range of the
            # (vocab_size, token_num) buffer.
            gather_data = self.alloc_tensor((vocab_size, token_num), dtype=input_embdings_dtype)
            split_indexes = np.linspace(0, vocab_size, self.tp_world_size_ + 1, dtype=np.int64)
            all_gather(
                [gather_data[split_indexes[i] : split_indexes[i + 1], :] for i in range(self.tp_world_size_)],
                logic_batch,
                group=infer_state.dist_group,
                async_op=False,
            )
        # Drop the local reference once the data has been gathered.
        logic_batch = None
        ans_logics = self.alloc_tensor((token_num, vocab_size), dtype=torch.float32)
        ans_logics[:, :] = gather_data.permute(1, 0)
        gather_data = None
        return ans_logics

    def _centroid_logits(self, last_hidden, token_num, layer_weight):
        """Score only the vocab ids inside each token's top-K centroid blocks.

        The LM-head rows of the selected vocab ids are gathered, dotted with the
        post-norm hidden state, and scattered into a ``(token_num, vocab_size_)``
        tensor pre-filled with ``-inf`` at their original vocab positions.

        Returns:
            A ``(token_num, vocab_size_)`` tensor whose dtype is that of the
            einsum result.

        NOTE(review): unlike ``_dense_logits`` the result is not cast to
        float32 - confirm callers accept the hidden-state dtype.
        NOTE(review): ``lm_head_weight_.weight`` is indexed with global vocab
        ids; presumably this requires the unsharded weight - confirm for
        ``tp_world_size_ > 1``.
        """
        centroid_scores = layer_weight.centroids_weight_.mm(last_hidden)  # [N, num_centroids]
        topk_centroids = torch.topk(centroid_scores, k=self.centroid_top_k_, dim=-1).indices  # [N, K]
        # token_ordering[i] = original vocab id at reordered position i;
        # row c of the (C, vpc) view holds the vocab ids of centroid c.
        token_ordering = layer_weight.token_ordering_.weight  # [vocab] int64
        clusters = token_ordering.view(self.num_centroids_, self._vocab_per_centroid_)  # [C, vpc]
        selected_vocab = clusters[topk_centroids]  # [N, K, vpc] - original vocab ids
        num_selected = self.centroid_top_k_ * self._vocab_per_centroid_
        selected_vocab = selected_vocab.reshape(token_num, num_selected)  # [N, num_selected]
        # Gather lm_head rows for the selected vocab ids.
        lm_head_w = layer_weight.lm_head_weight_.weight  # [vocab, draft_hidden]
        selected_embeddings = lm_head_w[selected_vocab]  # [N, num_selected, H]
        # Sparse logits: dot product per token vs its selected rows.
        selected_logits = torch.einsum("nh,nsh->ns", last_hidden, selected_embeddings)
        # Scatter to [N, vocab] with -inf elsewhere.
        output = torch.full(
            (token_num, self.vocab_size_),
            float("-inf"),
            dtype=selected_logits.dtype,
            device=selected_logits.device,
        )
        output.scatter_(-1, selected_vocab, selected_logits)
        return output

    def token_forward(self, input_embdings, infer_state, layer_weight):
        """Compute draft logits and the projected hidden state.

        Returns:
            ``(logits, projected_hidden)`` where ``projected_hidden`` is
            ``self._post_projection_weight_.mm(norm(last_hidden))``.

        Raises:
            AssertionError: if ``_post_projection_weight_`` has not been set.
        """
        # Fail fast: check the projection weight before running the norm, the
        # LM head and any collective, so a misconfigured instance is reported
        # without doing the expensive work first.
        assert self._post_projection_weight_ is not None, "post_projection weight is not initialized"

        last_hidden, token_num = self._slice_get_last_input(input_embdings, infer_state)
        last_hidden = self._norm(last_hidden, infer_state, layer_weight)
        if self.use_ordered_embeddings_:
            logits = self._centroid_logits(last_hidden, token_num, layer_weight)
        else:
            logits = self._dense_logits(last_hidden, token_num, input_embdings.dtype, infer_state, layer_weight)
        if self.final_logit_softcapping is not None and self.final_logit_softcapping > 0:
            cap = self.final_logit_softcapping
            # NOTE(review): tanh(-inf) == -1, so on the centroid path the -inf
            # fill becomes the finite value -cap after softcapping - confirm
            # that is intended for consumers that sample from these logits.
            logits = torch.tanh(logits / cap) * cap
        return logits, self._post_projection_weight_.mm(last_hidden)
41 changes: 41 additions & 0 deletions lightllm/models/gemma4_mtp/layer_infer/pre_layer_infer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import torch
from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer
from lightllm.models.gemma4_mtp.layer_weights.pre_and_post_layer_weight import Gemma4MTPPreAndPostLayerWeight


class Gemma4MTPPreLayerInfer(LlamaPreLayerInfer):
    """
    Pre-layer (input side) of the Gemma-4 assistant / MTP draft model.

    Both ``context_forward`` and ``token_forward`` take the embedding produced
    by the parent class and fuse it with the previous hidden state:

        embed = parent_embedding(input_ids) * sqrt(backbone_hidden_size)
        prev  = infer_state.mtp_draft_input_hiddens
        fused = pre_projection(concat[embed, prev] along the last dim)

    The matching output projection lives in Gemma4MTPPostLayerInfer, which
    hands back post_projection(norm(draft_hidden)) as mtp_main_output_hiddens.
    """

    def __init__(self, network_config):
        """Cache the draft/backbone widths and the embedding scale factor."""
        super().__init__(network_config)
        draft_width = network_config["hidden_size"]
        backbone_width = network_config["backbone_hidden_size"]
        self.draft_hidden_ = draft_width
        self.backbone_hidden_ = backbone_width
        # Square root of the backbone width, applied to the embedding in _mtp_fuse.
        self.embed_scale_ = float(backbone_width) ** 0.5

    def _mtp_fuse(self, input_embdings, infer_state, layer_weight: Gemma4MTPPreAndPostLayerWeight):
        """Scale the embedding, append the previous hidden state and project.

        Raises:
            AssertionError: if the embedding and the previous hidden state do
                not have the same size along dim 0.
        """
        prev = infer_state.mtp_draft_input_hiddens
        embed_tokens, prev_tokens = input_embdings.shape[0], prev.shape[0]
        assert (
            embed_tokens == prev_tokens
        ), f"token count mismatch: embed {input_embdings.shape} vs prev_hidden {prev.shape}"
        # The target embedding is backbone width; scale like the target model does.
        scaled_embed = input_embdings * self.embed_scale_
        fused_input = torch.cat((scaled_embed, prev), dim=-1)
        projection = layer_weight.pre_projection_weight_
        return projection.mm(fused_input)

    def context_forward(self, input_ids, infer_state, layer_weight):
        """Prefill path: parent embedding followed by the MTP fusion."""
        return self._mtp_fuse(
            super().context_forward(input_ids, infer_state, layer_weight),
            infer_state,
            layer_weight,
        )

    def token_forward(self, input_ids, infer_state, layer_weight):
        """Decode path: parent embedding followed by the MTP fusion."""
        return self._mtp_fuse(
            super().token_forward(input_ids, infer_state, layer_weight),
            infer_state,
            layer_weight,
        )
21 changes: 21 additions & 0 deletions lightllm/models/gemma4_mtp/layer_infer/transformer_layer_infer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from lightllm.models.gemma4.layer_infer.transformer_layer_infer import Gemma4TransformerLayerInfer


class Gemma4MTPTransformerLayerInfer(Gemma4TransformerLayerInfer):
    """
    Decoder block of the Gemma-4 assistant (MTP draft) model.

    It behaves like the target model's block, with attention pinned to the
    KV-shared (Q-only) path: K/V come from the *target model's* committed
    cache, and this layer neither computes nor writes K/V of its own.

    Args:
        layer_num: assistant-local layer index (0..num_mtp_layers-1), used for
            config lookups (layer_types / RoPE table / per-layer shapes).
        network_config: the MTP network config. It has num_kv_shared_layers
            forced to 0, so the parent __init__ leaves is_kv_shared_ False and
            the flag is switched on here instead.
        kv_share_target_layer: index of the *target model's* layer whose KV
            cache this layer reads (the target's last non-KV-shared layer of
            the same attention type).
    """

    def __init__(self, layer_num, network_config, kv_share_target_layer):
        super().__init__(layer_num, network_config)
        # Remember which target-model layer's cache to read from.
        self.kv_share_target_layer_ = kv_share_target_layer
        # Override the parent's value: this block always takes the KV-shared path.
        self.is_kv_shared_ = True
Empty file.
Loading
Loading