[Feature] Support bidirectional attention for Gemma-3 #10707
mickqian merged 3 commits into sgl-project:main
Conversation
Summary of Changes

Hello @zzhbrr, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request addresses a critical aspect of Gemma-3 multimodal model processing by introducing proper bidirectional attention for image tokens during the prefill stage. This change corrects the attention mechanism, leading to improved accuracy in multimodal tasks, and is specifically implemented within the Triton attention backend. While enhancing model performance, it introduces some constraints regarding CUDA Graph and chunked prefill compatibility.
Code Review
This pull request introduces support for bidirectional attention in Gemma-3, which is essential for its multimodal capabilities. The changes are well-targeted, primarily affecting the Triton attention backend and the Gemma-3 model implementation. The core logic correctly constructs and applies custom attention masks to enable bidirectional attention for image tokens during the prefill stage. My review includes suggestions to improve code clarity, performance, and type hint correctness.
Hi, @mickqian @JustinTong0323. Could you please help review this PR? Thanks!
Thanks for the contribution! This implementation looks good. I notice there are some restrictions; could you share the launch command? Also, it would be better if you could add a note to the docs, at an appropriate place, about the special usage of Gemma-3 :)
Note: required CI has passed.
Hi! The launch command is:

```shell
python -m sglang.launch_server \
  --model-path google/gemma-3-4b-it \
  --host 0.0.0.0 --port 30000 \
  --enable-multimodal \
  --dtype bfloat16 --triton-attention-reduce-in-fp32 \
  --attention-backend triton \
  --disable-cuda-graph \
  --chunked-prefill-size -1
```

The Triton attention backend is required, CUDA Graph must be disabled, and chunked prefill is disabled via `--chunked-prefill-size -1`. (Comments after a trailing `\` would break the shell line continuation, so the per-flag notes are kept here instead.) I also added a new chapter in the docs for this special usage.
Hi, @JustinTong0323. What changes do I need to make for this part of the code to be applied?
I'll help ask for the code owner's approval.
MMMU benchmark result: [screenshot]
This is a blocker for pi0-fast, since the PaliGemma model uses bidirectional attention for all prefix tokens.
Hey @zzhbrr, thanks for your brilliant work, and sorry for the late response. Could you rebase? Then we can merge this immediately after CI passes.
Got it! I'll rebase. |
Force-pushed from d56386a to aa47fba.
Motivation
As discussed in Question about the multimodal attention mask in Gemma-3 · Issue #10309 · sgl-project/sglang, the Gemma-3 multimodal model uses bidirectional attention between image tokens during the prefill stage.
Currently, SGLang only supports causal attention in the prefill stage.
Modifications
I have implemented bidirectional attention for the `TritonAttnBackend`. Specifically, I added mask computation in `gemma3_mm.py`, introduced the `DECODER_BIDIRECTIONAL` `AttentionType`, and used custom masks during the prefill stage. There are a few points to note:
- The custom mask can consume a lot of VRAM: a single request at the 128K context length needs `128K * 128K / 1024^3 = 15.25 GB` of VRAM for the full mask. A potential solution is to restrict the `custom_mask` to within the sliding window range.
- I changed `mask_indptr` and related variables from int32 to int64, as they might exceed the int32 range when the sequence length and batch size are large. For instance, with a batch size of 20 and a sequence length of 10K, `mask_indptr` could exceed the int32 representation limit.

Accuracy Tests
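To illustrate the masking scheme and the int64 point above, here is a minimal sketch (the function name, the image-span layout, and the indptr construction are illustrative assumptions, not the PR's actual code): text tokens keep causal attention, tokens inside the same image span additionally attend to each other bidirectionally, and cumulative mask offsets are held in int64 to avoid overflow.

```python
import torch

def build_gemma3_prefill_mask(seq_len: int, image_spans: list) -> torch.Tensor:
    """Boolean attention mask (True = may attend) for one prefill request.

    `image_spans` holds (start, end) half-open index ranges of image tokens.
    Hypothetical helper for illustration only.
    """
    # Start from the standard causal (lower-triangular) mask.
    mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    # Let tokens within each image span see each other in both directions.
    for start, end in image_spans:
        mask[start:end, start:end] = True
    return mask

# Toy request: 6 tokens, where tokens 2..4 form one image.
mask = build_gemma3_prefill_mask(6, [(2, 5)])

# Why int64 offsets: per-request mask sizes are seq_len**2, and their
# cumulative sum (a mask_indptr analogue) approaches int32's 2**31 - 1
# limit with batch 20 x seq_len 10K (20 * 10_000**2 = 2.0e9 elements).
sizes = torch.full((20,), 10_000 * 10_000, dtype=torch.int64)
indptr = torch.cat([torch.zeros(1, dtype=torch.int64), sizes.cumsum(0)])
```

With this layout, `mask[2, 4]` is True (an image token attends forward within its span) while `mask[0, 3]` stays False (text remains causal).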
Evaluation results using `lmms_eval` on different datasets are as follows:

ChartQA
MME
MMMU
The correct attention implementation leads to accuracy improvement.
Benchmarking and Profiling
Checklist