fuse qkvbfg linear into one gemm and f_b g_b into batched gemm. #17801
ispobock merged 4 commits into sgl-project:main
Conversation
/tag-and-rerun-ci
@yizhang2077 @ispobock The benchmark is now ready here.
/rerun-failed-ci
```python
        self.num_heads,
    ]
    self.fg_sizes = [self.head_dim, self.head_dim]
    self.fused_qkvbfg_proj = MergedColumnParallelRepeatedLinear(
```
It seems that, for large batch sizes, this will cause higher TTFT, since prefill is already compute bound.
A GEMM with large M should not cause higher latency; I think this benchmark result may be caused by noise. I'll run it a few more times and do a separate test.
@ispobock I ran the benchmark, but the TTFT result is not stable, so I use throughput at output_size=1 to measure prefill performance at batch_size=128.
Before the optimization it is 12.87 and after it is 12.96; this result is quite stable.
@ispobock For the separate test, I use the following code, mainly for m=4096 to 16384, because these values are commonly used as chunked_prefill_size.
```python
import torch

# Shapes modeled on the kda layer: 6 projections (q, k, v, b, f_a, g_a)
# share the same input x; f_b and g_b have the same input/output sizes.
m = 8192  # also tested with m = 4096 and m = 16384
k = 2304
n1s = [1024, 1024, 1024, 8, 128, 128]
n2 = 1024

x = torch.randn([m, k], device='cuda', dtype=torch.bfloat16)
w1s = [torch.randn([n, k], device='cuda', dtype=torch.bfloat16) for n in n1s]
w2s = [torch.randn([n2, 128], device='cuda', dtype=torch.bfloat16) for _ in range(2)]
merged_w1 = torch.cat(w1s, dim=0)    # fuse the 6 weights along the output dim
merged_w2 = torch.stack(w2s, dim=0)  # stack f_b/g_b weights for the batched gemm

def forward_before():
    # 6 separate gemms, then 2 small gemms on the f_a/g_a outputs
    y1s = [x @ w.T for w in w1s]
    y21 = y1s[-2] @ w2s[0].T
    y22 = y1s[-1] @ w2s[1].T

def forward_after():
    # 1 fused gemm, then 1 batched gemm (bmm) over the last 256 columns,
    # which hold the two 128-wide f_a/g_a outputs
    merged_y1 = x @ merged_w1.T
    merged_y2 = torch.bmm(
        merged_y1[:, -256:].view(m, 2, 128).transpose(0, 1),
        merged_w2.transpose(-1, -2),
    )

forward_before()
forward_after()
```
For m=4096: 0.57ms vs 0.50ms; for m=8192: 1.03ms vs 0.98ms; for m=16384: 1.99ms vs 1.91ms.
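The snippet above does not show how these millisecond figures were collected; a minimal, framework-free sketch of one way to time such a function (a hypothetical `time_fn` helper, not part of the PR) could look like this:

```python
import time

def time_fn(fn, iters=20, warmup=5):
    """Hypothetical helper: average wall-clock milliseconds per call of fn.

    Note: for GPU kernels you would also need to synchronize the device
    (e.g. torch.cuda.synchronize()) before reading the clock, because kernel
    launches are asynchronous; that is omitted here to keep the sketch
    framework-free.
    """
    for _ in range(warmup):  # warm up caches / autotuners first
        fn()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    elapsed = time.perf_counter() - start
    return elapsed / iters * 1e3  # ms per call

# usage: ms_before = time_fn(forward_before); ms_after = time_fn(forward_after)
ms = time_fn(lambda: sum(i * i for i in range(10_000)))
```

Per-kernel numbers like the ones quoted earlier would instead come from a profiler trace rather than wall-clock timing.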
/rerun-failed-ci
@ispobock It seems most CI has passed and the rest is unrelated, and the rerun bot seems not to be working.
/rerun-failed-ci
@ispobock Now it has all passed.
Motivation
There are 8 GEMMs in kda. 6 of them (q_proj, k_proj, v_proj, b_proj, f_a_proj, g_a_proj) share the same input, so they can be fused into a single GEMM; the other 2 (f_b_proj, g_b_proj) have the same input/output sizes, so they can be fused into a batched GEMM.
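As a sanity check on this idea: fusing GEMMs that share an input amounts to concatenating the weights and splitting the output, and the batched GEMM is a stack of the two equally-shaped small matmuls. A minimal numpy sketch verifying the equivalence (shapes are illustrative, not the real layer sizes):

```python
import numpy as np

rng = np.random.default_rng(0)
m, k = 4, 16              # illustrative shapes only
n1s = [8, 8, 8, 2, 4, 4]  # stand-ins for q, k, v, b, f_a, g_a output sizes

x = rng.standard_normal((m, k))
w1s = [rng.standard_normal((n, k)) for n in n1s]

# 6 separate gemms vs one fused gemm over concatenated weights
ys = [x @ w.T for w in w1s]
fused_y = x @ np.concatenate(w1s, axis=0).T
splits = np.split(fused_y, np.cumsum(n1s)[:-1], axis=1)
for y, s in zip(ys, splits):
    assert np.allclose(y, s)

# 2 small gemms with identical shapes vs one batched gemm
w2s = [rng.standard_normal((8, 4)) for _ in range(2)]
z0, z1 = ys[-2] @ w2s[0].T, ys[-1] @ w2s[1].T
batched = np.stack([ys[-2], ys[-1]]) @ np.stack(w2s).transpose(0, 2, 1)
assert np.allclose(batched[0], z0) and np.allclose(batched[1], z1)
```

The fusion does not change the math; the win is replacing many small kernel launches with one large GEMM and one bmm, which also tends to use the hardware better.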
With a 4k-input, 1k-output decode test, the profile looks like the following:

Kernel duration decreases from 5.9+5.7+5.6+5.6+2.5+5.2+5.7+2.6 = 38.8us to 8.7+3 = 11.7us.
For prefill, kernel duration goes from 286+286+286+45+20+14+45+19 = 1001us to 934+35 = 969us, so the optimization is much smaller than for decode.
Modifications
Accuracy Tests
Tested with gsm8k on all questions:
before:
after:
Benchmarking and Profiling
before:
after:
TPOT decreases by about 5%~10%.
Checklist
Review Process
/tag-run-ci-label,/rerun-failed-ci,/tag-and-rerun-ci