
support smem in per_token_quant_fp8 kernel #16725

Merged
BBuf merged 21 commits into sgl-project:main from zhangxin81:pertoken_quant_fp8_smem
Feb 2, 2026

Conversation

@zhangxin81
Contributor

@zhangxin81 zhangxin81 commented Jan 8, 2026

Motivation

This PR optimizes the per-token FP8 quantization kernel's shared memory usage strategy based on comprehensive performance analysis. The original implementation does not use shared memory, which can be inefficient because the input data must be loaded from global memory twice.

Performance benchmarking revealed three distinct regions:

  • hidden_dim < 2048: "Sweet spot" region where shared memory optimization provides significant performance benefits through better memory hierarchy utilization and reduced global memory traffic
  • hidden_dim >= 2048: Critical point where shared memory becomes a bottleneck due to bank conflicts, increased register pressure, and unfavorable computation-to-memory ratios
  • hidden_dim >= 4096: Baseline and shared memory versions perform similarly, with baseline being more stable

So we enable shared memory by default when hidden_dim < 2048, and otherwise keep the original implementation.

Modifications

1. Shared Memory Usage Strategy Optimization

  • Changed threshold logic: Replaced the size-based heuristic with a dimension-based strategy

    • Enable shared memory when hidden_dim < 2048 (sweet spot region)
    • Disable shared memory when hidden_dim >= 2048 (avoid bottlenecks)
    • Retain 48KB safety check to prevent kernel launch failures on GPUs with limited shared memory
  • Performance impact: This change ensures optimal performance in the sweet spot region while avoiding performance degradation at larger dimensions where shared memory overhead outweighs benefits.

2. Bank Conflict Avoidance

  • Added padding mechanism: Implemented 32-byte alignment padding for each warp's shared memory region

    • Padding ensures each warp's data is aligned to bank boundaries (32 banks × 4 bytes = 128 bytes)
    • Prevents bank conflicts when multiple warps access elements at the same offset within their respective regions
    • Minimal overhead: typically 0-16 bytes per warp depending on hidden_dim
  • Modified memory layout calculation:

    • Updated warp_smem_stride calculation to include padding: (hidden_dim * sizeof(T) + smem_padding - 1) / smem_padding * smem_padding
    • Applied consistent padding calculation in both kernel code and host code
    • Updated dynamicSmemSz calculation to account for padded stride

Accuracy Tests

baseline result
⚠️ vLLM not available, skipping vLLM comparison
Scale difference (Torch vs SGLang): 0.00000048
Output difference (Torch vs SGLang): 0.09891525
⚠️ vLLM not available, skipping vLLM comparison
Scale difference (Torch vs SGLang): 0.00000049
Output difference (Torch vs SGLang): 0.10589788
⚠️ vLLM not available, skipping vLLM comparison
Scale difference (Torch vs SGLang): 0.00000052
Output difference (Torch vs SGLang): 0.11760408

after modification
⚠️ vLLM not available, skipping vLLM comparison
Scale difference (Torch vs SGLang): 0.00000048
Output difference (Torch vs SGLang): 0.09924992
⚠️ vLLM not available, skipping vLLM comparison
Scale difference (Torch vs SGLang): 0.00000049
Output difference (Torch vs SGLang): 0.10579950
⚠️ vLLM not available, skipping vLLM comparison
Scale difference (Torch vs SGLang): 0.00000051
Output difference (Torch vs SGLang): 0.11713473

Benchmarking and Profiling

