From 075cd4538376956bd733ed03973e25c21593e41d Mon Sep 17 00:00:00 2001 From: zxy Date: Tue, 9 Sep 2025 21:05:32 +0800 Subject: [PATCH 1/8] support internvl flash --- lmdeploy/pytorch/engine/model_agent.py | 9 +- lmdeploy/pytorch/models/internvl.py | 371 +++++++++++++++++++++++-- lmdeploy/vl/model/internvl.py | 9 +- 3 files changed, 362 insertions(+), 27 deletions(-) diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 40c19814c2..ae2f68e9e0 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -238,7 +238,9 @@ def model_forward( context=context, ) output = model(**input_dict) - return dict(hidden_states=output, model_metas=model_metas) + seq_length = ctx_mgr.current_context().q_seqlens + + return dict(hidden_states=output, model_metas=model_metas, seq_length=seq_length) @record_function('stopping_criteria') @@ -502,7 +504,10 @@ async def __long_context_single_forward(new_inputs, max_seqlen: int): if not is_long_context: ret = await __forward(inputs) if not return_logits and not inputs.is_decoding: - last_token_loc = inputs.seq_length.cumsum(0) - 1 + # fetch seq_length from the returned context, since models may change it (e.g. InternVL-Flash) + seq_length = ret.get('seq_length', None) + last_token_loc = seq_length.cumsum(0) - 1 + ret['hidden_states'] = ret['hidden_states'][:, last_token_loc] else: ret = await __long_context_single_forward(inputs, max_seqlen) diff --git a/lmdeploy/pytorch/models/internvl.py b/lmdeploy/pytorch/models/internvl.py index de5f05efcf..7af3336c89 100644 --- a/lmdeploy/pytorch/models/internvl.py +++ b/lmdeploy/pytorch/models/internvl.py @@ -22,6 +22,110 @@ from .utils.model import DeployModelMixin, vlm_model +class Gating(nn.Module): + + def __init__(self, hidden_size=2048, expansion_factor=4, dropout=0.1, use_checkpoint=True): + super().__init__() + + self.use_checkpoint = use_checkpoint + mid_dim = hidden_size * expansion_factor + + def mlp_block(in_dim, out_dim): + return nn.Sequential( + nn.Linear(in_dim, out_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(out_dim, in_dim), + nn.Dropout(dropout), + nn.LayerNorm(in_dim), + ) + + self.block1 = mlp_block(hidden_size, mid_dim) + self.block2 = mlp_block(hidden_size, mid_dim) + self.block3 = mlp_block(hidden_size, mid_dim) + self.block4 = mlp_block(hidden_size, mid_dim) + + self.gate = nn.Sequential( + nn.LayerNorm(hidden_size), + nn.Linear(hidden_size, 2) # 2 experts + ) + + def forward(self, x): + if self.use_checkpoint: + import torch.utils.checkpoint as cp + + x = x + cp.checkpoint(self.block1, x) + x = x + cp.checkpoint(self.block2, x) + x = x + cp.checkpoint(self.block3, x) + x = x + cp.checkpoint(self.block4, x) + else: + x = x + self.block1(x) + x = x + self.block2(x) + x = x + self.block3(x) + x = x + self.block4(x) + + logits = self.gate(x) # shape: [B, 2] + probs = torch.softmax(logits, dim=-1) + return probs + + +class CrossAttentionPooling(nn.Module): + + def __init__(self, dim, num_heads=16): + super().__init__() + self.query_token = nn.Parameter(torch.randn(1, dim)) # [1, D] + + self.attn1 = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True) + self.norm1 = nn.LayerNorm(dim) + + self.attn2 = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True) + self.norm2 = nn.LayerNorm(dim) + + self.attn3 = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True) + self.norm3 = nn.LayerNorm(dim) + + self.attn4 = nn.MultiheadAttention(embed_dim=dim, 
num_heads=num_heads, batch_first=True)
+        self.norm4 = nn.LayerNorm(dim)
+
+    def forward(self, batched_tokens: list[torch.Tensor]):
+        """
+        batched_tokens: List of Tensors of shape [Ti, D], length = B
+        """
+        B = len(batched_tokens)
+        D = batched_tokens[0].shape[-1]
+        device = batched_tokens[0].device
+
+        # 1. Padding
+        max_len = max(t.shape[0] for t in batched_tokens)
+        dtype = self.query_token.dtype
+        padded = torch.zeros(B, max_len, D, dtype=dtype, device=device)
+        padding_mask = torch.ones(B, max_len, dtype=torch.bool, device=device)
+
+        for i, t in enumerate(batched_tokens):
+            L = t.shape[0]
+            padded[i, :L] = t
+            padding_mask[i, :L] = False
+
+        # 2. Query token: [B, 1, D]
+        query = self.query_token.unsqueeze(0).expand(B, -1, -1)  # learnable token for each sample
+
+        # 3. First attention
+        out1, _ = self.attn1(query, padded, padded, key_padding_mask=padding_mask)  # [B, 1, D]
+        out1 = self.norm1(out1)
+
+        # 4. Second attention
+        out2, _ = self.attn2(out1, padded, padded, key_padding_mask=padding_mask)  # [B, 1, D]
+        out2 = self.norm2(out2)
+
+        out3, _ = self.attn3(out2, padded, padded, key_padding_mask=padding_mask)  # [B, 1, D]
+        out3 = self.norm3(out3)
+
+        out4, _ = self.attn4(out3, padded, padded, key_padding_mask=padding_mask)  # [B, 1, D]
+        out4 = self.norm4(out4)
+
+        return out4.squeeze(1)
+
+
 class InternVisionEmbeddings(nn.Module):
     """Intern vision embedding."""
@@ -379,6 +483,23 @@ def __init__(self,

         self.compile_vit = False

+        self.flash_mode = getattr(config, 'flash_mode', False)
+        if self.flash_mode:
+            self.flash_relative_threshold = config.flash_relative_threshold
+            self.flash_absolute_threshold = config.flash_absolute_threshold
+
+        self.mlp2 = nn.Sequential(
+            nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**4, dtype=dtype, device=device),
+            nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**4,
+                      llm_hidden_size * 2,
+                      dtype=dtype,
+                      device=device), nn.GELU(), nn.Dropout(0.1),
+            nn.Linear(llm_hidden_size * 2, llm_hidden_size * 2, dtype=dtype, device=device), nn.GELU(),
+            nn.Dropout(0.1), nn.Linear(llm_hidden_size * 2, llm_hidden_size, dtype=dtype, device=device))
+
+        self.pooling_before_gating = CrossAttentionPooling(dim=vit_hidden_size)
+        self.gating = Gating(hidden_size=vit_hidden_size)
+
     def compile_model(self):
         torch_version = version.parse(torch.__version__)
         if torch_version < version.parse('2.5.0'):
@@ -433,6 +554,195 @@ def extract_feature(self, pixel_values):
         vit_embeds = self.mlp1(vit_embeds)
         return vit_embeds

+    def compress_visual_tokens_in_sentence(
+        self,
+        input_embeds: torch.Tensor,
+        input_ids: torch.Tensor,
+        img_context_token_id: int,
+        gate_result,
+    ) -> tuple:
+        # reshape
+        B, N, C = input_embeds.shape
+        input_embeds = input_embeds.reshape(B * N, C)
+        input_ids = input_ids.reshape(B * N)
+
+        N, C = input_embeds.shape
+        input_ids = input_ids.squeeze(0)  # (N,)
+        selected = (input_ids == img_context_token_id)
+        padded = torch.cat(
+            [torch.tensor([0], device=selected.device),
+             selected.int(),
+             torch.tensor([0], device=selected.device)])
+        diff = torch.diff(padded)
+
+        starts = (diff == 1).nonzero(as_tuple=True)[0]
+        ends = (diff == -1).nonzero(as_tuple=True)[0]
+        lengths = ends - starts
+
+        keep_mask = torch.ones(N, dtype=torch.bool, device=input_embeds.device)
+
+        delete_flags = torch.zeros(N, dtype=torch.int32, device=input_embeds.device)
+
+        total_blocks = 0
+        block_counts = []
+        for length in lengths.tolist():
+            if length % 256 != 0:
+                raise ValueError(f'length % 256 != 0, length = {length}')
+            num_blocks = length // 256
+ 
block_counts.append(num_blocks) + total_blocks += num_blocks + + flag_idx = 0 + for s, e, l, num_blocks in zip(starts.tolist(), ends.tolist(), lengths.tolist(), block_counts): + for i in range(num_blocks): + block_start = s + i * 256 + block_end = block_start + 256 + + compress = gate_result[flag_idx] + flag_idx += 1 + + if compress: + keep_mask[block_start + 64:block_end] = False + delete_flags[block_start + 64:block_end] = 1 + + cumulative_deletes = torch.cumsum(delete_flags, dim=0) + cumulative_deletes = torch.cat([cumulative_deletes, cumulative_deletes[-1:].clone()], dim=0) + + # update + new_input_embeds = input_embeds[keep_mask.to(input_embeds.device), :] + new_input_ids = input_ids[keep_mask.to(input_ids.device)] + new_image_mask = (new_input_ids == img_context_token_id) + + # reshape back + new_input_ids = new_input_ids.reshape(B, -1) + new_input_embeds = new_input_embeds.reshape(B, -1, C) + + return new_input_embeds, new_input_ids, new_image_mask + + def get_image_num_per_sample(self, input_ids: torch.Tensor, img_context_token_id: int): + input_ids = input_ids.squeeze(0) # (N,) + selected = (input_ids == img_context_token_id) + padded = torch.cat( + [torch.tensor([0], device=selected.device), + selected.int(), + torch.tensor([0], device=selected.device)]) + diff = torch.diff(padded) + + starts = (diff == 1).nonzero(as_tuple=True)[0] + ends = (diff == -1).nonzero(as_tuple=True)[0] + lengths = ends - starts + + return lengths + + def split_and_merge(self, features: torch.Tensor, split_sizes: torch.Tensor): + """ + features: Tensor of shape [T, 1024, 1024] + split_sizes: 1D Tensor like [3, 3, 4] — tile of each sample + + returns: List of Tensors of shape [tile_i * 1024, 1024] + """ + # split features -> each sample a tile list + tile_splits = torch.split(features, split_sizes, dim=0) + + # merge the first two dimensions: tile * 1024 × 1024 + merged = [x.reshape(-1, x.shape[-1]) for x in tile_splits] + + return merged + + def extract_feature_flash(self, pixel_values, lengths): + + with torch.no_grad(): + vit_embeds_1024 = self.vision_model(pixel_values) + + vit_embeds_1024 = vit_embeds_1024[:, 1:, :] + h = w = int(vit_embeds_1024.shape[1]**0.5) + vit_embeds_1024 = vit_embeds_1024.reshape(vit_embeds_1024.shape[0], h, w, -1) + + # begin moe + lengths = [int(x) for x in lengths.tolist()] + vit_embeds_1024_split_and_merge = self.split_and_merge(vit_embeds_1024, lengths) + + gate = self.pooling_before_gating(vit_embeds_1024_split_and_merge) + gate = self.gating(gate) + + vit_embeds_256 = vit_embeds_1024.clone() + + with torch.no_grad(): + vit_embeds_64 = self.pixel_shuffle(vit_embeds_1024, scale_factor=self.downsample_ratio**2) + vit_embeds_64 = vit_embeds_64.reshape(vit_embeds_64.shape[0], -1, vit_embeds_64.shape[-1]) + vit_embeds_64 = self.mlp2(vit_embeds_64) + + vit_embeds_256 = self.pixel_shuffle(vit_embeds_256, scale_factor=self.downsample_ratio) + vit_embeds_256 = vit_embeds_256.reshape(vit_embeds_256.shape[0], -1, vit_embeds_256.shape[-1]) + vit_embeds_256 = self.mlp1(vit_embeds_256) + + return vit_embeds_64, vit_embeds_256, gate + + def extract_and_compress(self, pixel_values: torch.Tensor, input_ids: torch.Tensor, img_context_token_id: int): + lang_embeds = self.language_model.get_input_embeddings()(input_ids) + + self._mark_dynamic_once(pixel_values, [0]) + + lengths = self.get_image_num_per_sample(input_ids, img_context_token_id) / 256 + lengths_sum = torch.ones(int(lengths.sum().item()), dtype=torch.int64) + lengths = lengths_sum.repeat_interleave(1) + vit_embeds_64, 
vit_embeds_256, gate_result = self.extract_feature_flash(pixel_values, lengths) + + relative_threshold_value = torch.quantile(gate_result[:, 0].to(torch.float32), self.flash_relative_threshold) + gate_result = (gate_result[:, 0] >= relative_threshold_value) & (gate_result[:, 0] + > self.flash_absolute_threshold) + + selected_embeds = [] + for i in range(gate_result.size(0)): + if gate_result[i]: + selected_embeds.append(vit_embeds_64[i]) + else: + selected_embeds.append(vit_embeds_256[i]) + + vit_embeds = torch.cat(selected_embeds, dim=0) + + # compress visual tokens in sentence + lang_embeds, input_ids, image_mask = self.compress_visual_tokens_in_sentence( + input_embeds=lang_embeds, + input_ids=input_ids, + img_context_token_id=img_context_token_id, + gate_result=gate_result, + ) + + return vit_embeds, lang_embeds, input_ids, image_mask + + def update_context(self, new_input_ids: torch.Tensor): + """Update the current context with new input_ids.""" + from lmdeploy.pytorch.model_inputs import ModelInputs + + crt_ctx = self.ctx_mgr.current_context() + if crt_ctx is None: + raise RuntimeError('Cannot update a non-existent context.') + + device = new_input_ids.device + new_seq_len = new_input_ids.size(-1) + + # create new model inputs + new_model_inputs = ModelInputs(input_ids=new_input_ids, + seq_length=torch.tensor([new_seq_len], device=device, dtype=torch.long), + history_lengths=torch.tensor([0], device=device, dtype=torch.long), + block_offsets=crt_ctx.block_offsets, + is_decoding=False, + num_ignored_history=torch.tensor([0], device=device, dtype=torch.long), + max_q_seqlen=new_seq_len, + max_kv_seqlen=new_seq_len, + sum_kv_seqlen=new_seq_len, + model_metas=[None]) + + # build and set new context + # NOTE: we keep original block_offsets, vision_inputs and kv_caches, might be wrong + new_ctx = self.ctx_mgr.build_context(new_model_inputs, crt_ctx.model_config) + new_ctx.vision_inputs = crt_ctx.vision_inputs + new_ctx.kv_caches = crt_ctx.kv_caches + self.ctx_mgr.set_context(new_ctx) + + return new_ctx + def forward( self, input_ids: torch.Tensor, @@ -444,13 +754,27 @@ def forward( inputs_embeds: torch.Tensor = None, vision_embedding_indexing: torch.Tensor = None, text_embedding_indexing: torch.Tensor = None, + image_token_id: int = None, **kwargs, ): if inputs_embeds is None and pixel_values is not None: - # extract feature - self._mark_dynamic_once(pixel_values, [0]) - vit_embeds = self.extract_feature(pixel_values) - lang_embeds = self.language_model.get_input_embeddings()(input_ids) + if not self.flash_mode: + # extract feature + self._mark_dynamic_once(pixel_values, [0]) + vit_embeds = self.extract_feature(pixel_values) + lang_embeds = self.language_model.get_input_embeddings()(input_ids) + else: + # extract feature and compress visual tokens + vit_embeds, lang_embeds, new_input_ids, new_image_mask = self.extract_and_compress( + pixel_values=pixel_values, input_ids=input_ids, img_context_token_id=image_token_id) + input_ids = new_input_ids + image_mask = new_image_mask + + # update context and relevant attributes + ctx = self.update_context(new_input_ids=new_input_ids) + position_ids = ctx.position_ids + attn_metadata = ctx.attn_metadata + lang_embeds.masked_scatter_(image_mask[..., None], vit_embeds) inputs_embeds = lang_embeds @@ -494,6 +818,7 @@ def prepare_inputs_for_generation( # vision inputs pixel_values = None image_mask = None + image_token_id = None if context.input_multimodals is not None: pixel_values = [input_mm.get('image', []) for input_mm in context.input_multimodals] 
# flatten batch
@@ -524,27 +849,25 @@ def prepare_inputs_for_generation(
             vision_embedding_indexing = None
             if text_embedding_indexing.numel() == 0:
                 text_embedding_indexing = None
-            return dict(
-                input_ids=input_ids,
-                position_ids=position_ids,
-                past_key_values=past_key_values,
-                attn_metadata=attn_metadata,
-                pixel_values=pixel_values,
-                image_mask=image_mask,
-                inputs_embeds=inputs_embeds,
-                vision_embedding_indexing=vision_embedding_indexing,
-                text_embedding_indexing=text_embedding_indexing,
-            )
+            return dict(input_ids=input_ids,
+                        position_ids=position_ids,
+                        past_key_values=past_key_values,
+                        attn_metadata=attn_metadata,
+                        pixel_values=pixel_values,
+                        image_mask=image_mask,
+                        inputs_embeds=inputs_embeds,
+                        vision_embedding_indexing=vision_embedding_indexing,
+                        text_embedding_indexing=text_embedding_indexing,
+                        image_token_id=image_token_id)
         else:
-            return dict(
-                input_ids=input_ids,
-                position_ids=position_ids,
-                past_key_values=past_key_values,
-                attn_metadata=attn_metadata,
-                pixel_values=pixel_values,
-                image_mask=image_mask,
-                inputs_embeds=inputs_embeds,
-            )
+            return dict(input_ids=input_ids,
+                        position_ids=position_ids,
+                        past_key_values=past_key_values,
+                        attn_metadata=attn_metadata,
+                        pixel_values=pixel_values,
+                        image_mask=image_mask,
+                        inputs_embeds=inputs_embeds,
+                        image_token_id=image_token_id)

     def load_lora_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], adapter_id: int):
         """Load lora weights."""
diff --git a/lmdeploy/vl/model/internvl.py b/lmdeploy/vl/model/internvl.py
index f508e31bc6..ba0d90a8db 100644
--- a/lmdeploy/vl/model/internvl.py
+++ b/lmdeploy/vl/model/internvl.py
@@ -2,7 +2,7 @@
 from typing import Dict, List, Optional

 import torch
-from transformers import AutoConfig, AutoModel, CLIPImageProcessor
+from transformers import AutoConfig, AutoModel, AutoTokenizer, CLIPImageProcessor

 from lmdeploy.utils import get_logger
 from lmdeploy.vl.model.base import VISION_MODELS, VisonModel
@@ -111,6 +111,11 @@ def build_preprocessor(self):
         downsample_ratio = self.hf_config.downsample_ratio
         self.image_tokens_per_patch = int((force_image_size // patch_size)**2 * (downsample_ratio**2))

+        if 'internvl3_5' in self.model_path.lower():
+            IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
+            tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True, use_fast=False)
+            self.image_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
+
     def build_model(self):
         """Build the vision part of a VLM model when backend is turbomind, or load the whole VLM model when
         `self.with_llm==True`"""
@@ -257,6 +262,8 @@ def proc_messages(
                                              sequence_start,
                                              tools=tools,
                                              enable_thinking=enable_thinking)
+        # import pdb; pdb.set_trace()
+        # FIXME: should double check here, different from provided one, in terms of something like \n words
         return prompt, IMAGE_TOKEN

     def to_pytorch(self,

From 18e4abb55ffe36892c9529b7496bc9883986d6e3 Mon Sep 17 00:00:00 2001
From: zxy
Date: Tue, 9 Sep 2025 21:07:26 +0800
Subject: [PATCH 2/8] clean

---
 lmdeploy/vl/model/internvl.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/lmdeploy/vl/model/internvl.py b/lmdeploy/vl/model/internvl.py
index ba0d90a8db..8e85e6269b 100644
--- a/lmdeploy/vl/model/internvl.py
+++ b/lmdeploy/vl/model/internvl.py
@@ -262,8 +262,6 @@ def proc_messages(
                                              sequence_start,
                                              tools=tools,
                                              enable_thinking=enable_thinking)
-        # import pdb; pdb.set_trace()
-        # FIXME: should double check here, different from provided one, in terms of something like \n words
         return prompt, IMAGE_TOKEN

     def to_pytorch(self,

From 
02cfea48cc984ac1bc944337e07c426e3c2ae462 Mon Sep 17 00:00:00 2001 From: zxy Date: Wed, 10 Sep 2025 12:39:46 +0800 Subject: [PATCH 3/8] fix --- lmdeploy/pytorch/models/internvl.py | 54 +++++++---------------------- lmdeploy/vl/model/internvl.py | 8 ++--- 2 files changed, 16 insertions(+), 46 deletions(-) diff --git a/lmdeploy/pytorch/models/internvl.py b/lmdeploy/pytorch/models/internvl.py index 7af3336c89..d3cf76e6d2 100644 --- a/lmdeploy/pytorch/models/internvl.py +++ b/lmdeploy/pytorch/models/internvl.py @@ -24,10 +24,9 @@ class Gating(nn.Module): - def __init__(self, hidden_size=2048, expansion_factor=4, dropout=0.1, use_checkpoint=True): + def __init__(self, hidden_size=2048, expansion_factor=4, dropout=0.1): super().__init__() - self.use_checkpoint = use_checkpoint mid_dim = hidden_size * expansion_factor def mlp_block(in_dim, out_dim): @@ -51,18 +50,10 @@ def mlp_block(in_dim, out_dim): ) def forward(self, x): - if self.use_checkpoint: - import torch.utils.checkpoint as cp - - x = x + cp.checkpoint(self.block1, x) - x = x + cp.checkpoint(self.block2, x) - x = x + cp.checkpoint(self.block3, x) - x = x + cp.checkpoint(self.block4, x) - else: - x = x + self.block1(x) - x = x + self.block2(x) - x = x + self.block3(x) - x = x + self.block4(x) + x = x + self.block1(x) + x = x + self.block2(x) + x = x + self.block3(x) + x = x + self.block4(x) logits = self.gate(x) # shape: [B, 2] probs = torch.softmax(logits, dim=-1) @@ -567,22 +558,10 @@ def compress_visual_tokens_in_sentence( input_ids = input_ids.reshape(B * N) N, C = input_embeds.shape - input_ids = input_ids.squeeze(0) # (N,) - selected = (input_ids == img_context_token_id) - padded = torch.cat( - [torch.tensor([0], device=selected.device), - selected.int(), - torch.tensor([0], device=selected.device)]) - diff = torch.diff(padded) - - starts = (diff == 1).nonzero(as_tuple=True)[0] - ends = (diff == -1).nonzero(as_tuple=True)[0] - lengths = ends - starts + lengths, starts, ends = self.get_image_num_per_sample(input_ids, img_context_token_id) keep_mask = torch.ones(N, dtype=torch.bool, device=input_embeds.device) - delete_flags = torch.zeros(N, dtype=torch.int32, device=input_embeds.device) - total_blocks = 0 block_counts = [] for length in lengths.tolist(): @@ -603,10 +582,6 @@ def compress_visual_tokens_in_sentence( if compress: keep_mask[block_start + 64:block_end] = False - delete_flags[block_start + 64:block_end] = 1 - - cumulative_deletes = torch.cumsum(delete_flags, dim=0) - cumulative_deletes = torch.cat([cumulative_deletes, cumulative_deletes[-1:].clone()], dim=0) # update new_input_embeds = input_embeds[keep_mask.to(input_embeds.device), :] @@ -632,7 +607,7 @@ def get_image_num_per_sample(self, input_ids: torch.Tensor, img_context_token_id ends = (diff == -1).nonzero(as_tuple=True)[0] lengths = ends - starts - return lengths + return lengths, starts, ends def split_and_merge(self, features: torch.Tensor, split_sizes: torch.Tensor): """ @@ -651,8 +626,7 @@ def split_and_merge(self, features: torch.Tensor, split_sizes: torch.Tensor): def extract_feature_flash(self, pixel_values, lengths): - with torch.no_grad(): - vit_embeds_1024 = self.vision_model(pixel_values) + vit_embeds_1024 = self.vision_model(pixel_values) vit_embeds_1024 = vit_embeds_1024[:, 1:, :] h = w = int(vit_embeds_1024.shape[1]**0.5) @@ -683,7 +657,8 @@ def extract_and_compress(self, pixel_values: torch.Tensor, input_ids: torch.Tens self._mark_dynamic_once(pixel_values, [0]) - lengths = self.get_image_num_per_sample(input_ids, img_context_token_id) / 256 + 
lengths, starts, ends = self.get_image_num_per_sample(input_ids, img_context_token_id)
+        lengths = lengths // 256
         lengths_sum = torch.ones(int(lengths.sum().item()), dtype=torch.int64)
         lengths = lengths_sum.repeat_interleave(1)
         vit_embeds_64, vit_embeds_256, gate_result = self.extract_feature_flash(pixel_values, lengths)
@@ -692,12 +667,9 @@ def extract_and_compress(self, pixel_values: torch.Tens
         gate_result = (gate_result[:, 0] >= relative_threshold_value) & (gate_result[:, 0]
                                                                          > self.flash_absolute_threshold)

-        selected_embeds = []
-        for i in range(gate_result.size(0)):
-            if gate_result[i]:
-                selected_embeds.append(vit_embeds_64[i])
-            else:
-                selected_embeds.append(vit_embeds_256[i])
+        selected_embeds = [
+            vit_embeds_64[i] if gate_result[i] else vit_embeds_256[i] for i in range(gate_result.size(0))
+        ]

         vit_embeds = torch.cat(selected_embeds, dim=0)

diff --git a/lmdeploy/vl/model/internvl.py b/lmdeploy/vl/model/internvl.py
index 8e85e6269b..a2b8d7f9b7 100644
--- a/lmdeploy/vl/model/internvl.py
+++ b/lmdeploy/vl/model/internvl.py
@@ -76,6 +76,9 @@ def __init__(self,
                  hf_config: AutoConfig = None,
                  backend: str = ''):
         super().__init__(model_path, with_llm, max_memory, hf_config, backend)
+        IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
+        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False)
+        self.image_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)

     def build_preprocessor(self):
         self.config = self.hf_config
@@ -111,11 +114,6 @@ def build_preprocessor(self):
         downsample_ratio = self.hf_config.downsample_ratio
         self.image_tokens_per_patch = int((force_image_size // patch_size)**2 * (downsample_ratio**2))

-        if 'internvl3_5' in self.model_path.lower():
-            IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
-            tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True, use_fast=False)
-            self.image_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
-
     def build_model(self):
         """Build the vision part of a VLM model when backend is turbomind, or load the whole VLM model when
         `self.with_llm==True`"""

From ab0adfc35921e80fc323b0b90d5b472cd448efef Mon Sep 17 00:00:00 2001
From: zxy
Date: Thu, 11 Sep 2025 20:32:04 +0800
Subject: [PATCH 4/8] fix context update for multi requests

---
 lmdeploy/pytorch/engine/model_agent.py |  3 +-
 lmdeploy/pytorch/models/internvl.py    | 54 ++++++++++++++++----------
 2 files changed, 36 insertions(+), 21 deletions(-)

diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py
index ae2f68e9e0..b90f8e4a26 100644
--- a/lmdeploy/pytorch/engine/model_agent.py
+++ b/lmdeploy/pytorch/engine/model_agent.py
@@ -505,7 +505,8 @@ async def __long_context_single_forward(new_inputs, max_seqlen: int):
             ret = await __forward(inputs)
             if not return_logits and not inputs.is_decoding:
                 # fetch seq_length from the returned context, since models may change it (e.g. InternVL-Flash)
-                seq_length = ret.get('seq_length', None)
+                seq_length = ret['seq_length']
+                assert seq_length is not None, 'seq_length cannot be None'
                 last_token_loc = seq_length.cumsum(0) - 1

                 ret['hidden_states'] = ret['hidden_states'][:, last_token_loc]

diff --git a/lmdeploy/pytorch/models/internvl.py b/lmdeploy/pytorch/models/internvl.py
index d3cf76e6d2..af987ea50d 100644
--- a/lmdeploy/pytorch/models/internvl.py
+++ b/lmdeploy/pytorch/models/internvl.py
@@ -31,10 +31,10 @@ def __init__(self, hidden_size=2048, expansion_factor=4, dropout=0.1):

         def mlp_block(in_dim, out_dim):
             return nn.Sequential(
-                nn.Linear(in_dim, out_dim),
+                nn.Linear(in_dim, out_dim, bias=True),
                 nn.GELU(),
                 nn.Dropout(dropout),
-                nn.Linear(out_dim, in_dim),
+                nn.Linear(out_dim, in_dim, bias=True),
                 nn.Dropout(dropout),
                 nn.LayerNorm(in_dim),
             )
@@ -46,7 +46,7 @@ def mlp_block(in_dim, out_dim):

         self.gate = nn.Sequential(
             nn.LayerNorm(hidden_size),
-            nn.Linear(hidden_size, 2)  # 2 experts
+            nn.Linear(hidden_size, 2, bias=True)  # 2 experts
         )

     def forward(self, x):
@@ -462,8 +462,12 @@ def __init__(self,
         self.downsample_ratio = config.downsample_ratio
         self.mlp1 = nn.Sequential(
             nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2, dtype=dtype, device=device),
-            nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**2, llm_hidden_size, dtype=dtype, device=device),
-            nn.GELU(), nn.Linear(llm_hidden_size, llm_hidden_size, dtype=dtype, device=device))
+            nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**2,
+                      llm_hidden_size,
+                      bias=True,
+                      dtype=dtype,
+                      device=device), nn.GELU(),
+            nn.Linear(llm_hidden_size, llm_hidden_size, bias=True, dtype=dtype, device=device))

         # for Mono-InternVL
         if self.is_mono:
@@ -483,10 +487,11 @@ def __init__(self,
             nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**4, dtype=dtype, device=device),
             nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**4,
                       llm_hidden_size * 2,
+                      bias=True,
                       dtype=dtype,
                       device=device), nn.GELU(), nn.Dropout(0.1),
-            nn.Linear(llm_hidden_size * 2, llm_hidden_size * 2, dtype=dtype, device=device), nn.GELU(),
-            nn.Dropout(0.1), nn.Linear(llm_hidden_size * 2, llm_hidden_size, dtype=dtype, device=device))
+            nn.Linear(llm_hidden_size * 2, llm_hidden_size * 2, bias=True, dtype=dtype, device=device), nn.GELU(),
+            nn.Dropout(0.1), nn.Linear(llm_hidden_size * 2, llm_hidden_size, bias=True, dtype=dtype, device=device))

         self.pooling_before_gating = CrossAttentionPooling(dim=vit_hidden_size)
         self.gating = Gating(hidden_size=vit_hidden_size)
@@ -592,7 +597,16 @@ def compress_visual_tokens_in_sentence(
         new_input_ids = new_input_ids.reshape(B, -1)
         new_input_embeds = new_input_embeds.reshape(B, -1, C)

-        return new_input_embeds, new_input_ids, new_image_mask
+        # since multiple sequences may be concatenated together, we need to update the seqlens individually
+        # we calculate the compressed token length of each sequence to get its new length
+        crt_ctx = self.ctx_mgr.current_context()
+        seq_lengths = crt_ctx.q_seqlens
+        # split the keep_mask into chunks corresponding to each original sequence
+        mask_chunks = torch.split(keep_mask, seq_lengths.tolist())
+        # the new length of each sequence is the number of tokens kept (sum of True values)
+        new_seq_lengths = [chunk.sum().item() for chunk in mask_chunks]
+
+        return new_input_embeds, new_input_ids, new_image_mask, new_seq_lengths

     def get_image_num_per_sample(self, input_ids: torch.Tensor, img_context_token_id: int):
@@ -674,16 +688,16 @@ def 
extract_and_compress(self, pixel_values: torch.Tensor, input_ids: torch.Tens vit_embeds = torch.cat(selected_embeds, dim=0) # compress visual tokens in sentence - lang_embeds, input_ids, image_mask = self.compress_visual_tokens_in_sentence( + lang_embeds, input_ids, image_mask, seq_lens = self.compress_visual_tokens_in_sentence( input_embeds=lang_embeds, input_ids=input_ids, img_context_token_id=img_context_token_id, gate_result=gate_result, ) - return vit_embeds, lang_embeds, input_ids, image_mask + return vit_embeds, lang_embeds, input_ids, image_mask, seq_lens - def update_context(self, new_input_ids: torch.Tensor): + def update_context(self, new_input_ids: torch.Tensor, new_seqlens: List[torch.Tensor]) -> StepContext: """Update the current context with new input_ids.""" from lmdeploy.pytorch.model_inputs import ModelInputs @@ -692,22 +706,22 @@ def update_context(self, new_input_ids: torch.Tensor): raise RuntimeError('Cannot update a non-existent context.') device = new_input_ids.device - new_seq_len = new_input_ids.size(-1) + new_seqlens = torch.tensor(new_seqlens, device=device, dtype=torch.long) # create new model inputs new_model_inputs = ModelInputs(input_ids=new_input_ids, - seq_length=torch.tensor([new_seq_len], device=device, dtype=torch.long), + seq_length=new_seqlens, history_lengths=torch.tensor([0], device=device, dtype=torch.long), block_offsets=crt_ctx.block_offsets, is_decoding=False, num_ignored_history=torch.tensor([0], device=device, dtype=torch.long), - max_q_seqlen=new_seq_len, - max_kv_seqlen=new_seq_len, - sum_kv_seqlen=new_seq_len, + max_q_seqlen=new_seqlens.max().item(), + max_kv_seqlen=new_seqlens.max().item(), + sum_kv_seqlen=new_seqlens.sum().item(), model_metas=[None]) # build and set new context - # NOTE: we keep original block_offsets, vision_inputs and kv_caches, might be wrong + # NOTE: we keep original block_offsets, vision_inputs and kv_caches new_ctx = self.ctx_mgr.build_context(new_model_inputs, crt_ctx.model_config) new_ctx.vision_inputs = crt_ctx.vision_inputs new_ctx.kv_caches = crt_ctx.kv_caches @@ -737,13 +751,13 @@ def forward( lang_embeds = self.language_model.get_input_embeddings()(input_ids) else: # extract feature and compress visual tokens - vit_embeds, lang_embeds, new_input_ids, new_image_mask = self.extract_and_compress( - pixel_values=pixel_values, input_ids=input_ids, img_context_token_id=image_token_id) + vit_embeds, lang_embeds, new_input_ids, new_image_mask, new_seqlens = self.extract_and_compress( + pixel_values, input_ids, image_token_id) input_ids = new_input_ids image_mask = new_image_mask # update context and relevant attributes - ctx = self.update_context(new_input_ids=new_input_ids) + ctx = self.update_context(new_input_ids, new_seqlens) position_ids = ctx.position_ids attn_metadata = ctx.attn_metadata From 9740225803d4600a8b8c1c7ff3b2af08b3fb43f2 Mon Sep 17 00:00:00 2001 From: zxy Date: Thu, 11 Sep 2025 21:15:36 +0800 Subject: [PATCH 5/8] dropout to identity, remove clone, fix type --- lmdeploy/pytorch/models/internvl.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/lmdeploy/pytorch/models/internvl.py b/lmdeploy/pytorch/models/internvl.py index af987ea50d..00f0755c4f 100644 --- a/lmdeploy/pytorch/models/internvl.py +++ b/lmdeploy/pytorch/models/internvl.py @@ -24,7 +24,7 @@ class Gating(nn.Module): - def __init__(self, hidden_size=2048, expansion_factor=4, dropout=0.1): + def __init__(self, hidden_size=2048, expansion_factor=4): super().__init__() mid_dim = hidden_size * expansion_factor 
@@ -33,9 +33,9 @@ def mlp_block(in_dim, out_dim): return nn.Sequential( nn.Linear(in_dim, out_dim, bias=True), nn.GELU(), - nn.Dropout(dropout), + nn.Identity(), nn.Linear(out_dim, in_dim, bias=True), - nn.Dropout(dropout), + nn.Identity(), nn.LayerNorm(in_dim), ) @@ -489,9 +489,9 @@ def __init__(self, llm_hidden_size * 2, bias=True, dtype=dtype, - device=device), nn.GELU(), nn.Dropout(0.1), + device=device), nn.GELU(), nn.Identity(), nn.Linear(llm_hidden_size * 2, llm_hidden_size * 2, bias=True, dtype=dtype, device=device), nn.GELU(), - nn.Dropout(0.1), nn.Linear(llm_hidden_size * 2, llm_hidden_size, bias=True, dtype=dtype, device=device)) + nn.Identity(), nn.Linear(llm_hidden_size * 2, llm_hidden_size, bias=True, dtype=dtype, device=device)) self.pooling_before_gating = CrossAttentionPooling(dim=vit_hidden_size) self.gating = Gating(hidden_size=vit_hidden_size) @@ -653,7 +653,7 @@ def extract_feature_flash(self, pixel_values, lengths): gate = self.pooling_before_gating(vit_embeds_1024_split_and_merge) gate = self.gating(gate) - vit_embeds_256 = vit_embeds_1024.clone() + vit_embeds_256 = vit_embeds_1024 with torch.no_grad(): vit_embeds_64 = self.pixel_shuffle(vit_embeds_1024, scale_factor=self.downsample_ratio**2) @@ -697,7 +697,7 @@ def extract_and_compress(self, pixel_values: torch.Tensor, input_ids: torch.Tens return vit_embeds, lang_embeds, input_ids, image_mask, seq_lens - def update_context(self, new_input_ids: torch.Tensor, new_seqlens: List[torch.Tensor]) -> StepContext: + def update_context(self, new_input_ids: torch.Tensor, new_seqlens: List[int]) -> StepContext: """Update the current context with new input_ids.""" from lmdeploy.pytorch.model_inputs import ModelInputs From de968e7fa60803cb7f974112119b470659f0d15b Mon Sep 17 00:00:00 2001 From: zxy Date: Tue, 16 Sep 2025 17:13:17 +0800 Subject: [PATCH 6/8] fix acc, explicit dtype, optimize --- lmdeploy/pytorch/backends/graph_runner.py | 7 ++ lmdeploy/pytorch/engine/model_agent.py | 15 +-- lmdeploy/pytorch/models/internvl.py | 123 ++++++++++++++-------- 3 files changed, 93 insertions(+), 52 deletions(-) diff --git a/lmdeploy/pytorch/backends/graph_runner.py b/lmdeploy/pytorch/backends/graph_runner.py index a88872f2bd..a81e81be2f 100644 --- a/lmdeploy/pytorch/backends/graph_runner.py +++ b/lmdeploy/pytorch/backends/graph_runner.py @@ -82,6 +82,13 @@ def update_model_metas( return None + def post_update_model_metas(self, model_metas): + """Post update model meta.""" + if hasattr(self.model, 'post_update_model_metas'): + return self.model.post_update_model_metas(model_metas) + + return None + def get_input_processor(self): """Get input processor.""" if hasattr(self.model, 'get_input_processor'): diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index b90f8e4a26..98285e9946 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -238,9 +238,8 @@ def model_forward( context=context, ) output = model(**input_dict) - seq_length = ctx_mgr.current_context().q_seqlens - - return dict(hidden_states=output, model_metas=model_metas, seq_length=seq_length) + model_metas = model.post_update_model_metas(model_metas) + return dict(hidden_states=output, model_metas=model_metas) @record_function('stopping_criteria') @@ -504,9 +503,13 @@ async def __long_context_single_forward(new_inputs, max_seqlen: int): if not is_long_context: ret = await __forward(inputs) if not return_logits and not inputs.is_decoding: - # fetch seq_length from the returned context, since 
models may change it (e.g. InternVL-Flash) - seq_length = ret['seq_length'] - assert seq_length is not None, 'seq_length cannot be None' + seq_length = inputs.seq_length + + # for InternVL-3.5-Flash, update seq_length if model_metas contain 'new_seqlen' + model_metas = ret.get('model_metas', None) + if model_metas is not None and 'new_seqlen' in model_metas[0]: + seq_length = torch.tensor([meta['new_seqlen'] for meta in model_metas], device='cuda') + last_token_loc = seq_length.cumsum(0) - 1 ret['hidden_states'] = ret['hidden_states'][:, last_token_loc] diff --git a/lmdeploy/pytorch/models/internvl.py b/lmdeploy/pytorch/models/internvl.py index 00f0755c4f..28d7ee0460 100644 --- a/lmdeploy/pytorch/models/internvl.py +++ b/lmdeploy/pytorch/models/internvl.py @@ -24,19 +24,19 @@ class Gating(nn.Module): - def __init__(self, hidden_size=2048, expansion_factor=4): + def __init__(self, hidden_size=2048, expansion_factor=4, dtype=None, device=None): super().__init__() mid_dim = hidden_size * expansion_factor def mlp_block(in_dim, out_dim): return nn.Sequential( - nn.Linear(in_dim, out_dim, bias=True), + nn.Linear(in_dim, out_dim, bias=True, dtype=dtype, device=device), nn.GELU(), nn.Identity(), - nn.Linear(out_dim, in_dim, bias=True), + nn.Linear(out_dim, in_dim, bias=True, dtype=dtype, device=device), nn.Identity(), - nn.LayerNorm(in_dim), + nn.LayerNorm(in_dim, dtype=dtype, device=device), ) self.block1 = mlp_block(hidden_size, mid_dim) @@ -45,8 +45,8 @@ def mlp_block(in_dim, out_dim): self.block4 = mlp_block(hidden_size, mid_dim) self.gate = nn.Sequential( - nn.LayerNorm(hidden_size), - nn.Linear(hidden_size, 2, bias=True) # 2 experts + nn.LayerNorm(hidden_size, dtype=dtype, device=device), + nn.Linear(hidden_size, 2, bias=True, dtype=dtype, device=device) # 2 experts ) def forward(self, x): @@ -62,21 +62,37 @@ def forward(self, x): class CrossAttentionPooling(nn.Module): - def __init__(self, dim, num_heads=16): + def __init__(self, dim, num_heads=16, dtype=None, device=None): super().__init__() - self.query_token = nn.Parameter(torch.randn(1, dim)) # [1, D] - - self.attn1 = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True) - self.norm1 = nn.LayerNorm(dim) - - self.attn2 = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True) - self.norm2 = nn.LayerNorm(dim) - - self.attn3 = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True) - self.norm3 = nn.LayerNorm(dim) - - self.attn4 = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True) - self.norm4 = nn.LayerNorm(dim) + self.query_token = nn.Parameter(torch.randn(1, dim, dtype=dtype, device=device)) # [1, D] + + self.attn1 = nn.MultiheadAttention(embed_dim=dim, + num_heads=num_heads, + batch_first=True, + dtype=dtype, + device=device) + self.norm1 = nn.LayerNorm(dim, dtype=dtype, device=device) + + self.attn2 = nn.MultiheadAttention(embed_dim=dim, + num_heads=num_heads, + batch_first=True, + dtype=dtype, + device=device) + self.norm2 = nn.LayerNorm(dim, dtype=dtype, device=device) + + self.attn3 = nn.MultiheadAttention(embed_dim=dim, + num_heads=num_heads, + batch_first=True, + dtype=dtype, + device=device) + self.norm3 = nn.LayerNorm(dim, dtype=dtype, device=device) + + self.attn4 = nn.MultiheadAttention(embed_dim=dim, + num_heads=num_heads, + batch_first=True, + dtype=dtype, + device=device) + self.norm4 = nn.LayerNorm(dim, dtype=dtype, device=device) def forward(self, batched_tokens: list[torch.Tensor]): """ @@ -493,8 +509,10 @@ def __init__(self, 
nn.Linear(llm_hidden_size * 2, llm_hidden_size * 2, bias=True, dtype=dtype, device=device), nn.GELU(), nn.Identity(), nn.Linear(llm_hidden_size * 2, llm_hidden_size, bias=True, dtype=dtype, device=device)) - self.pooling_before_gating = CrossAttentionPooling(dim=vit_hidden_size) - self.gating = Gating(hidden_size=vit_hidden_size) + self.pooling_before_gating = CrossAttentionPooling(dim=vit_hidden_size, dtype=dtype, device=device) + self.gating = Gating(hidden_size=vit_hidden_size, dtype=dtype, device=device) + + self.model_metas = None def compile_model(self): torch_version = version.parse(torch.__version__) @@ -688,46 +706,44 @@ def extract_and_compress(self, pixel_values: torch.Tensor, input_ids: torch.Tens vit_embeds = torch.cat(selected_embeds, dim=0) # compress visual tokens in sentence - lang_embeds, input_ids, image_mask, seq_lens = self.compress_visual_tokens_in_sentence( + new_lang_embeds, new_input_ids, new_image_mask, new_seq_lengths = self.compress_visual_tokens_in_sentence( input_embeds=lang_embeds, input_ids=input_ids, img_context_token_id=img_context_token_id, gate_result=gate_result, ) - return vit_embeds, lang_embeds, input_ids, image_mask, seq_lens + return vit_embeds, new_lang_embeds, new_input_ids, new_image_mask, new_seq_lengths - def update_context(self, new_input_ids: torch.Tensor, new_seqlens: List[int]) -> StepContext: + def update_forward_inputs(self, input_ids: torch.Tensor, new_seqlens: List[int]) -> StepContext: """Update the current context with new input_ids.""" from lmdeploy.pytorch.model_inputs import ModelInputs crt_ctx = self.ctx_mgr.current_context() - if crt_ctx is None: - raise RuntimeError('Cannot update a non-existent context.') + assert crt_ctx is not None, 'Current context cannot be None.' - device = new_input_ids.device - new_seqlens = torch.tensor(new_seqlens, device=device, dtype=torch.long) + # fill model metas + self.model_metas = [dict(new_seqlen=seqlen) for seqlen in new_seqlens] # create new model inputs - new_model_inputs = ModelInputs(input_ids=new_input_ids, + device = input_ids.device + total_msgs = len(new_seqlens) + new_seqlens = torch.tensor(new_seqlens, device=device, dtype=torch.long) + new_model_inputs = ModelInputs(input_ids=input_ids, seq_length=new_seqlens, - history_lengths=torch.tensor([0], device=device, dtype=torch.long), + history_lengths=torch.zeros(total_msgs, device=device, dtype=torch.long), block_offsets=crt_ctx.block_offsets, is_decoding=False, - num_ignored_history=torch.tensor([0], device=device, dtype=torch.long), + num_ignored_history=torch.zeros(total_msgs, device=device, dtype=torch.long), max_q_seqlen=new_seqlens.max().item(), max_kv_seqlen=new_seqlens.max().item(), sum_kv_seqlen=new_seqlens.sum().item(), - model_metas=[None]) + model_metas=[None for _ in range(total_msgs)]) - # build and set new context - # NOTE: we keep original block_offsets, vision_inputs and kv_caches + # build new context, to get new position_ids and attn_metadata new_ctx = self.ctx_mgr.build_context(new_model_inputs, crt_ctx.model_config) - new_ctx.vision_inputs = crt_ctx.vision_inputs - new_ctx.kv_caches = crt_ctx.kv_caches - self.ctx_mgr.set_context(new_ctx) - return new_ctx + return new_ctx.position_ids, new_ctx.attn_metadata def forward( self, @@ -751,15 +767,11 @@ def forward( lang_embeds = self.language_model.get_input_embeddings()(input_ids) else: # extract feature and compress visual tokens - vit_embeds, lang_embeds, new_input_ids, new_image_mask, new_seqlens = self.extract_and_compress( + vit_embeds, lang_embeds, input_ids, 
image_mask, new_seqlens = self.extract_and_compress( pixel_values, input_ids, image_token_id) - input_ids = new_input_ids - image_mask = new_image_mask - # update context and relevant attributes - ctx = self.update_context(new_input_ids, new_seqlens) - position_ids = ctx.position_ids - attn_metadata = ctx.attn_metadata + # update forward inputs + position_ids, attn_metadata = self.update_forward_inputs(input_ids, new_seqlens) lang_embeds.masked_scatter_(image_mask[..., None], vit_embeds) @@ -801,6 +813,20 @@ def prepare_inputs_for_generation( vision_embeddings = context.input_embeddings vision_embedding_indexing = None + if context.is_decoding and context.model_metas is not None and context.model_metas[0] is not None: + # model meta from the previous step, therefore +1 for the current decoding step + new_kv_seqlens = [(meta['new_seqlen'] + 1) for meta in context.model_metas] + + # update model metas for the next step + self.model_metas = [dict(new_seqlen=seqlen) for seqlen in new_kv_seqlens] + + # update position ids, attn_metadata + new_kv_seqlens = torch.tensor(new_kv_seqlens, device=input_ids.device, dtype=torch.long) + position_ids = new_kv_seqlens + attn_metadata.kv_seqlens = new_kv_seqlens + attn_metadata.cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(new_kv_seqlens, dim=0, dtype=torch.int32), + (1, 0)) + # vision inputs pixel_values = None image_mask = None @@ -890,6 +916,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): self.language_model.load_weights(new_weights.items()) + def post_update_model_metas(self, model_metas): + """Post update model meta.""" + new_model_metas = self.model_metas if self.model_metas is not None else model_metas + return new_model_metas + def get_input_processor(self) -> BaseModelInputProcessor: """Get input processor.""" return self.input_processor From c610ff3be13c1a9970b668a250ca6d19eb58e2a7 Mon Sep 17 00:00:00 2001 From: zxy Date: Wed, 17 Sep 2025 10:54:59 +0800 Subject: [PATCH 7/8] get seqlen from context, pass context in post update --- lmdeploy/pytorch/backends/graph_runner.py | 4 ++-- lmdeploy/pytorch/engine/model_agent.py | 16 +++++++--------- lmdeploy/pytorch/models/internvl.py | 10 ++++++---- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/lmdeploy/pytorch/backends/graph_runner.py b/lmdeploy/pytorch/backends/graph_runner.py index a81e81be2f..3397819ca6 100644 --- a/lmdeploy/pytorch/backends/graph_runner.py +++ b/lmdeploy/pytorch/backends/graph_runner.py @@ -82,10 +82,10 @@ def update_model_metas( return None - def post_update_model_metas(self, model_metas): + def post_update_model_metas(self, context: StepContext): """Post update model meta.""" if hasattr(self.model, 'post_update_model_metas'): - return self.model.post_update_model_metas(model_metas) + return self.model.post_update_model_metas(context) return None diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 98285e9946..c9bd28a87d 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -238,8 +238,10 @@ def model_forward( context=context, ) output = model(**input_dict) - model_metas = model.post_update_model_metas(model_metas) - return dict(hidden_states=output, model_metas=model_metas) + model_metas = model.post_update_model_metas(context) + seq_length = ctx_mgr.current_context().q_seqlens + + return dict(hidden_states=output, model_metas=model_metas, seq_length=seq_length) @record_function('stopping_criteria') @@ -503,13 +505,9 @@ async def 
__long_context_single_forward(new_inputs, max_seqlen: int): if not is_long_context: ret = await __forward(inputs) if not return_logits and not inputs.is_decoding: - seq_length = inputs.seq_length - - # for InternVL-3.5-Flash, update seq_length if model_metas contain 'new_seqlen' - model_metas = ret.get('model_metas', None) - if model_metas is not None and 'new_seqlen' in model_metas[0]: - seq_length = torch.tensor([meta['new_seqlen'] for meta in model_metas], device='cuda') - + # fetch seq_length from the context, since models may change it (e.g. InternVL-3.5-Flash) + seq_length = ret.get('seq_length', None) + assert seq_length is not None, 'seq_length cannot be None.' last_token_loc = seq_length.cumsum(0) - 1 ret['hidden_states'] = ret['hidden_states'][:, last_token_loc] diff --git a/lmdeploy/pytorch/models/internvl.py b/lmdeploy/pytorch/models/internvl.py index 28d7ee0460..4e7c05e1e9 100644 --- a/lmdeploy/pytorch/models/internvl.py +++ b/lmdeploy/pytorch/models/internvl.py @@ -716,7 +716,7 @@ def extract_and_compress(self, pixel_values: torch.Tensor, input_ids: torch.Tens return vit_embeds, new_lang_embeds, new_input_ids, new_image_mask, new_seq_lengths def update_forward_inputs(self, input_ids: torch.Tensor, new_seqlens: List[int]) -> StepContext: - """Update the current context with new input_ids.""" + """Update the forward inputs, position_ids and attention metadata.""" from lmdeploy.pytorch.model_inputs import ModelInputs crt_ctx = self.ctx_mgr.current_context() @@ -742,7 +742,7 @@ def update_forward_inputs(self, input_ids: torch.Tensor, new_seqlens: List[int]) # build new context, to get new position_ids and attn_metadata new_ctx = self.ctx_mgr.build_context(new_model_inputs, crt_ctx.model_config) - + self.ctx_mgr.set_context(new_ctx) return new_ctx.position_ids, new_ctx.attn_metadata def forward( @@ -916,9 +916,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): self.language_model.load_weights(new_weights.items()) - def post_update_model_metas(self, model_metas): + def post_update_model_metas(self, context: StepContext): """Post update model meta.""" - new_model_metas = self.model_metas if self.model_metas is not None else model_metas + new_model_metas = context.model_metas + if self.model_metas is not None: + new_model_metas = self.model_metas return new_model_metas def get_input_processor(self) -> BaseModelInputProcessor: From 4447ae0654cccbe1eedb2d2787f3f6c200be9df4 Mon Sep 17 00:00:00 2001 From: zxy Date: Wed, 17 Sep 2025 14:51:03 +0800 Subject: [PATCH 8/8] remove self.model_metas --- lmdeploy/pytorch/backends/graph_runner.py | 7 ---- lmdeploy/pytorch/engine/model_agent.py | 6 ++-- lmdeploy/pytorch/models/internvl.py | 42 +++++++++++------------ 3 files changed, 25 insertions(+), 30 deletions(-) diff --git a/lmdeploy/pytorch/backends/graph_runner.py b/lmdeploy/pytorch/backends/graph_runner.py index 3397819ca6..a88872f2bd 100644 --- a/lmdeploy/pytorch/backends/graph_runner.py +++ b/lmdeploy/pytorch/backends/graph_runner.py @@ -82,13 +82,6 @@ def update_model_metas( return None - def post_update_model_metas(self, context: StepContext): - """Post update model meta.""" - if hasattr(self.model, 'post_update_model_metas'): - return self.model.post_update_model_metas(context) - - return None - def get_input_processor(self): """Get input processor.""" if hasattr(self.model, 'get_input_processor'): diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index c9bd28a87d..2af5ab0e6b 100644 --- 
a/lmdeploy/pytorch/engine/model_agent.py
+++ b/lmdeploy/pytorch/engine/model_agent.py
@@ -238,8 +238,10 @@ def model_forward(
         context=context,
     )
     output = model(**input_dict)
-    model_metas = model.post_update_model_metas(context)
-    seq_length = ctx_mgr.current_context().q_seqlens
+
+    # InternVL-3.5-Flash may change the seqlen and model_metas during forward
+    model_metas = context.model_metas
+    seq_length = context.q_seqlens

     return dict(hidden_states=output, model_metas=model_metas, seq_length=seq_length)

diff --git a/lmdeploy/pytorch/models/internvl.py b/lmdeploy/pytorch/models/internvl.py
index 4e7c05e1e9..4e315cd998 100644
--- a/lmdeploy/pytorch/models/internvl.py
+++ b/lmdeploy/pytorch/models/internvl.py
@@ -512,8 +512,6 @@ def __init__(self,
         self.pooling_before_gating = CrossAttentionPooling(dim=vit_hidden_size, dtype=dtype, device=device)
         self.gating = Gating(hidden_size=vit_hidden_size, dtype=dtype, device=device)

-        self.model_metas = None
-
     def compile_model(self):
         torch_version = version.parse(torch.__version__)
         if torch_version < version.parse('2.5.0'):
@@ -715,17 +713,18 @@ def extract_and_compress(self, pixel_values: torch.Tens

         return vit_embeds, new_lang_embeds, new_input_ids, new_image_mask, new_seq_lengths

-    def update_forward_inputs(self, input_ids: torch.Tensor, new_seqlens: List[int]) -> StepContext:
+    def update_forward_inputs(self, input_ids: torch.Tensor, new_seqlens: List[int],
+                              context: StepContext) -> StepContext:
         """Update the forward inputs, position_ids and attention metadata."""
         from lmdeploy.pytorch.model_inputs import ModelInputs

         crt_ctx = self.ctx_mgr.current_context()
         assert crt_ctx is not None, 'Current context cannot be None.'

-        # fill model metas
-        self.model_metas = [dict(new_seqlen=seqlen) for seqlen in new_seqlens]
+        # create model metas with new seqlens
+        model_metas = [dict(new_seqlen=seqlen) for seqlen in new_seqlens]

-        # create new model inputs
+        # create new model inputs and context, to get updated position_ids and attn_metadata
         device = input_ids.device
         total_msgs = len(new_seqlens)
         new_seqlens = torch.tensor(new_seqlens, device=device, dtype=torch.long)
@@ -738,11 +737,13 @@ def update_forward_inputs(self, input_ids: torch.Tensor, new_seqlens: List[int])
                                        max_q_seqlen=new_seqlens.max().item(),
                                        max_kv_seqlen=new_seqlens.max().item(),
                                        sum_kv_seqlen=new_seqlens.sum().item(),
-                                       model_metas=[None for _ in range(total_msgs)])
-
-        # build new context, to get new position_ids and attn_metadata
+                                       model_metas=model_metas)
         new_ctx = self.ctx_mgr.build_context(new_model_inputs, crt_ctx.model_config)
-        self.ctx_mgr.set_context(new_ctx)
+
+        # update attributes of the current context, so the updated values are returned from the model agent
+        context.model_metas = model_metas
+        context.q_seqlens = new_seqlens
+
         return new_ctx.position_ids, new_ctx.attn_metadata

     def forward(
@@ -757,6 +758,7 @@ def forward(
         vision_embedding_indexing: torch.Tensor = None,
         text_embedding_indexing: torch.Tensor = None,
         image_token_id: int = None,
+        context: StepContext = None,
         **kwargs,
     ):
         if inputs_embeds is None and pixel_values is not None:
@@ -771,7 +773,7 @@ def forward(
                     pixel_values, input_ids, image_token_id)

                 # update forward inputs
-                position_ids, attn_metadata = self.update_forward_inputs(input_ids, new_seqlens)
+                position_ids, attn_metadata = self.update_forward_inputs(input_ids, new_seqlens, context)

             lang_embeds.masked_scatter_(image_mask[..., None], vit_embeds)

@@ -814,11 +816,14 @@ def prepare_inputs_for_generation(
             vision_embedding_indexing = None

         if context.is_decoding and context.model_metas is not None and context.model_metas[0] is not None:
+            # NOTE, zhouxinyu, we need to consider the increasing batch in the decoding stage
+            # the current implementation keeps the batch size the same as in the prefill stage
+
             # model meta from the previous step, therefore +1 for the current decoding step
             new_kv_seqlens = [(meta['new_seqlen'] + 1) for meta in context.model_metas]

             # update model metas for the next step
-            self.model_metas = [dict(new_seqlen=seqlen) for seqlen in new_kv_seqlens]
+            context.model_metas = [dict(new_seqlen=seqlen) for seqlen in new_kv_seqlens]

             # update position ids, attn_metadata
             new_kv_seqlens = torch.tensor(new_kv_seqlens, device=input_ids.device, dtype=torch.long)
             position_ids = new_kv_seqlens
             attn_metadata.kv_seqlens = new_kv_seqlens
             attn_metadata.cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(new_kv_seqlens, dim=0, dtype=torch.int32),
                                                                  (1, 0))
@@ -870,7 +875,8 @@ def prepare_inputs_for_generation(
                         inputs_embeds=inputs_embeds,
                         vision_embedding_indexing=vision_embedding_indexing,
                         text_embedding_indexing=text_embedding_indexing,
-                        image_token_id=image_token_id)
+                        image_token_id=image_token_id,
+                        context=context)
         else:
             return dict(input_ids=input_ids,
                         position_ids=position_ids,
                         past_key_values=past_key_values,
                         attn_metadata=attn_metadata,
                         pixel_values=pixel_values,
                         image_mask=image_mask,
                         inputs_embeds=inputs_embeds,
-                        image_token_id=image_token_id)
+                        image_token_id=image_token_id,
+                        context=context)

     def load_lora_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], adapter_id: int):
         """Load lora weights."""
@@ -916,13 +923,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

         self.language_model.load_weights(new_weights.items())

-    def post_update_model_metas(self, context: StepContext):
-        """Post update model meta."""
-        new_model_metas = context.model_metas
-        if self.model_metas is not None:
-            new_model_metas = self.model_metas
-        return new_model_metas
-
     def get_input_processor(self) -> BaseModelInputProcessor:
         """Get input processor."""
         return self.input_processor
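
For reference, the decision rule implemented across `extract_and_compress` and `compress_visual_tokens_in_sentence` in this series can be reproduced in isolation. In the sketch below, each 256-token image tile is either kept at full resolution or swapped for its 64-token pixel-shuffled variant when its gate probability clears both a batch-relative quantile and an absolute floor, and the surplus placeholder positions of compressed tiles are dropped. Only the threshold semantics mirror the `flash_relative_threshold` / `flash_absolute_threshold` config fields used above; the shapes, helper names, and random inputs are toy assumptions, not the shipped implementation.

import torch

def select_tile_embeddings(gate_probs: torch.Tensor,
                           embeds_64: torch.Tensor,
                           embeds_256: torch.Tensor,
                           relative_threshold: float = 0.5,
                           absolute_threshold: float = 0.2):
    """Compress a tile only if its gate probability clears both the batch
    quantile (relative) and a fixed floor (absolute)."""
    p_compress = gate_probs[:, 0].to(torch.float32)          # [T]
    cutoff = torch.quantile(p_compress, relative_threshold)  # scalar cutoff
    compress = (p_compress >= cutoff) & (p_compress > absolute_threshold)
    picked = [embeds_64[i] if compress[i] else embeds_256[i] for i in range(compress.numel())]
    return torch.cat(picked, dim=0), compress

def keep_mask_for_blocks(compress: torch.Tensor, block: int = 256, keep: int = 64) -> torch.Tensor:
    """Boolean mask over T*block flat placeholder positions: compressed tiles
    keep only their first `keep` tokens."""
    mask = torch.ones(compress.numel() * block, dtype=torch.bool)
    for i, c in enumerate(compress.tolist()):
        if c:
            mask[i * block + keep:(i + 1) * block] = False
    return mask

# toy usage: 3 tiles, hidden size 8
T, D = 3, 8
gate = torch.softmax(torch.randn(T, 2), dim=-1)
vit_embeds, compress = select_tile_embeddings(gate, torch.randn(T, 64, D), torch.randn(T, 256, D))
mask = keep_mask_for_blocks(compress)
assert vit_embeds.shape[0] == int(mask.sum())  # kept embeddings align with kept placeholders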
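The model_agent changes track the same requirement from the engine side: prefill hidden states are gathered at each sequence's last token, so once flash mode shortens the packed sequences mid-forward, the gather must use the post-compression lengths rather than the lengths the scheduler planned. A toy illustration of that indexing, with made-up values:

import torch

hidden_states = torch.arange(20).reshape(1, 10, 2)  # [1, 10, D]: two packed sequences of 6 + 4 tokens
seq_length = torch.tensor([6, 4])
last_token_loc = seq_length.cumsum(0) - 1           # tensor([5, 9])
print(hidden_states[:, last_token_loc])             # hidden states of each sequence's final token

# after compression the same prompts may occupy only 4 + 3 = 7 positions,
# so gathering with the stale lengths would read the wrong rows
new_seq_length = torch.tensor([4, 3])
packed = hidden_states[:, :7]
print(packed[:, new_seq_length.cumsum(0) - 1])      # correct rows under the new lengths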