diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index d6ab4b6740..c23a1c23f9 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -1049,6 +1049,76 @@ void performTest(float (*OP)(const float), compare_rowwise_amax(output, ref_amax); } +// Columnwise-only 2D NVFP4 must match the columnwise half of both-directions output +template +void performTestColumnwiseOnly2D(const std::vector& shape) { + using namespace test; + + DType itype = TypeInfo::dtype; + DType otype = DType::kFloat4E2M1; + + const size_t rows = first_dimension(shape); + const size_t cols = last_dimension(shape); + + // Columnwise (transposed) scale-tensor dimensions. + const std::array scale_dims_t = get_scale_tensor_dims(cols, rows, 1, 16); + const size_t unpadded_blocks_Y_t = scale_dims_t[0]; + const size_t unpadded_blocks_X_t = scale_dims_t[1]; + const size_t scales_stride_t = scale_dims_t[3]; + + Tensor input("input", shape, itype); + fillCase(&input, InputsFillCase::uniform); + + // Golden amax chosen so the 2nd-stage scaling mantissa is zero (avoids rounding noise). + const float golden_amax = 448.0f * 6.0f * 8.0f; + + // Reference: both directions produced in a single kernel call (rowwise + columnwise). + Tensor output_both("output_both", shape, otype, /*rowwise=*/true, /*columnwise=*/true, + NVTE_NVFP4_1D_SCALING); + output_both.cpu_rowwise_amax_ptr()[0] = golden_amax; + output_both.cpu_columnwise_amax_ptr()[0] = golden_amax; + output_both.from_cpu(); + + // System under test: columnwise-only output (no rowwise data allocated). + Tensor output_col("output_col", shape, otype, /*rowwise=*/false, /*columnwise=*/true, + NVTE_NVFP4_1D_SCALING); + output_col.cpu_columnwise_amax_ptr()[0] = golden_amax; + output_col.from_cpu(); + + QuantizationConfigWrapper quant_config; + quant_config.set_stochastic_rounding(false); + quant_config.set_nvfp4_2d_quantization(true); + + nvte_quantize_v2(input.data(), output_both.data(), quant_config, 0); + nvte_quantize_v2(input.data(), output_col.data(), quant_config, 0); + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + output_both.to_cpu(); + output_col.to_cpu(); + + // Columnwise FP4 data must match bitwise (atol = rtol = 0). + compare_nvfp4_tensors("columnwise_only_data", + output_col.columnwise_cpu_dptr(), + output_both.columnwise_cpu_dptr(), + static_cast(cols), static_cast(rows), + /*atol=*/0.0, /*rtol=*/0.0); + + // Columnwise scale factors must match over the in-bounds region. + size_t scale_mismatches = 0; + compare_scaling_factors("columnwise_only_scales", + output_col.columnwise_cpu_scale_inv_ptr(), + output_both.columnwise_cpu_scale_inv_ptr(), + unpadded_blocks_Y_t, unpadded_blocks_X_t, scales_stride_t, + scale_mismatches); + ASSERT_EQ(scale_mismatches, 0u); + + // The columnwise-only tensor must not allocate rowwise output. + EXPECT_FALSE(output_col.rowwise()); +} + std::vector> tensor_dims = { {32, 32}, {32, 64}, @@ -1226,3 +1296,30 @@ INSTANTIATE_TEST_SUITE_P( [](const testing::TestParamInfo& info) { return test_name(info.param); }); + +class CastNVFP4ColumnwiseOnly2DTestSuite : public ::testing::TestWithParam> {}; + +TEST_P(CastNVFP4ColumnwiseOnly2DTestSuite, ColumnwiseOnlyMatchesBothDirections) { + // The optimized NVFP4 quantize-transpose kernel requires Blackwell. + if (getDeviceComputeCapability() < blackwellComputeCapability) { + GTEST_SKIP(); + } + performTestColumnwiseOnly2D(GetParam()); +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + CastNVFP4ColumnwiseOnly2DTestSuite, + // Include rectangular 128-multiple shapes to guard transposed data/scale indexing. + ::testing::Values( + std::vector{128, 128}, + std::vector{256, 512}, + std::vector{384, 1024}, + std::vector{2048, 256}), + [](const testing::TestParamInfo& info) { + std::string name; + for (const auto& s : info.param) { + name += "X" + std::to_string(s); + } + return name; + }); diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index 12366b731e..4b472d83f4 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -590,3 +590,103 @@ def test_nvfp4_quantization_noncontiguous_inputs( torch.testing.assert_close(qx_amax_t, ref_amax_t, atol=0.0, rtol=0.0) torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + # Aligned tiles + (128, 128), + (256, 256), + (512, 512), + (2048, 2048), + # Padded tiles (non-multiple of kTileDim=128) + (256, 272), + (304, 304), + (320, 256), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +def test_nvfp4_2d_columnwise_only_matches_both_directions( + x_dtype: torch.dtype, + M: int, + N: int, +): + """Bitwise check: 2D NVFP4 with columnwise-only must produce the same + columnwise data/scales as the columnwise half of (rowwise + columnwise) 2D. + + Covers both kernels depending on the (dtype, shape) routing: + - bf16 with rows % 32 == 0 and cols % 32 == 0 routes to the optimized + ``quantize_transpose_nvfp4_2D_kernel`` (instantiated with RETURN_ROWWISE=false), + validating that gating the rowwise pass/store leaves the shared + ``block_amax_matrix`` and columnwise output bitwise-identical to both-directions. + - non-bf16, or cols % 32 != 0, falls back to the columnwise-only 2D-amax-only + pass in ``quantize_transpose_vector_blockwise_fp4.cu``. + """ + te_dtype = tex.DType.kFloat4E2M1 + device = "cuda" + + torch.manual_seed(0) + torch.cuda.manual_seed(0) + x = torch.randn((M, N), dtype=x_dtype, device=device) + + def _make_quantizer(*, rowwise: bool, columnwise: bool) -> NVFP4Quantizer: + return NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=rowwise, + columnwise=columnwise, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=True, + row_scaled_nvfp4=False, + ) + + # Reference: produce both directions in a single kernel call. + q_both = _make_quantizer(rowwise=True, columnwise=True) + out_both = q_both(x) + + # SUT: produce columnwise only (the path that hits the new amax-only pass). + q_col_only = _make_quantizer(rowwise=False, columnwise=True) + out_col_only = q_col_only(x) + + # Columnwise data/scales/amax must be bitwise identical between the two paths. + # If amax_smem is populated differently in the column-only path, scales diverge, + # and the FP4 cast (which divides by encode_scale) produces different bytes. + assert out_both._columnwise_data is not None + assert out_col_only._columnwise_data is not None + torch.testing.assert_close( + out_col_only._columnwise_data.view(dtype=torch.uint8), + out_both._columnwise_data.view(dtype=torch.uint8), + atol=0, + rtol=0, + ) + + # Compare only the valid (in-bounds) region of the columnwise scale tensor. + # The padded tail (rows K..round_up(K, 128), cols ceil(M/16)..round_up(.., 4)) + # exists for cuBLAS alignment and is NEVER written by the kernel — its bytes + # are whatever ``at::empty`` returned, which differs between two allocations. + NVFP4_BLOCK = 16 + valid_outer = N # cols of input == rows of columnwise scale tensor + valid_inner = (M + NVFP4_BLOCK - 1) // NVFP4_BLOCK + assert out_both._columnwise_scale_inv is not None + assert out_col_only._columnwise_scale_inv is not None + col_sx_both = out_both._columnwise_scale_inv.view(dtype=torch.uint8) + col_sx_col_only = out_col_only._columnwise_scale_inv.view(dtype=torch.uint8) + torch.testing.assert_close( + col_sx_col_only[:valid_outer, :valid_inner], + col_sx_both[:valid_outer, :valid_inner], + atol=0, + rtol=0, + ) + + assert out_both._amax_columnwise is not None + assert out_col_only._amax_columnwise is not None + torch.testing.assert_close( + out_col_only._amax_columnwise, out_both._amax_columnwise, atol=0, rtol=0 + ) + + # Sanity: column-only path must not allocate a rowwise output. + assert out_col_only._rowwise_data is None diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index bad53a03c6..6c71285cd4 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -113,8 +113,13 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, "Row-scaled NVFP4 quantization does not produce columnwise output."); nvfp4::compute_rowwise_amax(*input_tensor, noop_tensor, output_tensor, stream); } - bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && - (cols % 32 == 0) && output_tensor->has_data(); + // Columnwise-only is supported on the optimized path only for 2D scaling; rowwise-only and + // both-directions keep their existing routing. Columnwise-only 1D and non-bf16 fall back to + // quantize_transpose_vector_blockwise_fp4. + bool use_optimized_kernel = + (dtype == DType::kBFloat16) && (rows % 32 == 0) && (cols % 32 == 0) && + (output_tensor->has_data() || + (output_tensor->has_columnwise_data() && quant_config_cpp.nvfp4_2d_quantization)); // Launch NVFP4 quantize kernel if (nvfp4_use_4over6) { @@ -268,8 +273,13 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens "NVFP4 4over6 quantization does not support stochastic rounding."); NVTE_CHECK(!output_tensor->row_scaled_nvfp4, "Backward NVFP4 quantization does not support row-scaled outputs."); - bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && - (cols % 32 == 0) && output_tensor->has_data(); + // Columnwise-only is supported on the optimized path only for 2D scaling; rowwise-only and + // both-directions keep their existing routing. Columnwise-only 1D and non-bf16 fall back to + // quantize_transpose_vector_blockwise_fp4. + bool use_optimized_kernel = + (dtype == DType::kBFloat16) && (rows % 32 == 0) && (cols % 32 == 0) && + (output_tensor->has_data() || + (output_tensor->has_columnwise_data() && quant_config_cpp.nvfp4_2d_quantization)); // Launch NVFP4 quantize kernel if (nvfp4_use_4over6) { diff --git a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh index e5100ec86f..a4979f8b81 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh @@ -778,7 +778,7 @@ __global__ void __launch_bounds__(THREADS_NUM) } template + typename IType, bool USE_STOCHASTIC_ROUNDING, bool RETURN_ROWWISE, bool RETURN_TRANSPOSE> __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input, const __grid_constant__ CUtensorMap tensor_map_output, @@ -1127,7 +1127,7 @@ __global__ void __launch_bounds__(THREADS_NUM) } // ROWWISE scaling - { + if constexpr (RETURN_ROWWISE) { const size_t stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y; #pragma unroll for (size_t it = 0; it < ITERATIONS_NORMAL; ++it) { @@ -1270,9 +1270,11 @@ __global__ void __launch_bounds__(THREADS_NUM) const size_t global_offset_Y_t = block_offset_Y_t; const size_t global_offset_X_t = block_offset_X_t + stage_offset_Y; - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output), global_offset_X, global_offset_Y, - reinterpret_cast(&out_data_sh[buff_offset_out])); + if constexpr (RETURN_ROWWISE) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output), global_offset_X, + global_offset_Y, reinterpret_cast(&out_data_sh[buff_offset_out])); + } if constexpr (RETURN_TRANSPOSE) { ptx::cp_async_bulk_tensor_2d_shared_to_global( @@ -1326,6 +1328,9 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, // return the transposed data. // TODO(Frank): Is there a better way to do this? bool return_transpose = output->has_columnwise_data(); + // Columnwise-only (no rowwise output) is supported on the optimized 2D path; the rowwise pass + // and its store are gated out via the RETURN_ROWWISE template bool. + const bool return_rowwise = output->has_data(); if (!use_2d_quantization && (input.dtype() == DType::kBFloat16)) { quantize_transpose_tuned_1D(input, noop, output, quant_config, stream); @@ -1342,9 +1347,13 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, CheckOutputTensor(*output, "output", false); NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); - NVTE_CHECK(output->has_data(), "NVFP4 output tensor must be allocated."); - NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type."); - NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + NVTE_CHECK(return_rowwise || (return_transpose && use_2d_quantization), + "NVFP4 optimized kernel supports rowwise output (1D or 2D), or columnwise-only output " + "with 2D quantization."); + if (return_rowwise) { + NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + } NVTE_CHECK(!row_scaled_nvfp4 || output->amax.dptr != nullptr, "Row-scaled NVFP4 quantization requires rowwise amax."); NVTE_CHECK(!row_scaled_nvfp4 || !output->has_columnwise_data(), @@ -1370,7 +1379,7 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, const dim3 grid(blocks_X, blocks_Y); const size_t block_size = THREADS_NUM; - const size_t scale_stride = output->scale_inv.shape[1]; + const size_t scale_stride = return_rowwise ? output->scale_inv.shape[1] : 0; const size_t scale_stride_transpose = return_transpose ? output->columnwise_scale_inv.shape[1] : 0; @@ -1403,8 +1412,10 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0, sizeof(IType) * 8); - create_2D_tensor_map(tensor_map_output, output->data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0, - 4); + if (return_rowwise) { + create_2D_tensor_map(tensor_map_output, output->data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, + 0, 4); + } if (return_transpose) { create_2D_tensor_map(tensor_map_output_transpose, output->columnwise_data, cols, rows, BUFF_DIM_X, BUFF_DIM_Y, rows, 0, 4); @@ -1431,21 +1442,26 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, use_stochastic_rounding, USE_STOCHASTIC_ROUNDING, TRANSFORMER_ENGINE_SWITCH_CONDITION(row_scaled_nvfp4, ROW_SCALED_NVFP4, { - TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { - auto kernel = quantize_transpose_nvfp4_kernel; - - if constexpr (use_2d_quantization) { - kernel = quantize_transpose_nvfp4_2D_kernel; - } + TRANSFORMER_ENGINE_SWITCH_CONDITION(return_rowwise, RETURN_ROWWISE, { + TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { + // The 1D kernel always produces rowwise output (no RETURN_ROWWISE); the dispatch only + // routes columnwise-only requests here when use_2d_quantization is true. + auto kernel = quantize_transpose_nvfp4_kernel; + + if constexpr (use_2d_quantization) { + kernel = quantize_transpose_nvfp4_2D_kernel; + } - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); - kernel<<>>( - tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr, - scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, cols, - scale_stride, scale_stride_transpose, rng_state); + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + kernel<<>>( + tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr, + scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, cols, + scale_stride, scale_stride_transpose, rng_state); + }); }); });); #else diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu index 1596bb3fd4..d5f2fa9a2c 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu @@ -353,14 +353,11 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo extern __shared__ char smem_base[]; SMemVec* smem = reinterpret_cast(&smem_base[0]); - // 2D block scaling is not supported for E8 scaling MXFP4 or for colwise only mode. - // Instead of static_assert, return early if these invalid modes are detected. + // 2D block scaling is not supported for E8 scaling MXFP4. + // Instead of static_assert, return early if this invalid mode is detected. if constexpr (kIs2DBlockScaling && kIsE8Scaling) { return; } - if constexpr (kIs2DBlockScaling && !kReturnIdentity) { - return; - } // for 128x128 block, 2D block scaling means there will be 8x8 amax values for nvfp4, 4x4 for 2D mxfp4 // use constexpr to define the size, when not using 2D, use minimal size 1x1 constexpr int kFP4BlockScalingSize = 16; @@ -576,6 +573,67 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo } } + // Step 2b: 2D-amax-only pass for columnwise-only mode. + // When only the transposed output is requested but 2D block scaling is enabled, the columnwise + // reads in Step 3 below still need amax_smem populated. Re-run the load + local-amax + // + 2D warp/smem reduction from Step 2 (steps 2.1-2.3), skipping the rowwise scale/quantize/store + // writes that Step 2 normally does. Same amax_smem values as the rowwise-enabled path, so the + // dgrad/wgrad columnwise output of (rowwise=False, columnwise=True, 2D) is bitwise identical to + // the columnwise half of (rowwise=True, columnwise=True, 2D). + if constexpr (!kReturnIdentity && kIs2DBlockScaling) { + constexpr int r_stride = + kThreadsPerBlock / kNumThreadsStore; // stride in rows of shared memory + constexpr int num_iterations = kTileDim / r_stride; // 4 iterations for kTileDim=128 + const int c_s = + (threadIdx.x % kNumThreadsStore) * (kNVecOut / kNVecSMem); // Column in shared memory + int r_s = threadIdx.x / kNumThreadsStore; // Row in shared memory +#pragma unroll + for (int iter = 0; iter < num_iterations; ++iter) { + SMemVec smem_vec[kNVecOut / kNVecSMem]; + // Step 2.1 (amax-only): Load from shared memory to registers +#pragma unroll + for (int i = 0; i < kNVecOut / kNVecSMem; ++i) { + int c = c_s + i; + int r = r_s; + smem_vec[i] = smem[r * kSMemCol + c]; + } + // Step 2.2 (amax-only): Compute local amax + CType amax = 0; +#pragma unroll + for (int i = 0; i < kNVecOut / kNVecSMem; ++i) { +#pragma unroll + for (int j = 0; j < kNVecSMem; ++j) { + __builtin_assume(amax >= 0); + amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[j])); + } + } + // Step 2.3 (amax-only): 2D warp + smem amax reduction (mirrors Step 2's 2D path) + constexpr int kNumRowsPerIter = kThreadsPerBlock / kNumThreadsStore; // 32 + int warp_idx = threadIdx.x / kThreadsPerWarp; // 0 ~ 7 + int tid_in_warp_x = threadIdx.x % kNumThreadsStore; + int tid_in_warp_y = (threadIdx.x / kNumThreadsStore) % kNumRowsPerWarp; + CType amax_warp_reduced = groupMax( + amax, WARP_REDUCE_AMAX_GROUP_MASKS[tid_in_warp_x]); + int data_row_idx = iter * kNumRowsPerIter + warp_idx * kNumRowsPerWarp + tid_in_warp_y; + if (tid_in_warp_y == 0) { + amax_smem_red[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x] + [warp_idx % k2DBlockAmaxReduceDim] = amax_warp_reduced; + } + __syncthreads(); + + if (data_row_idx % kFP4BlockScalingSize == 0) { + CType amax_2d = 0.0; + for (int i = 0; i < k2DBlockAmaxReduceDim; i++) { + amax_2d = + fmaxf(amax_2d, amax_smem_red[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x][i]); + } + amax_smem[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x] = amax_2d; + } + __syncthreads(); + r_s += r_stride; + } + } + // Step 3: Transpose, cast and store to output_t if constexpr (kReturnTranspose) { constexpr int c_stride = @@ -731,8 +789,6 @@ void quantize_transpose_vector_blockwise_fp4( NVTE_CHECK(return_identity || return_transpose, "At least one of return_identity or return_transpose must be true."); - NVTE_CHECK(return_identity || !use_2d_quantization, - "2D block quantization is only supported when return_identity is true."); NVTE_CHECK(!row_scaled_nvfp4 || (return_identity && !return_transpose), "Row-scaled NVFP4 quantization only supports rowwise quantization."); NVTE_CHECK(!row_scaled_nvfp4 || !use_2d_quantization,