@@ -158,6 +158,103 @@ def _sgl_kernel_scaled_fp4_quant_fake(
158158 "SGLANG_CUTEDSL_MOE_SCALAR_INPUT_SCALE" , "true"
159159)
160160
# FP4 GEMM alignment constant - CUTLASS/FlashInfer kernels require
# dimensions divisible by 32.
FP4_GEMM_ALIGNMENT = 32


def round_up_to_multiple(x: int, m: int) -> int:
    """Round up x to the nearest multiple of m."""
    return (x + m - 1) // m * m


def pad_nvfp4_weight(
    weight: torch.Tensor,
    n_alignment: int = FP4_GEMM_ALIGNMENT,
    k_alignment: int = FP4_GEMM_ALIGNMENT,
) -> tuple[torch.Tensor, int]:
    """
    Pad packed NVFP4 weights to satisfy alignment constraints for FP4 GEMM kernels.

    Different backends have different alignment requirements:
      - CUTLASS/cuDNN: N % 32 == 0, K % 32 == 0
      - TRTLLM: N % 128 == 0 (for shuffle_matrix_sf_a), K padding handled separately

    Args:
        weight: Packed FP4 weight tensor of shape [N, K//2] (2 FP4 values per byte).
        n_alignment: Required alignment for the N dimension
            (default 32, use 128 for TRTLLM, <= 0 to skip).
        k_alignment: Required alignment for the K dimension, counted in FP4
            *elements* (default 32, <= 0 to skip). Must be even when positive,
            because two elements share one byte in the packed layout.

    Returns:
        Tuple of (padded_weight, weights_padding_cols) where weights_padding_cols
        is the number of columns added for K-dimension padding (in bytes).

    Raises:
        ValueError: If k_alignment is positive and odd — the required element
            padding could then be odd, which byte-granular padding cannot
            represent (``pad_cols // 2`` would silently truncate it).
    """
    n_rows = weight.shape[0]  # N dimension
    k_bytes = weight.shape[1]  # K//2 (two FP4 values packed per byte)

    # N-dimension padding (extra rows). round_up - n_rows is 0 when already
    # aligned, so no separate modulo pre-check is needed.
    pad_rows = 0
    if n_alignment > 0:
        pad_rows = round_up_to_multiple(n_rows, n_alignment) - n_rows

    # K-dimension padding (extra byte columns). The padding is computed in
    # elements and converted to bytes; an odd k_alignment could demand an odd
    # element padding that floor-division by 2 would drop, leaving the tensor
    # under-padded. Reject that explicitly instead of padding incorrectly.
    pad_cols_bytes = 0
    if k_alignment > 0:
        if k_alignment % 2 != 0:
            raise ValueError(
                f"k_alignment must be even for packed FP4 weights, got {k_alignment}"
            )
        k_elements = k_bytes * 2
        pad_cols = round_up_to_multiple(k_elements, k_alignment) - k_elements
        # pad_cols is in elements; padding is applied in bytes (2 elements per byte).
        pad_cols_bytes = pad_cols // 2

    # Apply both paddings in a single operation if needed.
    # For a 2D tensor, the pad argument is (pad_left, pad_right, pad_top, pad_bottom).
    if pad_rows > 0 or pad_cols_bytes > 0:
        weight = torch.nn.functional.pad(
            weight, (0, pad_cols_bytes, 0, pad_rows)
        ).contiguous()

    return weight, pad_cols_bytes
218+
219+
def pad_nvfp4_activation_for_cutlass(
    x_fp4: torch.Tensor,
    weights_padding_cols: int,
) -> torch.Tensor:
    """
    Pad packed FP4 activations to match the K-dimension padding applied to weights.

    Args:
        x_fp4: Packed FP4 activation tensor
        weights_padding_cols: Number of padding columns (in bytes) from weight padding

    Returns:
        Padded activation tensor (the input unchanged when no padding is needed)
    """
    # Fast path: the weights were not K-padded, so the activations fit as-is.
    if weights_padding_cols <= 0:
        return x_fp4
    # Zero-pad the last (packed K) dimension on the right to line up with the
    # padded weight columns.
    padded = torch.nn.functional.pad(x_fp4, (0, weights_padding_cols))
    return padded.contiguous()
237+
238+
def slice_nvfp4_output(
    out: torch.Tensor,
    output_size: int,
) -> torch.Tensor:
    """
    Slice the output tensor to remove padding in N dimension if weight was padded.

    Args:
        out: Output tensor from FP4 GEMM
        output_size: Original output size before padding

    Returns:
        Sliced output tensor with padding removed (the input unchanged when the
        last dimension already matches output_size)
    """
    # No-op when the GEMM output already has the requested width.
    if out.shape[-1] == output_size:
        return out
    # Drop the padded tail columns and make the result contiguous for callers.
    trimmed = out[..., :output_size]
    return trimmed.contiguous()
