Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Move out validation logic
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
  • Loading branch information
DarkLight1337 committed Jan 4, 2025
commit 0899dce6f34f712d06a4451bd4c6887211bb92bf
2 changes: 1 addition & 1 deletion vllm/model_executor/models/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,7 +723,7 @@ def get_replacement_mantis(item_idx: int):
mm_item_counts,
)

self._validate_placeholders(mm_placeholders, mm_item_counts)
self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

mm_placeholder_ranges = {
modality: [item.to_range() for item in placeholders]
Expand Down
40 changes: 26 additions & 14 deletions vllm/multimodal/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,17 +894,6 @@ def _cached_apply_hf_processor(

mm_kwargs = MultiModalKwargs.from_items(merged_kw_items)

if self.enable_sanity_checks:
mm_item_counts = mm_data_items.get_all_counts()

for modality, item_count in mm_item_counts.items():
for item_idx in range(item_count):
try:
mm_kwargs.get_item(modality, item_idx)
except Exception as e:
# Make it easy to set a breakpoint in the debugger
raise e

return prompt_ids, mm_kwargs

def _bind_and_group_repls(
Expand Down Expand Up @@ -999,7 +988,28 @@ def _apply_prompt_replacements(

return token_ids, text, placeholders

def _validate_placeholders(
def _validate_mm_kwargs(
self,
mm_kwargs: MultiModalKwargs,
mm_item_counts: Mapping[str, int],
) -> None:
for modality, item_count in mm_item_counts.items():
if modality in mm_kwargs.modalities:
items = mm_kwargs.get_items(modality)
else:
items = []

if len(items) != item_count:
raise RuntimeError(
f"Expected there to be {item_count} {modality} items in "
f"keyword arguments corresponding to {item_count} "
f"{modality} data items, but only found {len(items)}! "
"There is likely a problem with your "
"implementation of merged multi-modal processor for this "
"model (usually arising from an inconsistency between "
"`_call_hf_processor` and `_get_mm_fields_config`).")

def _validate_mm_placeholders(
self,
mm_placeholders: Mapping[str, list[_PlaceholderInfo]],
mm_item_counts: Mapping[str, int],
Expand Down Expand Up @@ -1061,6 +1071,8 @@ def apply(
mm_prompt_repls = self._bind_and_group_repls(unbound_prompt_repls)

mm_item_counts = mm_items.get_all_counts()
self._validate_mm_kwargs(mm_kwargs, mm_item_counts)

hf_mm_placeholders = self._find_mm_placeholders(
mm_prompt_repls,
prompt_ids,
Expand All @@ -1071,7 +1083,7 @@ def apply(
mm_missing_repl_counts = mm_item_counts
mm_missing_repls = dict(mm_prompt_repls)
else:
mm_missing_repl_counts = self._validate_placeholders(
mm_missing_repl_counts = self._validate_mm_placeholders(
hf_mm_placeholders,
mm_item_counts,
allow_missing=True,
Expand Down Expand Up @@ -1106,7 +1118,7 @@ def apply(

mm_placeholders = {**hf_mm_placeholders, **missing_mm_placeholders}

self._validate_placeholders(mm_placeholders, mm_item_counts)
self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

mm_placeholder_ranges = {
modality: [item.to_range() for item in placeholders]
Expand Down