66#include " utils.h"
77
// Lanes per warp; used for intra-warp reductions and per-lane strided loops.
static constexpr int kWarpSize = 32;
// Default dynamic shared-memory quota, in KB (48 KB is the default per-block
// dynamic shared-memory limit without a cudaFuncSetAttribute opt-in).
// NOTE(review): the host-side guard that consumes this constant compares a
// BYTE count (dynamicSmemSz) directly against this KB value, which would
// disable the shared-memory path for nearly every hidden_dim — confirm the
// comparison should be against DEFAULT_SHARED_MEM_THRESHOLD_KB * 1024.
static constexpr int DEFAULT_SHARED_MEM_THRESHOLD_KB = 48;  // Default shared memory quota in KB
910
// ---------------------------------------------------------------------------
// 1. Warp-local with configurable shared memory
//    • One warp handles one token.
//    • Eight tokens per 256-thread CTA.
//    • Shared memory usage is configurable via template parameter.
// ---------------------------------------------------------------------------
15- template <typename T, typename DST_DTYPE, int kTokensPerCTA = 8 , int kVecSize = 16 >
17+ template <typename T, typename DST_DTYPE, int kTokensPerCTA = 8 , int kVecSize = 16 , bool USE_SMEM = true >
1618__global__ void per_token_quant_fp8_kernel (
1719 const T* __restrict__ input,
1820 DST_DTYPE* __restrict__ output_q,
@@ -29,8 +31,14 @@ __global__ void per_token_quant_fp8_kernel(
2931 DST_DTYPE* token_output = output_q + token_id * hidden_dim;
3032 float * token_scale = output_s + token_id;
3133
34+ extern __shared__ char smem_buffer[];
35+ const int smem_padding = 32 ; // Pad to bank boundary (32 banks * 4 bytes = 128 bytes)
36+ const int warp_smem_stride = (hidden_dim * sizeof (T) + smem_padding - 1 ) / smem_padding * smem_padding;
37+ const int warp_smem_offset = warp_id * warp_smem_stride;
38+ T* shared_input = reinterpret_cast <T*>(smem_buffer + warp_smem_offset);
39+
3240 //
33- // Pass-1: Perform a warp reduce to find the max_value of a token's hidden_dim
41+ // Pass-1: Load data and compute max_value
3442 //
3543 float max_value = 0 .f ;
3644 using vec_t = flashinfer::vec_t <T, kVecSize >;
@@ -40,12 +48,26 @@ __global__ void per_token_quant_fp8_kernel(
4048 vec_t input_vec;
4149 input_vec.cast_load (token_input + i * kVecSize );
4250
51+ // Store to shared memory if USE_SMEM=true
52+ if constexpr (USE_SMEM) {
53+ #pragma unroll
54+ for (uint32_t j = 0 ; j < kVecSize ; ++j) {
55+ shared_input[i * kVecSize + j] = input_vec[j];
56+ }
57+ }
58+
59+ // Compute max value in parallel
4360#pragma unroll
4461 for (uint32_t j = 0 ; j < kVecSize ; ++j) {
4562 max_value = fmaxf (max_value, fabsf (static_cast <float >(input_vec[j])));
4663 }
4764 }
4865
66+ // Ensure all threads in the warp have finished writing to shared memory
67+ if constexpr (USE_SMEM) {
68+ __syncwarp ();
69+ }
70+
4971 float warp_max = warpReduceMax (max_value);
5072
5173 // NOTE: one CTA has multiple warps (each warp handles one token), so `scale`
@@ -58,11 +80,22 @@ __global__ void per_token_quant_fp8_kernel(
5880 const float scale_inv = (scale == 0 .f ) ? 0 .f : 1 .0f / scale;
5981
6082 //
61- // Pass-2: quantize and write back
83+ // Pass-2: Quantize and write back
6284 //
6385 for (int i = lane_id; i < num_vec_elems; i += kWarpSize ) {
6486 vec_t input_vec;
65- input_vec.cast_load (token_input + i * kVecSize );
87+
88+ if constexpr (USE_SMEM) {
89+ // Load from shared memory
90+ #pragma unroll
91+ for (uint32_t j = 0 ; j < kVecSize ; ++j) {
92+ input_vec[j] = shared_input[i * kVecSize + j];
93+ }
94+ } else {
95+ // Reload from global memory
96+ input_vec.cast_load (token_input + i * kVecSize );
97+ }
98+
6699 DST_DTYPE output_arr[kVecSize ];
67100#pragma unroll
68101 for (uint32_t j = 0 ; j < kVecSize ; ++j) {
@@ -164,6 +197,48 @@ __global__ void per_token_quant_fp8_small_batch_kernel(
164197 }
165198}
166199
// Dispatches the warp-local per-token FP8 quantization kernel, selecting the
// vectorization width (16 / 8 / 4 elements) at compile time from the runtime
// divisibility flags.
//
// Template parameters:
//   USE_SMEM       - true: kernel stages each token in dynamic shared memory
//                    (dynamicSmemSz bytes per CTA); false: kernel re-reads
//                    global memory in pass 2 and no dynamic smem is requested.
//   scalar_t       - input element type selected by the dtype dispatch macro.
//   TOKENS_PER_CTA - tokens handled per CTA (one warp per token).
//
// Preconditions (enforced by the caller): use_vec16 implies
// hidden_dim % 16 == 0; otherwise use_vec8 implies hidden_dim % 8 == 0;
// otherwise hidden_dim % 4 == 0 is assumed.
template <bool USE_SMEM, typename scalar_t, int TOKENS_PER_CTA>
static inline void launch_per_token_quant_fp8_warp_kernel(
    const dim3& grid,
    const dim3& block,
    size_t dynamicSmemSz,
    cudaStream_t stream,
    bool use_vec16,
    bool use_vec8,
    const torch::Tensor& input,     // const& avoids per-call Tensor refcount churn
    const torch::Tensor& output_q,
    const torch::Tensor& output_s,
    const int64_t hidden_dim,
    const int64_t num_tokens) {
  // Only the smem variant asks for dynamic shared memory; a 0-byte request is
  // valid even though the kernel declares `extern __shared__`.
  const size_t smem_size = USE_SMEM ? dynamicSmemSz : 0;

  // Hoist the pointer casts that were duplicated across all three branches.
  const scalar_t* in_ptr = static_cast<const scalar_t*>(input.data_ptr());
  __nv_fp8_e4m3* out_q_ptr = static_cast<__nv_fp8_e4m3*>(output_q.data_ptr());
  float* out_s_ptr = static_cast<float*>(output_s.data_ptr());

  if (use_vec16) {
    per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 16, USE_SMEM>
        <<<grid, block, smem_size, stream>>>(in_ptr, out_q_ptr, out_s_ptr, hidden_dim, num_tokens);
  } else if (use_vec8) {
    per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 8, USE_SMEM>
        <<<grid, block, smem_size, stream>>>(in_ptr, out_q_ptr, out_s_ptr, hidden_dim, num_tokens);
  } else {
    per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 4, USE_SMEM>
        <<<grid, block, smem_size, stream>>>(in_ptr, out_q_ptr, out_s_ptr, hidden_dim, num_tokens);
  }

  // Surface launch-configuration errors (e.g. dynamic smem over the per-block
  // limit) immediately instead of failing silently at the next sync point.
  cudaError_t launch_err = cudaGetLastError();
  TORCH_CHECK(
      launch_err == cudaSuccess,
      "per_token_quant_fp8_kernel launch failed: ",
      cudaGetErrorString(launch_err));
}
241+
167242void sgl_per_token_quant_fp8 (torch::Tensor input, torch::Tensor output_q, torch::Tensor output_s) {
168243 CHECK_INPUT (input);
169244 CHECK_INPUT (output_q);
@@ -180,34 +255,30 @@ void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch:
180255 const bool use_vec16 = (hidden_dim % 16 == 0 );
181256 const bool use_vec8 = (hidden_dim % 8 == 0 );
182257
258+ const int sizeof_T = input.scalar_type () == torch::kFloat16 ? 2 : (input.scalar_type () == torch::kBFloat16 ? 2 : 4 );
259+ const int smem_padding = 32 ; // Pad to bank boundary to avoid conflicts
260+ const int warp_smem_stride = (hidden_dim * sizeof_T + smem_padding - 1 ) / smem_padding * smem_padding;
261+ const size_t dynamicSmemSz = warp_smem_stride * TOKENS_PER_CTA;
262+
263+ bool use_smem = (hidden_dim < 2048 );
264+
265+ if (dynamicSmemSz >= DEFAULT_SHARED_MEM_THRESHOLD_KB) {
266+ use_smem = false ; // Disable shared memory if >= 48KB to avoid allocation failures
267+ }
268+
183269 DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16 (input.scalar_type (), scalar_t , [&] {
184270 if (use_warp_kernel) {
185271 // -------- warp‑local ---------------------------------------------------
186- constexpr int THREADS = TOKENS_PER_CTA * kWarpSize ; // 256
272+ constexpr int THREADS = TOKENS_PER_CTA * kWarpSize ;
187273 dim3 grid ((num_tokens + TOKENS_PER_CTA - 1 ) / TOKENS_PER_CTA);
188274 dim3 block (THREADS);
189275
190- if (use_vec16) {
191- per_token_quant_fp8_kernel<scalar_t , __nv_fp8_e4m3, TOKENS_PER_CTA, 16 ><<<grid, block, 0 , stream>>> (
192- static_cast <const scalar_t *>(input.data_ptr ()),
193- static_cast <__nv_fp8_e4m3*>(output_q.data_ptr ()),
194- static_cast <float *>(output_s.data_ptr ()),
195- hidden_dim,
196- num_tokens);
197- } else if (use_vec8) {
198- per_token_quant_fp8_kernel<scalar_t , __nv_fp8_e4m3, TOKENS_PER_CTA, 8 ><<<grid, block, 0 , stream>>> (
199- static_cast <const scalar_t *>(input.data_ptr ()),
200- static_cast <__nv_fp8_e4m3*>(output_q.data_ptr ()),
201- static_cast <float *>(output_s.data_ptr ()),
202- hidden_dim,
203- num_tokens);
276+ if (use_smem) {
277+ launch_per_token_quant_fp8_warp_kernel</* USE_SMEM=*/ true , scalar_t , TOKENS_PER_CTA>(
278+ grid, block, dynamicSmemSz, stream, use_vec16, use_vec8, input, output_q, output_s, hidden_dim, num_tokens);
204279 } else {
205- per_token_quant_fp8_kernel<scalar_t , __nv_fp8_e4m3, TOKENS_PER_CTA, 4 ><<<grid, block, 0 , stream>>> (
206- static_cast <const scalar_t *>(input.data_ptr ()),
207- static_cast <__nv_fp8_e4m3*>(output_q.data_ptr ()),
208- static_cast <float *>(output_s.data_ptr ()),
209- hidden_dim,
210- num_tokens);
280+ launch_per_token_quant_fp8_warp_kernel</* USE_SMEM=*/ false , scalar_t , TOKENS_PER_CTA>(
281+ grid, block, dynamicSmemSz, stream, use_vec16, use_vec8, input, output_q, output_s, hidden_dim, num_tokens);
211282 }
212283 } else {
213284 // -------- baseline -----------------------------------------------------
0 commit comments