support smem in per_token_quant_fp8 kernel#16725
Conversation
Code Review
This pull request adds support for using shared memory in the per_token_quant_fp8 kernel, which is a good optimization. The changes are well-structured across the benchmark, C++ extension, and CUDA kernel files. However, I've identified a couple of critical issues in the CUDA kernel implementation. There's a syntax error that will prevent compilation, and a more serious bug in the shared memory size calculation that could lead to kernel crashes. I've provided suggestions to fix these issues. Additionally, I've suggested a refactoring to reduce code duplication in the kernel launch logic, which would improve maintainability.
The suggested fix removes a stray closing parenthesis that would prevent compilation:

```diff
 }
 if constexpr (kVecSize == 16) {
-  *(uint4*)(token_output + i * kVecSize)) = *(uint4*)output_arr;
+  *(uint4*)(token_output + i * kVecSize) = *(uint4*)output_arr;
```
```diff
 if (use_vec16) {
-  per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 16><<<grid, block, 0, stream>>>(
-      static_cast<const scalar_t*>(input.data_ptr()),
-      static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
-      static_cast<float*>(output_s.data_ptr()),
-      hidden_dim,
-      num_tokens);
+  if (use_smem) {
+    per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 16, true><<<grid, block, dynamicSmemSz, stream>>>(
+        static_cast<const scalar_t*>(input.data_ptr()),
+        static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
+        static_cast<float*>(output_s.data_ptr()),
+        hidden_dim,
+        num_tokens);
+  } else {
+    per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 16, false><<<grid, block, 0, stream>>>(
+        static_cast<const scalar_t*>(input.data_ptr()),
+        static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
+        static_cast<float*>(output_s.data_ptr()),
+        hidden_dim,
+        num_tokens);
+  }
 } else if (use_vec8) {
-  per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 8><<<grid, block, 0, stream>>>(
-      static_cast<const scalar_t*>(input.data_ptr()),
-      static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
-      static_cast<float*>(output_s.data_ptr()),
-      hidden_dim,
-      num_tokens);
+  if (use_smem) {
+    per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 8, true><<<grid, block, dynamicSmemSz, stream>>>(
+        static_cast<const scalar_t*>(input.data_ptr()),
+        static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
+        static_cast<float*>(output_s.data_ptr()),
+        hidden_dim,
+        num_tokens);
+  } else {
+    per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 8, false><<<grid, block, 0, stream>>>(
+        static_cast<const scalar_t*>(input.data_ptr()),
+        static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
+        static_cast<float*>(output_s.data_ptr()),
+        hidden_dim,
+        num_tokens);
+  }
 } else {
-  per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 4><<<grid, block, 0, stream>>>(
-      static_cast<const scalar_t*>(input.data_ptr()),
-      static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
-      static_cast<float*>(output_s.data_ptr()),
-      hidden_dim,
-      num_tokens);
+  if (use_smem) {
+    per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 4, true><<<grid, block, dynamicSmemSz, stream>>>(
+        static_cast<const scalar_t*>(input.data_ptr()),
+        static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
+        static_cast<float*>(output_s.data_ptr()),
+        hidden_dim,
+        num_tokens);
+  } else {
+    per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 4, false><<<grid, block, 0, stream>>>(
+        static_cast<const scalar_t*>(input.data_ptr()),
+        static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
+        static_cast<float*>(output_s.data_ptr()),
+        hidden_dim,
+        num_tokens);
+  }
 }
```
There is a lot of duplicated code for launching the kernel with and without shared memory for different vector sizes. This can be refactored to improve readability and maintainability.
You can use a templated lambda to dispatch on the use_smem boolean at compile time. This will remove the need for nested if/else statements and reduce code duplication.
Here's a suggested refactoring. Note that this also assumes you've applied the fix for the dynamic shared memory size from my other comment (i.e., you have a totalDynamicSmemSz variable).
```cpp
auto launcher = [&](auto use_smem_tag) {
  constexpr bool USE_SMEM = decltype(use_smem_tag)::value;
  const size_t smem_size = USE_SMEM ? totalDynamicSmemSz : 0;
  if (use_vec16) {
    per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 16, USE_SMEM><<<grid, block, smem_size, stream>>>(
        static_cast<const scalar_t*>(input.data_ptr()),
        static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
        static_cast<float*>(output_s.data_ptr()),
        hidden_dim,
        num_tokens);
  } else if (use_vec8) {
    per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 8, USE_SMEM><<<grid, block, smem_size, stream>>>(
        static_cast<const scalar_t*>(input.data_ptr()),
        static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
        static_cast<float*>(output_s.data_ptr()),
        hidden_dim,
        num_tokens);
  } else {
    per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 4, USE_SMEM><<<grid, block, smem_size, stream>>>(
        static_cast<const scalar_t*>(input.data_ptr()),
        static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
        static_cast<float*>(output_s.data_ptr()),
        hidden_dim,
        num_tokens);
  }
};
if (use_smem) {
  launcher(std::true_type{});
} else {
  launcher(std::false_type{});
}
```
```cpp
// Additional safety check: disable shared memory if it exceeds GPU limits (48KB default)
// This prevents kernel launch failures on GPUs with limited shared memory
if (dynamicSmemSz >= 48 * 1024) {
```
To maintain compatibility with older GPUs (such as the P100, K80, etc., which have only 48 KB of shared memory), a relatively conservative limit was set.
```cpp
      hidden_dim,
      num_tokens);
  // Use templated lambda to dispatch on use_smem at compile time
  auto launcher = [&](auto use_smem_tag) {
```
Can you avoid using a new lambda since you've already added a template parameter?
Yes, it has been modified
The sgl-kernel CI is green, and the only failing case is unrelated to this change. https://github.com/sgl-project/sglang/actions/runs/21580343201/job/62176914998?pr=16725
Co-authored-by: zhangxin81 <969206500@qq.com>
Motivation
This PR optimizes the shared-memory usage strategy of the per-token FP8 quantization kernel based on performance analysis. The original implementation does not use shared memory, which is inefficient because the input data must be loaded from global memory twice.
Performance benchmarking revealed three distinct regions; based on these, shared memory is enabled by default when hidden_dim < 2048, while the original implementation is kept otherwise.
Modifications
1. Shared Memory Usage Strategy Optimization
Changed threshold logic: replaced the size-based heuristic with a dimension-based strategy
- hidden_dim < 2048 (sweet spot region): use shared memory
- hidden_dim >= 2048 (avoid bottlenecks): keep the original implementation
Performance impact: this change ensures optimal performance in the sweet spot region while avoiding degradation at larger dimensions, where shared-memory overhead outweighs the benefits.
2. Bank Conflict Avoidance
Added padding mechanism: implemented 32-byte alignment padding for each warp's shared-memory region
Modified memory layout calculation:
- warp_smem_stride calculation now includes padding: (hidden_dim * sizeof(T) + smem_padding - 1) / smem_padding * smem_padding
- dynamicSmemSz calculation accounts for the padded stride

