Skip to content

Commit ca25134

Browse files
danielafrimi and Johnsonms
authored and committed
[FIX] Always support TP > 4 for FP4 Gemm (sgl-project#17300)
1 parent e14a4ba commit ca25134

File tree

1 file changed

+154
-6
lines changed

1 file changed

+154
-6
lines changed

python/sglang/srt/layers/quantization/modelopt_quant.py

Lines changed: 154 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,103 @@ def _sgl_kernel_scaled_fp4_quant_fake(
158158
"SGLANG_CUTEDSL_MOE_SCALAR_INPUT_SCALE", "true"
159159
)
160160

161+
# Alignment (in FP4 elements) required by the CUTLASS/FlashInfer FP4 GEMM
# kernels: both the N and K dimensions must be divisible by 32.
FP4_GEMM_ALIGNMENT = 32
163+
164+
165+
def round_up_to_multiple(x: int, m: int) -> int:
    """Round up x to the nearest multiple of m."""
    quotient, remainder = divmod(x, m)
    # Already aligned -> unchanged; otherwise bump to the next multiple.
    return x if remainder == 0 else (quotient + 1) * m
168+
169+
170+
def pad_nvfp4_weight(
    weight: torch.Tensor,
    n_alignment: int = FP4_GEMM_ALIGNMENT,
    k_alignment: int = FP4_GEMM_ALIGNMENT,
) -> tuple[torch.Tensor, int]:
    """
    Zero-pad a packed NVFP4 weight so FP4 GEMM kernel alignment constraints hold.

    Different backends have different alignment requirements:
      - CUTLASS/cuDNN: N % 32 == 0, K % 32 == 0
      - TRTLLM: N % 128 == 0 (for shuffle_matrix_sf_a); K padding handled separately

    Args:
        weight: Packed FP4 weight of shape [N, K//2] (two FP4 values per byte).
        n_alignment: Required alignment of N (default 32; 128 for TRTLLM).
        k_alignment: Required alignment of K in elements (default 32; 0 to skip).

    Returns:
        Tuple of (padded_weight, weights_padding_cols), where
        weights_padding_cols is the number of K-dimension padding columns
        added, expressed in bytes.
    """
    n_rows, packed_col_bytes = weight.shape  # N, K//2

    # Rows to append so that N hits the requested alignment.
    extra_rows = 0
    if n_alignment > 0:
        row_rem = n_rows % n_alignment
        if row_rem:
            extra_rows = n_alignment - row_rem

    # Columns to append so that K (in elements, 2 per byte) hits alignment.
    # The padding itself is applied in bytes.
    extra_col_bytes = 0
    if k_alignment > 0:
        k_elements = packed_col_bytes * 2
        col_rem = k_elements % k_alignment
        if col_rem:
            extra_col_bytes = (k_alignment - col_rem) // 2

    # Single pad call covering both dimensions; for a 2D tensor the pad
    # tuple is (left, right, top, bottom).
    if extra_rows or extra_col_bytes:
        weight = torch.nn.functional.pad(
            weight, (0, extra_col_bytes, 0, extra_rows)
        ).contiguous()

    return weight, extra_col_bytes
218+
219+
220+
def pad_nvfp4_activation_for_cutlass(
    x_fp4: torch.Tensor,
    weights_padding_cols: int,
) -> torch.Tensor:
    """
    Zero-pad packed FP4 activations so their K dimension matches padded weights.

    Args:
        x_fp4: Packed FP4 activation tensor.
        weights_padding_cols: K-dimension padding (in bytes) that was applied
            to the weight tensor.

    Returns:
        The activation tensor, padded along its last dimension when needed.
    """
    # Nothing to do when the weight was not K-padded.
    if weights_padding_cols <= 0:
        return x_fp4
    padded = torch.nn.functional.pad(x_fp4, (0, weights_padding_cols))
    return padded.contiguous()
237+
238+
239+
def slice_nvfp4_output(
    out: torch.Tensor,
    output_size: int,
) -> torch.Tensor:
    """
    Drop N-dimension padding from an FP4 GEMM output, if any was added.

    Args:
        out: Output tensor from the FP4 GEMM.
        output_size: Original (unpadded) size of the last dimension.

    Returns:
        The output tensor trimmed to output_size along its last dimension;
        returned unchanged when no padding is present.
    """
    # Fast path: the weight was never N-padded, so no slicing is needed.
    if out.shape[-1] == output_size:
        return out
    return out[..., :output_size].contiguous()
256+
257+
161258
# TODO make it true by default when the DeepEP PR is merged
162259
MOE_NVFP4_DISPATCH = envs.SGLANG_MOE_NVFP4_DISPATCH.get()
163260
# Supported activation schemes for the current configuration
@@ -1059,36 +1156,75 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
10591156
layer.input_scale_inv = Parameter(
10601157
(1 / input_scale_2).to(torch.float32), requires_grad=False
10611158
)
1159+
1160+
# Store original output size before any padding
1161+
layer.output_size_per_partition = layer.weight.shape[0]
1162+
10621163
if get_fp4_gemm_runner_backend().is_flashinfer_trtllm():
10631164
# FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
10641165
# FlashInfer provides nvfp4_quantize to quantize + shuffle the
10651166
# layout but we use our own quantization so we have to call
10661167
# shuffles ourselves.
1168+
#
1169+
# Alignment requirements:
1170+
# - shuffle_matrix_a: weight.shape[0] (N) % 32 == 0
1171+
# - shuffle_matrix_sf_a: scale.shape[0] (N) % 128 == 0, scale.shape[1] (K/16) % 4 == 0
1172+
# We pad N to multiple of 128 and K/16 to multiple of 4.
10671173
from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a
10681174

1069-
weight = layer.weight
1175+
# Pad weight N dimension to 128
1176+
weight, _ = pad_nvfp4_weight(
1177+
layer.weight.data, n_alignment=128, k_alignment=0
1178+
)
1179+
# Pad scale N dimension to match weight
10701180
scale = layer.weight_scale
1181+
if scale.shape[0] != weight.shape[0]:
1182+
pad_n = weight.shape[0] - scale.shape[0]
1183+
scale = torch.nn.functional.pad(scale, (0, 0, 0, pad_n))
1184+
1185+
# Pad K dimension: scale K/16 must be multiple of 4
1186+
scale_k = scale.shape[1] # K/16
1187+
weights_padding_cols = 0
1188+
if scale_k % 4 != 0:
1189+
padded_scale_k = round_up_to_multiple(scale_k, 4)
1190+
pad_scale_k = padded_scale_k - scale_k
1191+
# Pad scale K/16 dimension
1192+
scale = torch.nn.functional.pad(scale, (0, pad_scale_k, 0, 0))
1193+
# Pad weight K/2 dimension correspondingly (K/2 = K/16 * 8)
1194+
pad_weight_k = pad_scale_k * 8
1195+
weight = torch.nn.functional.pad(weight, (0, pad_weight_k, 0, 0))
1196+
# Store K padding for activation padding in apply()
1197+
weights_padding_cols = pad_weight_k
1198+
1199+
# Shuffle for TRTLLM layout
10711200
epilogue_tile_m = 128
1201+
shuffled_scale_shape = scale.shape
10721202
weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m)
10731203
scale = (
10741204
shuffle_matrix_sf_a(scale.view(torch.uint8), epilogue_tile_m)
1075-
.reshape(scale.shape)
1205+
.reshape(shuffled_scale_shape)
10761206
.view(torch.float8_e4m3fn)
10771207
)
10781208

