
[Feature] Support bidirectional attention for Gemma-3 #10707

Merged
mickqian merged 3 commits into sgl-project:main from zzhbrr:fix_gemma3_attn
Feb 9, 2026
Conversation

@zzhbrr
Contributor

@zzhbrr zzhbrr commented Sep 21, 2025

Motivation

As discussed in Question about the multimodal attention mask in Gemma-3 · Issue #10309 · sgl-project/sglang, the Gemma-3 multimodal model uses bidirectional attention between image tokens during the prefill stage.

Currently, SGLang only supports causal attention during the prefill stage.

Modifications

I have implemented bidirectional attention for the TritonAttnBackend. Specifically, I added mask computation in gemma3_mm.py, introduced the DECODER_BIDIRECTIONAL AttentionType, and used custom masks during the prefill stage.
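The idea can be illustrated with a minimal sketch (this is not the actual code in gemma3_mm.py; the helper name and the span layout are made up for illustration): the mask starts out causal, and each image-token span is then opened up so its tokens can attend to one another in both directions.

```python
import torch

def build_bidirectional_mask(seq_len: int, image_spans: list[tuple[int, int]]) -> torch.Tensor:
    """Boolean attention mask: causal everywhere, bidirectional inside image spans.

    mask[i, j] == True means query position i may attend to key position j.
    """
    # Start from a standard causal (lower-triangular) mask.
    mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    # Within each image span, let every token attend to every other token
    # of the same span (bidirectional attention). `end` is exclusive.
    for start, end in image_spans:
        mask[start:end, start:end] = True
    return mask

# Example: a 6-token sequence whose positions 1..3 are image tokens.
m = build_bidirectional_mask(6, [(1, 4)])
assert m[1, 3]       # an image token may attend ahead to a later image token
assert not m[0, 5]   # text tokens remain strictly causal
```

During prefill, a mask like this would be passed to the attention kernel as the custom mask instead of the implicit causal one.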

There are a few points to note:

  1. The current implementation is not compatible with CUDA Graph. CUDA Graph requires pre-allocating space for the prefill custom mask, which is impractical. For example, Gemma-3 supports a 128K context length, meaning the prefill custom mask for a single request would occupy 128K * 128K / 1024^3 = 15.25 GB of VRAM. A potential solution is to restrict the custom mask to the sliding-window range.
  2. Bidirectional attention is not compatible with chunked prefill, since chunked prefill may split a run of image tokens across multiple chunks.
  3. The Triton extend attention kernel may need to change mask_indptr and related variables from int32 to int64, as they can exceed the int32 range when the sequence length and batch size are large. For instance, with a batch size of 20 and a sequence length of 10K, mask_indptr could exceed the int32 representation limit.

Accuracy Tests

Evaluation results using lmms_eval on different datasets are as follows:

  • ChartQA

    | ChartQA | SGLang-wrong_attention | SGLang-correct_attention | huggingface |
    | --- | --- | --- | --- |
    | relaxed_overall | 0.393 | 0.524 | 0.508 |
    | relaxed_human_split | 0.261 | 0.327 | 0.308 |
    | relaxed_augmented_split | 0.524 | 0.720 | 0.708 |
  • MME

    | MME | SGLang-wrong_attention | SGLang-correct_attention |
    | --- | --- | --- |
    | mme_cognition_score | 350 | 361 |
    | mme_perception_score | 1168 | 1224 |
  • MMMU

    | | SGLang-wrong_attention | SGLang-correct_attention | huggingface |
    | --- | --- | --- | --- |
    | MMMU | 0.372 | 0.387 | 0.4 |

The correct attention implementation leads to accuracy improvement.

Benchmarking and Profiling

Checklist

@gemini-code-assist
Contributor

Summary of Changes

Hello @zzhbrr, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical aspect of Gemma-3 multimodal model processing by introducing proper bidirectional attention for image tokens during the prefill stage. This change corrects the attention mechanism, leading to improved accuracy in multimodal tasks, and is specifically implemented within the Triton attention backend. While enhancing model performance, it introduces some constraints regarding CUDA Graph and chunked prefill compatibility.

Highlights

  • Gemma-3 Bidirectional Attention: Implemented support for bidirectional attention for Gemma-3 multimodal models during the prefill (extend) stage, specifically for image tokens, which is crucial for correct multimodal processing.
  • Triton Backend Integration: The new bidirectional attention logic is integrated into the TritonAttnBackend by dynamically generating and applying custom attention masks for image tokens.
  • New Attention Type: A new AttentionType.DECODER_BIDIRECTIONAL enum member has been introduced to explicitly categorize this specific attention behavior.
  • Accuracy Improvement: The correct implementation of bidirectional attention leads to notable accuracy improvements across various multimodal benchmarks, including ChartQA, MME, and MMMU.
  • Known Limitations: The current implementation has known limitations, including incompatibility with CUDA Graph (due to potential VRAM exhaustion for large custom masks) and chunked prefill.
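The new attention type mentioned in the highlights can be pictured with a small sketch (the names follow the PR description; the exact enum members and values in sglang may differ):

```python
from enum import Enum, auto

class AttentionType(Enum):
    # Standard decoder self-attention: strictly causal.
    DECODER = auto()
    # Encoder-style attention: fully bidirectional.
    ENCODER_ONLY = auto()
    # Introduced by this PR: causal overall, but bidirectional within
    # designated spans (e.g. Gemma-3 image tokens) during prefill.
    DECODER_BIDIRECTIONAL = auto()

# A layer can branch on the type when choosing its prefill mask:
attn_type = AttentionType.DECODER_BIDIRECTIONAL
use_custom_mask = attn_type is AttentionType.DECODER_BIDIRECTIONAL
print(use_custom_mask)  # True
```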

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for bidirectional attention in Gemma-3, which is essential for its multimodal capabilities. The changes are well-targeted, primarily affecting the Triton attention backend and the Gemma-3 model implementation. The core logic correctly constructs and applies custom attention masks to enable bidirectional attention for image tokens during the prefill stage. My review includes suggestions to improve code clarity, performance, and type hint correctness.

@zzhbrr
Contributor Author

zzhbrr commented Sep 23, 2025

Hi, @mickqian @JustinTong0323. Could you please help review this PR? Thanks!

@JustinTong0323
Collaborator

Thanks for the contribution! The implementation looks good. I notice there are some restrictions; could you share the launch command? It would also help to add a note to the docs, in an appropriate place, about the special usage of Gemma-3 :)

@JustinTong0323
Collaborator

Note: required CI has passed.

@zzhbrr
Contributor Author

zzhbrr commented Sep 25, 2025

> Thanks for the contribution! The implementation looks good. I notice there are some restrictions; could you share the launch command? It would also help to add a note to the docs, in an appropriate place, about the special usage of Gemma-3 :)

Hi! The launch command is:

python -m sglang.launch_server \
  --model-path google/gemma-3-4b-it \
  --host 0.0.0.0 --port 30000 \
  --enable-multimodal \
  --dtype bfloat16 --triton-attention-reduce-in-fp32 \
  --attention-backend triton \
  --disable-cuda-graph \
  --chunked-prefill-size -1

That is: use the Triton attention backend, disable CUDA Graph, and disable chunked prefill.

I also added a new section to the multimodal_language_models docs to explain bidirectional attention.

@zzhbrr
Contributor Author

zzhbrr commented Sep 30, 2025

Hi @JustinTong0323, what changes do I need to make for this code to be merged?

@JustinTong0323
Collaborator

> Hi @JustinTong0323, what changes do I need to make for this code to be merged?

I'll help ask for the code owner's approval.

@JustinTong0323
Collaborator

MMMU benchmark:

python -m sglang.launch_server \
  --model-path google/gemma-3-27b-it \
  --host 0.0.0.0 --port 30000 \
  --enable-multimodal \
  --dtype bfloat16 --triton-attention-reduce-in-fp32 \
  --attention-backend triton \
  --disable-cuda-graph \
  --chunked-prefill-size -1 --tp 2

python3 -m lmms_eval \
    --model openai_compatible \
    --model_args 'model_version="gemma-3-27b-it",tp=2' \
    --tasks mmmu_val \
    --batch_size 64 \
    --log_samples \
    --log_samples_suffix openai_compatible \
    --output_path ./output/gemma-3-27b-it

Result:
Main: 0.5278
PR: 0.5367

@vincentzed
Contributor

This is a blocker for pi0-fast, since the PaliGemma model uses bidirectional attention for the entire prefix.

@mickqian
Collaborator

mickqian commented Feb 8, 2026

hey @zzhbrr, thanks for your brilliant work and sorry for the late response. Could you rebase? Then we can merge this immediately after CI passes.

@zzhbrr
Contributor Author

zzhbrr commented Feb 8, 2026

> hey @zzhbrr, thanks for your brilliant work and sorry for the late response. Could you rebase? Then we can merge this immediately after CI passes.

Got it! I'll rebase.

@github-actions github-actions bot added documentation Improvements or additions to documentation Multi-modal multi-modal language model labels Feb 8, 2026
@mickqian mickqian merged commit ddbcfba into sgl-project:main Feb 9, 2026
276 of 301 checks passed
Johnsonms pushed a commit to Johnsonms/sglang that referenced this pull request Feb 14, 2026