diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py
index 40c19814c2..2af5ab0e6b 100644
--- a/lmdeploy/pytorch/engine/model_agent.py
+++ b/lmdeploy/pytorch/engine/model_agent.py
@@ -238,7 +238,12 @@ def model_forward(
                 context=context,
             )
             output = model(**input_dict)
-    return dict(hidden_states=output, model_metas=model_metas)
+
+    # InternVL-3.5-Flash may change the seq_length and model_metas during forward
+    model_metas = context.model_metas
+    seq_length = context.q_seqlens
+
+    return dict(hidden_states=output, model_metas=model_metas, seq_length=seq_length)


 @record_function('stopping_criteria')
@@ -502,7 +507,11 @@ async def __long_context_single_forward(new_inputs, max_seqlen: int):
         if not is_long_context:
             ret = await __forward(inputs)
             if not return_logits and not inputs.is_decoding:
-                last_token_loc = inputs.seq_length.cumsum(0) - 1
+                # fetch seq_length from the forward output, since models may change it (e.g. InternVL-3.5-Flash)
+                seq_length = ret.get('seq_length', None)
+                assert seq_length is not None, 'seq_length cannot be None.'
+                last_token_loc = seq_length.cumsum(0) - 1
+
                 ret['hidden_states'] = ret['hidden_states'][:, last_token_loc]
         else:
             ret = await __long_context_single_forward(inputs, max_seqlen)
diff --git a/lmdeploy/pytorch/models/internvl.py b/lmdeploy/pytorch/models/internvl.py
index de5f05efcf..4e315cd998 100644
--- a/lmdeploy/pytorch/models/internvl.py
+++ b/lmdeploy/pytorch/models/internvl.py
@@ -22,6 +22,117 @@
 from .utils.model import DeployModelMixin, vlm_model


+class Gating(nn.Module):
+
+    def __init__(self, hidden_size=2048, expansion_factor=4, dtype=None, device=None):
+        super().__init__()
+
+        mid_dim = hidden_size * expansion_factor
+
+        def mlp_block(in_dim, out_dim):
+            return nn.Sequential(
+                nn.Linear(in_dim, out_dim, bias=True, dtype=dtype, device=device),
+                nn.GELU(),
+                nn.Identity(),
+                nn.Linear(out_dim, in_dim, bias=True, dtype=dtype, device=device),
+                nn.Identity(),
+                nn.LayerNorm(in_dim, dtype=dtype, device=device),
+            )
+
+        self.block1 = mlp_block(hidden_size, mid_dim)
+        self.block2 = mlp_block(hidden_size, mid_dim)
+        self.block3 = mlp_block(hidden_size, mid_dim)
+        self.block4 = mlp_block(hidden_size, mid_dim)
+
+        self.gate = nn.Sequential(
+            nn.LayerNorm(hidden_size, dtype=dtype, device=device),
+            nn.Linear(hidden_size, 2, bias=True, dtype=dtype, device=device)  # 2 experts
+        )
+
+    def forward(self, x):
+        x = x + self.block1(x)
+        x = x + self.block2(x)
+        x = x + self.block3(x)
+        x = x + self.block4(x)
+
+        logits = self.gate(x)  # shape: [B, 2]
+        probs = torch.softmax(logits, dim=-1)
+        return probs
+
+
+class CrossAttentionPooling(nn.Module):
+
+    def __init__(self, dim, num_heads=16, dtype=None, device=None):
+        super().__init__()
+        self.query_token = nn.Parameter(torch.randn(1, dim, dtype=dtype, device=device))  # [1, D]
+
+        self.attn1 = nn.MultiheadAttention(embed_dim=dim,
+                                           num_heads=num_heads,
+                                           batch_first=True,
+                                           dtype=dtype,
+                                           device=device)
+        self.norm1 = nn.LayerNorm(dim, dtype=dtype, device=device)
+
+        self.attn2 = nn.MultiheadAttention(embed_dim=dim,
+                                           num_heads=num_heads,
+                                           batch_first=True,
+                                           dtype=dtype,
+                                           device=device)
+        self.norm2 = nn.LayerNorm(dim, dtype=dtype, device=device)
+
+        self.attn3 = nn.MultiheadAttention(embed_dim=dim,
+                                           num_heads=num_heads,
+                                           batch_first=True,
+                                           dtype=dtype,
+                                           device=device)
+        self.norm3 = nn.LayerNorm(dim, dtype=dtype, device=device)
+
+        self.attn4 = nn.MultiheadAttention(embed_dim=dim,
+                                           num_heads=num_heads,
+                                           batch_first=True,
+                                           dtype=dtype,
+                                           device=device)
+        self.norm4 = nn.LayerNorm(dim, dtype=dtype, device=device)
+
+    def forward(self, batched_tokens: list[torch.Tensor]):
+        """
+        batched_tokens: List of Tensors of shape [Ti, D], length = B
+        """
+        B = len(batched_tokens)
+        D = batched_tokens[0].shape[-1]
+        device = batched_tokens[0].device
+
+        # 1. Padding
+        max_len = max(t.shape[0] for t in batched_tokens)
+        dtype = self.query_token.dtype
+        padded = torch.zeros(B, max_len, D, dtype=dtype, device=device)
+        padding_mask = torch.ones(B, max_len, dtype=torch.bool, device=device)
+
+        for i, t in enumerate(batched_tokens):
+            L = t.shape[0]
+            padded[i, :L] = t
+            padding_mask[i, :L] = False
+
+        # 2. Query token: [B, 1, D]
+        query = self.query_token.unsqueeze(0).expand(B, -1, -1)  # learnable token for each sample
+
+        # 3. First attention
+        out1, _ = self.attn1(query, padded, padded, key_padding_mask=padding_mask)  # [B, 1, D]
+        out1 = self.norm1(out1)
+
+        # 4. Second attention
+        out2, _ = self.attn2(out1, padded, padded, key_padding_mask=padding_mask)  # [B, 1, D]
+        out2 = self.norm2(out2)
+
+        out3, _ = self.attn3(out2, padded, padded, key_padding_mask=padding_mask)  # [B, 1, D]
+        out3 = self.norm3(out3)
+
+        out4, _ = self.attn4(out3, padded, padded, key_padding_mask=padding_mask)  # [B, 1, D]
+        out4 = self.norm4(out4)
+
+        return out4.squeeze(1)
+
+
 class InternVisionEmbeddings(nn.Module):
     """Intern vision embedding."""

@@ -367,8 +478,12 @@ def __init__(self,
         self.downsample_ratio = config.downsample_ratio
         self.mlp1 = nn.Sequential(
             nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2, dtype=dtype, device=device),
-            nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**2, llm_hidden_size, dtype=dtype, device=device),
-            nn.GELU(), nn.Linear(llm_hidden_size, llm_hidden_size, dtype=dtype, device=device))
+            nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**2,
+                      llm_hidden_size,
+                      bias=True,
+                      dtype=dtype,
+                      device=device), nn.GELU(),
+            nn.Linear(llm_hidden_size, llm_hidden_size, bias=True, dtype=dtype, device=device))

         # for Mono-InternVL
         if self.is_mono:
@@ -379,6 +494,24 @@ def __init__(self,

         self.compile_vit = False

+        self.flash_mode = getattr(config, 'flash_mode', False)
+        if self.flash_mode:
+            self.flash_relative_threshold = config.flash_relative_threshold
+            self.flash_absolute_threshold = config.flash_absolute_threshold
+
+            self.mlp2 = nn.Sequential(
+                nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**4, dtype=dtype, device=device),
+                nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**4,
+                          llm_hidden_size * 2,
+                          bias=True,
+                          dtype=dtype,
+                          device=device), nn.GELU(), nn.Identity(),
+                nn.Linear(llm_hidden_size * 2, llm_hidden_size * 2, bias=True, dtype=dtype, device=device), nn.GELU(),
+                nn.Identity(), nn.Linear(llm_hidden_size * 2, llm_hidden_size, bias=True, dtype=dtype, device=device))
+
+            self.pooling_before_gating = CrossAttentionPooling(dim=vit_hidden_size, dtype=dtype, device=device)
+            self.gating = Gating(hidden_size=vit_hidden_size, dtype=dtype, device=device)
+
     def compile_model(self):
         torch_version = version.parse(torch.__version__)
         if torch_version < version.parse('2.5.0'):
@@ -433,6 +566,186 @@ def extract_feature(self, pixel_values):
         vit_embeds = self.mlp1(vit_embeds)
         return vit_embeds

+    def compress_visual_tokens_in_sentence(
+        self,
+        input_embeds: torch.Tensor,
+        input_ids: torch.Tensor,
+        img_context_token_id: int,
+        gate_result,
+    ) -> tuple:
+        # reshape
+        B, N, C = input_embeds.shape
+        input_embeds = input_embeds.reshape(B * N, C)
+        input_ids = input_ids.reshape(B * N)
+
+        N, C = input_embeds.shape
+        lengths, starts, ends = self.get_image_num_per_sample(input_ids, img_context_token_id)
+
+        keep_mask = torch.ones(N, dtype=torch.bool, device=input_embeds.device)
+
+        total_blocks = 0
+        block_counts = []
+        for length in lengths.tolist():
+            if length % 256 != 0:
+                raise ValueError(f'length % 256 != 0, length = {length}')
+            num_blocks = length // 256
+            block_counts.append(num_blocks)
+            total_blocks += num_blocks
+
+        flag_idx = 0
+        for s, e, l, num_blocks in zip(starts.tolist(), ends.tolist(), lengths.tolist(), block_counts):
+            for i in range(num_blocks):
+                block_start = s + i * 256
+                block_end = block_start + 256
+
+                compress = gate_result[flag_idx]
+                flag_idx += 1
+
+                if compress:
+                    keep_mask[block_start + 64:block_end] = False
+
+        # update
+        new_input_embeds = input_embeds[keep_mask.to(input_embeds.device), :]
+        new_input_ids = input_ids[keep_mask.to(input_ids.device)]
+        new_image_mask = (new_input_ids == img_context_token_id)
+
+        # reshape back
+        new_input_ids = new_input_ids.reshape(B, -1)
+        new_input_embeds = new_input_embeds.reshape(B, -1, C)
+
+        # since multiple sequences may be concatenated together, we need to update the seqlens individually:
+        # compute the number of kept tokens per sequence to get the new length of each sequence
+        crt_ctx = self.ctx_mgr.current_context()
+        seq_lengths = crt_ctx.q_seqlens
+        # split the keep_mask into chunks corresponding to each original sequence
+        mask_chunks = torch.split(keep_mask, seq_lengths.tolist())
+        # the new length of each sequence is the number of tokens kept (sum of True values)
+        new_seq_lengths = [chunk.sum().item() for chunk in mask_chunks]
+
+        return new_input_embeds, new_input_ids, new_image_mask, new_seq_lengths
+
+    def get_image_num_per_sample(self, input_ids: torch.Tensor, img_context_token_id: int):
+        input_ids = input_ids.squeeze(0)  # (N,)
+        selected = (input_ids == img_context_token_id)
+        padded = torch.cat(
+            [torch.tensor([0], device=selected.device),
+             selected.int(),
+             torch.tensor([0], device=selected.device)])
+        diff = torch.diff(padded)
+
+        starts = (diff == 1).nonzero(as_tuple=True)[0]
+        ends = (diff == -1).nonzero(as_tuple=True)[0]
+        lengths = ends - starts
+
+        return lengths, starts, ends
+
+    def split_and_merge(self, features: torch.Tensor, split_sizes: torch.Tensor):
+        """
+        features: Tensor of shape [T, 1024, 1024]
+        split_sizes: 1D Tensor or list like [3, 3, 4], the tile count of each sample
+
+        returns: List of Tensors of shape [tile_i * 1024, 1024]
+        """
+        # split features -> a tile list for each sample
+        tile_splits = torch.split(features, split_sizes, dim=0)
+
+        # merge the leading dimensions: tile_i * 1024 x 1024
+        merged = [x.reshape(-1, x.shape[-1]) for x in tile_splits]
+
+        return merged
+
+    def extract_feature_flash(self, pixel_values, lengths):
+
+        vit_embeds_1024 = self.vision_model(pixel_values)
+
+        vit_embeds_1024 = vit_embeds_1024[:, 1:, :]
+        h = w = int(vit_embeds_1024.shape[1]**0.5)
+        vit_embeds_1024 = vit_embeds_1024.reshape(vit_embeds_1024.shape[0], h, w, -1)
+
+        # begin moe
+        lengths = [int(x) for x in lengths.tolist()]
+        vit_embeds_1024_split_and_merge = self.split_and_merge(vit_embeds_1024, lengths)
+
+        gate = self.pooling_before_gating(vit_embeds_1024_split_and_merge)
+        gate = self.gating(gate)
+
+        vit_embeds_256 = vit_embeds_1024
+
+        with torch.no_grad():
+            vit_embeds_64 = self.pixel_shuffle(vit_embeds_1024, scale_factor=self.downsample_ratio**2)
+            vit_embeds_64 = vit_embeds_64.reshape(vit_embeds_64.shape[0], -1, vit_embeds_64.shape[-1])
+            vit_embeds_64 = self.mlp2(vit_embeds_64)
+
+            vit_embeds_256 = self.pixel_shuffle(vit_embeds_256, scale_factor=self.downsample_ratio)
+            vit_embeds_256 = vit_embeds_256.reshape(vit_embeds_256.shape[0], -1, vit_embeds_256.shape[-1])
+            vit_embeds_256 = self.mlp1(vit_embeds_256)
+
+        return vit_embeds_64, vit_embeds_256, gate
+
+    def extract_and_compress(self, pixel_values: torch.Tensor, input_ids: torch.Tensor, img_context_token_id: int):
+        lang_embeds = self.language_model.get_input_embeddings()(input_ids)
+
+        self._mark_dynamic_once(pixel_values, [0])
+
+        lengths, starts, ends = self.get_image_num_per_sample(input_ids, img_context_token_id)
+        lengths = lengths // 256
+        lengths_sum = torch.ones(int(lengths.sum().item()), dtype=torch.int64)
+        lengths = lengths_sum.repeat_interleave(1)
+        vit_embeds_64, vit_embeds_256, gate_result = self.extract_feature_flash(pixel_values, lengths)
+
+        relative_threshold_value = torch.quantile(gate_result[:, 0].to(torch.float32), self.flash_relative_threshold)
+        gate_result = (gate_result[:, 0] >= relative_threshold_value) & (gate_result[:, 0]
+                                                                         > self.flash_absolute_threshold)
+
+        selected_embeds = [
+            vit_embeds_64[i] if gate_result[i] else vit_embeds_256[i] for i in range(gate_result.size(0))
+        ]
+
+        vit_embeds = torch.cat(selected_embeds, dim=0)
+
+        # compress visual tokens in sentence
+        new_lang_embeds, new_input_ids, new_image_mask, new_seq_lengths = self.compress_visual_tokens_in_sentence(
+            input_embeds=lang_embeds,
+            input_ids=input_ids,
+            img_context_token_id=img_context_token_id,
+            gate_result=gate_result,
+        )
+
+        return vit_embeds, new_lang_embeds, new_input_ids, new_image_mask, new_seq_lengths
+
+    def update_forward_inputs(self, input_ids: torch.Tensor, new_seqlens: List[int], context: StepContext):
+        """Update the forward inputs, returning new position_ids and attention metadata."""
+        from lmdeploy.pytorch.model_inputs import ModelInputs
+
+        crt_ctx = self.ctx_mgr.current_context()
+        assert crt_ctx is not None, 'Current context cannot be None.'
+
+        # create model metas with the new seqlens
+        model_metas = [dict(new_seqlen=seqlen) for seqlen in new_seqlens]
+
+        # create new model inputs and context, to get the updated position_ids and attn_metadata
+        device = input_ids.device
+        total_msgs = len(new_seqlens)
+        new_seqlens = torch.tensor(new_seqlens, device=device, dtype=torch.long)
+        new_model_inputs = ModelInputs(input_ids=input_ids,
+                                       seq_length=new_seqlens,
+                                       history_lengths=torch.zeros(total_msgs, device=device, dtype=torch.long),
+                                       block_offsets=crt_ctx.block_offsets,
+                                       is_decoding=False,
+                                       num_ignored_history=torch.zeros(total_msgs, device=device, dtype=torch.long),
+                                       max_q_seqlen=new_seqlens.max().item(),
+                                       max_kv_seqlen=new_seqlens.max().item(),
+                                       sum_kv_seqlen=new_seqlens.sum().item(),
+                                       model_metas=model_metas)
+        new_ctx = self.ctx_mgr.build_context(new_model_inputs, crt_ctx.model_config)
+
+        # update attributes of the current context, so the new values are returned inside the model agent
+        context.model_metas = model_metas
+        context.q_seqlens = new_seqlens
+
+        return new_ctx.position_ids, new_ctx.attn_metadata
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -444,13 +757,24 @@ def forward(
         inputs_embeds: torch.Tensor = None,
         vision_embedding_indexing: torch.Tensor = None,
         text_embedding_indexing: torch.Tensor = None,
+        image_token_id: int = None,
+        context: StepContext = None,
         **kwargs,
     ):
         if inputs_embeds is None and pixel_values is not None:
-            # extract feature
-            self._mark_dynamic_once(pixel_values, [0])
-            vit_embeds = self.extract_feature(pixel_values)
-            lang_embeds = self.language_model.get_input_embeddings()(input_ids)
+            if not self.flash_mode:
+                # extract feature
+                self._mark_dynamic_once(pixel_values, [0])
+                vit_embeds = self.extract_feature(pixel_values)
+                lang_embeds = self.language_model.get_input_embeddings()(input_ids)
+            else:
+                # extract feature and compress visual tokens
+                vit_embeds, lang_embeds, input_ids, image_mask, new_seqlens = self.extract_and_compress(
+                    pixel_values, input_ids, image_token_id)
+
+                # update forward inputs
+                position_ids, attn_metadata = self.update_forward_inputs(input_ids, new_seqlens, context)
+
             lang_embeds.masked_scatter_(image_mask[..., None], vit_embeds)

             inputs_embeds = lang_embeds
@@ -491,9 +815,27 @@ def prepare_inputs_for_generation(
         vision_embeddings = context.input_embeddings
         vision_embedding_indexing = None

+        if context.is_decoding and context.model_metas is not None and context.model_metas[0] is not None:
+            # NOTE, zhouxinyu, we need to consider an increasing batch size in the decoding stage;
+            # the current implementation keeps the batch size the same as in the prefill stage
+
+            # the model metas come from the previous step, therefore +1 for the current decoding step
+            new_kv_seqlens = [(meta['new_seqlen'] + 1) for meta in context.model_metas]
+
+            # update model metas for the next step
+            context.model_metas = [dict(new_seqlen=seqlen) for seqlen in new_kv_seqlens]
+
+            # update position ids, attn_metadata
+            new_kv_seqlens = torch.tensor(new_kv_seqlens, device=input_ids.device, dtype=torch.long)
+            position_ids = new_kv_seqlens
+            attn_metadata.kv_seqlens = new_kv_seqlens
+            attn_metadata.cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(new_kv_seqlens, dim=0, dtype=torch.int32),
+                                                                 (1, 0))
+
         # vision inputs
         pixel_values = None
         image_mask = None
+        image_token_id = None
         if context.input_multimodals is not None:
             pixel_values = [input_mm.get('image', []) for input_mm in context.input_multimodals]
             # flatten batch
@@ -524,27 +866,27 @@
                     vision_embedding_indexing = None
                 if text_embedding_indexing.numel() == 0:
                     text_embedding_indexing = None
-            return dict(
-                input_ids=input_ids,
-                position_ids=position_ids,
-                past_key_values=past_key_values,
-                attn_metadata=attn_metadata,
-                pixel_values=pixel_values,
-                image_mask=image_mask,
-                inputs_embeds=inputs_embeds,
-                vision_embedding_indexing=vision_embedding_indexing,
-                text_embedding_indexing=text_embedding_indexing,
-            )
+            return dict(input_ids=input_ids,
+                        position_ids=position_ids,
+                        past_key_values=past_key_values,
+                        attn_metadata=attn_metadata,
+                        pixel_values=pixel_values,
+                        image_mask=image_mask,
+                        inputs_embeds=inputs_embeds,
+                        vision_embedding_indexing=vision_embedding_indexing,
+                        text_embedding_indexing=text_embedding_indexing,
+                        image_token_id=image_token_id,
+                        context=context)
         else:
-            return dict(
-                input_ids=input_ids,
-                position_ids=position_ids,
-                past_key_values=past_key_values,
-                attn_metadata=attn_metadata,
-                pixel_values=pixel_values,
-                image_mask=image_mask,
-                inputs_embeds=inputs_embeds,
-            )
+            return dict(input_ids=input_ids,
+                        position_ids=position_ids,
+                        past_key_values=past_key_values,
+                        attn_metadata=attn_metadata,
+                        pixel_values=pixel_values,
+                        image_mask=image_mask,
+                        inputs_embeds=inputs_embeds,
+                        image_token_id=image_token_id,
+                        context=context)

     def load_lora_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], adapter_id: int):
         """Load lora weights."""
diff --git a/lmdeploy/vl/model/internvl.py b/lmdeploy/vl/model/internvl.py
index f508e31bc6..a2b8d7f9b7 100644
--- a/lmdeploy/vl/model/internvl.py
+++ b/lmdeploy/vl/model/internvl.py
@@ -2,7 +2,7 @@
 from typing import Dict, List, Optional

 import torch
-from transformers import AutoConfig, AutoModel, CLIPImageProcessor
+from transformers import AutoConfig, AutoModel, AutoTokenizer, CLIPImageProcessor

 from lmdeploy.utils import get_logger
 from lmdeploy.vl.model.base import VISION_MODELS, VisonModel
@@ -76,6 +76,9 @@ def __init__(self,
                  hf_config: AutoConfig = None,
                  backend: str = ''):
         super().__init__(model_path, with_llm, max_memory, hf_config, backend)
+        IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
+        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False)
+        self.image_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)

     def build_preprocessor(self):
         self.config = self.hf_config
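A note on the engine-side change: `model_forward` now returns `seq_length` because the last-token gather in the non-long-context path has to use the post-compression lengths; with the original `inputs.seq_length`, the cumulative offsets would index past the shortened hidden states. A toy illustration of that gather, with assumed shapes:

```python
import torch

hidden_states = torch.randn(1, 368, 8)          # [1, total_tokens_after_compression, hidden]
seq_length = torch.tensor([108, 260])           # compressed lengths reported back by the model
last_token_loc = seq_length.cumsum(0) - 1       # tensor([107, 367])
last_hidden = hidden_states[:, last_token_loc]  # [1, num_seqs, hidden]
print(last_hidden.shape)                        # torch.Size([1, 2, 8])
```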
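On the gating decision in `extract_and_compress`: a tile is routed to the compressed 64-token branch only when its gate probability clears both the batch-relative quantile given by `flash_relative_threshold` and the `flash_absolute_threshold`. Below is a minimal standalone sketch of that selection logic with made-up inputs, not the deployed module:

```python
import torch


def select_compressed_tiles(gate_probs: torch.Tensor, relative_threshold: float,
                            absolute_threshold: float) -> torch.Tensor:
    """Return a bool mask: True -> use the 64-token branch, False -> keep 256 tokens.

    gate_probs: [num_tiles, 2] softmax output of the gating head; column 0 mirrors
    gate_result[:, 0] in the patch.
    """
    scores = gate_probs[:, 0].to(torch.float32)
    # relative: keep only tiles at or above the q-quantile of the batch
    relative_ok = scores >= torch.quantile(scores, relative_threshold)
    # absolute: never compress a tile whose score is not confidently high
    return relative_ok & (scores > absolute_threshold)


probs = torch.softmax(torch.randn(6, 2), dim=-1)
print(select_compressed_tiles(probs, relative_threshold=0.5, absolute_threshold=0.3))
```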
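`compress_visual_tokens_in_sentence` keeps the first 64 of every 256 image tokens in each compressed block and then re-derives the per-sequence query lengths by splitting the flat keep mask along the original `q_seqlens`. A small sketch of that bookkeeping; the 300/260 lengths and block offset are invented for the example:

```python
import torch


def recompute_seqlens(keep_mask: torch.Tensor, q_seqlens: torch.Tensor) -> list:
    """Split the flat keep mask by the original per-sequence lengths and count kept tokens."""
    chunks = torch.split(keep_mask, q_seqlens.tolist())
    return [int(chunk.sum().item()) for chunk in chunks]


# two concatenated sequences of 300 and 260 tokens; one 256-token image block in the
# first sequence (starting at offset 10) is compressed, dropping 192 of its tokens
keep_mask = torch.ones(560, dtype=torch.bool)
keep_mask[10 + 64:10 + 256] = False
print(recompute_seqlens(keep_mask, torch.tensor([300, 260])))  # [108, 260]
```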
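During decoding, the compressed prefill lengths are carried forward through `model_metas`, and `prepare_inputs_for_generation` patches the attention metadata each step: every per-sequence kv length grows by one and `cu_seqlens_k` is rebuilt as a zero-padded cumulative sum. A sketch of just that tensor manipulation, with illustrative values:

```python
import torch

new_kv_seqlens = torch.tensor([109, 261])  # previous new_seqlen + 1 for the current step
cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(new_kv_seqlens, dim=0, dtype=torch.int32), (1, 0))
print(cu_seqlens_k)  # tensor([  0, 109, 370], dtype=torch.int32)
```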