Test env: H20, cuda12.8, python3.12, torch 2.9.1+cu128
baseline
batch_size seq_len hidden_dim Torch Reference SGL Kernel
0 16.0 64.0 1368.0 29.646327 7.693756
1 16.0 64.0 2048.0 36.739619 7.769798
2 16.0 64.0 4096.0 59.636530 10.718977
3 16.0 128.0 1368.0 40.477385 16.193892
4 16.0 128.0 2048.0 54.182625 10.641069
5 16.0 128.0 4096.0 95.370303 20.021468
6 16.0 256.0 1368.0 68.273983 27.070275
7 16.0 256.0 2048.0 94.092407 18.739484
8 16.0 256.0 4096.0 195.518128 41.840198
9 16.0 512.0 1368.0 126.534215 53.106706
10 16.0 512.0 2048.0 199.952341 40.661654
11 16.0 512.0 4096.0 379.464645 87.691549
12 16.0 1024.0 1368.0 267.703886 99.069689
13 16.0 1024.0 2048.0 384.651768 80.183049
14 16.0 1024.0 4096.0 741.917419 167.837811
15 16.0 2048.0 1368.0 517.432004 190.873726
16 16.0 2048.0 2048.0 743.639983 152.849611
17 16.0 2048.0 4096.0 1454.121351 337.736003
18 16.0 4096.0 1368.0 1018.137991 374.650460
19 16.0 4096.0 2048.0 1467.684893 299.245137
20 16.0 4096.0 4096.0 3203.431924 972.257614
21 32.0 64.0 1368.0 40.363321 16.178037
22 32.0 64.0 2048.0 54.437726 10.647474
23 32.0 64.0 4096.0 95.382402 19.921997
24 32.0 128.0 1368.0 68.261046 27.045915
25 32.0 128.0 2048.0 94.176903 18.695826
26 32.0 128.0 4096.0 195.367833 41.774324
27 32.0 256.0 1368.0 126.377859 52.992323
28 32.0 256.0 2048.0 199.789362 40.577023
29 32.0 256.0 4096.0 378.575668 87.965213
30 32.0 512.0 1368.0 267.704248 98.951995
31 32.0 512.0 2048.0 384.486713 79.590603
32 32.0 512.0 4096.0 741.669769 168.114328
33 32.0 1024.0 1368.0 516.815513 190.870814
34 32.0 1024.0 2048.0 744.294790 153.332089
35 32.0 1024.0 4096.0 1451.184034 335.001326
36 32.0 2048.0 1368.0 1017.246008 374.412922
37 32.0 2048.0 2048.0 1468.242499 300.410271
38 32.0 2048.0 4096.0 3205.632051 972.964795
39 32.0 4096.0 1368.0 1984.675121 740.604291
40 32.0 4096.0 2048.0 3238.525391 893.841884
41 32.0 4096.0 4096.0 6357.877413 1934.074720
42 64.0 64.0 1368.0 68.664609 27.205118
43 64.0 64.0 2048.0 94.189751 18.729272
44 64.0 64.0 4096.0 195.518473 41.948088
45 64.0 128.0 1368.0 126.471057 53.250535
46 64.0 128.0 2048.0 200.005397 40.772312
47 64.0 128.0 4096.0 378.841610 88.279052
48 64.0 256.0 1368.0 267.490822 99.206563
49 64.0 256.0 2048.0 384.445133 80.182252
50 64.0 256.0 4096.0 741.692162 172.218396
51 64.0 512.0 1368.0 518.103753 191.162979
52 64.0 512.0 2048.0 744.359383 153.925334
53 64.0 512.0 4096.0 1453.385353 339.800008
54 64.0 1024.0 1368.0 1012.616992 374.700308
55 64.0 1024.0 2048.0 1468.544006 300.041555
56 64.0 1024.0 4096.0 3202.634652 972.154140
57 64.0 2048.0 1368.0 1983.733283 741.153240
58 64.0 2048.0 2048.0 3238.799890 893.892709
59 64.0 2048.0 4096.0 6355.258624 1934.047953
60 64.0 4096.0 1368.0 3950.086403 1474.644954
61 64.0 4096.0 2048.0 6465.925217 1778.223991
62 64.0 4096.0 4096.0 12692.096233 3857.122285
63 128.0 64.0 1368.0 126.281553 53.017800
64 128.0 64.0 2048.0 200.021733 40.841206
65 128.0 64.0 4096.0 378.875523 88.366959
66 128.0 128.0 1368.0 267.453642 99.246659
67 128.0 128.0 2048.0 384.636154 80.312486
68 128.0 128.0 4096.0 741.959686 170.309882
69 128.0 256.0 1368.0 516.608515 190.937566
70 128.0 256.0 2048.0 744.329856 153.928617
71 128.0 256.0 4096.0 1451.529344 338.519442
72 128.0 512.0 1368.0 1015.675008 374.994150
73 128.0 512.0 2048.0 1468.575991 300.274358
74 128.0 512.0 4096.0 3203.466733 973.186652
75 128.0 1024.0 1368.0 1984.363174 740.550775
76 128.0 1024.0 2048.0 3240.565300 894.563759
77 128.0 1024.0 4096.0 6358.693441 1933.537102
78 128.0 2048.0 1368.0 3949.756813 1475.326758
79 128.0 2048.0 2048.0 6455.930710 1778.697407
80 128.0 2048.0 4096.0 12685.455799 3860.089166
81 128.0 4096.0 1368.0 7884.495974 2943.461418
82 128.0 4096.0 2048.0 12887.295723 3559.793949
83 128.0 4096.0 4096.0 25633.647919 7841.941516

