[feat] support SWA in trtllm_mha #18970
Conversation
Summary of Changes
Hello @LuYanFCP, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request enhances the trtllm_mha kernel by adding SWA support, which was previously missing and caused incorrect output on the B200 platform. The implementation includes the necessary helper functions and translation logic to ensure correct output generation. Accuracy tests confirm that the changes maintain the model's accuracy.
Activity
Code Review
This pull request correctly adds support for Sliding Window Attention (SWA) in the trtllm_mha backend. The changes are well-structured, introducing helper functions for SWA page table management and integrating them correctly into the existing CUDA graph capture/replay logic and the standard forward pass. The modifications are consistent and correctly handle the translation between full and SWA KV pools. The code is of high quality and I have no suggestions for improvement.
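The SWA page-table helpers the review refers to are not shown in this thread. As a rough illustration only, a full-to-SWA page-table translation of the kind described might look like the sketch below; the function name `build_swa_page_table` and all parameter names are hypothetical and are not the actual sglang implementation.

```python
def build_swa_page_table(full_page_table, seq_len, window_size, page_size):
    """Hypothetical sketch: keep only the pages covering the sliding window.

    `full_page_table` maps logical page index -> physical page id for the
    full KV pool; an SWA kernel only needs the tail of that mapping, i.e.
    the pages holding the last `window_size` tokens (plus one extra page
    to cover a partially filled last page).
    """
    # Number of pages the sequence currently occupies (ceiling division).
    num_pages_used = (seq_len + page_size - 1) // page_size
    # Pages needed to cover the window, plus one for page misalignment.
    swa_pages = (window_size + page_size - 1) // page_size + 1
    start = max(0, num_pages_used - swa_pages)
    return full_page_table[start:num_pages_used]
```

For example, with a 160-token sequence, a 64-token window, and 16-token pages, only the last 5 of the 10 occupied pages are retained.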
Force-pushed 6f6c119 to 2dc87c9
@LuYanFCP Could you fix the lint CI?
Force-pushed 284c316 to e0078fd
Done
@ispobock Added the complete GPQA-Diamond evaluation using nemo-evaluator, tested with Step-3.5-Flash; there are no accuracy issues. B200: 0.8316498316498316
/tag-and-rerun-ci |
Motivation
While adapting Step-3.5-Flash, we found that on the B200 platform the default trtllm_mha kernel lacks an SWA implementation, resulting in incorrect output after a certain number of generated tokens.
Modifications
python/sglang/srt/layers/attention/trtllm_mha_backend.py
Accuracy Tests
Tested with Step-3.5-Flash on B200 and H20-3e.
GPQA-Diamond:
B200: 0.8316498316498316
H200: 0.835016835016835
Detail Case
Pre Test:
After Test:
Benchmarking and Profiling
Checklist
Review Process
/tag-run-ci-label,/rerun-failed-ci,/tag-and-rerun-ci