Accuracy Tests
baseline result
⚠️ vLLM not available, skipping vLLM comparison
⚠️ vLLM not available, skipping vLLM comparison
⚠️ vLLM not available, skipping vLLM comparison
Scale difference (Torch vs SGLang): 0.00000048
Output difference (Torch vs SGLang): 0.09891525
Scale difference (Torch vs SGLang): 0.00000049
Output difference (Torch vs SGLang): 0.10589788
Scale difference (Torch vs SGLang): 0.00000052
Output difference (Torch vs SGLang): 0.11760408
after modification
⚠️ vLLM not available, skipping vLLM comparison
⚠️ vLLM not available, skipping vLLM comparison
⚠️ vLLM not available, skipping vLLM comparison
Scale difference (Torch vs SGLang): 0.00000048
Output difference (Torch vs SGLang): 0.09924992
Scale difference (Torch vs SGLang): 0.00000049
Output difference (Torch vs SGLang): 0.10579950
Scale difference (Torch vs SGLang): 0.00000051
Output difference (Torch vs SGLang): 0.11713473
Benchmarking and Profiling
Test env: H20, cuda12.8, python3.12, torch 2.9.1+cu128
baseline
batch_size seq_len hidden_dim Torch Reference SGL Kernel
0 16.0 64.0 1368.0 29.646327 7.693756
1 16.0 64.0 2048.0 36.739619 7.769798
2 16.0 64.0 4096.0 59.636530 10.718977
3 16.0 128.0 1368.0 40.477385 16.193892
4 16.0 128.0 2048.0 54.182625 10.641069
5 16.0 128.0 4096.0 95.370303 20.021468
6 16.0 256.0 1368.0 68.273983 27.070275
7 16.0 256.0 2048.0 94.092407 18.739484
8 16.0 256.0 4096.0 195.518128 41.840198
9 16.0 512.0 1368.0 126.534215 53.106706
10 16.0 512.0 2048.0 199.952341 40.661654
11 16.0 512.0 4096.0 379.464645 87.691549
12 16.0 1024.0 1368.0 267.703886 99.069689
13 16.0 1024.0 2048.0 384.651768 80.183049
14 16.0 1024.0 4096.0 741.917419 167.837811
15 16.0 2048.0 1368.0 517.432004 190.873726
16 16.0 2048.0 2048.0 743.639983 152.849611
17 16.0 2048.0 4096.0 1454.121351 337.736003
18 16.0 4096.0 1368.0 1018.137991 374.650460
19 16.0 4096.0 2048.0 1467.684893 299.245137
20 16.0 4096.0 4096.0 3203.431924 972.257614
21 32.0 64.0 1368.0 40.363321 16.178037
22 32.0 64.0 2048.0 54.437726 10.647474
23 32.0 64.0 4096.0 95.382402 19.921997
24 32.0 128.0 1368.0 68.261046 27.045915
25 32.0 128.0 2048.0 94.176903 18.695826
26 32.0 128.0 4096.0 195.367833 41.774324
27 32.0 256.0 1368.0 126.377859 52.992323
28 32.0 256.0 2048.0 199.789362 40.577023
29 32.0 256.0 4096.0 378.575668 87.965213
30 32.0 512.0 1368.0 267.704248 98.951995
31 32.0 512.0 2048.0 384.486713 79.590603
32 32.0 512.0 4096.0 741.669769 168.114328
33 32.0 1024.0 1368.0 516.815513 190.870814
34 32.0 1024.0 2048.0 744.294790 153.332089
35 32.0 1024.0 4096.0 1451.184034 335.001326
36 32.0 2048.0 1368.0 1017.246008 374.412922
37 32.0 2048.0 2048.0 1468.242499 300.410271
38 32.0 2048.0 4096.0 3205.632051 972.964795
39 32.0 4096.0 1368.0 1984.675121 740.604291
40 32.0 4096.0 2048.0 3238.525391 893.841884
41 32.0 4096.0 4096.0 6357.877413 1934.074720
42 64.0 64.0 1368.0 68.664609 27.205118
43 64.0 64.0 2048.0 94.189751 18.729272
44 64.0 64.0 4096.0 195.518473 41.948088
45 64.0 128.0 1368.0 126.471057 53.250535
46 64.0 128.0 2048.0 200.005397 40.772312
47 64.0 128.0 4096.0 378.841610 88.279052
48 64.0 256.0 1368.0 267.490822 99.206563
49 64.0 256.0 2048.0 384.445133 80.182252
50 64.0 256.0 4096.0 741.692162 172.218396
51 64.0 512.0 1368.0 518.103753 191.162979
52 64.0 512.0 2048.0 744.359383 153.925334
53 64.0 512.0 4096.0 1453.385353 339.800008
54 64.0 1024.0 1368.0 1012.616992 374.700308
55 64.0 1024.0 2048.0 1468.544006 300.041555
56 64.0 1024.0 4096.0 3202.634652 972.154140
57 64.0 2048.0 1368.0 1983.733283 741.153240
58 64.0 2048.0 2048.0 3238.799890 893.892709
59 64.0 2048.0 4096.0 6355.258624 1934.047953
60 64.0 4096.0 1368.0 3950.086403 1474.644954
61 64.0 4096.0 2048.0 6465.925217 1778.223991
62 64.0 4096.0 4096.0 12692.096233 3857.122285
63 128.0 64.0 1368.0 126.281553 53.017800
64 128.0 64.0 2048.0 200.021733 40.841206
65 128.0 64.0 4096.0 378.875523 88.366959
66 128.0 128.0 1368.0 267.453642 99.246659
67 128.0 128.0 2048.0 384.636154 80.312486
68 128.0 128.0 4096.0 741.959686 170.309882
69 128.0 256.0 1368.0 516.608515 190.937566
70 128.0 256.0 2048.0 744.329856 153.928617
71 128.0 256.0 4096.0 1451.529344 338.519442
72 128.0 512.0 1368.0 1015.675008 374.994150
73 128.0 512.0 2048.0 1468.575991 300.274358
74 128.0 512.0 4096.0 3203.466733 973.186652
75 128.0 1024.0 1368.0 1984.363174 740.550775
76 128.0 1024.0 2048.0 3240.565300 894.563759
77 128.0 1024.0 4096.0 6358.693441 1933.537102
78 128.0 2048.0 1368.0 3949.756813 1475.326758
79 128.0 2048.0 2048.0 6455.930710 1778.697407
80 128.0 2048.0 4096.0 12685.455799 3860.089166
81 128.0 4096.0 1368.0 7884.495974 2943.461418
82 128.0 4096.0 2048.0 12887.295723 3559.793949
83 128.0 4096.0 4096.0 25633.647919 7841.941516
after modification
batch_size seq_len hidden_dim Torch Reference SGL Kernel
0 16.0 64.0 1368.0 29.617860 7.818903
1 16.0 64.0 2048.0 36.760108 7.869630
2 16.0 64.0 4096.0 59.624661 10.829550
3 16.0 128.0 1368.0 40.423062 13.748718
4 16.0 128.0 2048.0 54.158679 10.739860
5 16.0 128.0 4096.0 95.420762 20.046813
6 16.0 256.0 1368.0 68.277600 22.341624
7 16.0 256.0 2048.0 94.175012 18.894946
8 16.0 256.0 4096.0 195.467908 41.953588
9 16.0 512.0 1368.0 126.439942 44.532937
10 16.0 512.0 2048.0 200.052178 40.870309
11 16.0 512.0 4096.0 379.287033 87.594001
12 16.0 1024.0 1368.0 267.510027 84.456216
13 16.0 1024.0 2048.0 384.763095 80.126844
14 16.0 1024.0 4096.0 740.679374 167.809910
15 16.0 2048.0 1368.0 514.154665 159.894218
16 16.0 2048.0 2048.0 743.463370 153.955694
17 16.0 2048.0 4096.0 1451.664050 337.815730
18 16.0 4096.0 1368.0 1014.358997 310.609530
19 16.0 4096.0 2048.0 1469.172918 299.084165
20 16.0 4096.0 4096.0 3204.533418 975.174395
21 32.0 64.0 1368.0 41.018890 13.753491
22 32.0 64.0 2048.0 54.638873 10.743134
23 32.0 64.0 4096.0 95.432003 20.026558
24 32.0 128.0 1368.0 68.253939 22.468845
25 32.0 128.0 2048.0 94.218024 18.860435
26 32.0 128.0 4096.0 195.380004 41.758302
27 32.0 256.0 1368.0 126.322454 44.466069
28 32.0 256.0 2048.0 199.911380 40.712394
29 32.0 256.0 4096.0 378.574409 87.605160
30 32.0 512.0 1368.0 267.332644 84.395788
31 32.0 512.0 2048.0 384.481277 80.129982
32 32.0 512.0 4096.0 741.952667 167.779446
33 32.0 1024.0 1368.0 515.731869 159.883105
34 32.0 1024.0 2048.0 744.195681 154.097986
35 32.0 1024.0 4096.0 1452.787956 338.460000
36 32.0 2048.0 1368.0 1012.876987 310.605715
37 32.0 2048.0 2048.0 1469.004338 301.278100
38 32.0 2048.0 4096.0 3207.794666 975.072543
39 32.0 4096.0 1368.0 1985.804749 609.508991
40 32.0 4096.0 2048.0 3237.773418 895.201403
41 32.0 4096.0 4096.0 6356.303851 1936.829885
42 64.0 64.0 1368.0 68.277307 22.288776
43 64.0 64.0 2048.0 94.191913 18.854252
44 64.0 64.0 4096.0 195.553019 41.899990
45 64.0 128.0 1368.0 126.219743 44.488439
46 64.0 128.0 2048.0 199.857743 40.752123
47 64.0 128.0 4096.0 378.594551 88.069430
48 64.0 256.0 1368.0 267.318123 84.384359
49 64.0 256.0 2048.0 384.589119 80.089084
50 64.0 256.0 4096.0 741.670418 168.234662
51 64.0 512.0 1368.0 517.268509 159.833388
52 64.0 512.0 2048.0 746.521840 156.082856
53 64.0 512.0 4096.0 1454.720020 339.689334
54 64.0 1024.0 1368.0 1017.301977 310.276834
55 64.0 1024.0 2048.0 1469.255374 300.602030
56 64.0 1024.0 4096.0 3206.781387 973.668798
57 64.0 2048.0 1368.0 1984.043217 609.552503
58 64.0 2048.0 2048.0 3238.874753 894.412714
59 64.0 2048.0 4096.0 6356.847763 1937.302399
60 64.0 4096.0 1368.0 3952.172852 1209.543049
61 64.0 4096.0 2048.0 6456.197103 1780.689857
62 64.0 4096.0 4096.0 12688.127995 3863.046782
63 128.0 64.0 1368.0 126.629871 44.320961
64 128.0 64.0 2048.0 200.008351 40.936135
65 128.0 64.0 4096.0 378.891201 88.158736
66 128.0 128.0 1368.0 267.397798 84.670429
67 128.0 128.0 2048.0 384.598537 80.167795
68 128.0 128.0 4096.0 741.516151 169.494564
69 128.0 256.0 1368.0 517.369002 159.945921
70 128.0 256.0 2048.0 744.374165 154.913719
71 128.0 256.0 4096.0 1452.633381 340.217752
72 128.0 512.0 1368.0 1012.525976 310.533584
73 128.0 512.0 2048.0 1469.053562 300.445130
74 128.0 512.0 4096.0 3208.733400 973.441060
75 128.0 1024.0 1368.0 1984.448009 609.678000
76 128.0 1024.0 2048.0 3237.690608 894.697891
77 128.0 1024.0 4096.0 6357.642810 1936.261304
78 128.0 2048.0 1368.0 3952.332878 1210.336030
79 128.0 2048.0 2048.0 6464.906693 1780.900731
80 128.0 2048.0 4096.0 12696.752071 3862.123489
81 128.0 4096.0 1368.0 7881.160021 2409.373999
82 128.0 4096.0 2048.0 12887.392044 3558.654070
83 128.0 4096.0 4096.0 25635.439873 7858.458837
hidden_dim=1368 configurations show consistent and substantial gains:
For long sequences (seq_len=4096): Up to 17.1% improvement (374.65ms → 310.61ms at batch=16)
For medium sequences (seq_len=1024): Up to 14.8% improvement (99.07ms → 84.46ms at batch=16)
For shorter sequences (seq_len=256): Up to 17.4% improvement (27.07ms → 22.34ms at batch=16)
Other hidden_dim configurations stay the same.
Checklist
Review Process
Trigger CI with one of the bot commands (/tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci) or contact authorized users to do so.