after modification
batch_size seq_len hidden_dim Torch Reference SGL Kernel
0 16.0 64.0 1368.0 29.617860 7.818903
1 16.0 64.0 2048.0 36.760108 7.869630
2 16.0 64.0 4096.0 59.624661 10.829550
3 16.0 128.0 1368.0 40.423062 13.748718
4 16.0 128.0 2048.0 54.158679 10.739860
5 16.0 128.0 4096.0 95.420762 20.046813
6 16.0 256.0 1368.0 68.277600 22.341624
7 16.0 256.0 2048.0 94.175012 18.894946
8 16.0 256.0 4096.0 195.467908 41.953588
9 16.0 512.0 1368.0 126.439942 44.532937
10 16.0 512.0 2048.0 200.052178 40.870309
11 16.0 512.0 4096.0 379.287033 87.594001
12 16.0 1024.0 1368.0 267.510027 84.456216
13 16.0 1024.0 2048.0 384.763095 80.126844
14 16.0 1024.0 4096.0 740.679374 167.809910
15 16.0 2048.0 1368.0 514.154665 159.894218
16 16.0 2048.0 2048.0 743.463370 153.955694
17 16.0 2048.0 4096.0 1451.664050 337.815730
18 16.0 4096.0 1368.0 1014.358997 310.609530
19 16.0 4096.0 2048.0 1469.172918 299.084165
20 16.0 4096.0 4096.0 3204.533418 975.174395
21 32.0 64.0 1368.0 41.018890 13.753491
22 32.0 64.0 2048.0 54.638873 10.743134
23 32.0 64.0 4096.0 95.432003 20.026558
24 32.0 128.0 1368.0 68.253939 22.468845
25 32.0 128.0 2048.0 94.218024 18.860435
26 32.0 128.0 4096.0 195.380004 41.758302
27 32.0 256.0 1368.0 126.322454 44.466069
28 32.0 256.0 2048.0 199.911380 40.712394
29 32.0 256.0 4096.0 378.574409 87.605160
30 32.0 512.0 1368.0 267.332644 84.395788
31 32.0 512.0 2048.0 384.481277 80.129982
32 32.0 512.0 4096.0 741.952667 167.779446
33 32.0 1024.0 1368.0 515.731869 159.883105
34 32.0 1024.0 2048.0 744.195681 154.097986
35 32.0 1024.0 4096.0 1452.787956 338.460000
36 32.0 2048.0 1368.0 1012.876987 310.605715
37 32.0 2048.0 2048.0 1469.004338 301.278100
38 32.0 2048.0 4096.0 3207.794666 975.072543
39 32.0 4096.0 1368.0 1985.804749 609.508991
40 32.0 4096.0 2048.0 3237.773418 895.201403
41 32.0 4096.0 4096.0 6356.303851 1936.829885
42 64.0 64.0 1368.0 68.277307 22.288776
43 64.0 64.0 2048.0 94.191913 18.854252
44 64.0 64.0 4096.0 195.553019 41.899990
45 64.0 128.0 1368.0 126.219743 44.488439
46 64.0 128.0 2048.0 199.857743 40.752123
47 64.0 128.0 4096.0 378.594551 88.069430
48 64.0 256.0 1368.0 267.318123 84.384359
49 64.0 256.0 2048.0 384.589119 80.089084
50 64.0 256.0 4096.0 741.670418 168.234662
51 64.0 512.0 1368.0 517.268509 159.833388
52 64.0 512.0 2048.0 746.521840 156.082856
53 64.0 512.0 4096.0 1454.720020 339.689334
54 64.0 1024.0 1368.0 1017.301977 310.276834
55 64.0 1024.0 2048.0 1469.255374 300.602030
56 64.0 1024.0 4096.0 3206.781387 973.668798
57 64.0 2048.0 1368.0 1984.043217 609.552503
58 64.0 2048.0 2048.0 3238.874753 894.412714
59 64.0 2048.0 4096.0 6356.847763 1937.302399
60 64.0 4096.0 1368.0 3952.172852 1209.543049
61 64.0 4096.0 2048.0 6456.197103 1780.689857
62 64.0 4096.0 4096.0 12688.127995 3863.046782
63 128.0 64.0 1368.0 126.629871 44.320961
64 128.0 64.0 2048.0 200.008351 40.936135
65 128.0 64.0 4096.0 378.891201 88.158736
66 128.0 128.0 1368.0 267.397798 84.670429
67 128.0 128.0 2048.0 384.598537 80.167795
68 128.0 128.0 4096.0 741.516151 169.494564
69 128.0 256.0 1368.0 517.369002 159.945921
70 128.0 256.0 2048.0 744.374165 154.913719
71 128.0 256.0 4096.0 1452.633381 340.217752
72 128.0 512.0 1368.0 1012.525976 310.533584
73 128.0 512.0 2048.0 1469.053562 300.445130
74 128.0 512.0 4096.0 3208.733400 973.441060
75 128.0 1024.0 1368.0 1984.448009 609.678000
76 128.0 1024.0 2048.0 3237.690608 894.697891
77 128.0 1024.0 4096.0 6357.642810 1936.261304
78 128.0 2048.0 1368.0 3952.332878 1210.336030
79 128.0 2048.0 2048.0 6464.906693 1780.900731
80 128.0 2048.0 4096.0 12696.752071 3862.123489
81 128.0 4096.0 1368.0 7881.160021 2409.373999
82 128.0 4096.0 2048.0 12887.392044 3558.654070
83 128.0 4096.0 4096.0 25635.439873 7858.458837

