[VLM] Optimize get_rope_index for GLM4v#17420
Conversation
Code Review
The pull request introduces a new benchmark script for get_rope_index_glm4v and includes several performance optimizations and a critical bug fix in the MRotaryEmbedding.get_rope_index_glm4v function. The optimizations primarily focus on reducing CPU-GPU transfers by avoiding .item() calls, explicitly setting device for torch.arange, preallocating lists, and consolidating torch.cat and torch.tensor operations outside loops. The critical fix addresses an issue where image_index and video_index were not reset per batch item, which could lead to incorrect multimodal data processing.
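The per-batch reset fix can be illustrated with a toy sketch (a hypothetical stand-in, not the PR's actual code — `assign_mm_indices` and the token-type strings are made up for illustration). The point is that `image_index` and `video_index` track which multimodal item's grid to consume next, so they must restart at zero for every sequence in the batch:

```python
# Hypothetical sketch of the per-batch-item reset fix, assuming the loop
# structure of MRotaryEmbedding.get_rope_index_glm4v: image_index and
# video_index select the next image/video grid, so carrying them over
# from a previous batch item would pick the wrong multimodal data.
def assign_mm_indices(batch_token_types):
    """Toy stand-in: map each 'image'/'video' token to its item index."""
    results = []
    for token_types in batch_token_types:  # one entry per batch item
        image_index, video_index = 0, 0    # the fix: reset per batch item
        indices = []
        for t in token_types:
            if t == "image":
                indices.append(("image", image_index))
                image_index += 1
            elif t == "video":
                indices.append(("video", video_index))
                video_index += 1
        results.append(indices)
    return results
```

Without the reset, the second sequence would start counting from where the first one stopped.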
```python
# Move attention mask to device once to avoid repeated transfers
attention_mask = attention_mask.to(total_input_ids.device)
```
```python
mrope_position_deltas = torch.tensor(
    mrope_position_deltas, device=input_ids.device
).unsqueeze(1)
```
```python
# Concatenate once outside for speed
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
```
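The single-concatenation pattern above can be sketched as follows (a toy sketch under assumed shapes — `build_positions` and `segment_lengths` are hypothetical names, not the PR's code). Each segment's positions are appended to a Python list, and `torch.cat` runs exactly once at the end instead of reallocating a growing tensor inside the loop:

```python
import torch

# Sketch: collect per-segment (3, n) position blocks in a list, then
# concatenate once. In the real code the three rows are the temporal,
# height, and width components of the M-RoPE position ids.
def build_positions(segment_lengths):
    llm_pos_ids_list = []
    offset = 0
    for n in segment_lengths:
        block = torch.arange(offset, offset + n).unsqueeze(0).expand(3, n)
        llm_pos_ids_list.append(block)
        offset += n
    # one concatenation outside the loop
    return torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
```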
```diff
 for j, token in enumerate(input_tokens):
     if token == video_start_token_id:
         video_check_flg = True
     elif token == video_end_token_id:
         video_check_flg = False

     if token == image_token_id and not video_check_flg:
-        input_token_type.append("image")
+        input_token_type[j] = "image"
     elif token == image_token_id and video_check_flg:
-        input_token_type.append("video")
+        input_token_type[j] = "video"
     else:
-        input_token_type.append("text")
+        input_token_type[j] = "text"
```
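The indexed assignment above works because the token count is known up front, so the list can be preallocated once. A minimal self-contained sketch of the pattern (the token-id constants and `classify_tokens` are made-up placeholders, not GLM4v's real ids):

```python
# Sketch of the list-preallocation pattern: allocate the type list once
# and fill it by index instead of growing it with repeated .append().
IMAGE_TOKEN, VIDEO_START, VIDEO_END = 1, 2, 3  # hypothetical ids

def classify_tokens(input_tokens):
    input_token_type = ["text"] * len(input_tokens)  # preallocated
    video_check_flg = False
    for j, token in enumerate(input_tokens):
        if token == VIDEO_START:
            video_check_flg = True
        elif token == VIDEO_END:
            video_check_flg = False
        if token == IMAGE_TOKEN and not video_check_flg:
            input_token_type[j] = "image"
        elif token == IMAGE_TOKEN and video_check_flg:
            input_token_type[j] = "video"
    return input_token_type
```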
```diff
 torch.arange(llm_grid_h, device=position_ids.device)
 .view(1, -1, 1)
-.expand(llm_grid_t, -1, llm_grid_w)
-.flatten()
+.expand(llm_grid_t, llm_grid_h, llm_grid_w)
+.reshape(-1)
 )
```
```diff
 mrope_position_deltas.append(
-    llm_positions.max() + 1 - len(total_input_ids[i])
+    llm_positions.max().item() + 1 - len(total_input_ids[i])
 )
```
```diff
 torch.arange(llm_grid_t, device=position_ids.device)
 .view(-1, 1)
-.expand(-1, llm_grid_h * llm_grid_w)
-.flatten()
+.expand(llm_grid_t, llm_grid_h * llm_grid_w)
+.reshape(-1)
 )
```
```python
max_position_ids = position_ids.amax(dim=0, keepdim=False)
mrope_position_deltas = (
    max_position_ids.amax(-1, keepdim=True)
    + 1
    - attention_mask.shape[-1]
)
```
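The `amax` reduction can be checked on a small example (a sketch under assumed shapes: `position_ids` is `(3, batch, seq)` and `attention_mask` is `(batch, seq)`, so the per-sequence delta is the max position plus one minus the padded length):

```python
import torch

# Toy shapes: 3 rope components, batch of 2, sequence length 3.
position_ids = torch.tensor([[[0, 1, 2], [0, 1, 1]]]).expand(3, 2, 3)
attention_mask = torch.ones(2, 3)

# Reduce over the 3 rope components, then over the sequence dimension.
max_position_ids = position_ids.amax(dim=0, keepdim=False)  # (batch, seq)
mrope_position_deltas = (
    max_position_ids.amax(-1, keepdim=True) + 1 - attention_mask.shape[-1]
)
```

The second sequence has a max position of 1 in a length-3 window, so its delta is negative, reflecting the repeated positions assigned to vision tokens.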
```python
# Use torch.arange with in-place expansion
arange_ids = torch.arange(length, device=input_ids.device).view(
    1, 1, -1
)
position_ids = arange_ids.expand(3, batch_size, length)
```
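The win here is that `.expand()` returns a broadcasting view rather than a copy, so the text-only fallback path materializes only `length` elements instead of `3 * batch_size * length`. A self-contained sketch of the same pattern (device argument omitted for simplicity):

```python
import torch

batch_size, length = 4, 8
# (1, 1, length) arange, then broadcast-view it to (3, batch, length);
# no data is copied -- all three planes share the same storage.
arange_ids = torch.arange(length).view(1, 1, -1)
position_ids = arange_ids.expand(3, batch_size, length)
```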
/tag-and-rerun-ci
/rerun-failed-ci
Impressive optimization 👍
Overall LGTM, but since this is a common method, we might need benchmarks for other models too.
Sure, will benchmark other VLMs. Since this PR is dedicated to GLM4V, I will follow up in new PRs once this one is merged.
It seems GLM4.5V was broken on the main branch; I encountered this error on main with an H800.
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
Motivation
Speeds up get_rope_index for GLM4v by 12% to 600% (at long token lengths).
A benchmark test is added.
lmms_evals shows no accuracy drop.
PR:
Main:
Modifications
Accuracy Tests
Benchmarking and Profiling
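A hedged micro-benchmark sketch of the pattern the PR optimizes (names and sizes are illustrative, not the PR's actual benchmark script): concatenating inside the loop performs O(n) growing copies, while one `torch.cat` at the end does a single allocation.

```python
import time
import torch

def cat_in_loop(blocks):
    # anti-pattern: reallocates and copies the accumulator every iteration
    out = blocks[0]
    for b in blocks[1:]:
        out = torch.cat([out, b], dim=1)
    return out

def cat_once(blocks):
    # the PR's pattern: collect blocks in a list, concatenate once
    return torch.cat(blocks, dim=1)

blocks = [torch.randn(3, 64) for _ in range(200)]

t0 = time.perf_counter()
a = cat_in_loop(blocks)
t1 = time.perf_counter()
b = cat_once(blocks)
t2 = time.perf_counter()
# Both produce identical results; timings will vary by machine, so no
# specific speedup is claimed here.
```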
Checklist
Review Process
/tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
cc: @zRzRzRzRzRzRzR