256+
257+
# TODO: make this true by default once the DeepEP PR is merged.
MOE_NVFP4_DISPATCH = envs.SGLANG_MOE_NVFP4_DISPATCH.get()
# Supported activation schemes for the current configuration
@@ -1059,36 +1156,75 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
10591156 layer .input_scale_inv = Parameter (
10601157 (1 / input_scale_2 ).to (torch .float32 ), requires_grad = False
10611158 )
1159+
1160+ # Store original output size before any padding
1161+ layer .output_size_per_partition = layer .weight .shape [0 ]
1162+
10621163 if get_fp4_gemm_runner_backend ().is_flashinfer_trtllm ():
10631164 # FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
10641165 # FlashInfer provides nvfp4_quantize to quantize + shuffle the
10651166 # layout but we use our own quantization so we have to call
10661167 # shuffles ourselves.
1168+ #
1169+ # Alignment requirements:
1170+ # - shuffle_matrix_a: weight.shape[0] (N) % 32 == 0
1171+ # - shuffle_matrix_sf_a: scale.shape[0] (N) % 128 == 0, scale.shape[1] (K/16) % 4 == 0
1172+ # We pad N to multiple of 128 and K/16 to multiple of 4.
10671173 from flashinfer import shuffle_matrix_a , shuffle_matrix_sf_a
10681174
1069- weight = layer .weight
1175+ # Pad weight N dimension to 128
1176+ weight , _ = pad_nvfp4_weight (
1177+ layer .weight .data , n_alignment = 128 , k_alignment = 0
1178+ )
1179+ # Pad scale N dimension to match weight
10701180 scale = layer .weight_scale
1181+ if scale .shape [0 ] != weight .shape [0 ]:
1182+ pad_n = weight .shape [0 ] - scale .shape [0 ]
1183+ scale = torch .nn .functional .pad (scale , (0 , 0 , 0 , pad_n ))
1184+
1185+ # Pad K dimension: scale K/16 must be multiple of 4
1186+ scale_k = scale .shape [1 ] # K/16
1187+ weights_padding_cols = 0
1188+ if scale_k % 4 != 0 :
1189+ padded_scale_k = round_up_to_multiple (scale_k , 4 )
1190+ pad_scale_k = padded_scale_k - scale_k
1191+ # Pad scale K/16 dimension
1192+ scale = torch .nn .functional .pad (scale , (0 , pad_scale_k , 0 , 0 ))
1193+ # Pad weight K/2 dimension correspondingly (K/2 = K/16 * 8)
1194+ pad_weight_k = pad_scale_k * 8
1195+ weight = torch .nn .functional .pad (weight , (0 , pad_weight_k , 0 , 0 ))
1196+ # Store K padding for activation padding in apply()
1197+ weights_padding_cols = pad_weight_k
1198+
1199+ # Shuffle for TRTLLM layout
10711200 epilogue_tile_m = 128
1201+ shuffled_scale_shape = scale .shape
10721202 weight = shuffle_matrix_a (weight .view (torch .uint8 ), epilogue_tile_m )
10731203 scale = (
10741204 shuffle_matrix_sf_a (scale .view (torch .uint8 ), epilogue_tile_m )
1075- .reshape (scale . shape )
1205+ .reshape (shuffled_scale_shape )
10761206 .view (torch .float8_e4m3fn )
10771207 )
10781208
10791209 layer .weight_scale_interleaved = Parameter (scale , requires_grad = False )
10801210 layer .weight = Parameter (weight , requires_grad = False )
1211+ layer .weights_padding_cols = weights_padding_cols
10811212 return
1213+
1214+ # Pad weights for CUTLASS/FlashInfer kernel alignment (K and N divisible by 32)
1215+ weight , weights_padding_cols = pad_nvfp4_weight (layer .weight .data )
1216+ layer .weights_padding_cols = weights_padding_cols
1217+ layer .weight = Parameter (weight , requires_grad = False )
1218+
10821219 # Pad and blockwise interleave weight_scale
10831220 scales = layer .weight_scale
10841221 scale_ndim = scales .ndim
10851222 if scale_ndim == 2 :
10861223 scales = scales .unsqueeze (0 )
10871224 assert scales .ndim == 3
10881225 B , M , K = scales .shape
1089- round_up_multiple = lambda x , m : (x + m - 1 ) // m * m
1090- M_padded = round_up_multiple (M , 128 )
1091- K_padded = round_up_multiple (K , 4 )
1226+ M_padded = round_up_to_multiple (M , 128 )
1227+ K_padded = round_up_to_multiple (K , 4 )
10921228 padded_scales = torch .zeros ((B , M_padded , K_padded ), dtype = scales .dtype )
10931229 padded_scales [:B , :M , :K ] = scales
10941230 batches , rows , cols = padded_scales .shape
@@ -1112,8 +1248,11 @@ def apply(
11121248 ) -> torch .Tensor :
11131249 output_dtype = x .dtype
11141250 x_m , _ = x .shape
1251+
1252+ # Get original output size (before padding) and padded weight size
1253+ output_size = layer .output_size_per_partition
11151254 w_n , _ = layer .weight .shape
1116- output_shape = [x_m , w_n ]
1255+ output_shape = [x_m , output_size ]
11171256
11181257 # Quantize BF16 or FP16 to (FP4 and interleaved block scale)
11191258 x_fp4 , x_scale_interleaved = fp4_quantize (x , layer .input_scale_inv )
@@ -1123,11 +1262,16 @@ def apply(
11231262 assert layer .weight_scale_interleaved .dtype == torch .float8_e4m3fn
11241263 assert layer .alpha .dtype == torch .float32
11251264
1265+ # Pad activations to match weight K-dimension padding
1266+ weights_padding_cols = getattr (layer , "weights_padding_cols" , 0 )
1267+ x_fp4 = pad_nvfp4_activation_for_cutlass (x_fp4 , weights_padding_cols )
1268+
11261269 w = layer .weight
11271270 w_scale_interleaved = layer .weight_scale_interleaved
11281271 if enable_flashinfer_fp4_gemm :
11291272 w = layer .weight .T
11301273 w_scale_interleaved = layer .weight_scale_interleaved .T
1274+
11311275 out = fp4_gemm (
11321276 x_fp4 ,
11331277 w ,
@@ -1137,6 +1281,10 @@ def apply(
11371281 output_dtype ,
11381282 w_n ,
11391283 )
1284+
1285+ # Slice output to remove N-dimension padding
1286+ out = slice_nvfp4_output (out , output_size )
1287+
11401288 if bias is not None :
11411289 out = out + bias
11421290 return out .view (* output_shape )
0 commit comments