hidden_dim=1368 configurations show consistent and substantial gains:
For long sequences (seq_len=4096): Up to 17.1% improvement (374.65ms → 310.61ms at batch=16)
For medium sequences (seq_len=1024): Up to 14.8% improvement (99.07ms → 84.46ms at batch=16)
For shorter sequences (seq_len=256): Up to 17.4% improvement (27.07ms → 22.34ms at batch=16)
Other hidden_dim configurations are unchanged.

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments (/tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci) or contact authorized users to do so.
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist
Contributor

Summary of Changes

Hello @zhangxin81, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant optimization to the per_token_quant_fp8 kernel by leveraging shared memory to cache input data. This change is designed to enhance performance by minimizing redundant global memory reads during the two-pass quantization process. The implementation provides a configurable option for shared memory usage, along with a dynamic sizing heuristic, and updates the benchmarking tools to facilitate performance evaluation of this new feature.

Highlights

  • Shared Memory Integration: The per_token_quant_fp8 CUDA kernel now supports optional shared memory (smem) caching for input data, aiming to reduce global memory access during the quantization process.
  • Configurable Shared Memory Usage: A new use_smem_cache boolean parameter has been introduced across the Python, C++, and CUDA interfaces, allowing users to explicitly control whether shared memory is utilized.
  • Dynamic Shared Memory Heuristic: The C++ implementation includes a heuristic to dynamically decide whether to enable shared memory based on the required size (calculated from hidden_dim and data type), preventing excessive shared memory allocation and potential kernel launch failures.
  • Benchmarking Enhancements: The benchmarking script has been updated to include separate profiles for the SGL kernel with and without shared memory (sglang_smem and sglang_no_smem), enabling direct performance comparison of the optimization.


Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds support for using shared memory in the per_token_quant_fp8 kernel, which is a good optimization. The changes are well-structured across the benchmark, C++ extension, and CUDA kernel files. However, I've identified a couple of critical issues in the CUDA kernel implementation. There's a syntax error that will prevent compilation, and a more serious bug in the shared memory size calculation that could lead to kernel crashes. I've provided suggestions to fix these issues. Additionally, I've suggested a refactoring to reduce code duplication in the kernel launch logic, which would improve maintainability.

  }
  if constexpr (kVecSize == 16) {
-   *(uint4*)(token_output + i * kVecSize) = *(uint4*)output_arr;
+   *(uint4*)(token_output + i * kVecSize)) = *(uint4*)output_arr;
Contributor


