From 022c6b4a306d0392e671848b08acc91266febf71 Mon Sep 17 00:00:00 2001
From: Roger Wang
Date: Wed, 1 Jan 2025 17:37:26 +0000
Subject: [PATCH 01/47] initial

Signed-off-by: Roger Wang
---
 vllm/multimodal/utils.py | 30 +++++++++++++++++++++++++++++-
 1 file changed, 29 insertions(+), 1 deletion(-)

diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py
index 7b6ded6a2708..2e491d017f2c 100644
--- a/vllm/multimodal/utils.py
+++ b/vllm/multimodal/utils.py
@@ -1,6 +1,6 @@
 from functools import lru_cache
 from pathlib import Path
-from typing import Optional, TypeVar, Union
+from typing import TYPE_CHECKING, Optional, TypeVar, Union
 from urllib.parse import ParseResult, urlparse
 
 import numpy as np
@@ -25,6 +25,9 @@
 
 _M = TypeVar("_M")
 
+if TYPE_CHECKING:
+    from ..multimodal import MultiModalPlaceholderDict
+
 
 class MediaConnector:
 
@@ -437,3 +440,28 @@ def consecutive_placeholder_ranges(
         PlaceholderRange(offset=initial_offset + i * item_size,
                          length=item_size) for i in range(num_items)
     ]
+
+
+def merge_and_sort_placeholders_from_modalities(
+    modalities: list[str], mm_positions: "MultiModalPlaceholderDict"
+) -> tuple[list[tuple[str, int]], list[PlaceholderRange]]:
+
+    placeholder_lists_with_modality = [(modality, mm_positions[modality])
+                                       for modality in modalities
+                                       if modality in mm_positions]
+
+    sorted_lists_with_modality = sorted(placeholder_lists_with_modality,
+                                        key=lambda x: x[1][0]['offset'])
+
+    # Verify if the sorted order avoids interleaving
+    merged: list[PlaceholderRange] = []
+    for modality, placeholder_list in sorted_lists_with_modality:
+        if merged and placeholder_list[0]['offset'] < merged[-1]['offset']:
+            raise ValueError(
+                "Interleaved mixed-modality inference is currently not "
+                "supported.")
+        merged.extend(placeholder_list)
+
+    # Return the order of the keys and the merged result
+    return [(modality, len(lst))
+            for modality, lst in sorted_lists_with_modality], merged
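Note: for reference, a minimal self-contained sketch of what the helper added in PATCH 01 computes. PlaceholderRange is modeled here as a plain dict with "offset" and "length" keys (an assumption based on the subscripting in the patch), and the offsets are made up:

    # Hypothetical placeholder layout: two images followed by one video.
    mm_positions = {
        "image": [{"offset": 0, "length": 576}, {"offset": 576, "length": 576}],
        "video": [{"offset": 1152, "length": 1024}],
    }
    modalities = ["image", "video", "audio"]

    present = [(m, mm_positions[m]) for m in modalities if m in mm_positions]
    present.sort(key=lambda x: x[1][0]["offset"])

    merged = []
    for modality, ranges in present:
        # Reject interleaving: a modality's first placeholder must not start
        # before the last placeholder already merged.
        if merged and ranges[0]["offset"] < merged[-1]["offset"]:
            raise ValueError("Interleaved mixed-modality inference is "
                             "currently not supported.")
        merged.extend(ranges)

    print([(m, len(r)) for m, r in present])  # [('image', 2), ('video', 1)]
    print(merged[0]["offset"], merged[-1]["offset"])  # 0 1152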
From 43fdf458bb6fb9c706d5bdbd89666c75389fbc50 Mon Sep 17 00:00:00 2001
From: Roger Wang
Date: Wed, 1 Jan 2025 17:37:40 +0000
Subject: [PATCH 02/47] fix llava ov

Signed-off-by: Roger Wang
---
 vllm/model_executor/models/llava_onevision.py | 33 ++++++++-----------
 1 file changed, 13 insertions(+), 20 deletions(-)

diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py
index 0bebc1c745e2..cfe0c4c86f44 100644
--- a/vllm/model_executor/models/llava_onevision.py
+++ b/vllm/model_executor/models/llava_onevision.py
@@ -38,8 +38,8 @@
 from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                     maybe_prefix, merge_multimodal_embeddings)
 
-# Result in the max possible feature size (2x2 grid of 336x336px tiles)
-MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448
+# Ref: https://github.com/LLaVA-VL/LLaVA-NeXT/blob/main/docs/LLaVA_OneVision.md?plain=1#L14
+MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 2304
 
 # For profile run
 _MAX_FRAMES_PER_VIDEO = 16
@@ -366,9 +366,11 @@ def input_processor_for_llava_onevision(ctx: InputContext,
             and "image" not in multi_modal_data):
         return inputs
     if "image" in multi_modal_data:
-        return input_processor_when_multimodal_input_image(ctx, inputs)
+        inputs = input_processor_when_multimodal_input_image(ctx, inputs)
     if "video" in multi_modal_data:
        return input_processor_when_multimodal_input_video(ctx, inputs)
+    else:
+        return inputs
 
     msg = "Unsupported multi data type"
     raise NotImplementedError(msg)
@@ -832,21 +834,18 @@ def get_multimodal_embeddings(
         if not modalities:
             return None
 
-        # We make a tuple of each embedding with its modality string. This is a
-        # temporary workaround for models to handle mixed modalities when
-        # get_multimodal_embeddings and get_input_embeddings are called
-        # separately.
-        # TODO(ywang96): Add support for mixed-modality inference for v1.
-        multimodal_embeddings: List[Tuple[NestedTensors, str]] = []
+        # The resulting multimodal_embeddings is a tuple of tensors, with
+        # each tensor corresponding to a multimodal data item (image or video).
+        multimodal_embeddings: tuple[torch.Tensor, ...] = ()
 
         if "images" in modalities:
             image_input = modalities["images"]
             vision_embeddings = self._process_image_input(image_input)
-            multimodal_embeddings.append((vision_embeddings, "image"))
+            multimodal_embeddings += tuple(vision_embeddings)
         if "videos" in modalities:
             video_input = modalities["videos"]
             video_embeddings = self._process_video_pixels(video_input)
-            multimodal_embeddings.append((video_embeddings, "video"))
+            multimodal_embeddings += tuple(video_embeddings)
 
         return multimodal_embeddings
 
@@ -858,15 +857,9 @@ def get_input_embeddings(
         self,
         input_ids: torch.Tensor,
     ) -> torch.Tensor:
         inputs_embeds = self.language_model.get_input_embeddings(input_ids)
         if multimodal_embeddings is not None:
-            for embeddings, modality in multimodal_embeddings:
-                if modality == "image":
-                    inputs_embeds = merge_multimodal_embeddings(
-                        input_ids, inputs_embeds, embeddings,
-                        self.config.image_token_index)
-                if modality == "video":
-                    inputs_embeds = merge_multimodal_embeddings(
-                        input_ids, inputs_embeds, embeddings,
-                        self.config.video_token_index)
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids, inputs_embeds, multimodal_embeddings,
+                [self.config.image_token_index, self.config.video_token_index])
         return inputs_embeds
 
     def forward(
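Note: after PATCH 02, get_multimodal_embeddings returns one flat tuple of per-item tensors, and merge_multimodal_embeddings scatters them over every position holding either placeholder token id. The following is an illustrative sketch of that scatter step only, not vLLM's actual implementation; the function name and shapes are assumptions:

    import torch

    def scatter_mm_embeddings(input_ids, inputs_embeds, mm_embeds,
                              placeholder_token_ids):
        # Mask every sequence position that holds any placeholder token id.
        mask = torch.isin(input_ids, torch.tensor(placeholder_token_ids))
        # mm_embeds: tuple of (num_tokens_i, hidden_size) tensors in prompt
        # order; their total length must match the number of placeholders.
        flat = torch.cat(list(mm_embeds), dim=0)
        assert int(mask.sum()) == flat.shape[0]
        inputs_embeds[mask] = flat
        return inputs_embeds

A single masked scatter is only correct because the concatenation order of mm_embeds matches prompt order, which is exactly the non-interleaving property the helper from PATCH 01 enforces.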
From e0fb002a37f0edb716e409029361e5ad61bf201b Mon Sep 17 00:00:00 2001
From: Roger Wang
Date: Wed, 1 Jan 2025 17:38:06 +0000
Subject: [PATCH 03/47] iterate

Signed-off-by: Roger Wang
---
 vllm/v1/engine/mm_input_mapper.py |  2 ++
 vllm/v1/request.py                | 22 +++++++++++++++++++---
 2 files changed, 21 insertions(+), 3 deletions(-)

diff --git a/vllm/v1/engine/mm_input_mapper.py b/vllm/v1/engine/mm_input_mapper.py
index 8bfc739b3dbb..d2dd5f7d07b0 100644
--- a/vllm/v1/engine/mm_input_mapper.py
+++ b/vllm/v1/engine/mm_input_mapper.py
@@ -1,6 +1,7 @@
 from typing import Any, Dict, List, Optional
 
 import PIL
+import torch
 from blake3 import blake3
 
 from vllm.config import ModelConfig
@@ -102,6 +103,7 @@ def process_inputs(
                 {"image": [image_inputs[input_id]]},
                 mm_processor_kwargs=mm_processor_kwargs,
             )
+            mm_input["image"] = torch.tensor([])
 
             if self.use_cache:
                 # Add to cache
diff --git a/vllm/v1/request.py b/vllm/v1/request.py
index f4783ae366ef..7806fadb79a3 100644
--- a/vllm/v1/request.py
+++ b/vllm/v1/request.py
@@ -4,6 +4,7 @@
 from vllm.inputs import DecoderOnlyInputs, SingletonInputsAdapter, token_inputs
 from vllm.lora.request import LoRARequest
 from vllm.multimodal import MultiModalKwargs
+from vllm.multimodal.utils import merge_and_sort_placeholders_from_modalities
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import RequestMetrics
 from vllm.v1.engine import EngineCoreRequest
@@ -49,16 +50,31 @@ def __init__(
         self.num_computed_tokens = 0
 
         # Multi-modal input metadata.
+        all_modalities = ["image", "video", "audio"]
         mm_positions = self.inputs.multi_modal_placeholders
         if mm_positions:
-            # FIXME(woosuk): Support other modalities.
-            self.mm_positions = mm_positions.get("image", [])
+            sorted_modalities, sorted_mm_positions = merge_and_sort_placeholders_from_modalities(  # noqa: E501
+                all_modalities, mm_positions)
+            self.mm_positions = sorted_mm_positions
         else:
             self.mm_positions = []
         # Output of the mm input mapper (e.g., image tensors).
         self.mm_inputs: List[MultiModalKwargs] = []
         if self.inputs.multi_modal_inputs:
-            self.mm_inputs = self.inputs.multi_modal_inputs
+            if len(sorted_modalities) == 1:
+                self.mm_inputs = self.inputs.multi_modal_inputs
+            else:
+                for modality, count in sorted_modalities:
+                    for i in range(len(self.inputs.multi_modal_inputs)):
+                        if modality in self.inputs.multi_modal_inputs[i]:
+                            for j in range(count):
+                                self.inputs.multi_modal_inputs[i +
+                                                               j].pop(modality)
+                                self.mm_inputs.append(
+                                    self.inputs.multi_modal_inputs[i + j])
+                            break
+            assert len(self.mm_inputs) == len(self.inputs.multi_modal_inputs)
+            assert len(self.mm_inputs) == len(self.mm_positions)
 
         self.mm_hashes: List[str] = self.inputs.multi_modal_hashes

From b45010b77e63d363e8d4748fbe44fb88a1921b39 Mon Sep 17 00:00:00 2001
From: Roger Wang
Date: Thu, 2 Jan 2025 07:24:31 +0000
Subject: [PATCH 04/47] revert padding tensor

Signed-off-by: Roger Wang
---
 vllm/v1/engine/mm_input_mapper.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/vllm/v1/engine/mm_input_mapper.py b/vllm/v1/engine/mm_input_mapper.py
index d2dd5f7d07b0..8bfc739b3dbb 100644
--- a/vllm/v1/engine/mm_input_mapper.py
+++ b/vllm/v1/engine/mm_input_mapper.py
@@ -1,7 +1,6 @@
 from typing import Any, Dict, List, Optional
 
 import PIL
-import torch
 from blake3 import blake3
 
 from vllm.config import ModelConfig
@@ -103,7 +102,6 @@ def process_inputs(
                 {"image": [image_inputs[input_id]]},
                 mm_processor_kwargs=mm_processor_kwargs,
             )
-            mm_input["image"] = torch.tensor([])
 
             if self.use_cache:
                 # Add to cache
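Note: the reordering loop added to Request.__init__ in PATCH 03 (reworked in PATCH 05 below) is meant to group the per-item kwargs so their order matches the modality order returned by the placeholder helper. A toy illustration of the intended outcome, with hypothetical plain dicts standing in for MultiModalKwargs:

    inputs = [{"video": "v0"}, {"image": "i0"}, {"image": "i1"}]
    sorted_modalities = [("image", 2), ("video", 1)]  # (modality, item count)

    reordered = []
    for modality, _count in sorted_modalities:
        # Pull all items of this modality forward, preserving relative order.
        reordered.extend(d for d in inputs if modality in d)

    print(reordered)  # [{'image': 'i0'}, {'image': 'i1'}, {'video': 'v0'}]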
From d83e25e5e93fc3341f6e365d78b479ee164cccc4 Mon Sep 17 00:00:00 2001
From: Roger Wang
Date: Thu, 2 Jan 2025 11:10:59 +0000
Subject: [PATCH 05/47] simplify

Signed-off-by: Roger Wang
---
 vllm/multimodal/utils.py |  9 ++++++---
 vllm/v1/request.py       | 40 ++++++++++++++++++++++++----------------
 2 files changed, 30 insertions(+), 19 deletions(-)

diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py
index 2e491d017f2c..460c9d96bedf 100644
--- a/vllm/multimodal/utils.py
+++ b/vllm/multimodal/utils.py
@@ -444,7 +444,11 @@ def consecutive_placeholder_ranges(
 
 def merge_and_sort_placeholders_from_modalities(
     modalities: list[str], mm_positions: "MultiModalPlaceholderDict"
-) -> tuple[list[tuple[str, int]], list[PlaceholderRange]]:
+) -> tuple[list[str], list[PlaceholderRange]]:
+
+    # For single modality, its placeholder ranges are already sorted.
+    if len(modalities) == 1:
+        return modalities, list(mm_positions[modalities[0]])
 
     placeholder_lists_with_modality = [(modality, mm_positions[modality])
                                        for modality in modalities
                                        if modality in mm_positions]
@@ -463,5 +467,4 @@
         merged.extend(placeholder_list)
 
     # Return the order of the keys and the merged result
-    return [(modality, len(lst))
-            for modality, lst in sorted_lists_with_modality], merged
+    return [modality for modality, _ in sorted_lists_with_modality], merged
diff --git a/vllm/v1/request.py b/vllm/v1/request.py
index 7806fadb79a3..4317b3e70f27 100644
--- a/vllm/v1/request.py
+++ b/vllm/v1/request.py
@@ -50,32 +50,40 @@ def __init__(
         self.num_computed_tokens = 0
 
         # Multi-modal input metadata.
-        all_modalities = ["image", "video", "audio"]
         mm_positions = self.inputs.multi_modal_placeholders
         if mm_positions:
+            available_modalities = mm_positions.keys()
             sorted_modalities, sorted_mm_positions = merge_and_sort_placeholders_from_modalities(  # noqa: E501
-                all_modalities, mm_positions)
+                list(available_modalities), mm_positions)
             self.mm_positions = sorted_mm_positions
         else:
+            sorted_modalities = []
             self.mm_positions = []
+
         # Output of the mm input mapper (e.g., image tensors).
         self.mm_inputs: List[MultiModalKwargs] = []
         if self.inputs.multi_modal_inputs:
-            if len(sorted_modalities) == 1:
-                self.mm_inputs = self.inputs.multi_modal_inputs
-            else:
-                for modality, count in sorted_modalities:
-                    for i in range(len(self.inputs.multi_modal_inputs)):
-                        if modality in self.inputs.multi_modal_inputs[i]:
-                            for j in range(count):
-                                self.inputs.multi_modal_inputs[i +
-                                                               j].pop(modality)
-                                self.mm_inputs.append(
-                                    self.inputs.multi_modal_inputs[i + j])
-                            break
-            assert len(self.mm_inputs) == len(self.inputs.multi_modal_inputs)
-            assert len(self.mm_inputs) == len(self.mm_positions)
+            # NOTE: We only need to sort multimodal kwargs when there
+            # are multiple modalities involved.
+            if len(sorted_modalities) > 1:
+                modality_order_dict = {
+                    modality: order
+                    for order, modality in enumerate(sorted_modalities)
+                }
+
+                # Sanity check to make sure each multimodal input
+                # has only one modality key.
+                for mm_input in self.inputs.multi_modal_inputs:
+                    assert len(mm_input.modalities) == 1
+
+                # Sort MultiModalKwargs to match sorted_mm_positions
+                self.inputs.multi_modal_inputs.sort(
+                    key=lambda mm_input: modality_order_dict[list(
+                        mm_input.modalities)[0]])
+
+            self.mm_inputs = self.inputs.multi_modal_inputs
+        assert len(self.mm_inputs) == len(self.mm_positions)
 
         self.mm_hashes: List[str] = self.inputs.multi_modal_hashes

From d13b0f7533e9c3e5efd0b96bb2885cba3ee9cca9 Mon Sep 17 00:00:00 2001
From: Roger Wang
Date: Thu, 2 Jan 2025 11:11:12 +0000
Subject: [PATCH 06/47] comment

Signed-off-by: Roger Wang
---
 vllm/multimodal/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py
index 460c9d96bedf..fefc7c76eed1 100644
--- a/vllm/multimodal/utils.py
+++ b/vllm/multimodal/utils.py
@@ -466,5 +466,5 @@ def merge_and_sort_placeholders_from_modalities(
                 "supported.")
         merged.extend(placeholder_list)
 
-    # Return the order of the keys and the merged result
+    # Return the order of modalities and the merged placeholder ranges
     return [modality for modality, _ in sorted_lists_with_modality], merged
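Note: the sort introduced in PATCH 05 relies on list.sort being stable, so items within one modality keep their original, already-sorted order. A minimal sketch under the same assumption the patch asserts (each MultiModalKwargs carries exactly one modality key), again with plain dicts as stand-ins:

    sorted_modalities = ["image", "video"]
    modality_order = {m: i for i, m in enumerate(sorted_modalities)}

    items = [{"video": "v0"}, {"image": "i0"}, {"image": "i1"}]
    # Stable sort by modality rank; within a modality the order is unchanged.
    items.sort(key=lambda item: modality_order[next(iter(item))])

    print(items)  # [{'image': 'i0'}, {'image': 'i1'}, {'video': 'v0'}]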
From 6959ec0ca8ec37e17f4c755889b2c251334f2bac Mon Sep 17 00:00:00 2001
From: Roger Wang
Date: Thu, 2 Jan 2025 11:58:17 +0000
Subject: [PATCH 07/47] simplify and doc

Signed-off-by: Roger Wang
---
 vllm/multimodal/utils.py | 19 ++++++++++++++++++-
 vllm/v1/request.py       |  3 +--
 2 files changed, 19 insertions(+), 3 deletions(-)

diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py
index fefc7c76eed1..ef980c48b852 100644
--- a/vllm/multimodal/utils.py
+++ b/vllm/multimodal/utils.py
@@ -443,8 +443,25 @@ def consecutive_placeholder_ranges(
 
 def merge_and_sort_placeholders_from_modalities(
-    modalities: list[str], mm_positions: "MultiModalPlaceholderDict"
+    mm_positions: "MultiModalPlaceholderDict"
 ) -> tuple[list[str], list[PlaceholderRange]]:
+    """Given a MultiModalPlaceholderDict, merge all PlaceholderRange
+    objects from all available modalities into a single list of
+    PlaceholderRange, sorted by their offset (starting index in the input
+    sequence) in ascending order.
+
+    Raises:
+        ValueError: If the input prompt has interleaved placeholders from
+            different modalities (e.g., "