Skip to content

Commit b35d8f7

Browse files
zhangxin81zxxxxxxxin
authored and committed
support smem in per_token_quant_fp8 kernel (sgl-project#16725)
Co-authored-by: zhangxin81 <969206500@qq.com>
1 parent 8ac57b5 commit b35d8f7

File tree

1 file changed

+97
-26
lines changed

1 file changed

+97
-26
lines changed

sgl-kernel/csrc/gemm/per_token_quant_fp8.cu

Lines changed: 97 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@
66
#include "utils.h"
77

88
static constexpr int kWarpSize = 32;
9+
static constexpr int DEFAULT_SHARED_MEM_THRESHOLD_KB = 48; // Default shared memory quota in KB
910

1011
// ---------------------------------------------------------------------------
11-
// 1. Warp‑local, no shared memory
12+
// 1. Warp‑local with configurable shared memory
1213
// • One warp handles one token.
1314
// • Eight tokens per 256‑thread CTA.
15+
// • Shared memory usage is configurable via template parameter.
1416
// ---------------------------------------------------------------------------
15-
template <typename T, typename DST_DTYPE, int kTokensPerCTA = 8, int kVecSize = 16>
17+
template <typename T, typename DST_DTYPE, int kTokensPerCTA = 8, int kVecSize = 16, bool USE_SMEM = true>
1618
__global__ void per_token_quant_fp8_kernel(
1719
const T* __restrict__ input,
1820
DST_DTYPE* __restrict__ output_q,
@@ -29,8 +31,14 @@ __global__ void per_token_quant_fp8_kernel(
2931
DST_DTYPE* token_output = output_q + token_id * hidden_dim;
3032
float* token_scale = output_s + token_id;
3133

34+
extern __shared__ char smem_buffer[];
35+
const int smem_padding = 32; // Pad to bank boundary (32 banks * 4 bytes = 128 bytes)
36+
const int warp_smem_stride = (hidden_dim * sizeof(T) + smem_padding - 1) / smem_padding * smem_padding;
37+
const int warp_smem_offset = warp_id * warp_smem_stride;
38+
T* shared_input = reinterpret_cast<T*>(smem_buffer + warp_smem_offset);
39+
3240
//
33-
// Pass-1: Perform a warp reduce to find the max_value of a token's hidden_dim
41+
// Pass-1: Load data and compute max_value
3442
//
3543
float max_value = 0.f;
3644
using vec_t = flashinfer::vec_t<T, kVecSize>;
@@ -40,12 +48,26 @@ __global__ void per_token_quant_fp8_kernel(
4048
vec_t input_vec;
4149
input_vec.cast_load(token_input + i * kVecSize);
4250

51+
// Store to shared memory if USE_SMEM=true
52+
if constexpr (USE_SMEM) {
53+
#pragma unroll
54+
for (uint32_t j = 0; j < kVecSize; ++j) {
55+
shared_input[i * kVecSize + j] = input_vec[j];
56+
}
57+
}
58+
59+
// Compute max value in parallel
4360
#pragma unroll
4461
for (uint32_t j = 0; j < kVecSize; ++j) {
4562
max_value = fmaxf(max_value, fabsf(static_cast<float>(input_vec[j])));
4663
}
4764
}
4865

66+
// Ensure all threads in the warp have finished writing to shared memory
67+
if constexpr (USE_SMEM) {
68+
__syncwarp();
69+
}
70+
4971
float warp_max = warpReduceMax(max_value);
5072

5173
// NOTE: one CTA has multiple warps (each warp handles one token), so `scale`
@@ -58,11 +80,22 @@ __global__ void per_token_quant_fp8_kernel(
5880
const float scale_inv = (scale == 0.f) ? 0.f : 1.0f / scale;
5981

6082
//
61-
// Pass-2: quantize and write back
83+
// Pass-2: Quantize and write back
6284
//
6385
for (int i = lane_id; i < num_vec_elems; i += kWarpSize) {
6486
vec_t input_vec;
65-
input_vec.cast_load(token_input + i * kVecSize);
87+
88+
if constexpr (USE_SMEM) {
89+
// Load from shared memory
90+
#pragma unroll
91+
for (uint32_t j = 0; j < kVecSize; ++j) {
92+
input_vec[j] = shared_input[i * kVecSize + j];
93+
}
94+
} else {
95+
// Reload from global memory
96+
input_vec.cast_load(token_input + i * kVecSize);
97+
}
98+
6699
DST_DTYPE output_arr[kVecSize];
67100
#pragma unroll
68101
for (uint32_t j = 0; j < kVecSize; ++j) {
@@ -164,6 +197,48 @@ __global__ void per_token_quant_fp8_small_batch_kernel(
164197
}
165198
}
166199

200+
// ---------------------------------------------------------------------------
// Host-side dispatcher for the warp-local per-token FP8 quantization kernel.
//
// Selects the vector width at runtime (16 -> 8 -> 4 element fallback, driven
// by the divisibility flags computed by the caller) and launches the matching
// per_token_quant_fp8_kernel instantiation on `stream`.
//
//   USE_SMEM      - compile-time switch: when true, the kernel caches each
//                   token's row in dynamic shared memory during pass 1 and
//                   re-reads it in pass 2 instead of reloading from global
//                   memory; `dynamicSmemSz` (bytes) is forwarded to the launch
//                   only in that case (0 otherwise).
//   grid/block    - launch configuration chosen by the caller
//                   (TOKENS_PER_CTA warps per CTA, one warp per token).
//   input         - source activations; element type is `scalar_t`.
//   output_q      - destination FP8 (e4m3) tensor, same logical shape.
//   output_s      - per-token scales, one float per token.
//
// Tensors are taken by const reference; only data_ptr() is read, so no
// refcount traffic or mutation occurs. The three branches launch the same
// kernel and differ only in the kVecSize template argument, so the pointer
// casts are hoisted out of the branches.
//
// NOTE(review): no cudaGetLastError() check after the launch here — callers
// rely on the surrounding project's error-handling conventions; confirm a
// check exists at a coarser boundary.
template <bool USE_SMEM, typename scalar_t, int TOKENS_PER_CTA>
static inline void launch_per_token_quant_fp8_warp_kernel(
    const dim3& grid,
    const dim3& block,
    size_t dynamicSmemSz,
    cudaStream_t stream,
    bool use_vec16,
    bool use_vec8,
    const torch::Tensor& input,
    const torch::Tensor& output_q,
    const torch::Tensor& output_s,
    const int64_t hidden_dim,
    const int64_t num_tokens) {
  // Only request dynamic shared memory when the kernel will actually use it.
  const size_t smem_size = USE_SMEM ? dynamicSmemSz : 0;

  // Hoist the pointer casts: identical across all three vec-width branches.
  const scalar_t* input_ptr = static_cast<const scalar_t*>(input.data_ptr());
  __nv_fp8_e4m3* output_q_ptr = static_cast<__nv_fp8_e4m3*>(output_q.data_ptr());
  float* output_s_ptr = static_cast<float*>(output_s.data_ptr());

  if (use_vec16) {
    per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 16, USE_SMEM>
        <<<grid, block, smem_size, stream>>>(input_ptr, output_q_ptr, output_s_ptr, hidden_dim, num_tokens);
  } else if (use_vec8) {
    per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 8, USE_SMEM>
        <<<grid, block, smem_size, stream>>>(input_ptr, output_q_ptr, output_s_ptr, hidden_dim, num_tokens);
  } else {
    per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 4, USE_SMEM>
        <<<grid, block, smem_size, stream>>>(input_ptr, output_q_ptr, output_s_ptr, hidden_dim, num_tokens);
  }
}
241+
167242
void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch::Tensor output_s) {
168243
CHECK_INPUT(input);
169244
CHECK_INPUT(output_q);
@@ -180,34 +255,30 @@ void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch:
180255
const bool use_vec16 = (hidden_dim % 16 == 0);
181256
const bool use_vec8 = (hidden_dim % 8 == 0);
182257

258+
const int sizeof_T = input.scalar_type() == torch::kFloat16 ? 2 : (input.scalar_type() == torch::kBFloat16 ? 2 : 4);
259+
const int smem_padding = 32; // Pad to bank boundary to avoid conflicts
260+
const int warp_smem_stride = (hidden_dim * sizeof_T + smem_padding - 1) / smem_padding * smem_padding;
261+
const size_t dynamicSmemSz = warp_smem_stride * TOKENS_PER_CTA;
262+
263+
bool use_smem = (hidden_dim < 2048);
264+
265+
if (dynamicSmemSz >= DEFAULT_SHARED_MEM_THRESHOLD_KB) {
266+
use_smem = false; // Disable shared memory if >= 48KB to avoid allocation failures
267+
}
268+
183269
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
184270
if (use_warp_kernel) {
185271
// -------- warp‑local ---------------------------------------------------
186-
constexpr int THREADS = TOKENS_PER_CTA * kWarpSize; // 256
272+
constexpr int THREADS = TOKENS_PER_CTA * kWarpSize;
187273
dim3 grid((num_tokens + TOKENS_PER_CTA - 1) / TOKENS_PER_CTA);
188274
dim3 block(THREADS);
189275

190-
if (use_vec16) {
191-
per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 16><<<grid, block, 0, stream>>>(
192-
static_cast<const scalar_t*>(input.data_ptr()),
193-
static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
194-
static_cast<float*>(output_s.data_ptr()),
195-
hidden_dim,
196-
num_tokens);
197-
} else if (use_vec8) {
198-
per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 8><<<grid, block, 0, stream>>>(
199-
static_cast<const scalar_t*>(input.data_ptr()),
200-
static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
201-
static_cast<float*>(output_s.data_ptr()),
202-
hidden_dim,
203-
num_tokens);
276+
if (use_smem) {
277+
launch_per_token_quant_fp8_warp_kernel</*USE_SMEM=*/true, scalar_t, TOKENS_PER_CTA>(
278+
grid, block, dynamicSmemSz, stream, use_vec16, use_vec8, input, output_q, output_s, hidden_dim, num_tokens);
204279
} else {
205-
per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 4><<<grid, block, 0, stream>>>(
206-
static_cast<const scalar_t*>(input.data_ptr()),
207-
static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()),
208-
static_cast<float*>(output_s.data_ptr()),
209-
hidden_dim,
210-
num_tokens);
280+
launch_per_token_quant_fp8_warp_kernel</*USE_SMEM=*/false, scalar_t, TOKENS_PER_CTA>(
281+
grid, block, dynamicSmemSz, stream, use_vec16, use_vec8, input, output_q, output_s, hidden_dim, num_tokens);
211282
}
212283
} else {
213284
// -------- baseline -----------------------------------------------------

0 commit comments

Comments
 (0)