critical

There's a syntax error on this line due to an extra closing parenthesis. This will cause a compilation failure.

      *(uint4*)(token_output + i * kVecSize) = *(uint4*)output_arr;

Comment on lines 243 to 291
  if (use_vec16) {
-   per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 16><<<grid, block, 0, stream>>>(
-       static_cast<const scalar_t*>(input.data_ptr()),
-       static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
-       static_cast<float*>(output_s.data_ptr()),
-       hidden_dim,
-       num_tokens);
+   if (use_smem) {
+     per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 16, true><<<grid, block, dynamicSmemSz, stream>>>(
+         static_cast<const scalar_t*>(input.data_ptr()),
+         static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
+         static_cast<float*>(output_s.data_ptr()),
+         hidden_dim,
+         num_tokens);
+   } else {
+     per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 16, false><<<grid, block, 0, stream>>>(
+         static_cast<const scalar_t*>(input.data_ptr()),
+         static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
+         static_cast<float*>(output_s.data_ptr()),
+         hidden_dim,
+         num_tokens);
+   }
  } else if (use_vec8) {
-   per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 8><<<grid, block, 0, stream>>>(
-       static_cast<const scalar_t*>(input.data_ptr()),
-       static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
-       static_cast<float*>(output_s.data_ptr()),
-       hidden_dim,
-       num_tokens);
+   if (use_smem) {
+     per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 8, true><<<grid, block, dynamicSmemSz, stream>>>(
+         static_cast<const scalar_t*>(input.data_ptr()),
+         static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
+         static_cast<float*>(output_s.data_ptr()),
+         hidden_dim,
+         num_tokens);
+   } else {
+     per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 8, false><<<grid, block, 0, stream>>>(
+         static_cast<const scalar_t*>(input.data_ptr()),
+         static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
+         static_cast<float*>(output_s.data_ptr()),
+         hidden_dim,
+         num_tokens);
+   }
  } else {
-   per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 4><<<grid, block, 0, stream>>>(
-       static_cast<const scalar_t*>(input.data_ptr()),
-       static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
-       static_cast<float*>(output_s.data_ptr()),
-       hidden_dim,
-       num_tokens);
+   if (use_smem) {
+     per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 4, true><<<grid, block, dynamicSmemSz, stream>>>(
+         static_cast<const scalar_t*>(input.data_ptr()),
+         static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
+         static_cast<float*>(output_s.data_ptr()),
+         hidden_dim,
+         num_tokens);
+   } else {
+     per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 4, false><<<grid, block, 0, stream>>>(
+         static_cast<const scalar_t*>(input.data_ptr()),
+         static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
+         static_cast<float*>(output_s.data_ptr()),
+         hidden_dim,
+         num_tokens);
+   }
  }
