fuse qkvbfg linear into one gemm and f_b g_b into batched gemm #17801

Merged
ispobock merged 4 commits into sgl-project:main from antgroup:dev/kda_fuse_gemm
Feb 4, 2026

Conversation

@strgrb
Collaborator

strgrb commented Jan 27, 2026

Motivation

There are 8 GEMMs in KDA. 6 of them (q_proj, k_proj, v_proj, b_proj, f_a_proj, g_a_proj) share the same input, so they can be fused into a single GEMM; the other 2 (f_b_proj, g_b_proj) have the same input/output sizes, so they can be fused into a batched GEMM.
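The two fusions above can be sketched in plain NumPy (a CPU sketch, not the PR's actual torch/CUDA code): GEMMs that share an input are fused by concatenating their weights along the output dimension, and the result is split back per projection. Shapes follow the dimensions quoted later in this thread (k=2304); projection names are from the description above.

```python
import numpy as np

rng = np.random.default_rng(0)
m, k = 16, 2304                        # tokens x hidden size
n1s = [1024, 1024, 1024, 8, 128, 128]  # q, k, v, b, f_a, g_a output sizes

x = rng.standard_normal((m, k))
w1s = [rng.standard_normal((n, k)) for n in n1s]

# Separate path: one GEMM per projection.
ys = [x @ w.T for w in w1s]

# Fused path: one GEMM against the row-concatenated weight, then split.
merged_w1 = np.concatenate(w1s, axis=0)            # (sum(n1s), k)
merged_y = x @ merged_w1.T                         # (m, sum(n1s))
splits = np.split(merged_y, np.cumsum(n1s)[:-1], axis=1)

# Each split equals the corresponding separate GEMM output.
for y, s in zip(ys, splits):
    assert np.allclose(y, s)
```

The fused path trades 6 kernel launches for 1; the per-projection outputs are recovered with a cheap column split.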

With a 4k-input / 1k-output decode test, the profile looks like the following:
[profile screenshot]

Total kernel duration drops from 5.9+5.7+5.6+5.6+2.5+5.2+5.7+2.6 = 38.8 us to 8.7+3 = 11.7 us.

For prefill, kernel duration drops from 286+286+286+45+20+14+45+19 = 1001 us to 934+35 = 969 us; the improvement is much smaller than for decode.
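A rough way to see why decode gains much more than prefill: at M=1 each GEMM is memory-bound (time is dominated by reading its weights plus per-kernel overhead), while prefill at large M is compute-bound, so fusing kernels saves little. A back-of-envelope sketch using the dimensions from the separate test later in this thread; the ~3 TB/s HBM bandwidth is an assumed ballpark, not a figure from the PR:

```python
# Decode (M=1): lower-bound kernel time from weight traffic alone.
k = 2304
n1s = [1024, 1024, 1024, 8, 128, 128]  # q, k, v, b, f_a, g_a
n2, head_dim = 1024, 128               # f_b / g_b

bytes_per_elem = 2  # bfloat16
w1_bytes = sum(n * k for n in n1s) * bytes_per_elem   # 6 shared-input projections
w2_bytes = 2 * n2 * head_dim * bytes_per_elem         # f_b and g_b

bandwidth = 3e12  # bytes/s, assumed HBM ballpark
min_time_us = (w1_bytes + w2_bytes) / bandwidth * 1e6
print(f"weight traffic: {(w1_bytes + w2_bytes) / 1e6:.1f} MB, "
      f"lower bound ~{min_time_us:.1f} us at 3 TB/s")
```

The ~5 us floor is the same order as the 11.7 us measured for the fused kernels, whereas the unfused 38.8 us pays launch/readout overhead eight separate times.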

Modifications

Accuracy Tests

Tested with GSM8K on all questions:

python bench_sglang.py --port 8189 --data-path ./test.jsonl --num-questions -1

before:

Accuracy: 0.895
Invalid: 0.001
Latency: 109.134 s
Output throughput: 1175.404 token/s

after:

Accuracy: 0.898
Invalid: 0.000
Latency: 104.709 s
Output throughput: 1224.809 token/s

Benchmarking and Profiling

before:

| max concurrency | random input len | random output len | throughput | TTFT | TPOT |
|---|---|---|---|---|---|
| 8 | 4000 | 1000 | 1.26 | 473.58 | 5.88 |
| 32 | 4000 | 1000 | 2.50 | 1445.09 | 11.33 |
| 128 | 4000 | 1000 | 4.38 | 5326.57 | 23.62 |

after:

| max concurrency | random input len | random output len | throughput | TTFT | TPOT |
|---|---|---|---|---|---|
| 8 | 4000 | 1000 | 1.37 | 470.89 | 5.28 |
| 32 | 4000 | 1000 | 2.65 | 1443.90 | 10.62 |
| 128 | 4000 | 1000 | 4.56 | 5612.32 | 22.48 |

TPOT decreases by about 5% to 10%.

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.


@strgrb
Collaborator Author

strgrb commented Jan 29, 2026

/tag-and-rerun-ci

@strgrb
Collaborator Author

strgrb commented Jan 29, 2026

@yizhang2077 @ispobock The benchmark results are now ready here.

@strgrb
Collaborator Author

strgrb commented Jan 30, 2026

/rerun-failed-ci

1 similar comment
@strgrb
Collaborator Author

strgrb commented Jan 30, 2026

/rerun-failed-ci

self.num_heads,
]
self.fg_sizes = [self.head_dim, self.head_dim]
self.fused_qkvbfg_proj = MergedColumnParallelRepeatedLinear(
Collaborator

It seems that, for large batch sizes, it will cause higher TTFT, since that regime is already compute bound.

Collaborator Author

A GEMM with large M should not have higher latency; I think this benchmark result may be caused by run-to-run variance. I'll run it more times and do a separate test.

Collaborator Author

@ispobock I used the benchmark, but the TTFT result is not stable, so I use throughput at output_size=1 to measure prefill performance at batch_size=128.
Before the optimization it is 12.87 and after it is 12.96; this result is quite stable.

Collaborator Author

@ispobock And for the separate test, I use the following code, mainly for m=4096 to 16384, because these values are commonly used for chunked_prefill_size.

import torch

m = 8192   # also tested with 4096 and 16384
k = 2304
n1s = [1024, 1024, 1024, 8, 128, 128]  # q, k, v, b, f_a, g_a output sizes
n2 = 1024

x = torch.randn([m, k], device='cuda', dtype=torch.bfloat16)
w1s = [torch.randn([n, k], device='cuda', dtype=torch.bfloat16) for n in n1s]
w2s = [torch.randn([n2, 128], device='cuda', dtype=torch.bfloat16) for _ in range(2)]

# Fused weights: concatenate the 6 shared-input projections along the output
# dim, and stack f_b/g_b for a batched gemm.
merged_w1 = torch.cat(w1s, dim=0)
merged_w2 = torch.stack(w2s, dim=0)

def forward_before():
    # 6 separate gemms, then 2 more gemms on the f_a/g_a outputs.
    y1s = [x @ w.T for w in w1s]
    y21 = y1s[-2] @ w2s[0].T
    y22 = y1s[-1] @ w2s[1].T

def forward_after():
    # 1 fused gemm, then 1 batched gemm over the last two 128-wide chunks.
    merged_y1 = x @ merged_w1.T
    merged_y2 = torch.bmm(
        merged_y1[:, -256:].view(m, 2, 128).transpose(0, 1),
        merged_w2.transpose(-1, -2),
    )

forward_before()
forward_after()

For m=4096: 0.57 ms vs 0.50 ms; for m=8192: 1.03 ms vs 0.98 ms; for m=16384: 1.99 ms vs 1.91 ms.
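The same before/after comparison can be reproduced on CPU with NumPy (a sketch, not the PR's code; the millisecond figures above were measured on GPU, so CPU timings here are only illustrative of the measurement structure):

```python
import time
import numpy as np

m, k = 512, 2304  # smaller m than the GPU test, to keep the CPU run quick
n1s = [1024, 1024, 1024, 8, 128, 128]
n2, head_dim = 1024, 128

rng = np.random.default_rng(0)
x = rng.standard_normal((m, k), dtype=np.float32)
w1s = [rng.standard_normal((n, k), dtype=np.float32) for n in n1s]
w2s = np.stack([rng.standard_normal((n2, head_dim), dtype=np.float32)
                for _ in range(2)])
merged_w1 = np.concatenate(w1s, axis=0)

def forward_before():
    # 8 separate matmuls.
    y1s = [x @ w.T for w in w1s]
    return y1s[-2] @ w2s[0].T, y1s[-1] @ w2s[1].T

def forward_after():
    # 1 merged matmul + 1 batched matmul over the last two 128-wide chunks.
    merged_y1 = x @ merged_w1.T
    tail = merged_y1[:, -2 * head_dim:].reshape(m, 2, head_dim)
    merged_y2 = np.matmul(tail.transpose(1, 0, 2), w2s.transpose(0, 2, 1))
    return merged_y2[0], merged_y2[1]

for fn in (forward_before, forward_after):
    t0 = time.perf_counter()
    fn()
    print(f"{fn.__name__}: {(time.perf_counter() - t0) * 1e3:.2f} ms")

# The two paths are numerically equivalent (up to float32 reduction order).
for u, v in zip(forward_before(), forward_after()):
    assert np.allclose(u, v, rtol=1e-3, atol=1e-1)
```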

@strgrb
Collaborator Author

strgrb commented Feb 2, 2026

/rerun-failed-ci

@strgrb
Collaborator Author

strgrb commented Feb 2, 2026

/rerun-failed-ci

3 similar comments
@strgrb
Collaborator Author

strgrb commented Feb 3, 2026

/rerun-failed-ci

@strgrb
Collaborator Author

strgrb commented Feb 3, 2026

/rerun-failed-ci

@strgrb
Collaborator Author

strgrb commented Feb 3, 2026

/rerun-failed-ci

@strgrb
Collaborator Author

strgrb commented Feb 3, 2026

@ispobock It seems most CI has passed and the rest is unrelated, and the rerun bot seems not to be working.

@strgrb
Collaborator Author

strgrb commented Feb 3, 2026

/rerun-failed-ci

1 similar comment
@strgrb
Collaborator Author

strgrb commented Feb 4, 2026

/rerun-failed-ci

@strgrb
Collaborator Author

strgrb commented Feb 4, 2026

@ispobock All checks have passed now.

@ispobock ispobock merged commit 37c33cc into sgl-project:main Feb 4, 2026
376 of 398 checks passed
charlesHsuGG pushed a commit to charlesHsuGG/sglang that referenced this pull request Feb 5, 2026