[VLM] Optimize get_rope_index for GLM4v #17420

Merged
BBuf merged 2 commits into sgl-project:main from antgroup:opt_get_rope_index_glm4v
Feb 1, 2026

Conversation

@yuan-luo
Collaborator

@yuan-luo yuan-luo commented Jan 20, 2026

Motivation

Speeds up get_rope_index for GLM4v by 12% to 600% (the larger gains at long token lengths).
Benchmark test added.


No accuracy drop in lmms_evals.
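The benchmark script itself is not reproduced here; as a rough, hedged sketch of the methodology visible in the logs below (warmup iterations, timed iterations with CUDA synchronization, mean/median/p99 statistics), with a stand-in workload in place of the real get_rope_index:

```python
# Minimal timing-harness sketch. `get_rope_index` below is a stand-in
# workload, not the actual sglang implementation.
import statistics
import time

import torch


def get_rope_index(input_ids: torch.Tensor) -> torch.Tensor:
    # Stand-in: build 3 x T position ids, as mrope would for plain text.
    t = input_ids.shape[-1]
    return torch.arange(t, device=input_ids.device).view(1, -1).expand(3, t)


def bench(fn, *args, warmup_iter: int = 10, benchmark_iter: int = 100):
    # Warmup to exclude one-time costs (kernel compilation, caches).
    for _ in range(warmup_iter):
        fn(*args)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    times = []
    for _ in range(benchmark_iter):
        start = time.perf_counter()
        fn(*args)
        if torch.cuda.is_available():
            # CUDA launches are asynchronous; synchronize before reading the clock.
            torch.cuda.synchronize()
        times.append(time.perf_counter() - start)
    times.sort()
    return {
        "mean": statistics.mean(times),
        "median": statistics.median(times),
        "p99": times[int(0.99 * (len(times) - 1))],
    }


device = "cuda" if torch.cuda.is_available() else "cpu"
ids = torch.zeros(1, 1024, dtype=torch.int64, device=device)
stats = bench(get_rope_index, ids)
print(f"mean={stats['mean']:.8f}s median={stats['median']:.8f}s p99={stats['p99']:.8f}s")
```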

PR:

root@c7e9bb6a6789:/sgl-workspace/bench_script# python3 benchmark_rope_index.py
[2026-01-20 13:18:27] INFO utils.py:148: Note: detected 224 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
[2026-01-20 13:18:27] INFO utils.py:151: Note: NumExpr detected 224 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.
[2026-01-20 13:18:27] INFO utils.py:164: NumExpr defaulting to 16 threads.
Namespace(model_name='GLM4V', tp_size=1, device='cuda', warmup_iter=10, benchmark_iter=100, dtype='int64', seed=0, num_tokens=None, batch_size=1, pad_ratio=0.0, spatial_merge_size=2, num_images=1, image_patch_tokens=256, num_videos=1, video_patch_tokens=256, out_dir='.')
================================================================================
Evaluating: GLM4V tp_size=1 num_tokens=1 batch=1 pad_ratio=0.0 images/sample=1 image_patch_tokens=256 videos/sample=1 video_patch_tokens=256 dtype=torch.int64 device=cuda

Performance for config (B=1, T=1):
Multimodal: mean=0.00006863s, median=0.00006843s, p99=0.00007617s
Fallback:   mean=0.00006712s, median=0.00006628s, p99=0.00007729s
[INFO] num_tokens too small for multimodal segments; skip multimodal benchmark.
Fallback Speedup over Multimodal: nanx
================================================================================
Evaluating: GLM4V tp_size=1 num_tokens=2 batch=1 pad_ratio=0.0 images/sample=1 image_patch_tokens=256 videos/sample=1 video_patch_tokens=256 dtype=torch.int64 device=cuda

Performance for config (B=1, T=2):
Multimodal: mean=0.00006740s, median=0.00006676s, p99=0.00007325s
Fallback:   mean=0.00006693s, median=0.00006676s, p99=0.00007153s
[INFO] num_tokens too small for multimodal segments; skip multimodal benchmark.
Fallback Speedup over Multimodal: nanx
================================================================================
Evaluating: GLM4V tp_size=1 num_tokens=4 batch=1 pad_ratio=0.0 images/sample=1 image_patch_tokens=256 videos/sample=1 video_patch_tokens=256 dtype=torch.int64 device=cuda

Performance for config (B=1, T=4):
Multimodal: mean=0.00006786s, median=0.00006700s, p99=0.00007822s
Fallback:   mean=0.00007262s, median=0.00006831s, p99=0.00009158s
[INFO] num_tokens too small for multimodal segments; skip multimodal benchmark.
Fallback Speedup over Multimodal: nanx
================================================================================
Evaluating: GLM4V tp_size=1 num_tokens=8 batch=1 pad_ratio=0.0 images/sample=1 image_patch_tokens=256 videos/sample=1 video_patch_tokens=256 dtype=torch.int64 device=cuda

Performance for config (B=1, T=8):
Multimodal: mean=0.00006834s, median=0.00006843s, p99=0.00007395s
Fallback:   mean=0.00006839s, median=0.00006819s, p99=0.00007631s
[INFO] num_tokens too small for multimodal segments; skip multimodal benchmark.
Fallback Speedup over Multimodal: nanx
================================================================================
Evaluating: GLM4V tp_size=1 num_tokens=16 batch=1 pad_ratio=0.0 images/sample=1 image_patch_tokens=256 videos/sample=1 video_patch_tokens=256 dtype=torch.int64 device=cuda

Performance for config (B=1, T=16):
Multimodal: mean=0.00006753s, median=0.00006676s, p99=0.00007678s
Fallback:   mean=0.00006691s, median=0.00006628s, p99=0.00007582s
[INFO] num_tokens too small for multimodal segments; skip multimodal benchmark.
Fallback Speedup over Multimodal: nanx
================================================================================
Evaluating: GLM4V tp_size=1 num_tokens=32 batch=1 pad_ratio=0.0 images/sample=1 image_patch_tokens=256 videos/sample=1 video_patch_tokens=256 dtype=torch.int64 device=cuda

Performance for config (B=1, T=32):
Multimodal: mean=0.00006731s, median=0.00006652s, p99=0.00007820s
Fallback:   mean=0.00006707s, median=0.00006652s, p99=0.00007345s
[INFO] num_tokens too small for multimodal segments; skip multimodal benchmark.
Fallback Speedup over Multimodal: nanx
================================================================================
Evaluating: GLM4V tp_size=1 num_tokens=64 batch=1 pad_ratio=0.0 images/sample=1 image_patch_tokens=256 videos/sample=1 video_patch_tokens=256 dtype=torch.int64 device=cuda

Performance for config (B=1, T=64):
Multimodal: mean=0.00006737s, median=0.00006700s, p99=0.00007582s
Fallback:   mean=0.00006725s, median=0.00006676s, p99=0.00007298s
[INFO] num_tokens too small for multimodal segments; skip multimodal benchmark.
Fallback Speedup over Multimodal: nanx
================================================================================
Evaluating: GLM4V tp_size=1 num_tokens=128 batch=1 pad_ratio=0.0 images/sample=1 image_patch_tokens=256 videos/sample=1 video_patch_tokens=256 dtype=torch.int64 device=cuda

Performance for config (B=1, T=128):
Multimodal: mean=0.00006764s, median=0.00006700s, p99=0.00007632s
Fallback:   mean=0.00006750s, median=0.00006700s, p99=0.00007276s
[INFO] num_tokens too small for multimodal segments; skip multimodal benchmark.
Fallback Speedup over Multimodal: nanx
================================================================================
Evaluating: GLM4V tp_size=1 num_tokens=256 batch=1 pad_ratio=0.0 images/sample=1 image_patch_tokens=256 videos/sample=1 video_patch_tokens=256 dtype=torch.int64 device=cuda

Performance for config (B=1, T=256):
Multimodal: mean=0.00006818s, median=0.00006723s, p99=0.00007582s
Fallback:   mean=0.00006721s, median=0.00006676s, p99=0.00007439s
[INFO] num_tokens too small for multimodal segments; skip multimodal benchmark.
Fallback Speedup over Multimodal: nanx
================================================================================
Evaluating: GLM4V tp_size=1 num_tokens=512 batch=1 pad_ratio=0.0 images/sample=1 image_patch_tokens=256 videos/sample=1 video_patch_tokens=256 dtype=torch.int64 device=cuda

Performance for config (B=1, T=512):
Multimodal: mean=0.00006747s, median=0.00006700s, p99=0.00007679s
Fallback:   mean=0.00006716s, median=0.00006676s, p99=0.00007488s
[INFO] num_tokens too small for multimodal segments; skip multimodal benchmark.
Fallback Speedup over Multimodal: nanx
================================================================================
Evaluating: GLM4V tp_size=1 num_tokens=1024 batch=1 pad_ratio=0.0 images/sample=1 image_patch_tokens=256 videos/sample=1 video_patch_tokens=256 dtype=torch.int64 device=cuda

Performance for config (B=1, T=1024):
Multimodal: mean=0.00073278s, median=0.00071609s, p99=0.00114831s
Fallback:   mean=0.00006872s, median=0.00006771s, p99=0.00007756s
Fallback Speedup over Multimodal: 10.66258456x
================================================================================
Evaluating: GLM4V tp_size=1 num_tokens=2048 batch=1 pad_ratio=0.0 images/sample=1 image_patch_tokens=256 videos/sample=1 video_patch_tokens=256 dtype=torch.int64 device=cuda

Performance for config (B=1, T=2048):
Multimodal: mean=0.00088511s, median=0.00085711s, p99=0.00097703s
Fallback:   mean=0.00007475s, median=0.00006795s, p99=0.00010307s
Fallback Speedup over Multimodal: 11.84109467x
================================================================================
Evaluating: GLM4V tp_size=1 num_tokens=4096 batch=1 pad_ratio=0.0 images/sample=1 image_patch_tokens=256 videos/sample=1 video_patch_tokens=256 dtype=torch.int64 device=cuda

Performance for config (B=1, T=4096):
Multimodal: mean=0.00115255s, median=0.00115132s, p99=0.00118212s
Fallback:   mean=0.00007426s, median=0.00006843s, p99=0.00009696s
Fallback Speedup over Multimodal: 15.52043535x
================================================================================
Evaluating: GLM4V tp_size=1 num_tokens=8192 batch=1 pad_ratio=0.0 images/sample=1 image_patch_tokens=256 videos/sample=1 video_patch_tokens=256 dtype=torch.int64 device=cuda

Performance for config (B=1, T=8192):
Multimodal: mean=0.00179818s, median=0.00177395s, p99=0.00249509s
Fallback:   mean=0.00006861s, median=0.00006723s, p99=0.00007953s
Fallback Speedup over Multimodal: 26.20972338x
================================================================================
Evaluating: GLM4V tp_size=1 num_tokens=16384 batch=1 pad_ratio=0.0 images/sample=1 image_patch_tokens=256 videos/sample=1 video_patch_tokens=256 dtype=torch.int64 device=cuda

Performance for config (B=1, T=16384):
Multimodal: mean=0.00371805s, median=0.00366211s, p99=0.00525098s
Fallback:   mean=0.00006894s, median=0.00006747s, p99=0.00007742s
Fallback Speedup over Multimodal: 53.93269929x
================================================================================
Evaluating: GLM4V tp_size=1 num_tokens=32768 batch=1 pad_ratio=0.0 images/sample=1 image_patch_tokens=256 videos/sample=1 video_patch_tokens=256 dtype=torch.int64 device=cuda

Performance for config (B=1, T=32768):
Multimodal: mean=0.00686914s, median=0.00684130s, p99=0.00718685s
Fallback:   mean=0.00006953s, median=0.00006795s, p99=0.00007931s
Fallback Speedup over Multimodal: 98.79056371x
================================================================================
Evaluating: GLM4V tp_size=1 num_tokens=65536 batch=1 pad_ratio=0.0 images/sample=1 image_patch_tokens=256 videos/sample=1 video_patch_tokens=256 dtype=torch.int64 device=cuda

Performance for config (B=1, T=65536):
Multimodal: mean=0.01308443s, median=0.01293290s, p99=0.01568485s
Fallback:   mean=0.00007095s, median=0.00006795s, p99=0.00008220s
Fallback Speedup over Multimodal: 184.41506771x
================================================================================
Evaluating: GLM4V tp_size=1 num_tokens=131072 batch=1 pad_ratio=0.0 images/sample=1 image_patch_tokens=256 videos/sample=1 video_patch_tokens=256 dtype=torch.int64 device=cuda

Performance for config (B=1, T=131072):
Multimodal: mean=0.02458367s, median=0.02454710s, p99=0.02778321s
Fallback:   mean=0.00007313s, median=0.00007129s, p99=0.00008432s
Fallback Speedup over Multimodal: 336.15240921x

Main:

root@c7e9bb6a6789:/sgl-workspace/bench_script# python benchmark_rope_index.py 
[2026-01-21 01:21:26] INFO utils.py:148: Note: detected 224 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
[2026-01-21 01:21:26] INFO utils.py:151: Note: NumExpr detected 224 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.
[2026-01-21 01:21:26] INFO utils.py:164: NumExpr defaulting to 16 threads.
Namespace(model_name='GLM4V', tp_size=1, device='cuda', warmup_iter=10, benchmark_iter=100, dtype='int64', seed=0, num_tokens=None, batch_size=1, pad_ratio=0.0, spatial_merge_size=2, num_images=1, image_patch_tokens=256, num_videos=1, video_patch_tokens=256, out_dir='.')
================================================================================
Evaluating: GLM4V tp_size=1 num_tokens=1 batch=1 pad_ratio=0.0 images/sample=1 image_patch_tokens=256 videos/sample=1 video_patch_tokens=256 dtype=torch.int64 device=cuda

Performance for config (B=1, T=1):
Multimodal: mean=0.00007301s, median=0.00007248s, p99=0.00008588s
Fallback:   mean=0.00007123s, median=0.00007033s, p99=0.00007727s
[INFO] num_tokens too small for multimodal segments; skip multimodal benchmark.
Fallback Speedup over Multimodal: nanx
================================================================================
Evaluating: GLM4V tp_size=1 num_tokens=2 batch=1 pad_ratio=0.0 images/sample=1 image_patch_tokens=256 videos/sample=1 video_patch_tokens=256 dtype=torch.int64 device=cuda

Performance for config (B=1, T=2):
Multimodal: mean=0.00007161s, median=0.00007105s, p99=0.00007751s
Fallback:   mean=0.00007135s, median=0.00007081s, p99=0.00008059s
[INFO] num_tokens too small for multimodal segments; skip multimodal benchmark.
Fallback Speedup over Multimodal: nanx
================================================================================
Evaluating: GLM4V tp_size=1 num_tokens=4 batch=1 pad_ratio=0.0 images/sample=1 image_patch_tokens=256 videos/sample=1 video_patch_tokens=256 dtype=torch.int64 device=cuda

Performance for config (B=1, T=4):
Multimodal: mean=0.00007182s, median=0.00007105s, p99=0.00008083s
Fallback:   mean=0.00007078s, median=0.00007057s, p99=0.00007510s
[INFO] num_tokens too small for multimodal segments; skip multimodal benchmark.
Fallback Speedup over Multimodal: nanx
================================================================================
Evaluating: GLM4V tp_size=1 num_tokens=8 batch=1 pad_ratio=0.0 images/sample=1 image_patch_tokens=256 videos/sample=1 video_patch_tokens=256 dtype=torch.int64 device=cuda

Performance for config (B=1, T=8):
Multimodal: mean=0.00007123s, median=0.00007057s, p99=0.00008227s
Fallback:   mean=0.00007075s, median=0.00007033s, p99=0.00007750s
[INFO] num_tokens too small for multimodal segments; skip multimodal benchmark.
Fallback Speedup over Multimodal: nanx
================================================================================
Evaluating: GLM4V tp_size=1 num_tokens=16 batch=1 pad_ratio=0.0 images/sample=1 image_patch_tokens=256 videos/sample=1 video_patch_tokens=256 dtype=torch.int64 device=cuda

Performance for config (B=1, T=16):
Multimodal: mean=0.00007248s, median=0.00007153s, p99=0.00008084s
Fallback:   mean=0.00007366s, median=0.00007319s, p99=0.00008538s
[INFO] num_tokens too small for multimodal segments; skip multimodal benchmark.
Fallback Speedup over Multimodal: nanx
================================================================================
Evaluating: GLM4V tp_size=1 num_tokens=32 batch=1 pad_ratio=0.0 images/sample=1 image_patch_tokens=256 videos/sample=1 video_patch_tokens=256 dtype=torch.int64 device=cuda

Performance for config (B=1, T=32):
Multimodal: mean=0.00007210s, median=0.00007129s, p99=0.00007941s
Fallback:   mean=0.00007169s, median=0.00007129s, p99=0.00008083s
[INFO] num_tokens too small for multimodal segments; skip multimodal benchmark.
Fallback Speedup over Multimodal: nanx
================================================================================
Evaluating: GLM4V tp_size=1 num_tokens=64 batch=1 pad_ratio=0.0 images/sample=1 image_patch_tokens=256 videos/sample=1 video_patch_tokens=256 dtype=torch.int64 device=cuda

Performance for config (B=1, T=64):
Multimodal: mean=0.00007199s, median=0.00007129s, p99=0.00007758s
Fallback:   mean=0.00007163s, median=0.00007129s, p99=0.00007610s
[INFO] num_tokens too small for multimodal segments; skip multimodal benchmark.
Fallback Speedup over Multimodal: nanx
================================================================================
Evaluating: GLM4V tp_size=1 num_tokens=128 batch=1 pad_ratio=0.0 images/sample=1 image_patch_tokens=256 videos/sample=1 video_patch_tokens=256 dtype=torch.int64 device=cuda

Performance for config (B=1, T=128):
Multimodal: mean=0.00007338s, median=0.00007153s, p99=0.00009597s
Fallback:   mean=0.00007150s, median=0.00007105s, p99=0.00007868s
[INFO] num_tokens too small for multimodal segments; skip multimodal benchmark.
Fallback Speedup over Multimodal: nanx
================================================================================
Evaluating: GLM4V tp_size=1 num_tokens=256 batch=1 pad_ratio=0.0 images/sample=1 image_patch_tokens=256 videos/sample=1 video_patch_tokens=256 dtype=torch.int64 device=cuda

Performance for config (B=1, T=256):
Multimodal: mean=0.00007224s, median=0.00007153s, p99=0.00008178s
Fallback:   mean=0.00007251s, median=0.00007153s, p99=0.00008560s
[INFO] num_tokens too small for multimodal segments; skip multimodal benchmark.
Fallback Speedup over Multimodal: nanx
================================================================================
Evaluating: GLM4V tp_size=1 num_tokens=512 batch=1 pad_ratio=0.0 images/sample=1 image_patch_tokens=256 videos/sample=1 video_patch_tokens=256 dtype=torch.int64 device=cuda

Performance for config (B=1, T=512):
Multimodal: mean=0.00007182s, median=0.00007129s, p99=0.00007965s
Fallback:   mean=0.00007184s, median=0.00007129s, p99=0.00008036s
[INFO] num_tokens too small for multimodal segments; skip multimodal benchmark.
Fallback Speedup over Multimodal: nanx
================================================================================
Evaluating: GLM4V tp_size=1 num_tokens=1024 batch=1 pad_ratio=0.0 images/sample=1 image_patch_tokens=256 videos/sample=1 video_patch_tokens=256 dtype=torch.int64 device=cuda

Performance for config (B=1, T=1024):
Multimodal: mean=0.00057760s, median=0.00057685s, p99=0.00059559s
Fallback:   mean=0.00009443s, median=0.00007212s, p99=0.00024320s
Fallback Speedup over Multimodal: 6.11677524x
================================================================================
Evaluating: GLM4V tp_size=1 num_tokens=2048 batch=1 pad_ratio=0.0 images/sample=1 image_patch_tokens=256 videos/sample=1 video_patch_tokens=256 dtype=torch.int64 device=cuda

Performance for config (B=1, T=2048):
Multimodal: mean=0.00072351s, median=0.00072169s, p99=0.00077027s
Fallback:   mean=0.00007354s, median=0.00007224s, p99=0.00008548s
Fallback Speedup over Multimodal: 9.83793685x
================================================================================
Evaluating: GLM4V tp_size=1 num_tokens=4096 batch=1 pad_ratio=0.0 images/sample=1 image_patch_tokens=256 videos/sample=1 video_patch_tokens=256 dtype=torch.int64 device=cuda

Performance for config (B=1, T=4096):
Multimodal: mean=0.00100719s, median=0.00100577s, p99=0.00103309s
Fallback:   mean=0.00007341s, median=0.00007224s, p99=0.00008432s
Fallback Speedup over Multimodal: 13.72020136x
================================================================================
Evaluating: GLM4V tp_size=1 num_tokens=8192 batch=1 pad_ratio=0.0 images/sample=1 image_patch_tokens=256 videos/sample=1 video_patch_tokens=256 dtype=torch.int64 device=cuda

Performance for config (B=1, T=8192):
Multimodal: mean=0.00162929s, median=0.00162792s, p99=0.00166395s
Fallback:   mean=0.00007344s, median=0.00007200s, p99=0.00008435s
Fallback Speedup over Multimodal: 22.18458642x
================================================================================
Evaluating: GLM4V tp_size=1 num_tokens=16384 batch=1 pad_ratio=0.0 images/sample=1 image_patch_tokens=256 videos/sample=1 video_patch_tokens=256 dtype=torch.int64 device=cuda

Performance for config (B=1, T=16384):
Multimodal: mean=0.01191243s, median=0.01175499s, p99=0.01465990s
Fallback:   mean=0.00008588s, median=0.00008965s, p99=0.00012821s
Fallback Speedup over Multimodal: 138.71285397x
================================================================================
Evaluating: GLM4V tp_size=1 num_tokens=32768 batch=1 pad_ratio=0.0 images/sample=1 image_patch_tokens=256 videos/sample=1 video_patch_tokens=256 dtype=torch.int64 device=cuda

Performance for config (B=1, T=32768):
Multimodal: mean=0.01443047s, median=0.01413357s, p99=0.01946598s
Fallback:   mean=0.00008696s, median=0.00007868s, p99=0.00012399s
Fallback Speedup over Multimodal: 165.93765593x
================================================================================
Evaluating: GLM4V tp_size=1 num_tokens=65536 batch=1 pad_ratio=0.0 images/sample=1 image_patch_tokens=256 videos/sample=1 video_patch_tokens=256 dtype=torch.int64 device=cuda

Performance for config (B=1, T=65536):
Multimodal: mean=0.11286658s, median=0.11191475s, p99=0.12707203s
Fallback:   mean=0.00008549s, median=0.00007534s, p99=0.00011458s
Fallback Speedup over Multimodal: 1320.19844944x
================================================================================
Evaluating: GLM4V tp_size=1 num_tokens=131072 batch=1 pad_ratio=0.0 images/sample=1 image_patch_tokens=256 videos/sample=1 video_patch_tokens=256 dtype=torch.int64 device=cuda

Performance for config (B=1, T=131072):
Multimodal: mean=0.16182978s, median=0.16150701s, p99=0.16971291s
Fallback:   mean=0.00008327s, median=0.00007534s, p99=0.00012350s
Fallback Speedup over Multimodal: 1943.48827487x

Modifications

Accuracy Tests

Benchmarking and Profiling

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

cc: @zRzRzRzRzRzRzR

@gemini-code-assist
Contributor

Summary of Changes

Hello @yuan-luo, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed.

This pull request introduces a series of targeted optimizations to the get_rope_index_glm4v function, which is critical for the performance of GLM4v models. The primary goal is to accelerate the computation of Rotary Position Embeddings, especially for multimodal inputs, by reducing overhead from CPU-GPU data transfers and streamlining tensor operations. These enhancements have resulted in significant speed improvements, validated by a newly added benchmark script, all while maintaining the existing model accuracy.

Highlights

  • Performance Optimization: The get_rope_index_glm4v function for GLM4v models has been significantly optimized, leading to substantial speedups, particularly for multimodal inputs with large token counts. Benchmarks show speedups ranging from 10x to over 336x for higher token counts.
  • Reduced CPU-GPU Transfers: Key optimizations include minimizing CPU-GPU data transfers by replacing .item() calls with direct int() conversions and explicitly specifying the device for torch.arange operations.
  • Efficient Tensor Operations: Tensor manipulation has been made more efficient by using torch.amax() instead of torch.max(), preferring .reshape(-1) over .flatten(), and optimizing tensor concatenation and assignment within loops through pre-allocation and advanced indexing.
  • New Benchmark Script: A new benchmarking script (benchmark/bench_rope/benchmark_rope_index.py) has been added to systematically measure and validate the performance of the get_rope_index_glm4v function under various multimodal and fallback scenarios.
  • Accuracy Preservation: The changes ensure no regression in model accuracy, as confirmed by lmms_evals showing no drop in performance.
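The optimization patterns in the highlights above can be illustrated with toy snippets; this is standalone illustrative code, not the sglang implementation.

```python
# Illustrative toy snippets for the optimization patterns; shapes are made up.
import torch

x = torch.randint(0, 100, (3, 4, 16))

# torch.amax reduces over a dim directly, while chained .max(dim=...) calls
# also materialize argmax indices that are discarded here.
m1 = x.amax(dim=0).amax(-1, keepdim=True)
m2 = x.max(dim=0).values.max(-1, keepdim=True).values
assert torch.equal(m1, m2)

# .reshape(-1) and .flatten() agree on an expanded (non-contiguous) view.
y = torch.arange(4).view(1, -1, 1).expand(2, -1, 3)
assert torch.equal(y.reshape(-1), y.flatten())

# Creating tensors directly on the target device avoids an implicit
# host-to-device copy later.
pos = torch.arange(8, device=x.device)

# Collect per-segment results in a list and concatenate once, instead of
# calling torch.cat inside the loop.
segments = [torch.full((3, 2), i) for i in range(4)]
llm_positions = torch.cat(segments, dim=1).reshape(3, -1)
assert llm_positions.shape == (3, 8)
```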



@gemini-code-assist gemini-code-assist bot left a comment


Code Review

The pull request introduces a new benchmark script for get_rope_index_glm4v and includes several performance optimizations and a critical bug fix in the MRotaryEmbedding.get_rope_index_glm4v function. The optimizations primarily focus on reducing CPU-GPU transfers by avoiding .item() calls, explicitly setting device for torch.arange, preallocating lists, and consolidating torch.cat and torch.tensor operations outside loops. The critical fix addresses an issue where image_index and video_index were not reset per batch item, which could lead to incorrect multimodal data processing.

Comment on lines +2124 to 2125
# Move attention mask to device once to avoid repeated transfers
attention_mask = attention_mask.to(total_input_ids.device)

high

Preallocating input_token_type with [""] * len(input_tokens) is a significant performance improvement. Repeated append operations can be inefficient for large lists, as they may involve reallocations. Direct assignment after preallocation is much faster.
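A minimal sketch of the preallocation pattern this comment describes, with hypothetical helper names (`classify_append`, `classify_prealloc`); in the real function the loop body is the image/video/text token classification. Note the practical CPython win here is modest — the larger gains in this PR are on the tensor side.

```python
# Toy contrast: append-in-loop vs preallocate-then-assign.
def classify_append(tokens, image_id):
    types = []
    for tok in tokens:
        types.append("image" if tok == image_id else "text")
    return types


def classify_prealloc(tokens, image_id):
    types = [""] * len(tokens)  # preallocate once; no reallocation during the loop
    for j, tok in enumerate(tokens):
        types[j] = "image" if tok == image_id else "text"
    return types


tokens = [1, 7, 7, 2, 7]
assert classify_append(tokens, 7) == classify_prealloc(tokens, 7)
```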

Comment on lines 2268 to 2270
mrope_position_deltas = torch.tensor(
mrope_position_deltas, device=input_ids.device
).unsqueeze(1)

high

Building the mrope_position_deltas tensor in one call outside the loop is a good performance optimization, avoiding repeated tensor creations and concatenations.

Comment on lines +2258 to 2259
# Concatenate once outside for speed
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)

high

Concatenating llm_pos_ids_list once outside the loop is a significant performance optimization. Repeated torch.cat calls within a loop can be very expensive due to frequent memory reallocations and data copying.
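A toy contrast of the two concatenation patterns on made-up shapes:

```python
# Repeated torch.cat in a loop re-copies everything accumulated so far
# (quadratic copying); one cat at the end copies each element once.
import torch

chunks = [torch.ones(3, n) for n in (2, 5, 1, 4)]

# Slower pattern: cat inside the loop.
acc = torch.empty(3, 0)
for c in chunks:
    acc = torch.cat([acc, c], dim=1)

# Faster pattern: collect, then cat once.
once = torch.cat(chunks, dim=1)

assert torch.equal(acc, once)
assert once.shape == (3, 12)
```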

Comment on lines +2137 to +2148
for j, token in enumerate(input_tokens):
    if token == video_start_token_id:
        video_check_flg = True
    elif token == video_end_token_id:
        video_check_flg = False

    if token == image_token_id and not video_check_flg:
-       input_token_type.append("image")
+       input_token_type[j] = "image"
    elif token == image_token_id and video_check_flg:
-       input_token_type.append("video")
+       input_token_type[j] = "video"
    else:
-       input_token_type.append("text")
+       input_token_type[j] = "text"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Using direct assignment input_token_type[j] = "image" instead of append is a good performance optimization, especially when the list size is known beforehand and preallocated.

Comment on lines +1744 to 1748
  torch.arange(llm_grid_h, device=position_ids.device)
  .view(1, -1, 1)
- .expand(llm_grid_t, -1, llm_grid_w)
- .flatten()
+ .expand(llm_grid_t, llm_grid_h, llm_grid_w)
+ .reshape(-1)
  )

medium

Explicitly setting device=position_ids.device for torch.arange is a good practice to ensure tensors are created on the correct device, preventing potential implicit device transfers. Using .reshape(-1) instead of .flatten() is also a minor stylistic and potential performance improvement.

Comment on lines 2264 to 2266
  mrope_position_deltas.append(
-     llm_positions.max() + 1 - len(total_input_ids[i])
+     llm_positions.max().item() + 1 - len(total_input_ids[i])
  )

medium

Using .max().item() makes the device-to-host transfer explicit and one-time: the scalar is extracted once, rather than a 0-d tensor being carried into later Python arithmetic and converted implicitly.

Comment on lines +1736 to 1740
  torch.arange(llm_grid_t, device=position_ids.device)
  .view(-1, 1)
- .expand(-1, llm_grid_h * llm_grid_w)
- .flatten()
+ .expand(llm_grid_t, llm_grid_h * llm_grid_w)
+ .reshape(-1)
  )

medium

Explicitly setting device=position_ids.device for torch.arange is a good practice to ensure tensors are created on the correct device, preventing potential implicit device transfers. Using .reshape(-1) instead of .flatten() is also a minor stylistic and potential performance improvement.

Comment on lines +2282 to +2287
max_position_ids = position_ids.amax(dim=0, keepdim=False)
mrope_position_deltas = (
max_position_ids.amax(-1, keepdim=True)
+ 1
- attention_mask.shape[-1]
)

medium

Replacing chained .max() calls with torch.amax is a good performance optimization, as torch.amax is generally more efficient for finding maximums across multiple dimensions.
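On dummy shapes, the computation in the hunk above can be exercised like this (the sizes and the all-ones mask are illustrative, not from the real model):

```python
# Reduce position_ids over the 3 mrope sections, then over the sequence,
# and compare against the attention-mask length.
import torch

batch, length = 2, 16
position_ids = torch.arange(length).view(1, 1, -1).expand(3, batch, length)
attention_mask = torch.ones(batch, length)

max_position_ids = position_ids.amax(dim=0, keepdim=False)  # (batch, length)
mrope_position_deltas = (
    max_position_ids.amax(-1, keepdim=True) + 1 - attention_mask.shape[-1]
)
assert mrope_position_deltas.shape == (batch, 1)
# With contiguous positions and a full mask, the delta is exactly zero.
assert (mrope_position_deltas == 0).all()
```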

Comment on lines +2291 to +2295
# Use torch.arange with in-place expansion
arange_ids = torch.arange(length, device=input_ids.device).view(
1, 1, -1
)
position_ids = arange_ids.expand(3, batch_size, length)

medium

Using torch.arange with explicit device and then expanding is a clear and efficient way to create the position IDs.
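A self-contained sketch of this hunk with made-up sizes, showing that `expand` returns a broadcasted view rather than allocating a (3, batch, length) buffer:

```python
import torch

batch_size, length = 2, 5
input_ids = torch.zeros(batch_size, length, dtype=torch.int64)

arange_ids = torch.arange(length, device=input_ids.device).view(1, 1, -1)
position_ids = arange_ids.expand(3, batch_size, length)

assert position_ids.shape == (3, 2, 5)
# expand creates a view: both tensors share the same underlying storage.
assert position_ids.data_ptr() == arange_ids.data_ptr()
```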

@yuan-luo
Collaborator Author

/tag-and-rerun-ci

@yuan-luo yuan-luo added the Multi-modal, vlm, and performance labels and removed the run-ci label Jan 21, 2026
@yuan-luo yuan-luo force-pushed the opt_get_rope_index_glm4v branch from 87843cd to ac17cff Compare January 21, 2026 02:06
@yuan-luo
Collaborator Author

/rerun-failed-ci

@yuan-luo yuan-luo force-pushed the opt_get_rope_index_glm4v branch from ac17cff to a78b779 Compare January 24, 2026 01:43
@yuan-luo
Collaborator Author

/rerun-failed-ci

1 similar comment
@yuan-luo
Collaborator Author

/rerun-failed-ci

@WhoisZihan

Impressive optimization 👍

@yuan-luo
Collaborator Author

/rerun-failed-ci

@mickqian
Collaborator

overall LGTM, but since it's a common method, we might need benchmarks for other models too

@yuan-luo yuan-luo force-pushed the opt_get_rope_index_glm4v branch from a78b779 to 0952830 Compare January 29, 2026 11:59
@yuan-luo
Collaborator Author

yuan-luo commented Jan 29, 2026

overall LGTM, but since it's a common method, we might need benchmarks for other models too

Sure, will benchmark other VLMs. Since this PR is dedicated to GLM4V, I will follow up with new PRs once this one is merged.

@yuan-luo
Collaborator Author

/rerun-failed-ci

2 similar comments
@yuan-luo
Collaborator Author

/rerun-failed-ci

@yuan-luo
Collaborator Author

/rerun-failed-ci

@yuan-luo yuan-luo force-pushed the opt_get_rope_index_glm4v branch from 0952830 to a4ad83e Compare January 31, 2026 08:01
@yuan-luo
Collaborator Author

yuan-luo commented Jan 31, 2026

It seems GLM4.5V was broken on the main branch; I encountered the following error there.

H800

root@c7e9bb6a6789:/sgl-workspace# FLASHINFER_DISABLE_VERSION_CHECK=1 SGLANG_USE_CUDA_IPC_TRANSPORT=1 python -m sglang.launch_server --model-path zai-org/GLM-4.5V --mm-attention-backend sdpa --port 30000 --chunked-prefill-size 8192 --disable-radix-cache --disable-overlap-schedule --attention-backend fa3 --tp 4 --mem-fraction-static=0.7
[2026-01-31 08:40:47] INFO utils.py:148: Note: detected 224 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
[2026-01-31 08:40:47] INFO utils.py:151: Note: NumExpr detected 224 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.
[2026-01-31 08:40:47] INFO utils.py:164: NumExpr defaulting to 16 threads.
<frozen importlib._bootstrap_external>:1297: FutureWarning: The cuda.cudart module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.runtime module instead.
<frozen importlib._bootstrap_external>:1297: FutureWarning: The cuda.nvrtc module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.nvrtc module instead.
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/usr/local/lib/python3.12/dist-packages/sglang/launch_server.py", line 32, in <module>
    run_server(server_args)
  File "/usr/local/lib/python3.12/dist-packages/sglang/launch_server.py", line 25, in run_server
    launch_server(server_args)
  File "/usr/local/lib/python3.12/dist-packages/sglang/srt/entrypoints/http_server.py", line 1844, in launch_server
    _launch_subprocesses(
  File "/usr/local/lib/python3.12/dist-packages/sglang/srt/entrypoints/engine.py", line 1004, in _launch_subprocesses
    tokenizer_manager, template_manager = init_tokenizer_manager_func(
                                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/sglang/srt/entrypoints/engine.py", line 102, in init_tokenizer_manager
    tokenizer_manager = TokenizerManagerClass(server_args, port_args)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/sglang/srt/managers/tokenizer_manager.py", line 208, in __init__
    self.init_tokenizer_and_processor()
  File "/usr/local/lib/python3.12/dist-packages/sglang/srt/managers/tokenizer_manager.py", line 278, in init_tokenizer_and_processor
    self.mm_processor = get_mm_processor(
                        ^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/sglang/srt/managers/multimodal_processor.py", line 52, in get_mm_processor
    raise ValueError(
ValueError: No processor registered for architecture: ['Glm4vMoeForConditionalGeneration'].
Registered architectures: ['CLIPModel', 'DeepseekOCRForCausalLM', 'DeepseekVL2ForCausalLM', 'DotsVLMForCausalLM', 'DotsOCRForCausalLM', 'Ernie4_5_VLMoeForConditionalGeneration', 'Gemma3ForConditionalGeneration', 'Gemma3nForConditionalGeneration', 'InternVLChatModel', 'InternS1ForConditionalGeneration', 'MultiModalityCausalLM', 'KimiK25ForConditionalGeneration', 'KimiVLForConditionalGeneration', 'LightOnOCRForConditionalGeneration', 'LlavaLlamaForCausalLM', 'LlavaVidForCausalLM', 'LlavaQwenForCausalLM', 'LlavaMistralForCausalLM', 'LlavaForConditionalGeneration', 'Mistral3ForConditionalGeneration', 'MiDashengLMModel', 'MiniCPMV', 'MiniCPMO', 'MllamaForConditionalGeneration', 'Llama4ForConditionalGeneration', 'NemotronH_Nano_VL_V2', 'NVILAForConditionalGeneration', 'NVILALiteForConditionalGeneration', 'JetVLMForConditionalGeneration', 'PaddleOCRVLForConditionalGeneration', 'Phi4MMForCausalLM', 'PixtralVisionModel', 'PixtralForConditionalGeneration', 'POINTSV15ChatModel', 'Qwen2AudioForConditionalGeneration', 'Qwen2VLForConditionalGeneration', 'Qwen2_5_VLForConditionalGeneration', 'Qwen3VLForConditionalGeneration', 'Qwen3VLMoeForConditionalGeneration', 'Qwen3OmniMoeForConditionalGeneration', 'Sarashina2VisionForCausalLM', 'Step3VLForConditionalGeneration', 'StepVLForConditionalGeneration']

@yuan-luo
Collaborator Author

/rerun-failed-ci

2 similar comments
@yuan-luo
Collaborator Author

yuan-luo commented Feb 1, 2026

/rerun-failed-ci

@yuan-luo
Collaborator Author

yuan-luo commented Feb 1, 2026

/rerun-failed-ci

@BBuf BBuf merged commit 4ea4f2a into sgl-project:main Feb 1, 2026
233 of 252 checks passed
charlesHsuGG pushed a commit to charlesHsuGG/sglang that referenced this pull request Feb 2, 2026
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
sfiisf pushed a commit to sfiisf/sglang that referenced this pull request Feb 5, 2026
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
Johnsonms pushed a commit to Johnsonms/sglang that referenced this pull request Feb 14, 2026
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>