Contributor


medium

There is a lot of duplicated code for launching the kernel with and without shared memory for different vector sizes. This can be refactored to improve readability and maintainability.

You can use a templated lambda to dispatch on the use_smem boolean at compile time. This will remove the need for nested if/else statements and reduce code duplication.

Here's a suggested refactoring. Note that this also assumes you've applied the fix for the dynamic shared memory size from my other comment (i.e., you have a totalDynamicSmemSz variable).

      auto launcher = [&](auto use_smem_tag) {
        constexpr bool USE_SMEM = decltype(use_smem_tag)::value;
        const size_t smem_size = USE_SMEM ? totalDynamicSmemSz : 0;

        if (use_vec16) {
          per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 16, USE_SMEM><<<grid, block, smem_size, stream>>>(
              static_cast<const scalar_t*>(input.data_ptr()),
              static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
              static_cast<float*>(output_s.data_ptr()),
              hidden_dim,
              num_tokens);
        } else if (use_vec8) {
          per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 8, USE_SMEM><<<grid, block, smem_size, stream>>>(
              static_cast<const scalar_t*>(input.data_ptr()),
              static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
              static_cast<float*>(output_s.data_ptr()),
              hidden_dim,
              num_tokens);
        } else {
          per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 4, USE_SMEM><<<grid, block, smem_size, stream>>>(
              static_cast<const scalar_t*>(input.data_ptr()),
              static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
              static_cast<float*>(output_s.data_ptr()),
              hidden_dim,
              num_tokens);
        }
      };

      if (use_smem) {
        launcher(std::true_type{});
      } else {
        launcher(std::false_type{});
      }

@zhangxin81
Contributor Author

/tag-run-ci-label

@zhangxin81
Contributor Author

/tag-and-rerun-ci

@zhangxin81
Contributor Author

/tag-run-ci-label


// Additional safety check: disable shared memory if it exceeds GPU limits (48KB default)
// This prevents kernel launch failures on GPUs with limited shared memory
if (dynamicSmemSz >= 48 * 1024) {
Collaborator

@BBuf BBuf Jan 11, 2026


Why 48KB?

Contributor Author


To maintain compatibility with older GPUs (such as the P100, K80, etc., which have only 48 KB of shared memory), a relatively conservative limit was set.

@zhangxin81 zhangxin81 requested a review from BBuf January 24, 2026 09:14
hidden_dim,
num_tokens);
// Use templated lambda to dispatch on use_smem at compile time
auto launcher = [&](auto use_smem_tag) {
Collaborator


Can you avoid using a new lambda since you've already added a template parameter?

Contributor Author


Yes, it has been modified.

@BBuf
Collaborator

BBuf commented Jan 28, 2026

/tag-and-rerun-ci

@BBuf
Collaborator

BBuf commented Feb 2, 2026

The sgl-kernel CI is green, and the only failing case is unrelated to this change. https://github.com/sgl-project/sglang/actions/runs/21580343201/job/62176914998?pr=16725

@BBuf BBuf merged commit e3021b6 into sgl-project:main Feb 2, 2026
87 of 94 checks passed
@zhangxin81 zhangxin81 deleted the pertoken_quant_fp8_smem branch February 2, 2026 09:41
yingluosanqian pushed a commit to AichenF/sglang that referenced this pull request Feb 3, 2026
Co-authored-by: zhangxin81 <969206500@qq.com>
charlesHsuGG pushed a commit to charlesHsuGG/sglang that referenced this pull request Feb 5, 2026
Co-authored-by: zhangxin81 <969206500@qq.com>
sfiisf pushed a commit to sfiisf/sglang that referenced this pull request Feb 5, 2026
Co-authored-by: zhangxin81 <969206500@qq.com>
Johnsonms pushed a commit to Johnsonms/sglang that referenced this pull request Feb 14, 2026
Co-authored-by: zhangxin81 <969206500@qq.com>