10791209
layer.weight_scale_interleaved = Parameter(scale, requires_grad=False)
10801210
layer.weight = Parameter(weight, requires_grad=False)
1211+
layer.weights_padding_cols = weights_padding_cols
10811212
return
1213+
1214+
# Pad weights for CUTLASS/FlashInfer kernel alignment (K and N divisible by 32)
1215+
weight, weights_padding_cols = pad_nvfp4_weight(layer.weight.data)
1216+
layer.weights_padding_cols = weights_padding_cols
1217+
layer.weight = Parameter(weight, requires_grad=False)
1218+
10821219
# Pad and blockwise interleave weight_scale
10831220
scales = layer.weight_scale
10841221
scale_ndim = scales.ndim
10851222
if scale_ndim == 2:
10861223
scales = scales.unsqueeze(0)
10871224
assert scales.ndim == 3
10881225
B, M, K = scales.shape
1089-
round_up_multiple = lambda x, m: (x + m - 1) // m * m
1090-
M_padded = round_up_multiple(M, 128)
1091-
K_padded = round_up_multiple(K, 4)
1226+
M_padded = round_up_to_multiple(M, 128)
1227+
K_padded = round_up_to_multiple(K, 4)
10921228
padded_scales = torch.zeros((B, M_padded, K_padded), dtype=scales.dtype)
10931229
padded_scales[:B, :M, :K] = scales
10941230
batches, rows, cols = padded_scales.shape
@@ -1112,8 +1248,11 @@ def apply(
11121248
) -> torch.Tensor:
11131249
output_dtype = x.dtype
11141250
x_m, _ = x.shape
1251+
1252+
# Get original output size (before padding) and padded weight size
1253+
output_size = layer.output_size_per_partition
11151254
w_n, _ = layer.weight.shape
1116-
output_shape = [x_m, w_n]
1255+
output_shape = [x_m, output_size]
11171256

11181257
# Quantize BF16 or FP16 to (FP4 and interleaved block scale)
11191258
x_fp4, x_scale_interleaved = fp4_quantize(x, layer.input_scale_inv)
@@ -1123,11 +1262,16 @@ def apply(
11231262
assert layer.weight_scale_interleaved.dtype == torch.float8_e4m3fn
11241263
assert layer.alpha.dtype == torch.float32
11251264

1265+
# Pad activations to match weight K-dimension padding
1266+
weights_padding_cols = getattr(layer, "weights_padding_cols", 0)
1267+
x_fp4 = pad_nvfp4_activation_for_cutlass(x_fp4, weights_padding_cols)
1268+
11261269
w = layer.weight
11271270
w_scale_interleaved = layer.weight_scale_interleaved
11281271
if enable_flashinfer_fp4_gemm:
11291272
w = layer.weight.T
11301273
w_scale_interleaved = layer.weight_scale_interleaved.T
1274+
11311275
out = fp4_gemm(
11321276
x_fp4,
11331277
w,
@@ -1137,6 +1281,10 @@ def apply(
11371281
output_dtype,
11381282
w_n,
11391283
)
1284+
1285+
# Slice output to remove N-dimension padding
1286+
out = slice_nvfp4_output(out, output_size)
1287+
11401288
if bias is not None:
11411289
out = out + bias
11421290
return out.view(*output_shape)

0 commit comments

Comments
 (0)