
[feat] feat: support swa in trtllm_mha #18970

Merged
ispobock merged 2 commits into sgl-project:main from LuYanFCP:feat/support_swa_in_trtllm_mha
Feb 20, 2026

Conversation

@LuYanFCP
Contributor

@LuYanFCP LuYanFCP commented Feb 18, 2026

Motivation

While adapting step3.5-flash, we found that on the B200 platform the default trtllm_mha kernel lacks a sliding window attention (SWA) implementation, producing incorrect output after a certain number of generated tokens.

Modifications

  1. Add helper functions in python/sglang/srt/layers/attention/trtllm_mha_backend.py.
  2. Add SWA translation in trtllm_mha.
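
The SWA translation above can be sketched as follows. This is a minimal illustrative sketch, not the PR's actual code: `make_swa_page_table`, its parameters, and the list-based page table are hypothetical stand-ins for the real tensor-based helpers in trtllm_mha_backend.py. The core idea is that sliding-window attention only reads the last `window_size` tokens of each sequence, so the KV page table can be truncated to the pages covering that window.

```python
def make_swa_page_table(page_table, seq_lens, window_size, page_size):
    """Hypothetical sketch: truncate each sequence's KV page table to the
    pages that hold its last `window_size` tokens, since sliding-window
    attention never attends to older entries."""
    swa_table, swa_lens = [], []
    for pages, seq_len in zip(page_table, seq_lens):
        kept_tokens = min(seq_len, window_size)
        if kept_tokens == 0:
            swa_table.append([])
            swa_lens.append(0)
            continue
        # Pages covering token indices [seq_len - kept_tokens, seq_len - 1];
        # the window may straddle a page boundary, so compute both ends.
        first_page = (seq_len - kept_tokens) // page_size
        last_page = (seq_len - 1) // page_size
        swa_table.append(pages[first_page:last_page + 1])
        swa_lens.append(kept_tokens)
    return swa_table, swa_lens
```

For example, a 10-token sequence with page_size=4 and window_size=4 keeps only the last two pages, because the 4-token window spans tokens 6 through 9, which straddle a page boundary.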

Accuracy Tests

Tested with Step-3.5-Flash.

On B200:

4188fc30b06b% python -m sglang.test.few_shot_gsm8k \
  --host http://127.0.0.1 \
  --port 8000 \
  --num-questions 200 \
  --num-shots 5

Accuracy: 0.875
Invalid: 0.005
Latency: 9.289 s
Output throughput: 2154.766 token/s

On H20-3e:

root@aba9b8cd2ed9:/sgl-workspace/sglang# python -m sglang.test.few_shot_gsm8k \
  --host http://127.0.0.1 \
  --port 8000 \
  --num-questions 200 \
  --num-shots 5

Accuracy: 0.835
Invalid: 0.005
Latency: 11.075 s
Output throughput: 1789.372 token/s

GPQA-Diamond:

B200: 0.8316498316498316
H200: 0.835016835016835

Detailed Case

Before the fix (output degenerates into garbage partway through):

4188fc30b06b% uv run test-scripts/chat.py '写一个python的快排' --temperature 0
.....
def quick_sort(arr, low=0, high=None):
    """原地快速排序(升序)"""
    if high is None:
        high = len(arr) - 1
    if low < high:
        # 分区操作,返回基准元素的正确位置
        pivot_index = partition(arr, low, high)
        # 递归排序左半部分和右半部分
        quick_sort(arr, low, pivot_index - 1)
        quick_sort(arr, pivot_index + 1, high)

def partition(arr, low, high):
    """分区函数:将数组分为小于基准和大于基准的两部分"""
    # 选择最后一个元素作为基准(可优化为随机选择)
    pivot = arr[high]
    i = low - 1  # 指向小于基准的区域的最后一个元素
    for j in range(low, high):
        if arr[j] <= pivot:
            i += 1
            i, arr[i], arr[i], arr[i], 交换 arr[i], 交换 arr[i] 交换 arr[i] 1 1 1, 1 1, 1:
    # 1:
                # 1
    # 1:
            arr[j] 1:
    # 1:
    # 1:
            arr[j] # 1:
    # 1:
            # 交换 arr[j]:
    # 1:
                # 1:
            arr[j] # 交换 arr[i] # 1:
            arr[i] # 交换 arr[i] # 交换 arr[i] # 交换 arr[j] # 交换 arr:
            arr[i] # 1:
            arr[j] # 1:
        arr[j] #  partition arr:
        # 1:
            arr:
            arr:
            arr:
            # 交换 arr[1:
            arr:
        arr:
        # 交换
    for _1:
        #  partition 1:
        #  partition 1:
        #  partition 1:
        # 1:
        # 1:
        #  partition 1:
            # 1:
        # 1, arr    #  partition  partition  partition 1:
        # 1,  partition arr    # 1:
    #  partition] = pivot_index = []
    ]:
    ]:
    # 1:
    ]:
        for j in range(:
    # 1    partition] # 1:
    # 升序):
    partition 1  partition]    partition 越 partition 1 排序    # 排序    # 排序1    # 1]    # 1]    # 排序后的元素交换    partition2, arr    partition 的    partition  partition 的 partition]   交换    #  partition 2, 1    # 的    ]    ]   交换的 partition  partition 1]:
    }

After the fix:

4188fc30b06b% uv run test-scripts/chat.py '写一个python的快排' --temperature 0

[Thinking]
我们来实现一个快速排序(Quick Sort)的Python函数。
 快速排序的基本思想是选择一个基准元素,通过一趟排序将数组分成两部分,其中一部分的所有数据都比另一部分的所有数据小,然后再按此方法对这两部分数据分别进行快速排序,整个过程可以递归进行。

 步骤:
 1. 选择一个基准元素(通常选择第一个元素,但也可以选择最后一个或中间的元素)。
 2. 通过一趟排序将数组分成两部分,使得左边的元素都小于等于基准,右边的元素都大于等于基准。
 3. 递归地对左右两部分进行快速排序。

 注意:这里我们实现一个原地排序的版本,但为了简单起见,也可以使用额外的列表。不过,原地排序更节省空间。

 我们这里实现一个原地排序的版本,使用两个指针(左右指针)进行划分。

 但是,注意:原地排序的快速排序在划分时需要注意细节,我们这里采用《算法导论》中的划分方法,即使用最后一个元素作为基准,然后维护一个索引i,使得[0, i-1]都是小于等于基准的,然后遍历j从0到high-1,如果当前元素小于等于基准,就交换到i的位置,然后i++。最后将基准(最后一个元素)与i位置的元素交换,这样基准就放在了正确的位置。

 然而,我们也可以选择第一个元素作为基准,但那样需要调整。这里我们选择最后一个元素作为基准。

 但是,为了通用性,我们可以让用户指定基准的选择方式,但这里我们固定使用最后一个元素。

 我们写一个函数,参数是数组arr,以及排序的起始和结束索引(low和high)。如果low小于high,则进行划分,然后递归排序。

 注意:我们这里实现的是升序排序。

 步骤:
 1. 如果low < high,则:
     a. 调用partition函数,返回基准元素的位置pivot_index。
     b. 递归排序low到pivot_index-1。
     c. 递归排序pivot_index+1到high。

 2. partition函数:
    选择最后一个元素作为基准(pivot = arr[high])。
    初始化一个索引i = low - 1,表示小于等于基准的区域的最后一个元素的下标。
    遍历j从low到high-1:
        如果arr[j] <= pivot,则i++,交换arr[i]和arr[j]。
    最后,交换arr[i+1]和arr[high](将基准放到正确位置)。
    返回i+1,即基准的位置。

 注意:这个划分方法会使得基准左边的元素都小于等于基准,右边的元素都大于等于基准。

 但是,快速排序的平均时间复杂度是O(nlogn),最坏情况(已经有序)是O(n^2)。我们可以通过随机选择基准来避免最坏情况,但这里我们固定使用最后一个元素。

 我们也可以实现一个随机选择基准的版本,但题目没有要求,我们就先实现这个简单的。

 另外,注意递归深度,对于大数组可能会递归很深,但Python的递归深度有限,对于大数组可能会栈溢出。我们可以使用迭代方式,但这里我们先用递归。

 我们写一个函数,然后测试一下。

 但是,注意:我们这里实现的是原地排序,所以不需要返回新的列表,而是直接修改传入的列表。

 我们写两个函数:quick_sort和partition。

 但是,为了用户方便,我们可以提供一个对外的接口,只传入列表,然后内部调用递归函数。

 我们这样设计:

 定义一个内部递归函数_quick_sort(arr, low, high),然后对外暴露一个quick_sort(arr)函数,它调用_quick_sort(arr, 0, len(arr)-1)。

 但是,注意:如果列表为空或只有一个元素,则直接返回。

 我们开始写代码。

[/Thinking]


以下是 Python 中快速排序(Quick Sort)的几种实现方式,包括**原地排序版本**(节省内存)和**易读版本**(使用额外列表):

---

### 1. **原地排序版本(推荐)**
直接在原数组上操作,空间复杂度为 O(log n)(递归栈),时间复杂度平均 O(n log n),最坏 O(n²)。

```python
def quick_sort(arr, low=0, high=None):
    """原地快速排序(升序)"""
    if high is None:
        high = len(arr) - 1
    if low < high:
        # 分区操作,返回基准元素的正确位置
        pivot_index = partition(arr, low, high)
        # 递归排序左半部分和右半部分
        quick_sort(arr, low, pivot_index - 1)
        quick_sort(arr, pivot_index + 1, high)

def partition(arr, low, high):
    """分区函数:将数组分为小于基准和大于基准的两部分"""
    # 选择最后一个元素作为基准(可优化为随机选择)
    pivot = arr[high]
    i = low - 1  # 指向小于基准的区域的最后一个元素
    for j in range(low, high):
        if arr[j] <= pivot:
            i += 1
            arr[i], arr[j] = arr[j], arr[i]  # 交换元素
    # 将基准元素放到正确位置
    arr[i + 1], arr[high] = arr[high], arr[i + 1]
    return i + 1

# 示例
if __name__ == "__main__":
    data = [3, 6, 8, 10, 1, 2, 1]
    quick_sort(data)
    print("排序后:", data)  # 输出: [1, 1, 2, 3, 6, 8, 10]
```

---

### 2. **易读版本(使用额外列表)**
逻辑清晰,适合理解快速排序原理,但空间复杂度 O(n)。

```python
def quick_sort_readable(arr):
    """易读的快速排序(返回新列表)"""
    if len(arr) <= 1:
        return arr
    pivot = arr[-1]  # 选择最后一个元素作为基准
    left = [x for x in arr[:-1] if x <= pivot]  # 小于等于基准的子列表
    right = [x for x in arr[:-1] if x > pivot]   # 大于基准的子列表
    return quick_sort_readable(left) + [pivot] + quick_sort_readable(right)

# 示例
if __name__ == "__main__":
    data = [3, 6, 8, 10, 1, 2, 1]
    sorted_data = quick_sort_readable(data)
    print("排序后:", sorted_data)  # 输出: [1, 1, 2, 3, 6, 8, 10]
```

---

### 3. **优化版本(随机选择基准 + 三数取中)**
避免最坏情况(如数组已排序),提升性能:

```python
import random

def quick_sort_optimized(arr, low=0, high=None):
    """优化版快速排序:随机选择基准"""
    if high is None:
        high = len(arr) - 1
    if low < high:
        # 随机选择基准并交换到末尾
        random_index = random.randint(low, high)
        arr[random_index], arr[high] = arr[high], arr[random_index]
        pivot_index = partition(arr, low, high)
        quick_sort_optimized(arr, low, pivot_index - 1)
        quick_sort_optimized(arr, pivot_index + 1, high)

# partition 函数同上(使用 arr[high] 作为基准)
```

---

### 关键点说明:
1. **基准选择**:固定选择最后一个元素简单但可能退化;随机选择可避免最坏情况。
2. **稳定性**:快速排序是**不稳定**的(相同元素的相对位置可能改变)。
3. **适用场景**:适合大规模无序数据;小数组可切换到插入排序优化。
4. **递归深度**:最坏情况递归深度 O(n),可能栈溢出;可改用迭代+栈模拟递归。

---

### 测试建议:
```python
# 测试不同情况
test_cases = [
    [], [1], [2, 1], [1, 2, 3], [3, 2, 1],
    [5, 2, 9, 1, 5, 6], [1, 1, 1, 1]
]
for case in test_cases:
    arr = case.copy()
    quick_sort(arr)
    print(f"{case} -> {arr}")
```

选择版本建议:
- **生产环境**:使用原地排序版本(`quick_sort`),注意随机化基准。
- **教学/理解**:使用易读版本(`quick_sort_readable`)。
- **避免最坏情况**:使用优化版本(`quick_sort_optimized`)。


@gemini-code-assist
Contributor

Summary of Changes

Hello @LuYanFCP, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the trtllm_mha kernel by adding SWA support, which was previously missing and causing issues on the B200 platform. The implementation includes necessary helper functions and translation logic to ensure correct output generation. Accuracy tests confirm that the changes maintain the model's performance.

Highlights

  • SWA Support: This PR introduces support for SWA (Sliding Window Attention) in the trtllm_mha kernel, addressing an issue where the default kernel lacked an SWA implementation on the B200 platform, leading to incorrect output after generating a certain number of tokens.
  • Code Modifications: The changes include adding helper functions in python/sglang/srt/layers/attention/trtllm_mha_backend.py and implementing SWA translation within the trtllm_mha kernel.
  • Accuracy Tests: Accuracy tests were conducted before and after the modifications, demonstrating that the changes do not negatively impact the model's output quality, as shown by the consistent response to a sample query.


Changelog
  • python/sglang/srt/layers/attention/trtllm_mha_backend.py
    • feat: add some helper function in python/sglang/srt/layers/attention/trtllm_mha_backend.py
    • feat: add swa translate in trtllm_mha
Activity
  • The PR introduces SWA support to the trtllm_mha kernel.
  • Helper functions were added to trtllm_mha_backend.py.
  • SWA translation was implemented in trtllm_mha.
  • Accuracy tests were performed to validate the changes.

@gemini-code-assist gemini-code-assist bot left a comment
Code Review

This pull request correctly adds support for Sliding Window Attention (SWA) in the trtllm_mha backend. The changes are well-structured, introducing helper functions for SWA page table management and integrating them correctly into the existing CUDA graph capture/replay logic and the standard forward pass. The modifications are consistent and correctly handle the translation between full and SWA KV pools. The code is of high quality and I have no suggestions for improvement.
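
The capture/replay integration mentioned above follows a standard CUDA graph constraint: buffers addressed by a captured graph must keep fixed addresses across replays. The sketch below is a hypothetical illustration of that pattern (the class name `SwaPageTableBuffer` and its list-based storage are invented for this example; the real backend would use preallocated CUDA tensors): allocate the SWA page table once at capture time, then overwrite it in place at each replay instead of allocating a new buffer.

```python
class SwaPageTableBuffer:
    """Hypothetical sketch of a CUDA-graph-friendly update pattern: the
    storage is allocated once (fixed address) and mutated in place on
    every replay, never reallocated."""

    def __init__(self, max_batch, max_pages):
        # Stand-in for a preallocated (max_batch, max_pages) device tensor.
        self.table = [[0] * max_pages for _ in range(max_batch)]

    def update(self, swa_page_table):
        """Copy the per-sequence SWA page lists into the fixed buffer,
        zero-padding each row so stale entries from earlier steps
        cannot leak into the current batch."""
        for i, pages in enumerate(swa_page_table):
            row = self.table[i]
            row[:len(pages)] = pages
            for j in range(len(pages), len(row)):
                row[j] = 0
```

The in-place write is what makes the buffer safe to reference from a captured graph; a fresh allocation per step would leave the graph replaying against a dangling address.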

@LuYanFCP LuYanFCP force-pushed the feat/support_swa_in_trtllm_mha branch from 6f6c119 to 2dc87c9 Compare February 18, 2026 09:41
@LuYanFCP LuYanFCP changed the title [feat]: support swa in trtllm_mha [feat] feat: support swa in trtllm_mha Feb 18, 2026
@ispobock
Collaborator

@LuYanFCP Could you fix the lint ci?

@LuYanFCP LuYanFCP force-pushed the feat/support_swa_in_trtllm_mha branch from 284c316 to e0078fd Compare February 19, 2026 15:30
@LuYanFCP
Contributor Author

@LuYanFCP Could you fix the lint ci?

Done

@LuYanFCP
Contributor Author

LuYanFCP commented Feb 20, 2026

@ispobock Added complete GPQA-Diamond results using nemo-evaluator; there are no issues with accuracy.

Tested with Step-3.5-Flash:

B200: 0.8316498316498316
H200: 0.835016835016835

@ispobock
Collaborator

/tag-and-rerun-ci

@ispobock ispobock merged commit ab18734 into sgl-project:main Feb 20, 2026
155 of 169 checks passed
@LuYanFCP LuYanFCP deleted the feat/support_swa_in_trtllm_mha branch February 21, 2026 14:00