[Perf] Add Flashinfer DeepGEMM SM90 for SwapAB Optimization#15514
[Perf] Add Flashinfer DeepGEMM SM90 for SwapAB Optimization#15514Fridge003 merged 3 commits intosgl-project:mainfrom
Conversation
|
Warning You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again! |
efad674 to
d735872
Compare
|
@b8zhong Will the warmup process be handled by flashinfer for this case? |
d735872 to
ae254a2
Compare
|
@Fridge003 I think, it uses the same DeepGEMM compiler under the hood. E.g during warmup you can see this process and a few similar ones. Although, I don't absolutely have the most context, so it may or may not be fully correct This is the regular |
|
@Fridge003 Although, it does seem somewhat faster actually (maybe 2x faster?) |
b8zhong
left a comment
There was a problem hiding this comment.
Added accuracy numbers too~
ae254a2 to
d12cb83
Compare
|
/tag-and-rerun-ci |
d12cb83 to
87e195c
Compare
|
Can we add a test for this new fp8 gemm kernel |
87e195c to
0d1dfc4
Compare
0d1dfc4 to
35b6375
Compare
|
Done @Fridge003 |
|
/rerun-failed-ci |
…ect#15514) Co-authored-by: Brayden Zhong <b8zhong@users.noreply.github.com>
…ect#15514) Co-authored-by: Brayden Zhong <b8zhong@users.noreply.github.com>
…ect#15514) Co-authored-by: Brayden Zhong <b8zhong@users.noreply.github.com>
Motivation
After flashinfer-ai/flashinfer#2131 in Flashinfer, we can benefit from SwapAB, where the input order is swapped to benefit when the M dimension is < 32 (e.g when BS < 32 in decoding). When it is larger, there is no benefit.
Modifications
(Requires Flashinfer nightly, and the backend currently only supports SM90)
Note that Flashinfer will compile it's own DeepGEMM. So it is separate from the DeepGEMM built in the Docker container.
Accuracy Tests
Benchmarking and Profiling
We can see that when the M dimension is small, there is around a 5-8% E2E benefit
Checklist