diff --git a/docs/envvars.rst b/docs/envvars.rst
index 1e040b4c3e..58988b5473 100644
--- a/docs/envvars.rst
+++ b/docs/envvars.rst
@@ -281,6 +281,12 @@ Kernel Configuration
    :Default: ``0``
    :Description: Emit a warning when falling back from CUTLASS to cuBLAS for grouped GEMM operations.
 
+.. envvar:: NVTE_NVFP4_PER_TOKEN_ACTIVATION
+
+   :Type: ``int`` (0 or 1)
+   :Default: ``0``
+   :Description: Enable per-token activation quantization for the ``NVFP4BlockScaling`` recipe in GroupedLinear split-quantize paths. When set to ``1`` (or when ``NVFP4BlockScaling(per_token_activation=True)`` is used), NVFP4 rowwise ``amax`` metadata stores one FP32 value per token (row) instead of a single scalar.
+
 Torch Compilation and Fusion
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
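[Reviewer note, not part of the patch] The test changes below rewrite the 1D NVFP4 reference so that the global encode scale S_enc is folded into the per-block decode scale before the FP8 cast, with both scales clamped. A minimal NumPy sketch of that two-level scaling for one 16-element block, assuming S_enc = (448 * 6) / amax as in compute_global_encode_scaling_factor_FP4, and ignoring the E4M3 rounding of the stored scale:

    import numpy as np

    def nvfp4_block_scales(block: np.ndarray, amax: float):
        s_enc = (448.0 * 6.0) / amax                                # global encode scale from amax
        s_dec_b = min(np.abs(block).max() * s_enc / 6.0, 448.0)     # per-block scale, stored as E4M3
        s_enc_b = min(1.0 / (s_dec_b / s_enc), np.finfo(np.float32).max)  # "correct" encode scale
        return s_dec_b, s_enc_b        # quantize: round(block * s_enc_b) to E2M1 values in [-6, 6]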
diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu
index 15d7c695c9..f7a16539cc 100644
--- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu
+++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu
@@ -114,16 +114,14 @@ void quantize_nvfp4_1d(float (*OP)(const float),
       block_amax = std::max(block_amax, std::abs(elt));
     }
 
-    // 2. Compute E4M3 scaling factor
-    // Compute per-block encoding/decoding scaling factor
-    const float S_dec_b = block_amax / 6.0f;
-
-    // Scale & Store per-block decoding scaling factor
-    const fp8e4m3 S_dec_b_fp8 = static_cast<fp8e4m3>(S_dec_b * S_enc);
+    // Compute and store the per-block FP8 decode scale
+    const float S_dec_b = block_amax * (S_enc * (1.0f / 6.0f));
+    const fp8e4m3 S_dec_b_fp8 =
+        static_cast<fp8e4m3>(fminf(S_dec_b, Numeric_Traits<fp8e4m3>::maxNorm));
     const float S_dec_b_fp32 = static_cast<float>(S_dec_b_fp8);
 
     // Compute "correct" per-block encoding scaling factor
-    const float S_enc_b_fp8 = S_dec_b_fp32 == 0.f ? 0.f : S_enc / S_dec_b_fp32;
+    const float S_enc_b_fp8 = S_dec_b_fp32 == 0.f ? 0.f :
+        fminf(1.0f / (S_dec_b_fp32 * (1.0f / S_enc)), Numeric_Traits<float>::maxNorm);
 
     const size_t scale_idx = i * scales_stride + block_X;
     scales[scale_idx] = S_dec_b_fp8;
@@ -317,11 +315,69 @@ void compute_ref(float (*OP)(const float),
                  const size_t scales_stride,
                  const size_t scales_stride_t,
                  const bool use_fast_math,
-                 const bool use_2d_quantization = false)
+                 const bool use_2d_quantization = false,
+                 std::vector<float> *per_token_amax = nullptr)
 {
   std::vector<InputType> input_t = create_transpose(input, rows, cols);
 
-  if (use_2d_quantization) {
+  if (per_token_amax != nullptr) {
+    constexpr size_t kBlockSize = 16;
+    constexpr float fp4_max_inv = 1.0f / 6.0f;
+    constexpr float float_max = Numeric_Traits<float>::maxNorm;
+
+    per_token_amax->resize(rows, 0.0f);
+    for (size_t row = 0; row < rows; ++row) {
+      float row_amax = 0.0f;
+      for (size_t col = 0; col < cols; ++col) {
+        row_amax = fmaxf(row_amax, fabsf(static_cast<float>(input[row * cols + col])));
+      }
+      (*per_token_amax)[row] = row_amax;
+      quantize_nvfp4(OP,
+                     input + row * cols,
+                     output + row * (cols / 2),
+                     scales + row * scales_stride,
+                     1,
+                     cols,
+                     scales_stride,
+                     row_amax,
+                     use_fast_math,
+                     use_2d_quantization);
+    }
+
+    for (size_t col = 0; col < cols; ++col) {
+      for (size_t row_start = 0; row_start < rows; row_start += kBlockSize) {
+        float vals[kBlockSize];
+        float s_enc[kBlockSize];
+        float scaled_block_amax = 0.0f;
+        for (size_t i = 0; i < kBlockSize; ++i) {
+          const size_t row = row_start + i;
+          const float val = static_cast<float>(input[row * cols + col]);
+          const float S_enc =
+              compute_global_encode_scaling_factor_FP4((*per_token_amax)[row], false);
+          vals[i] = val;
+          s_enc[i] = S_enc;
+          scaled_block_amax = fmaxf(scaled_block_amax, fabsf(val) * (S_enc * fp4_max_inv));
+        }
+
+        const float S_dec_b_f32 = fminf(scaled_block_amax, float_max);
+        const fp8e4m3 S_dec_b_fp8 = static_cast<fp8e4m3>(S_dec_b_f32);
+        scales_t[col * scales_stride_t + row_start / kBlockSize] = S_dec_b_fp8;
+
+        for (size_t i = 0; i < kBlockSize; i += 2) {
+          const float S_dec_rowwise_x = 1.0f / s_enc[i];
+          const float S_dec_rowwise_y = 1.0f / s_enc[i + 1];
+          const float S_dec_b_fp32 = static_cast<float>(S_dec_b_fp8);
+          const float S_enc_b_fp8_x =
+              fminf(1.0f / (S_dec_b_fp32 * S_dec_rowwise_x), float_max);
+          const float S_enc_b_fp8_y =
+              fminf(1.0f / (S_dec_b_fp32 * S_dec_rowwise_y), float_max);
+          const float2 scaled_elt_pair = {vals[i] * S_enc_b_fp8_x,
+                                          vals[i + 1] * S_enc_b_fp8_y};
+          output_t[(col * rows + row_start + i) / 2] = fp4e2m1x2(scaled_elt_pair);
+        }
+      }
+    }
+  } else if (use_2d_quantization) {
     // Step 1: Compute mathematical 8×8 scaling factors
     std::vector<std::vector<float>> math_scales;
     compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales, use_fast_math);
@@ -526,10 +582,56 @@ void compareResults_nvfp4(const Tensor &test,
   compare_nvfp4_tensors("output_t", test_data_t, ref_data_t, cols, rows, atol, rtol);
 }
 
+void compare_per_token_amax(const Tensor &output, const std::vector<float> &ref_amax) {
+  NVTEBasicTensor amax;
+  nvte_get_tensor_param_v2(output.data(), kNVTEAmax, &amax, sizeof(amax), nullptr);
+  ASSERT_NE(amax.data_ptr, nullptr);
+  ASSERT_EQ(amax.shape.ndim, 1);
+  ASSERT_EQ(amax.shape.data[0], ref_amax.size());
+
+  std::vector<float> test_amax_data(ref_amax.size());
+  ASSERT_EQ(cudaMemcpy(test_amax_data.data(),
+                       amax.data_ptr,
+                       ref_amax.size() * sizeof(float),
+                       cudaMemcpyDeviceToHost),
+            cudaSuccess);
+  for (size_t row = 0; row < ref_amax.size(); ++row) {
+    ASSERT_EQ(test_amax_data[row], ref_amax[row])
+        << "Per-token amax mismatch at row " << row;
+  }
+}
+
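[Reviewer note] In the per-token reference branch above, every row is quantized against its own global amax, so the columnwise (transposed) pass has to combine 16 rows that each carry a different encode scale; the block amax is therefore taken over |val| * S_enc(row) / 6 before the FP8 cast. A small torch sketch of the per-row reference values (my own illustration, not patch code):

    import torch

    def per_token_global_scales(x: torch.Tensor):
        row_amax = x.abs().amax(dim=1).float()   # one FP32 amax per token (row)
        s_enc = (448.0 * 6.0) / row_amax         # per-row global encode scale
        return row_amax, s_enc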
+void set_per_token_amax_metadata(Tensor &output, const size_t rows) {
+  const std::vector<size_t> shape = {rows};
+  NVTETensor output_tensor = output.data();
+
+  auto replace_amax = [&](const NVTETensorParam param) {
+    NVTEBasicTensor old_amax;
+    nvte_get_tensor_param_v2(output_tensor, param, &old_amax, sizeof(old_amax), nullptr);
+    if (old_amax.data_ptr != nullptr) {
+      NVTE_CHECK_CUDA(cudaFree(old_amax.data_ptr));
+    }
+
+    float *amax = nullptr;
+    NVTE_CHECK_CUDA(cudaMalloc(&amax, rows * sizeof(float)));
+    NVTE_CHECK_CUDA(cudaMemset(amax, 0, rows * sizeof(float)));
+
+    NVTEBasicTensor amax_tensor = {amax,
+                                   static_cast<NVTEDType>(DType::kFloat32),
+                                   nvte_make_shape(shape.data(), shape.size())};
+    nvte_set_tensor_param_v2(output_tensor, param, &amax_tensor, sizeof(amax_tensor));
+    return amax;
+  };
+
+  replace_amax(kNVTEAmax);
+  replace_amax(kNVTEColumnwiseAmax);
+}
+
 template <typename InputType>
 void performTest(float (*OP)(const float),
                  const std::vector<size_t>& shape,
-                 const bool use_fast_math) {
+                 const bool use_fast_math,
+                 const bool per_token_activation = false) {
   using namespace test;
 
   DType itype = TypeInfo<InputType>::dtype;
@@ -567,26 +669,43 @@ void performTest(float (*OP)(const float),
 
   // Golden value of amax chosen to make the 2nd-stage scaling mantissa zero and avoid rounding issues
   const float amax = 448.0f * 6.0f * 8.0f;
-
-  // Set 2nd stage NVFP4 scaling factor
-  output.set_tensor_amax(amax);
-  output.set_tensor_amax_columnwise(amax);
-
+  std::vector<float> ref_per_token_amax;
   bool use_2d_quantization = false;
+  if (per_token_activation) {
+    set_per_token_amax_metadata(output, rows);
+    compute_ref(OP,
+                input.rowwise_cpu_dptr<InputType>(),
+                ref_output.get(),
+                ref_output_t.get(),
+                ref_scales.get(),
+                ref_scales_t.get(),
+                0.0f,
+                rows,
+                cols,
+                scales_stride,
+                scales_stride_t,
+                use_fast_math,
+                use_2d_quantization,
+                &ref_per_token_amax);
+  } else {
+    // Set 2nd stage NVFP4 scaling factor
+    output.set_tensor_amax(amax);
+    output.set_tensor_amax_columnwise(amax);
+    compute_ref(OP,
+                input.rowwise_cpu_dptr<InputType>(),
+                ref_output.get(),
+                ref_output_t.get(),
+                ref_scales.get(),
+                ref_scales_t.get(),
+                amax,
+                rows,
+                cols,
+                scales_stride,
+                scales_stride_t,
+                use_fast_math,
+                use_2d_quantization);
+  }
 
-  compute_ref(OP,
-              input.rowwise_cpu_dptr<InputType>(),
-              ref_output.get(),
-              ref_output_t.get(),
-              ref_scales.get(),
-              ref_scales_t.get(),
-              amax,
-              rows,
-              cols,
-              scales_stride,
-              scales_stride_t,
-              use_fast_math,
-              use_2d_quantization);
   // Initialize stochastic rounding
   Tensor rng_state("rng_state", std::vector<size_t>{2}, DType::kInt64);
   rng_state.rowwise_cpu_dptr<int64_t>()[0] = 123;  // rng_seed
@@ -600,6 +719,7 @@ void performTest(float (*OP)(const float),
 
   // Set 2D quantization based on compile-time flag
   quant_config.set_nvfp4_2d_quantization(use_2d_quantization);
+  quant_config.set_nvfp4_per_token_activation(per_token_activation);
 
   // Call appropriate function based on operation type
   // Activation functions take 3 parameters (input, output, stream)
@@ -646,6 +766,10 @@ void performTest(float (*OP)(const float),
                        ref_scales_t.get(), unpadded_blocks_Y_t, unpadded_blocks_X_t,
                        scales_stride_t, scale_mismatches_num);
+
+  if (per_token_activation) {
+    compare_per_token_amax(output, ref_per_token_amax);
+  }
 }
 
 std::vector<std::vector<size_t>> tensor_dims = {
@@ -678,6 +802,7 @@ class FusedCastTransposeNVFP4TestSuite
 class FusedCastTransposeNVFP4TestSuite : public ::testing::TestWithParam
     <std::tuple<ActivationType,
                 std::vector<size_t>,
                 transformer_engine::DType,
+                bool,
                 bool>> {};
 
 TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) {
@@ -693,6 +818,7 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) {
   const auto tensor_dims = std::get<1>(GetParam());
   const DType input_type = std::get<2>(GetParam());
   const bool use_fast_math = std::get<3>(GetParam());
+  const bool per_token_activation = std::get<4>(GetParam());
 
   // Skip tests if the input tensor is 1D
   if (tensor_dims.size() < 2) {
@@ -710,7 +836,7 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) {
   }
 
   TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType,
-    performTest<InputType>(OP, tensor_dims, use_fast_math);
+    performTest<InputType>(OP, tensor_dims, use_fast_math, per_token_activation);
   );
 }
 
@@ -733,6 +859,7 @@ INSTANTIATE_TEST_SUITE_P(
         ::testing::ValuesIn(Activation_types),
         ::testing::ValuesIn(tensor_dims),
         ::testing::Values(DType::kBFloat16),
+        ::testing::Values(false),
         ::testing::Values(false)),
     [](const testing::TestParamInfo<FusedCastTransposeNVFP4TestSuite::ParamType>& info) {
       std::string name = to_string(std::get<0>(info.param));
@@ -746,3 +873,28 @@ INSTANTIATE_TEST_SUITE_P(
       }
       return name;
     });
+
+INSTANTIATE_TEST_SUITE_P(
+    OperatorTestPerToken,
+    FusedCastTransposeNVFP4TestSuite,
+    ::testing::Combine(
+        ::testing::Values(ActivationType::Identity),
+        ::testing::Values(tensor_dims[4], tensor_dims[9], tensor_dims[12]),
+        ::testing::Values(DType::kBFloat16, DType::kFloat32),
+        ::testing::Values(false),
+        ::testing::Values(true)),
+    [](const testing::TestParamInfo<FusedCastTransposeNVFP4TestSuite::ParamType>& info) {
+      std::string name = to_string(std::get<0>(info.param));
+      const auto& shape = std::get<1>(info.param);
+      for (const auto& s: shape) {
+        name += "X" + std::to_string(s);
+      }
+      name += "X" + test::typeName(std::get<2>(info.param));
+      if (std::get<3>(info.param)) {
+        name += "X_FAST_SCALING";
+      }
+      if (std::get<4>(info.param)) {
+        name += "XPER_TOKEN";
+      }
+      return name;
+    });
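[Reviewer note] The dequantize test changes that follow make the CPU reference accept either a 1-element amax (per-tensor) or a rows-element amax (per-token) and select amax[row] per output row; the effective dequantize factor stays scale * amax / (6 * 448), only the amax source changes. A hedged torch sketch of that selection:

    import torch

    def per_row_dequant_scale(scale_e4m3: torch.Tensor, amax: torch.Tensor, row: int) -> float:
        factor_inv = 1.0 / (6.0 * 448.0)
        amax_row = amax[0] if amax.numel() == 1 else amax[row]   # per-tensor vs per-token
        return float(scale_e4m3) * float(amax_row) * factor_inv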
diff --git a/tests/cpp/operator/test_dequantize_nvfp4.cu b/tests/cpp/operator/test_dequantize_nvfp4.cu
index 96e85cb5ed..f932c7dd7a 100644
--- a/tests/cpp/operator/test_dequantize_nvfp4.cu
+++ b/tests/cpp/operator/test_dequantize_nvfp4.cu
@@ -42,7 +42,7 @@ float2 cvt_fp4x2_to_float2(fp4e2m1x2 fp4_pair) {
 template <typename OType>
 void compute_ref_dequantize_nvfp4(const uint8_t *packed_data,
                                   const fp8e4m3 *scales,
-                                  float amax,
+                                  const std::vector<float> &amax,
                                   OType *output,
                                   size_t rows,
                                   size_t cols,
@@ -55,7 +55,8 @@ void compute_ref_dequantize_nvfp4(const uint8_t *packed_data,
   for (size_t row = 0; row < rows; ++row) {
     for (size_t block = 0; block < Mread; ++block) {
       const fp8e4m3 scale = scales[row * scale_stride + block];
-      const float final_scale = static_cast<float>(scale) * amax * factor_inv;
+      const float final_scale =
+          static_cast<float>(scale) * (amax.size() == 1 ? amax[0] : amax[row]) * factor_inv;
 
       for (size_t pair_idx = 0; pair_idx < bytes_per_block; ++pair_idx) {
         const size_t byte_idx =
@@ -74,6 +75,43 @@ void compute_ref_dequantize_nvfp4(const uint8_t *packed_data,
   }
 }
 
+void set_per_token_amax_metadata(Tensor &output, const size_t rows) {
+  const std::vector<size_t> shape = {rows};
+  NVTETensor output_tensor = output.data();
+
+  auto replace_amax = [&](const NVTETensorParam param) {
+    NVTEBasicTensor old_amax;
+    nvte_get_tensor_param_v2(output_tensor, param, &old_amax, sizeof(old_amax), nullptr);
+    if (old_amax.data_ptr != nullptr) {
+      NVTE_CHECK_CUDA(cudaFree(old_amax.data_ptr));
+    }
+
+    float *amax = nullptr;
+    NVTE_CHECK_CUDA(cudaMalloc(&amax, rows * sizeof(float)));
+    NVTE_CHECK_CUDA(cudaMemset(amax, 0, rows * sizeof(float)));
+
+    NVTEBasicTensor amax_tensor = {amax,
+                                   static_cast<NVTEDType>(DType::kFloat32),
+                                   nvte_make_shape(shape.data(), shape.size())};
+    nvte_set_tensor_param_v2(output_tensor, param, &amax_tensor, sizeof(amax_tensor));
+  };
+
+  replace_amax(kNVTEAmax);
+  replace_amax(kNVTEColumnwiseAmax);
+}
+
+std::vector<float> get_amax_values(const Tensor &tensor) {
+  NVTEBasicTensor amax;
+  nvte_get_tensor_param_v2(tensor.data(), kNVTEAmax, &amax, sizeof(amax), nullptr);
+  const size_t numel = amax.shape.ndim == 0 ? 1 : amax.shape.data[0];
+  std::vector<float> amax_values(numel);
+  if (numel > 0) {
+    NVTE_CHECK_CUDA(cudaMemcpy(amax_values.data(), amax.data_ptr, numel * sizeof(float),
+                               cudaMemcpyDeviceToHost));
+  }
+  return amax_values;
+}
+
 template <typename IType>
 float compute_amax(const test::Tensor &t, size_t rows, size_t cols) {
   t.to_cpu();
@@ -88,7 +126,8 @@ float compute_amax(const test::Tensor &t, size_t rows, size_t cols) {
 
 // Quantize a high-precision input to NVFP4, then dequantize and compare
 // against a CPU reference computed from the quantized data.
 template <typename OType>
-void performTest_dequantize_nvfp4(const size_t rows, const size_t cols) {
+void performTest_dequantize_nvfp4(const size_t rows, const size_t cols,
+                                  const bool per_token_activation) {
   using namespace test;
 
   DType otype = TypeInfo<OType>::dtype;
@@ -97,14 +136,22 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols,
   Tensor quantized("quantized", std::vector<size_t>{rows, cols}, DType::kFloat4E2M1,
                    true, false, NVTE_NVFP4_1D_SCALING);
 
-  if (rows > 0 && cols > 0) {
+  if (per_token_activation) {
+    set_per_token_amax_metadata(quantized, rows);
+  } else if (rows > 0 && cols > 0) {
     quantized.set_tensor_amax(compute_amax<OType>(input, rows, cols));
   } else {
     quantized.set_tensor_amax(0.0f);
   }
 
   if (rows > 0 && cols > 0) {
-    nvte_quantize(input.data(), quantized.data(), 0);
+    if (per_token_activation) {
+      QuantizationConfigWrapper quant_config;
+      quant_config.set_nvfp4_per_token_activation(true);
+      nvte_quantize_v2(input.data(), quantized.data(), quant_config, 0);
+    } else {
+      nvte_quantize(input.data(), quantized.data(), 0);
+    }
     cudaDeviceSynchronize();
   }
 
@@ -120,7 +167,7 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols,
   const uint8_t *fp4_data =
       reinterpret_cast<const uint8_t *>(quantized.rowwise_cpu_dptr<fp4e2m1x2>());
   const fp8e4m3 *scales = quantized.rowwise_cpu_scale_inv_ptr<fp8e4m3>();
-  const float amax_val = quantized.amax();
+  const std::vector<float> amax_val = get_amax_values(quantized);
 
   const NVTEShape scale_shape = quantized.rowwise_scale_inv_shape();
   const size_t scale_stride = scale_shape.data[scale_shape.ndim - 1];
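[Reviewer note] set_per_token_amax_metadata() above frees the scalar amax buffer and reattaches a rows-element FP32 vector under kNVTEAmax / kNVTEColumnwiseAmax; that is the metadata contract the per-token kernels fill in. On the PyTorch side, the analogous check (assuming, as the Python tests below suggest, that _amax_rowwise is the rowwise amax buffer of the quantized tensor) would look like:

    x_q = quantizer(x)   # quantizer = NVFP4Quantizer(per_token_activation=True)
    assert x_q._amax_rowwise.numel() == x.shape[0]
    torch.testing.assert_close(
        x_q._amax_rowwise.flatten(), x.abs().amax(dim=1).float()
    )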
@@ -137,7 +184,8 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols,
 
 // Dequantize NVFP4 with GEMM-swizzled scales and compare against compact path.
 template <typename OType>
-void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols) {
+void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols,
+                                           const bool per_token_activation) {
   using namespace test;
 
   DType otype = TypeInfo<OType>::dtype;
@@ -146,14 +194,22 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols)
   Tensor quantized_compact("quantized_compact", std::vector<size_t>{rows, cols},
                            DType::kFloat4E2M1, true, false, NVTE_NVFP4_1D_SCALING);
 
-  if (rows > 0 && cols > 0) {
+  if (per_token_activation) {
+    set_per_token_amax_metadata(quantized_compact, rows);
+  } else if (rows > 0 && cols > 0) {
     quantized_compact.set_tensor_amax(compute_amax<OType>(input, rows, cols));
   } else {
     quantized_compact.set_tensor_amax(0.0f);
   }
 
   if (rows > 0 && cols > 0) {
-    nvte_quantize(input.data(), quantized_compact.data(), 0);
+    if (per_token_activation) {
+      QuantizationConfigWrapper quant_config;
+      quant_config.set_nvfp4_per_token_activation(true);
+      nvte_quantize_v2(input.data(), quantized_compact.data(), quant_config, 0);
+    } else {
+      nvte_quantize(input.data(), quantized_compact.data(), 0);
+    }
     cudaDeviceSynchronize();
   }
 
@@ -165,13 +221,30 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols)
   // Create tensor with same FP4 data but swizzled scales
   Tensor quantized_swizzled("quantized_swizzled", std::vector<size_t>{rows, cols},
                             DType::kFloat4E2M1, true, false, NVTE_NVFP4_1D_SCALING);
-  quantized_swizzled.set_tensor_amax(0.0f);
+  if (per_token_activation) {
+    set_per_token_amax_metadata(quantized_swizzled, rows);
+  } else {
+    quantized_swizzled.set_tensor_amax(0.0f);
+  }
   quantized_swizzled.set_with_gemm_swizzled_scales(true);
 
   // Copy amax and scale from compact to swizzled before FP4 data,
   // since from_cpu() uploads all CPU buffers (including zero-init data).
   quantized_compact.to_cpu();
-  quantized_swizzled.set_tensor_amax(quantized_compact.amax());
+  if (per_token_activation) {
+    NVTEBasicTensor compact_amax;
+    NVTEBasicTensor swizzled_amax;
+    nvte_get_tensor_param_v2(quantized_compact.data(), kNVTEAmax, &compact_amax,
+                             sizeof(compact_amax), nullptr);
+    nvte_get_tensor_param_v2(quantized_swizzled.data(), kNVTEAmax, &swizzled_amax,
+                             sizeof(swizzled_amax), nullptr);
+    if (rows > 0) {
+      NVTE_CHECK_CUDA(cudaMemcpy(swizzled_amax.data_ptr, compact_amax.data_ptr,
+                                 rows * sizeof(float), cudaMemcpyDeviceToDevice));
+    }
+  } else {
+    quantized_swizzled.set_tensor_amax(quantized_compact.amax());
+  }
 
   // Copy FP4 data after from_cpu() to avoid being overwritten
   const size_t data_bytes = rows * cols / 2;
@@ -227,7 +300,8 @@ std::vector<std::pair<size_t, size_t>> nvfp4_tensor_dims = {
 
 class DequantizeNVFP4TestSuite : public ::testing::TestWithParam
     <std::tuple<std::pair<size_t, size_t>,
-                transformer_engine::DType>> {};
+                transformer_engine::DType,
+                bool>> {};
 
 TEST_P(DequantizeNVFP4TestSuite, TestDequantizeNVFP4)
 {
@@ -237,10 +311,11 @@ TEST_P(DequantizeNVFP4TestSuite, TestDequantizeNVFP4)
 
   const auto tensor_size = std::get<0>(GetParam());
   const DType output_type = std::get<1>(GetParam());
+  const bool per_token_activation = std::get<2>(GetParam());
 
   TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(output_type, OutputType,
     performTest_dequantize_nvfp4<OutputType>(
-      tensor_size.first, tensor_size.second);
+      tensor_size.first, tensor_size.second, per_token_activation);
   );
 }
 
@@ -249,19 +324,22 @@ INSTANTIATE_TEST_SUITE_P(
     DequantizeNVFP4TestSuite,
     ::testing::Combine(
         ::testing::ValuesIn(nvfp4_tensor_dims),
-        ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16)),
+        ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
+        ::testing::Bool()),
     [](const testing::TestParamInfo<DequantizeNVFP4TestSuite::ParamType>& info) {
       std::string name = std::to_string(std::get<0>(info.param).first) + "X" +
                          std::to_string(std::get<0>(info.param).second) + "X" +
-                         test::typeName(std::get<1>(info.param));
+                         test::typeName(std::get<1>(info.param)) + "X" +
+                         (std::get<2>(info.param) ? "PerToken" : "PerTensor");
       return name;
     }
 );
 
 class DequantizeNVFP4SwizzledTestSuite : public ::testing::TestWithParam
     <std::tuple<std::pair<size_t, size_t>,
-                transformer_engine::DType>> {};
+                transformer_engine::DType,
+                bool>> {};
 
 TEST_P(DequantizeNVFP4SwizzledTestSuite, TestDequantizeNVFP4Swizzled)
 {
@@ -271,10 +349,11 @@ TEST_P(DequantizeNVFP4SwizzledTestSuite, TestDequantizeNVFP4Swizzled)
 
   const auto tensor_size = std::get<0>(GetParam());
   const DType output_type = std::get<1>(GetParam());
+  const bool per_token_activation = std::get<2>(GetParam());
 
   TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(output_type, OutputType,
     performTest_dequantize_nvfp4_swizzled<OutputType>(
-      tensor_size.first, tensor_size.second);
+      tensor_size.first, tensor_size.second, per_token_activation);
   );
 }
 
@@ -283,12 +362,14 @@ INSTANTIATE_TEST_SUITE_P(
     DequantizeNVFP4SwizzledTestSuite,
     ::testing::Combine(
        ::testing::ValuesIn(nvfp4_tensor_dims),
-        ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16)),
+        ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
+        ::testing::Bool()),
     [](const testing::TestParamInfo<DequantizeNVFP4SwizzledTestSuite::ParamType>& info) {
       std::string name = std::to_string(std::get<0>(info.param).first) + "X" +
                          std::to_string(std::get<0>(info.param).second) + "X" +
                          test::typeName(std::get<1>(info.param)) + "X" +
+                         (std::get<2>(info.param) ? "PerToken" : "PerTensor") + "X" +
                          "Swizzled";
       return name;
     }
 );
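[Reviewer note] In the GEMM tests that follow, per-token activations mean each row of x is quantized against its own amax, but the reference output is still just the product of the dequantized operands; the per-row scales are folded into the dequantized activation rows. As a hedged sketch, with deq() standing in for a hypothetical NVFP4 round-trip dequantize:

    y_ref = deq(x_q) @ deq(w_q).t()   # per-row scales already folded into deq(x_q)
    torch.testing.assert_close(y_native, y_ref, atol=8e-3, rtol=8e-3)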
"PerToken" : "PerTensor") + "X" + "Swizzled"; return name; } diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py index 911b7660dc..ef6eda8dcd 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -8,6 +8,7 @@ import transformer_engine_torch as tex from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch import NVFP4Quantizer +from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef from transformer_engine.pytorch.custom_recipes import utils @@ -26,6 +27,7 @@ def check_nvfp4_gemm_versus_reference( *, x_columnwise: bool = False, w_columnwise: bool = False, + per_token_activation: bool = False, ): te_dtype = tex.DType.kFloat4E2M1 @@ -56,6 +58,7 @@ def check_nvfp4_gemm_versus_reference( amax_reduction_group=None, with_rht=False, with_post_rht_amax=False, + per_token_activation=per_token_activation, ) w_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, @@ -112,7 +115,16 @@ def check_nvfp4_gemm_versus_reference( sw_trimmed = sw_trimmed.view(torch.float8_e4m3fn) # Create reference quantizer for reference GEMM - ref_quantizer = NVFP4QuantizerRef( + x_ref_quantizer = NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + rowwise=True, + columnwise=True, + pow_2_scales=False, + eps=0.0, + quant_tile_shape=(1, 16), + per_token_activation=per_token_activation, + ) + w_ref_quantizer = NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, rowwise=True, columnwise=True, @@ -124,16 +136,16 @@ def check_nvfp4_gemm_versus_reference( # Create reference quantized tensors needed by reference GEMM # Reference GEMM is only rowwise. if x_columnwise: - x_nvfp4_ref = ref_quantizer.quantize(x.t().contiguous()) + x_nvfp4_ref = x_ref_quantizer.quantize(x.t().contiguous()) else: - x_nvfp4_ref = ref_quantizer.quantize(x) + x_nvfp4_ref = x_ref_quantizer.quantize(x) if w_columnwise: - w_nvfp4_ref = ref_quantizer.quantize(w.t().contiguous()) + w_nvfp4_ref = w_ref_quantizer.quantize(w.t().contiguous()) else: - w_nvfp4_ref = ref_quantizer.quantize(w) + w_nvfp4_ref = w_ref_quantizer.quantize(w) # Reference GEMM using quantizer's qgemm method - y_ref = ref_quantizer.qgemm( + y_ref = x_ref_quantizer.qgemm( qx=qx_data, qw=qw_data, m_params=None, # MMParams not used in reference @@ -166,27 +178,38 @@ def check_nvfp4_gemm_versus_reference( x_nvfp4_native.update_usage(rowwise_usage=False) if w_columnwise: w_nvfp4_native.update_usage(rowwise_usage=False) - # Native cuBLAS GEMM - # return type is out, bias_grad, gelu_input, extra_output - # We are just capturing out. - y_native = tex.generic_gemm( - w_nvfp4_native, - transa, - x_nvfp4_native, - transb, - out.clone() if accumulate else None, - out_quantizer, - TE_DType[out_dtype], - bias, - bias_dtype, - use_gelu, - gelu_input, - use_grad, - workspace, - workspace.shape[0], - accumulate, - use_split_accumulator, - )[0] + if per_token_activation: + layout = ("T" if transa else "N") + ("T" if transb else "N") + y_native = general_gemm( + w_nvfp4_native, + x_nvfp4_native, + out_dtype=out_dtype, + accumulate=accumulate, + layout=layout, + out=out.clone() if accumulate else None, + )[0] + else: + # Native cuBLAS GEMM + # return type is out, bias_grad, gelu_input, extra_output + # We are just capturing out. 
+        y_native = tex.generic_gemm(
+            w_nvfp4_native,
+            transa,
+            x_nvfp4_native,
+            transb,
+            out.clone() if accumulate else None,
+            out_quantizer,
+            TE_DType[out_dtype],
+            bias,
+            bias_dtype,
+            use_gelu,
+            gelu_input,
+            use_grad,
+            workspace,
+            workspace.shape[0],
+            accumulate,
+            use_split_accumulator,
+        )[0]
 
     # just in case of accumulation, make sure y_ref and y_native are not the same tensor
     assert y_ref is not y_native, "y_ref and y_native should not be the same tensor"
@@ -199,6 +222,98 @@ def check_nvfp4_gemm_versus_reference(
     torch.testing.assert_close(y_native, y_ref, atol=8e-3, rtol=8e-3)
 
 
+def check_nvfp4_per_token_grouped_gemm_matches_per_gemm(
+    x_dtype: torch.dtype,
+    w_dtype: torch.dtype,
+    out_dtype: torch.dtype,
+    m_splits: list[int],
+    k: int,
+    n: int,
+    *,
+    use_bias: bool,
+    single_output: bool,
+):
+    te_dtype = tex.DType.kFloat4E2M1
+    device = "cuda"
+    torch.manual_seed(23)
+    torch.cuda.manual_seed(23)
+
+    num_gemms = len(m_splits)
+
+    x_quantizer = NVFP4Quantizer(
+        fp4_dtype=te_dtype,
+        rowwise=True,
+        columnwise=True,
+        with_amax_reduction=False,
+        amax_reduction_group=None,
+        with_rht=False,
+        with_post_rht_amax=False,
+        per_token_activation=True,
+    )
+    w_quantizer = NVFP4Quantizer(
+        fp4_dtype=te_dtype,
+        rowwise=True,
+        columnwise=True,
+        with_amax_reduction=False,
+        amax_reduction_group=None,
+        with_rht=False,
+        with_post_rht_amax=False,
+    )
+
+    x_nvfp4 = []
+    w_nvfp4 = []
+    bias = []
+    expected = []
+    for m in m_splits:
+        x = torch.randn((m, k), dtype=x_dtype, device=device)
+        w = torch.randn((n, k), dtype=w_dtype, device=device)
+        x_nvfp4.append(
+            x_quantizer.update_quantized(
+                x, x_quantizer.make_empty(x.shape, dtype=x_dtype, device=device)
+            )
+        )
+        w_nvfp4.append(
+            w_quantizer.update_quantized(
+                w, w_quantizer.make_empty(w.shape, dtype=w_dtype, device=device)
+            )
+        )
+        bias.append(torch.randn(n, dtype=torch.bfloat16, device=device) if use_bias else None)
+        expected.append(
+            general_gemm(
+                w_nvfp4[-1],
+                x_nvfp4[-1],
+                out_dtype=out_dtype,
+                layout="TN",
+                bias=bias[-1],
+            )[0]
+        )
+
+    if single_output:
+        out = [torch.empty((sum(m_splits), n), dtype=out_dtype, device=device)]
+    else:
+        out = [torch.empty((m, n), dtype=out_dtype, device=device) for m in m_splits]
+
+    grouped_out, _, _ = general_grouped_gemm(
+        w_nvfp4,
+        x_nvfp4,
+        out,
+        quantization_params=[None] * num_gemms,
+        out_dtype=out_dtype,
+        layout="TN",
+        m_splits=m_splits,
+        bias=bias,
+        use_bias=use_bias,
+        single_output=single_output,
+    )
+
+    if single_output:
+        grouped_slices = torch.split(grouped_out, m_splits, dim=0)
+    else:
+        grouped_slices = grouped_out
+    for grouped, ref in zip(grouped_slices, expected):
+        torch.testing.assert_close(grouped, ref, atol=0.0, rtol=0.0)
+
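[Reviewer note] The grouped check above requires bit-exact agreement (atol=0, rtol=0) between general_grouped_gemm and per-GEMM general_gemm calls, which only holds because both paths consume identically quantized per-token inputs. The quantized operands are built the same way in both paths:

    x_q = x_quantizer.update_quantized(
        x, x_quantizer.make_empty(x.shape, dtype=x.dtype, device="cuda")
    )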
 @pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
 @pytest.mark.parametrize(
     "M, K, N",
     [
@@ -229,6 +344,7 @@ def check_nvfp4_gemm_versus_reference(
     ],
     ids=["rowxrow", "colxrow", "colxcol"],
 )
+@pytest.mark.parametrize("per_token_activation", [False, True], ids=["nvfp4", "nvfp4_per_token"])
 def test_nvfp4_gemm_versus_reference(
     M: int,
     K: int,
     N: int,
     x_dtype: torch.dtype,
     w_dtype: torch.dtype,
     out_dtype: torch.dtype,
     accumulate: bool,
     is_x_columnwise: bool,
     is_w_columnwise: bool,
+    per_token_activation: bool,
 ):
+    if per_token_activation:
+        if accumulate:
+            pytest.skip("Per-token NVFP4 GEMM output rescale does not support accumulation")
+        if is_x_columnwise:
+            pytest.skip("Per-token NVFP4 GEMM output rescale requires rowwise activation usage")
+
     check_nvfp4_gemm_versus_reference(
         x_dtype=x_dtype,
@@ -250,4 +373,49 @@ def test_nvfp4_gemm_versus_reference(
         accumulate=accumulate,
         x_columnwise=is_x_columnwise,
         w_columnwise=is_w_columnwise,
+        per_token_activation=per_token_activation,
     )
+
+
+@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
+@pytest.mark.parametrize(
+    "m_splits, k, n",
+    [
+        ([32, 48, 48], 128, 128),
+        ([64, 80, 112], 128, 256),
+        ([64, 80, 112], 256, 256),
+        ([64, 80, 112], 1024, 256),
+        ([256, 256, 512], 1024, 1024),
+        ([1024, 1536, 1536], 512, 3072),
+        ([16, 32, 64], 128, 96),
+        ([80, 96, 128], 640, 304),
+        ([320, 336, 352], 3072, 992),
+        ([64, 80, 112], 64, 256),
+        ([32, 48, 48], 128, 112),
+    ],
+)
+@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
+@pytest.mark.parametrize("w_dtype", [torch.float32, torch.bfloat16], ids=str)
+@pytest.mark.parametrize("out_dtype", [torch.float32, torch.bfloat16], ids=str)
+@pytest.mark.parametrize("use_bias", [False, True], ids=["no_bias", "bias"])
+@pytest.mark.parametrize("single_output", [False, True], ids=["list_output", "single_output"])
+def test_nvfp4_per_token_grouped_gemm_matches_per_gemm(
+    m_splits: list[int],
+    k: int,
+    n: int,
+    x_dtype: torch.dtype,
+    w_dtype: torch.dtype,
+    out_dtype: torch.dtype,
+    use_bias: bool,
+    single_output: bool,
+):
+    check_nvfp4_per_token_grouped_gemm_matches_per_gemm(
+        x_dtype=x_dtype,
+        w_dtype=w_dtype,
+        out_dtype=out_dtype,
+        m_splits=m_splits,
+        k=k,
+        n=n,
+        use_bias=use_bias,
+        single_output=single_output,
+    )
diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py
index bf3f545b8b..098807b685 100644
--- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py
+++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py
@@ -31,6 +31,7 @@ def check_quantization_nvfp4_versus_reference(
     swizzled_scale: bool,
     use_cpp_allocator: bool,
     with_2d_quantization: bool,
+    per_token_activation: bool = False,
 ) -> None:
     te_dtype = tex.DType.kFloat4E2M1
 
@@ -52,6 +53,7 @@ def check_quantization_nvfp4_versus_reference(
         with_rht=False,
         with_post_rht_amax=False,
         with_2d_quantization=with_2d_quantization,
+        per_token_activation=per_token_activation,
     )
     if use_cpp_allocator:
         x_nvfp4_sut = nvfp4_quantizer(x)
@@ -73,6 +75,7 @@ def check_quantization_nvfp4_versus_reference(
     )
     sx_t = x_nvfp4_sut._columnwise_scale_inv
     qx_amax = x_nvfp4_sut._amax_rowwise
+    qx_amax_t = x_nvfp4_sut._amax_columnwise
 
     # Reference quantization
     quant_tile_shape = (1, 16) if not with_2d_quantization else (16, 16)
@@ -83,6 +86,7 @@ def check_quantization_nvfp4_versus_reference(
         pow_2_scales=False,
         eps=0.0,
         quant_tile_shape=quant_tile_shape,
+        per_token_activation=per_token_activation,
     )
     x_nvfp4_ref = ref_quantizer.quantize(x)
 
@@ -102,6 +106,7 @@ def check_quantization_nvfp4_versus_reference(
         x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None
     )
     ref_amax = x_nvfp4_ref.global_amax_row
+    ref_amax_t = x_nvfp4_ref.global_amax_col
 
     qx = unpack_fp4(qx)
     qx_t = unpack_fp4(qx_t) if qx_t is not None else None
@@ -121,6 +126,7 @@ def check_quantization_nvfp4_versus_reference(
         ref_sx_t_shape = sx_t_ref.shape
         sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]]
         torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0)
+        torch.testing.assert_close(qx_amax_t, ref_amax_t, atol=0.0, rtol=0.0)
 
     torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0)
 
@@ -155,6 +161,7 @@ def check_quantization_nvfp4_versus_reference(
 @pytest.mark.parametrize(
     "with_2d_quantization", [True, False], ids=["2d_quantization", "1d_quantization"]
 )
"1d_quantization"] ) +@pytest.mark.parametrize("per_token_activation", [False, True], ids=["nvfp4", "nvfp4_per_token"]) def test_quantization_block_tiling_versus_reference( x_dtype: torch.dtype, M: int, @@ -163,7 +170,11 @@ def test_quantization_block_tiling_versus_reference( swizzled_scale: bool, use_cpp_allocator: bool, with_2d_quantization: bool, + per_token_activation: bool, ) -> None: + if per_token_activation and with_2d_quantization: + pytest.skip("Per-token NVFP4 does not support 2D quantization") + check_quantization_nvfp4_versus_reference( x_dtype=x_dtype, M=M, @@ -172,6 +183,7 @@ def test_quantization_block_tiling_versus_reference( swizzled_scale=swizzled_scale, use_cpp_allocator=use_cpp_allocator, with_2d_quantization=with_2d_quantization, + per_token_activation=per_token_activation, ) @@ -188,6 +200,7 @@ def test_quantization_block_tiling_versus_reference( @pytest.mark.parametrize( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) +@pytest.mark.parametrize("per_token_activation", [False, True], ids=["nvfp4", "nvfp4_per_token"]) def test_nvfp4_quantization_extrema_versus_reference( x_dtype: torch.dtype, M: int, @@ -195,6 +208,7 @@ def test_nvfp4_quantization_extrema_versus_reference( extrema_high: bool, return_transpose: bool, use_cpp_allocator: bool, + per_token_activation: bool, ): te_dtype = tex.DType.kFloat4E2M1 @@ -216,6 +230,7 @@ def test_nvfp4_quantization_extrema_versus_reference( amax_reduction_group=None, with_rht=False, with_post_rht_amax=False, + per_token_activation=per_token_activation, ) if use_cpp_allocator: @@ -237,6 +252,7 @@ def test_nvfp4_quantization_extrema_versus_reference( ) sx_t = x_nvfp4_sut._columnwise_scale_inv qx_amax = x_nvfp4_sut._amax_rowwise + qx_amax_t = x_nvfp4_sut._amax_columnwise ref_quantizer = NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, @@ -245,6 +261,7 @@ def test_nvfp4_quantization_extrema_versus_reference( pow_2_scales=False, eps=0.0, quant_tile_shape=(1, 16), + per_token_activation=per_token_activation, ) x_nvfp4_ref = ref_quantizer.quantize(x) @@ -257,6 +274,7 @@ def test_nvfp4_quantization_extrema_versus_reference( x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None ) ref_amax = x_nvfp4_ref.global_amax_row + ref_amax_t = x_nvfp4_ref.global_amax_col torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0) @@ -269,6 +287,7 @@ def test_nvfp4_quantization_extrema_versus_reference( ref_sx_t_shape = sx_t_ref.shape sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0) + torch.testing.assert_close(qx_amax_t, ref_amax_t, atol=0.0, rtol=0.0) torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0) @@ -286,12 +305,14 @@ def test_nvfp4_quantization_extrema_versus_reference( @pytest.mark.parametrize( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) +@pytest.mark.parametrize("per_token_activation", [False, True], ids=["nvfp4", "nvfp4_per_token"]) def test_nvfp4_quantization_boundary_values( x_dtype: torch.dtype, M: int, N: int, return_transpose: bool, use_cpp_allocator: bool, + per_token_activation: bool, ): """ Stress rounding/threshold behavior by placing values just below/above @@ -327,6 +348,7 @@ def test_nvfp4_quantization_boundary_values( amax_reduction_group=None, with_rht=False, with_post_rht_amax=False, + per_token_activation=per_token_activation, ) if use_cpp_allocator: @@ -348,6 +370,7 @@ def test_nvfp4_quantization_boundary_values( ) sx_t = 
     qx_amax = x_nvfp4_sut._amax_rowwise
+    qx_amax_t = x_nvfp4_sut._amax_columnwise
 
     ref_quantizer = NVFP4QuantizerRef(
         dtype=utils.Fp4Formats.E2M1,
@@ -356,6 +379,7 @@ def test_nvfp4_quantization_boundary_values(
         pow_2_scales=False,
         eps=0.0,
         quant_tile_shape=(1, 16),
+        per_token_activation=per_token_activation,
     )
     x_nvfp4_ref = ref_quantizer.quantize(x)
 
@@ -368,6 +392,7 @@ def test_nvfp4_quantization_boundary_values(
         x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None
     )
     ref_amax = x_nvfp4_ref.global_amax_row
+    ref_amax_t = x_nvfp4_ref.global_amax_col
 
     torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0)
 
@@ -381,6 +406,7 @@ def test_nvfp4_quantization_boundary_values(
         ref_sx_t_shape = sx_t_ref.shape
         sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]]
         torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0)
+        torch.testing.assert_close(qx_amax_t, ref_amax_t, atol=0.0, rtol=0.0)
 
     torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0)
 
@@ -397,12 +423,14 @@ def test_nvfp4_quantization_boundary_values(
 @pytest.mark.parametrize(
     "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"]
 )
+@pytest.mark.parametrize("per_token_activation", [False, True], ids=["nvfp4", "nvfp4_per_token"])
 def test_nvfp4_quantization_noncontiguous_inputs(
     x_dtype: torch.dtype,
     M: int,
     N: int,
     return_transpose: bool,
     use_cpp_allocator: bool,
+    per_token_activation: bool,
 ):
     te_dtype = tex.DType.kFloat4E2M1
 
@@ -424,6 +452,7 @@ def test_nvfp4_quantization_noncontiguous_inputs(
         amax_reduction_group=None,
         with_rht=False,
         with_post_rht_amax=False,
+        per_token_activation=per_token_activation,
     )
 
     if use_cpp_allocator:
@@ -445,6 +474,7 @@ def test_nvfp4_quantization_noncontiguous_inputs(
     )
     sx_t = x_nvfp4_sut._columnwise_scale_inv
     qx_amax = x_nvfp4_sut._amax_rowwise
+    qx_amax_t = x_nvfp4_sut._amax_columnwise
 
     ref_quantizer = NVFP4QuantizerRef(
         dtype=utils.Fp4Formats.E2M1,
@@ -453,6 +483,7 @@ def test_nvfp4_quantization_noncontiguous_inputs(
         pow_2_scales=False,
         eps=0.0,
         quant_tile_shape=(1, 16),
+        per_token_activation=per_token_activation,
     )
     x_nvfp4_ref = ref_quantizer.quantize(x_nc)
 
@@ -465,6 +496,7 @@ def test_nvfp4_quantization_noncontiguous_inputs(
         x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None
     )
     ref_amax = x_nvfp4_ref.global_amax_row
+    ref_amax_t = x_nvfp4_ref.global_amax_col
 
     # Quantized must match
     torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0)
@@ -479,5 +511,6 @@ def test_nvfp4_quantization_noncontiguous_inputs(
         ref_sx_t_shape = sx_t_ref.shape
         sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]]
         torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0)
+        torch.testing.assert_close(qx_amax_t, ref_amax_t, atol=0.0, rtol=0.0)
 
     torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0)
diff --git a/tests/pytorch/test_backward_override.py b/tests/pytorch/test_backward_override.py
index ed4f73adbc..15f08975e2 100644
--- a/tests/pytorch/test_backward_override.py
+++ b/tests/pytorch/test_backward_override.py
@@ -78,6 +78,11 @@
         marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4),
         id="NVFP4BlockScaling",
     ),
+    pytest.param(
+        "nvfp4_per_token",
+        marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4),
+        id="NVFP4PerTokenBlockScaling",
+    ),
 ]
 
 
@@ -165,7 +170,7 @@ def _maybe_skip_recipe_dtype(
 ) -> None:
     if dtype == torch.bfloat16 and not bf16_available:
         pytest.skip(reason_for_no_bf16)
== "nvfp4": + if recipe_name in ("nvfp4", "nvfp4_per_token"): if module_type in ("linear", "layernorm_linear") and dtype not in ( torch.bfloat16, torch.float32, @@ -180,6 +185,11 @@ def _maybe_skip_unsupported_recipe_module_combo(recipe_name: str, module_type: s pytest.skip("Fusible ops (te_ops.Linear) do not support Float8BlockScaling recipe") +def _maybe_skip_unsupported_fused_ops(recipe_name: str) -> None: + if recipe_name == "nvfp4_per_token": + pytest.skip("Per-token NVFP4 currently does not support fused te_ops paths.") + + def _maybe_skip_unsupported_recipe_shape( recipe_name: str, input_shape: tuple[int, ...], @@ -195,7 +205,9 @@ def _maybe_skip_unsupported_recipe_shape( " by 32." ) return - if recipe_name == "nvfp4" and (flat_first_dim % 16 != 0 or last_dim % 16 != 0): + if recipe_name in ("nvfp4", "nvfp4_per_token") and ( + flat_first_dim % 16 != 0 or last_dim % 16 != 0 + ): pytest.skip( "Linear/LayerNormLinear + NVFP4 requires prod(shape[:-1]) and shape[-1] divisible" " by 16." @@ -220,7 +232,9 @@ def _maybe_skip_unsupported_recipe_shape( pytest.skip( "te_ops.Linear + MXFP8 requires prod(shape[:-1]) and shape[-1] divisible by 32." ) - if recipe_name == "nvfp4" and (flat_first_dim % 16 != 0 or last_dim % 16 != 0): + if recipe_name in ("nvfp4", "nvfp4_per_token") and ( + flat_first_dim % 16 != 0 or last_dim % 16 != 0 + ): pytest.skip( "te_ops.Linear + NVFP4 requires prod(shape[:-1]) and shape[-1] divisible by 16." ) @@ -239,9 +253,9 @@ def _maybe_skip_unsupported_grouped_splits(recipe_name: str, m_splits: list[int] ) if recipe_name == "mxfp8" and any(m % 32 != 0 for m in non_empty_splits): pytest.skip("GroupedLinear + MXFP8 requires each non-empty m_split divisible by 32.") - if recipe_name == "nvfp4" and any(m % 16 != 0 for m in non_empty_splits): + if recipe_name in ("nvfp4", "nvfp4_per_token") and any(m % 16 != 0 for m in non_empty_splits): pytest.skip("GroupedLinear + NVFP4 requires each non-empty m_split divisible by 16.") - if recipe_name == "nvfp4" and any(m % 64 != 0 for m in non_empty_splits): + if recipe_name in ("nvfp4", "nvfp4_per_token") and any(m % 64 != 0 for m in non_empty_splits): pytest.skip( "GroupedLinear + NVFP4 grouped split_quantize currently requires each non-empty " "m_split divisible by 64 due to grouped amax kernel constraints." 
@@ -1033,6 +1047,7 @@ def test_grouped_linear_backward_override_matches_reference(
 
     quantized_ref_recipe = make_recipe(recipe_name)
     mode_recipe = make_recipe(recipe_name, backward_override=backward_override)
+    skip_unsupported_backward_override("grouped_linear", mode_recipe, backward_override)
 
     module_quantized_ref = te.GroupedLinear(
         num_gemms,
@@ -1271,6 +1286,7 @@ def test_grouped_linear_runtime_backward_override_switch_updates_ctx(
 
     default_recipe = make_recipe(recipe_name)
     mode_recipe = make_recipe(recipe_name, backward_override=backward_override)
+    skip_unsupported_backward_override("grouped_linear", mode_recipe, backward_override)
 
     *_, default_ctx = _run_grouped_linear_single_step_with_ctx_state(
         module,
@@ -1333,6 +1349,7 @@ def test_fused_linear_paths_match_backward_override_reference(
     _maybe_skip_recipe_dtype(recipe_name, dtype, "ops_linear")
     _maybe_skip_unsupported_recipe_module_combo(recipe_name, "ops_linear")
     _maybe_skip_unsupported_recipe_shape(recipe_name, (m, in_features), "ops_linear")
+    _maybe_skip_unsupported_fused_ops(recipe_name)
 
     reset_rng_states()
 
@@ -1472,6 +1489,7 @@ def test_fused_bias_activation_matches_masked_linear_backward(
     _maybe_skip_recipe_dtype(recipe_name, dtype, "ops_linear")
     _maybe_skip_unsupported_recipe_module_combo(recipe_name, "ops_linear")
     _maybe_skip_unsupported_recipe_shape(recipe_name, input_shape, "ops_linear")
+    _maybe_skip_unsupported_fused_ops(recipe_name)
 
     reset_rng_states()
     in_features = input_shape[-1]
diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py
index a782dadc60..8a01acf0eb 100644
--- a/tests/pytorch/test_cuda_graphs.py
+++ b/tests/pytorch/test_cuda_graphs.py
@@ -20,17 +20,19 @@
     is_fp8_available,
     is_fp8_block_scaling_available,
     is_mxfp8_available,
+    is_nvfp4_available,
     is_bf16_available,
 )
 from transformer_engine.pytorch.quantization import FP8GlobalStateManager
 import transformer_engine.pytorch.ops as te_ops
 from transformer_engine.common import recipe
-from utils import ModelConfig, reset_rng_states, skip_unsupported_backward_override
+from utils import ModelConfig, recipe_id, reset_rng_states, skip_unsupported_backward_override
 
 # Check if FP8 is supported.
 fp8_available = is_fp8_available()
 fp8_block_scaling_available = is_fp8_block_scaling_available()
 mxfp8_available = is_mxfp8_available()
+nvfp4_available = is_nvfp4_available()
 
 # Reset RNG states.
 reset_rng_states()
@@ -62,6 +64,14 @@ def nvfp4_rht_and_2d_quantization():
     return nvfp4_recipe
 
 
+def nvfp4_per_token():
+    nvfp4_recipe = recipe.NVFP4BlockScaling(per_token_activation=True)
+    nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams()
+    nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams()
+    nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams()
+    return nvfp4_recipe
+
+
 def check_rht_usage(recipe: recipe.Recipe) -> bool:
     # if using RHT, we can only support bf16
     # check fp4_quant_fwd_inp, fp4_quant_fwd_weight, fp4_quant_bwd_grad
@@ -88,7 +98,9 @@ def get_nvfp4_inp_supported_dtypes(recipe: recipe.Recipe, dtype: torch.dtype) ->
 fp8_recipes = []
 if mxfp8_available:
     fp8_recipes.append(recipe.MXFP8BlockScaling())
+if nvfp4_available:
     fp8_recipes.append(nvfp4_rht_and_2d_quantization())
+    fp8_recipes.append(nvfp4_per_token())
 if fp8_block_scaling_available:
     fp8_recipes.append(recipe.Float8BlockScaling())
 if fp8_available:
@@ -360,7 +372,7 @@ def _test_cuda_graphs(
 @pytest.mark.parametrize("module", _test_cuda_graphs_modules)
 @pytest.mark.parametrize("dtype", dtypes)
 @pytest.mark.parametrize("fp8_params", (False, True))
-@pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None], ids=lambda r: type(r).__name__)
+@pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None], ids=recipe_id)
 @pytest.mark.parametrize("backward_override", (None, "high_precision", "dequantized"))
 def test_make_graphed_callables(
     *,
@@ -390,6 +402,8 @@ def test_make_graphed_callables(
             f"Module not yet supported for {fp8_recipe.__class__.__name__} with CUDA graphs"
         )
     if fp8 and fp8_recipe.nvfp4():
+        if getattr(fp8_recipe, "per_token_activation", False) and module == "mha":
+            pytest.skip("Per-token NVFP4 CUDA graph coverage applies to GEMM modules.")
         if dtype not in get_nvfp4_inp_supported_dtypes(fp8_recipe, dtype):
             pytest.skip(
                 f"Input dtype {dtype} not supported for NVFP4 Recipe"
@@ -448,7 +462,7 @@ def test_make_graphed_callables(
 )
 @pytest.mark.parametrize("dtype", dtypes)
 @pytest.mark.parametrize("fp8_params", (False, True))
-@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=lambda r: type(r).__name__)
+@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=recipe_id)
 @pytest.mark.parametrize("backward_override", (None, "high_precision", "dequantized"))
 def test_make_graphed_callables_with_fp8_weight_caching(
     *,
diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py
index 91d4b89013..f12148232c 100644
--- a/tests/pytorch/test_recipe.py
+++ b/tests/pytorch/test_recipe.py
@@ -25,10 +25,16 @@ import transformer_engine_torch as tex
 from transformer_engine.pytorch.quantization import (
     FP8GlobalStateManager,
+    NVFP4BlockScalingRecipeState,
     _amax_and_scale_update,
 )
 import transformer_engine.pytorch.ops as te_ops
-from transformer_engine.common.recipe import DelayedScaling, Float8BlockScaling, MXFP8BlockScaling
+from transformer_engine.common.recipe import (
+    DelayedScaling,
+    Float8BlockScaling,
+    MXFP8BlockScaling,
+    NVFP4BlockScaling,
+)
 
 # Check if FP8 is supported
 fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
@@ -507,8 +513,28 @@ def test_quantizer_update(self, module_class):
         y = module(x)
 
 
+@pytest.mark.skipif(not fp4_available, reason=reason_for_no_fp4)
+def test_nvfp4_per_token_quantizer_roles():
+    recipe = NVFP4BlockScaling(per_token_activation=True)
+
+    forward_quantizers = NVFP4BlockScalingRecipeState(
+        recipe,
+        mode="forward",
+        num_quantizers=3,
+    ).make_quantizers()
+    assert [q.per_token_activation for q in forward_quantizers] == [True, False, True]
+
+    backward_quantizers = NVFP4BlockScalingRecipeState(
+        recipe,
+        mode="backward",
+        num_quantizers=2,
+    ).make_quantizers()
+    assert [q.per_token_activation for q in backward_quantizers] == [False, False]
+
+
 @pytest.mark.skipif(not fp4_available, reason=reason_for_no_fp4)
 @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=str)
+@pytest.mark.parametrize("per_token_activation", [False, True], ids=["nvfp4", "nvfp4_per_token"])
 @pytest.mark.parametrize(
     "M, N",
     [
@@ -524,8 +550,8 @@ def test_quantizer_update(self, module_class):
         (8192, 8192),
     ],
 )
-def test_fp4_dequantize(dtype, M, N):
-    q = NVFP4Quantizer()
+def test_fp4_dequantize(dtype, per_token_activation, M, N):
+    q = NVFP4Quantizer(per_token_activation=per_token_activation)
     a = torch.rand((M, N)).cuda().to(dtype=dtype)
     starting_tensor = q(a)
     dequantized_tensor = starting_tensor.dequantize()
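[Reviewer note] test_fp4_dequantize above is a pure round-trip: quantize with NVFP4Quantizer(per_token_activation=...) and dequantize back. A minimal standalone repro of the same flow, using only APIs already imported in these tests:

    import torch
    from transformer_engine.pytorch import NVFP4Quantizer

    q = NVFP4Quantizer(per_token_activation=True)
    a = torch.rand((256, 1024), device="cuda", dtype=torch.bfloat16)
    out = q(a).dequantize()
    assert out.shape == a.shape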
diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py
index 7f2f24fd69..bb1c952163 100644
--- a/tests/pytorch/test_sanity.py
+++ b/tests/pytorch/test_sanity.py
@@ -38,12 +38,13 @@
 import transformer_engine_torch as tex
 from transformer_engine.pytorch.cpp_extensions import general_gemm
 from transformer_engine.pytorch.tensor.utils import replace_raw_data
-from utils import ModelConfig, skip_unsupported_backward_override
+from utils import ModelConfig, recipe_id, skip_unsupported_backward_override
 
 # Only run FP8 tests on supported devices.
 fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
 fp8_block_scaling_available, _ = te.is_fp8_block_scaling_available(return_reason=True)
 mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
+nvfp4_available, _ = te.is_nvfp4_available(return_reason=True)
 
 # Record initial RNG state from script run.
 seed = 1234
@@ -93,9 +94,18 @@ def nvfp4_vanilla():
     return nvfp4_recipe
 
 
+def nvfp4_per_token():
+    nvfp4_recipe = recipe.NVFP4BlockScaling(per_token_activation=True)
+    nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams()
+    nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams()
+    nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams()
+    return nvfp4_recipe
+
+
 fp8_recipes = []
 if mxfp8_available:
     fp8_recipes.append(recipe.MXFP8BlockScaling())
+if nvfp4_available:
     fp8_recipes.append(nvfp4_vanilla())  # TODO: fix check for this
 if fp8_block_scaling_available:
     fp8_recipes.append(recipe.Float8BlockScaling())
@@ -103,6 +113,9 @@
     fp8_recipes.append(recipe.Float8CurrentScaling())
 fp8_recipes.append(recipe.DelayedScaling())
 fp8_recipes.append(None)
+fp8_recipes_with_per_token = fp8_recipes.copy()
+if nvfp4_available:
+    fp8_recipes_with_per_token.insert(-1, nvfp4_per_token())
 param_types = [torch.float32, torch.float16]
 if is_bf16_available():  # bf16 requires sm_80 or higher
@@ -402,7 +415,7 @@ def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normaliz
 
 @pytest.mark.parametrize("dtype", param_types)
-@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
+@pytest.mark.parametrize("fp8_recipe", fp8_recipes_with_per_token, ids=recipe_id)
 @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"])
 @pytest.mark.parametrize("model", ["small", "weird"])
 @pytest.mark.parametrize("skip_wgrad", all_boolean)
@@ -450,7 +463,7 @@ def test_sanity_layernorm_linear(
 
 @pytest.mark.parametrize("dtype", param_types)
-@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
+@pytest.mark.parametrize("fp8_recipe", fp8_recipes_with_per_token, ids=recipe_id)
 @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"])
 @pytest.mark.parametrize("model", ["small", "weird"])
 @pytest.mark.parametrize("skip_wgrad", all_boolean)
@@ -488,7 +501,7 @@ def test_sanity_linear(
 
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes_with_zero)
 @pytest.mark.parametrize("model", ["small", "weird"])
-@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
+@pytest.mark.parametrize("fp8_recipe", fp8_recipes_with_per_token, ids=recipe_id)
 @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"])
 @pytest.mark.parametrize("fp8_model_params", all_boolean)
 @pytest.mark.parametrize("use_bias", all_boolean)
@@ -529,7 +542,7 @@ def test_sanity_linear_with_zero_tokens(
 
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes_with_zero)
 @pytest.mark.parametrize("model", ["small", "weird"])
-@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
+@pytest.mark.parametrize("fp8_recipe", fp8_recipes_with_per_token, ids=recipe_id)
 @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"])
 @pytest.mark.parametrize("fp8_model_params", all_boolean)
 @pytest.mark.parametrize("use_bias", all_boolean)
@@ -563,7 +576,12 @@ def test_sanity_grouped_linear(
     if not is_fp8_supported(config):
         pytest.skip("Model config does not support FP8")
     if fp8_recipe.nvfp4():
-        pytest.skip("NVFP4 not supported for grouped linear")
+        if not getattr(fp8_recipe, "per_token_activation", False):
+            pytest.skip("NVFP4 not supported for grouped linear")
+        if dtype == torch.float16:
+            pytest.skip("FP16 output for NVFP4 not supported")
+        if backward_override is None and dtype != torch.bfloat16:
+            pytest.skip("NVFP4 grouped default backward requires BF16 grad output")
 
     use_fp8 = fp8_recipe is not None
     with quantized_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
diff --git a/tests/pytorch/test_torch_compile.py b/tests/pytorch/test_torch_compile.py
index 9d0ed79888..d67c5e77b7 100644
--- a/tests/pytorch/test_torch_compile.py
+++ b/tests/pytorch/test_torch_compile.py
@@ -32,6 +32,7 @@
     is_fp8_block_scaling_available,
     is_nvfp4_available,
 )
+from utils import recipe_id
 
 fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True)
 mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True)
@@ -47,6 +48,7 @@
     _all_recipes.append(recipe.MXFP8BlockScaling())
 if nvfp4_available:
     _all_recipes.append(recipe.NVFP4BlockScaling())
+    _all_recipes.append(recipe.NVFP4BlockScaling(per_token_activation=True))
 
 # ---------------------------------------------------------------------------
 
@@ -303,7 +305,7 @@ def fn(inp):
 
 @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
-@pytest.mark.parametrize("fp8_recipe", _all_recipes, ids=lambda r: type(r).__name__)
+@pytest.mark.parametrize("fp8_recipe", _all_recipes, ids=recipe_id)
 def test_autocast_sanity(fp8_recipe):
     """Smoke test: torch.nn.Linear inside a single te.autocast with each built-in recipe.
     Forward + backward under torch.compile(fullgraph=True)."""
diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py
index 8f8852edc2..3dc4cdffe8 100644
--- a/tests/pytorch/utils.py
+++ b/tests/pytorch/utils.py
@@ -115,7 +115,7 @@ def quantization_tols(name: str) -> dict[str, float]:
         "mxfp8_block_scaling",
     ):
         return dtype_tols(tex.DType.kFloat8E4M3)
-    if name == "nvfp4":
+    if name in ("nvfp4", "nvfp4_per_token"):
         return dtype_tols(tex.DType.kFloat4E2M1)
     raise ValueError(f"Unsupported quantization scheme ({name})")
 
@@ -149,9 +149,26 @@ def make_recipe(name: Optional[str], **recipe_kwargs: Any) -> Optional[Recipe]:
             disable_2d_quantization=True,
             **recipe_kwargs,
         )
+    if name == "nvfp4_per_token":
+        return transformer_engine.common.recipe.NVFP4BlockScaling(
+            disable_rht=True,
+            disable_stochastic_rounding=True,
+            disable_2d_quantization=True,
+            per_token_activation=True,
+            **recipe_kwargs,
+        )
     raise ValueError(f"Unsupported quantization scheme ({name})")
 
 
+def recipe_id(fp8_recipe: Optional[Recipe]) -> str:
+    """Readable pytest id for FP8/FP4 recipes."""
+    if fp8_recipe is None:
+        return "None"
+    if fp8_recipe.nvfp4() and getattr(fp8_recipe, "per_token_activation", False):
+        return "NVFP4PerTokenBlockScaling"
+    return type(fp8_recipe).__name__
+
+
 def skip_unsupported_backward_override(
     layer_type: str,
     quant_recipe: Optional[Recipe],
diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh
index 5d0d3c28e8..1200979f6b 100644
--- a/transformer_engine/common/cast/dispatch/quantize.cuh
+++ b/transformer_engine/common/cast/dispatch/quantize.cuh
@@ -21,6 +21,7 @@
 #include "../mxfp8/group_quantize_mxfp8.cuh"
 #include "../mxfp8/quantize_mxfp8.cuh"
 #include "../nvfp4/group_quantize_transpose_nvfp4.cuh"
+#include "../nvfp4/quantize_per_token_nvfp4.cuh"
 #include "../nvfp4/quantize_transpose_nvfp4.cuh"
 
 namespace transformer_engine {
@@ -100,6 +101,13 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output,
       int32_t rows = input_tensor->flat_first_dim();
       int32_t cols = input_tensor->flat_last_dim();
       auto dtype = input_tensor->dtype();
+      const bool per_token_activation = quant_config_cpp.nvfp4_per_token_activation;
+      if (per_token_activation) {
+        NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization,
+                   "Per-token NVFP4 quantization does not support 2D quantization.");
2D quantization."); + nvfp4::quantize_per_token(*input_tensor, noop_tensor, output_tensor, stream); + break; + } bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && (cols % 32 == 0) && output_tensor->has_data(); @@ -239,6 +247,13 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens int32_t rows = grad_tensor->flat_first_dim(); int32_t cols = grad_tensor->flat_last_dim(); auto dtype = grad_tensor->dtype(); + const bool per_token_activation = quant_config_cpp.nvfp4_per_token_activation; + if (per_token_activation) { + NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, + "Per-token NVFP4 quantization does not support 2D quantization."); + nvfp4::quantize_per_token(*grad_tensor, noop_tensor, output_tensor, stream); + break; + } bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && (cols % 32 == 0) && output_tensor->has_data(); diff --git a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh index 4143208153..85e858e146 100644 --- a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh @@ -34,8 +34,9 @@ namespace dequantize_kernel { template __global__ void __launch_bounds__(512) dequantize_fp4_kernel(const void *const input, OType *output, const fp8e4m3 *const scales, - const float *const tensor_amax, const size_t N, const size_t M, - const size_t scale_stride, const size_t num_scale_tiles_X) { + const float *const tensor_amax, const size_t amax_numel, const size_t N, + const size_t M, const size_t scale_stride, + const size_t num_scale_tiles_X) { const size_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x; const size_t x = thread_idx % M; const size_t y = thread_idx / M; @@ -63,7 +64,7 @@ __global__ void __launch_bounds__(512) fp4vec value; value.vec = input_vectorized[my_index]; fp8e4m3 scale = scales[my_scale_index]; - float amax = *tensor_amax; + float amax = (amax_numel == 1) ? tensor_amax[0] : tensor_amax[y]; constexpr float factor_inv = 1.0 / (6.0 * 448.0); float final_scale = static_cast(scale) * amax * factor_inv; #pragma unroll @@ -112,7 +113,8 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) dequantize_fp4_kernel<<>>( input.data.dptr, reinterpret_cast(output->data.dptr), reinterpret_cast(input.scale_inv.dptr), - reinterpret_cast(input.amax.dptr), N, Mread, input.scale_inv.shape.back(), + reinterpret_cast(input.amax.dptr), input.amax.numel(), N, Mread, + input.scale_inv.shape.back(), num_scale_tiles_X);); // NOLINT(*) ); // NOLINT(*) NVTE_CHECK_CUDA(cudaGetLastError()); diff --git a/transformer_engine/common/cast/nvfp4/quantize_per_token_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_per_token_nvfp4.cuh new file mode 100644 index 0000000000..c4b16c557e --- /dev/null +++ b/transformer_engine/common/cast/nvfp4/quantize_per_token_nvfp4.cuh @@ -0,0 +1,486 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file quantize_per_token_nvfp4.cuh + * \brief CUDA kernels to cast to NVFP4 with per-token (per-row) global scaling. 
+ */ + +#ifndef TRANSFORMER_ENGINE_QUANTIZE_PERTOKEN_NVFP4_CUH_ +#define TRANSFORMER_ENGINE_QUANTIZE_PERTOKEN_NVFP4_CUH_ + +#include +#include + +#include +#include + +#include "../../common.h" +#include "../../util/ptx.cuh" +#include "../../utils.cuh" +#include "core_nvfp4.cuh" + +#if FP4_TYPE_SUPPORTED +#include +#endif + +namespace transformer_engine { +namespace dispatch { +namespace nvfp4 { +namespace quantize_per_token_kernel { + +using namespace core; +using namespace ptx; + +constexpr int PERTOKEN_BLOCK_SIZE = 256; +constexpr int PERTOKEN_SF_VEC_SIZE = 16; + +template +__device__ __forceinline__ void abs_max_2x_update(ptx::FPx2 &dst, + const ptx::FPx2 &val) { + if constexpr (std::is_same_v) { + dst.x = fmaxf(fabsf(dst.x), fabsf(val.x)); + dst.y = fmaxf(fabsf(dst.y), fabsf(val.y)); + } else { + ptx::abs_max_2x(dst, dst, val); + } +} + +template +__device__ __forceinline__ float abs_max_2x_to_float(const ptx::FPx2 &val) { + if constexpr (std::is_same_v) { + return fmaxf(fabsf(val.x), fabsf(val.y)); + } else { + return static_cast(__hmax(__habs(val.x), __habs(val.y))); + } +} + +template +__global__ void +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +__launch_bounds__(BLOCK_SIZE) +#endif + quantize_per_token_nvfp4_kernel(const int num_rows, const int num_cols, + const IType *__restrict__ input, + const int *__restrict__ row_offsets, + uint8_t *__restrict__ output_data, + fp8e4m3 *__restrict__ output_scales, + float *__restrict__ output_per_token_amax, + const int scale_stride, const float *__restrict__ noop) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + using namespace detail; + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + + using IType2 = typename ptx::FPx2; + + const int row_idx = blockIdx.x; + if (row_idx >= num_rows) return; + + const int actual_row = (row_offsets != nullptr) ? 
row_offsets[row_idx] : row_idx; + if (actual_row < 0) return; + + const int num_vec2 = num_cols / 2; + const IType2 *input_row = reinterpret_cast(input + actual_row * num_cols); + + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; + for (int i = threadIdx.x; i < num_vec2; i += BLOCK_SIZE) { + const IType2 val = input_row[i]; + abs_max_2x_update(thread_amax_2x, val); + } + const float thread_max = abs_max_2x_to_float(thread_amax_2x); + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + float row_amax = + BlockReduce(temp_storage).Reduce(thread_max, [](float a, float b) { return fmaxf(a, b); }); + + __shared__ float shared_s_enc; + if (threadIdx.x == 0) { + const float s_enc = compute_global_encode_scaling_factor_FP4(row_amax); + output_per_token_amax[row_idx] = row_amax; + shared_s_enc = s_enc; + } + __syncthreads(); + const float S_enc = shared_s_enc; + const float S_dec_rowwise = 1.0 / S_enc; + constexpr float fp4_max_inv = 1.0f / detail::TypeExtrema::max; + const float global_encode_scale_multiplier = S_enc * fp4_max_inv; + + const int num_sf_blocks = num_cols / PERTOKEN_SF_VEC_SIZE; + for (int sf_idx = threadIdx.x; sf_idx < num_sf_blocks; sf_idx += BLOCK_SIZE) { + const int col_start = sf_idx * PERTOKEN_SF_VEC_SIZE; + + IType2 block_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; + alignas(8) IType2 vals[PERTOKEN_SF_VEC_SIZE / 2]; + const IType2 *input_block = + reinterpret_cast(input + actual_row * num_cols + col_start); + for (int j = 0; j < PERTOKEN_SF_VEC_SIZE / 2; ++j) { + vals[j] = input_block[j]; + abs_max_2x_update(block_amax_2x, vals[j]); + } + const float block_max = abs_max_2x_to_float(block_amax_2x); + + const float S_dec_b_f32 = + fminf(block_max * global_encode_scale_multiplier, detail::TypeExtrema::max); + const nvfp4_scale_t S_dec_b_fp8 = static_cast(S_dec_b_f32); + output_scales[row_idx * scale_stride + sf_idx] = S_dec_b_fp8; + + constexpr float float_max = detail::TypeExtrema::max; + const float block_scale_inverse = + fminf(1.0f / (static_cast(S_dec_b_fp8) * S_dec_rowwise), float_max); + const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse}; + + uint8_t *out_ptr = output_data + actual_row * (num_cols / 2) + col_start / 2; + if constexpr (std::is_same_v) { + auto *out_fp4_8x = reinterpret_cast(out_ptr); + for (int j = 0; j < PERTOKEN_SF_VEC_SIZE / 2; j += 4) { + const uint64_t elts03 = *reinterpret_cast(&vals[j]); + const uint64_t elts47 = *reinterpret_cast(&vals[j + 2]); + out_fp4_8x[j / 4] = ptx::mul_cvt_bf16_to_fp4_8x_round_to_nearest( + elts03, elts47, block_scale_inverse); + } + } else { + auto *out_fp4 = reinterpret_cast(out_ptr); + for (int j = 0; j < PERTOKEN_SF_VEC_SIZE / 2; j += 2) { + const float2 in01 = + make_float2(static_cast(vals[j].x), static_cast(vals[j].y)); + const float2 in23 = + make_float2(static_cast(vals[j + 1].x), static_cast(vals[j + 1].y)); + out_fp4[j / 2] = ptx::mul_cvt_fp32_to_fp4_4x( + in01, in23, block_scale_inverse_2x, /*rbits=*/0u); + } + } + } +#endif +} + +template +void launch_quantize_per_token_nvfp4(const int num_rows, const int num_cols, const IType *input, + const int *row_offsets, uint8_t *output_data, + fp8e4m3 *output_scales, float *output_per_token_amax, + const int scale_stride, cudaStream_t stream, + const float *noop = nullptr) { +#if FP4_TYPE_SUPPORTED + if (num_rows == 0 || num_cols == 0) return; + + NVTE_CHECK(num_cols % PERTOKEN_SF_VEC_SIZE == 0, "num_cols must be a multiple of ", + PERTOKEN_SF_VEC_SIZE, " for per-token NVFP4 
quantization, got ", num_cols); + dim3 grid(num_rows); + dim3 block(PERTOKEN_BLOCK_SIZE); + + quantize_per_token_nvfp4_kernel + <<>>(num_rows, num_cols, input, row_offsets, output_data, + output_scales, output_per_token_amax, scale_stride, noop); + NVTE_CHECK_CUDA(cudaGetLastError()); +#else + NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!"); +#endif +} + +template +__global__ void +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +__launch_bounds__(BLOCK_SIZE) +#endif + compute_per_token_amax_kernel(const int num_rows, const int num_cols, + const IType *__restrict__ input, + float *__restrict__ output_per_token_amax, + const float *__restrict__ noop) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + + using IType2 = typename ptx::FPx2; + + const int row_idx = blockIdx.x; + if (row_idx >= num_rows) return; + + const int num_vec2 = num_cols / 2; + const IType2 *input_row = reinterpret_cast(input + row_idx * num_cols); + + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; + for (int i = threadIdx.x; i < num_vec2; i += BLOCK_SIZE) { + const IType2 val = input_row[i]; + abs_max_2x_update(thread_amax_2x, val); + } + const float thread_max = abs_max_2x_to_float(thread_amax_2x); + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + float row_amax = + BlockReduce(temp_storage).Reduce(thread_max, [](float a, float b) { return fmaxf(a, b); }); + + if (threadIdx.x == 0) { + output_per_token_amax[row_idx] = row_amax; + } +#endif +} + +template +void launch_compute_per_token_amax(const int num_rows, const int num_cols, const IType *input, + float *output_per_token_amax, cudaStream_t stream, + const float *noop = nullptr) { +#if FP4_TYPE_SUPPORTED + if (num_rows == 0 || num_cols == 0) return; + + NVTE_CHECK(num_cols % 2 == 0, "num_cols must be even for per-token amax computation, got ", + num_cols); + dim3 grid(num_rows); + dim3 block(PERTOKEN_BLOCK_SIZE); + + compute_per_token_amax_kernel + <<>>(num_rows, num_cols, input, output_per_token_amax, noop); + NVTE_CHECK_CUDA(cudaGetLastError()); +#else + NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!"); +#endif +} + +template +__global__ void +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +__launch_bounds__(BLOCK_SIZE) +#endif + quantize_per_token_nvfp4_columnwise_kernel(const int num_rows, const int num_cols, + const IType *__restrict__ input, + uint8_t *__restrict__ output_data_t, + fp8e4m3 *__restrict__ output_scales_t, + const float *__restrict__ per_token_amax, + const int scale_stride_t, + const float *__restrict__ noop) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + using namespace detail; + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + + const int col_idx = blockIdx.x; + if (col_idx >= num_cols) return; + + constexpr float fp4_max_inv = 1.0f / TypeExtrema::max; + constexpr float float_max = TypeExtrema::max; + constexpr float one = 1.0f; + const float2 one_2x{one, one}; + const int num_row_blocks = num_rows / PERTOKEN_SF_VEC_SIZE; + + for (int row_block = threadIdx.x; row_block < num_row_blocks; row_block += BLOCK_SIZE) { + const int row_start = row_block * PERTOKEN_SF_VEC_SIZE; + + float vals[PERTOKEN_SF_VEC_SIZE]; + float s_enc[PERTOKEN_SF_VEC_SIZE]; + float scaled_block_amax = 0.0f; +#pragma unroll + for (int i = 0; i < PERTOKEN_SF_VEC_SIZE; ++i) { + const int row_idx = row_start + i; + const float val = static_cast(input[row_idx * num_cols + 
col_idx]); + const float S_enc = compute_global_encode_scaling_factor_FP4(per_token_amax[row_idx]); + vals[i] = val; + s_enc[i] = S_enc; + scaled_block_amax = fmaxf(scaled_block_amax, fabsf(val) * (S_enc * fp4_max_inv)); + } + + const float S_dec_b_f32 = fminf(scaled_block_amax, float_max); + const nvfp4_scale_t S_dec_b_fp8 = static_cast(S_dec_b_f32); + output_scales_t[col_idx * scale_stride_t + row_block] = S_dec_b_fp8; + + float scaled_vals[PERTOKEN_SF_VEC_SIZE]; +#pragma unroll + for (int i = 0; i < PERTOKEN_SF_VEC_SIZE; ++i) { + const float S_dec_rowwise = 1.0f / s_enc[i]; + const float block_scale_inverse = + fminf(1.0f / (static_cast(S_dec_b_fp8) * S_dec_rowwise), float_max); + scaled_vals[i] = vals[i] * block_scale_inverse; + } + + uint8_t *out_ptr = output_data_t + col_idx * (num_rows / 2) + row_start / 2; + auto *out_fp4 = reinterpret_cast(out_ptr); +#pragma unroll + for (int j = 0; j < PERTOKEN_SF_VEC_SIZE; j += 4) { + const float2 in01 = make_float2(scaled_vals[j], scaled_vals[j + 1]); + const float2 in23 = make_float2(scaled_vals[j + 2], scaled_vals[j + 3]); + out_fp4[j / 4] = ptx::mul_cvt_fp32_to_fp4_4x( + in01, in23, one_2x, /*rbits=*/0u); + } + } +#endif +} + +template +void launch_quantize_per_token_nvfp4_columnwise(const int num_rows, const int num_cols, + const IType *input, uint8_t *output_data_t, + fp8e4m3 *output_scales_t, + const float *per_token_amax, + const int scale_stride_t, cudaStream_t stream, + const float *noop = nullptr) { +#if FP4_TYPE_SUPPORTED + if (num_rows == 0 || num_cols == 0) return; + + NVTE_CHECK(num_rows % PERTOKEN_SF_VEC_SIZE == 0, "num_rows must be a multiple of ", + PERTOKEN_SF_VEC_SIZE, " for per-token NVFP4 columnwise quantization, got ", num_rows); + dim3 grid(num_cols); + dim3 block(PERTOKEN_BLOCK_SIZE); + + quantize_per_token_nvfp4_columnwise_kernel + <<>>(num_rows, num_cols, input, output_data_t, output_scales_t, + per_token_amax, scale_stride_t, noop); + NVTE_CHECK_CUDA(cudaGetLastError()); +#else + NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!"); +#endif +} + +} // namespace quantize_per_token_kernel + +inline void quantize_per_token(const Tensor &input, const Tensor *noop, Tensor *output, + cudaStream_t stream) { +#if FP4_TYPE_SUPPORTED + checkCuDriverContext(stream); + CheckNoopTensor(*noop, "cast_noop"); + CheckInputTensor(input, "input"); + CheckOutputTensor(*output, "output", false); + + NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); + NVTE_CHECK(output->has_data() || output->has_columnwise_data(), + "NVFP4 output tensor must be allocated."); + NVTE_CHECK(!output->with_gemm_swizzled_scales, "Output must have scales in compact format."); + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + NVTE_CHECK(cols % quantize_per_token_kernel::PERTOKEN_SF_VEC_SIZE == 0, + "Per-token NVFP4 quantization requires last dim divisible by ", + quantize_per_token_kernel::PERTOKEN_SF_VEC_SIZE, "."); + + const auto *noop_ptr = reinterpret_cast(noop->data.dptr); + auto *amax_ptr = reinterpret_cast(output->amax.dptr); + auto *amax_colwise_ptr = reinterpret_cast(output->columnwise_amax.dptr); + auto *per_token_amax_ptr = (amax_ptr != nullptr) ? 
amax_ptr : amax_colwise_ptr; + NVTE_CHECK(per_token_amax_ptr != nullptr, "Per-token amax tensor must be allocated."); + if (amax_ptr != nullptr) { + NVTE_CHECK(output->amax.numel() == rows, "Per-token rowwise amax must have ", rows, + " entries, got ", output->amax.shape, "."); + } + if (amax_colwise_ptr != nullptr) { + NVTE_CHECK(output->columnwise_amax.numel() == rows, "Per-token columnwise amax must have ", + rows, " entries, got ", output->columnwise_amax.shape, "."); + } + const int *row_offsets = nullptr; + + if (input.dtype() == DType::kBFloat16) { + const auto *input_ptr = reinterpret_cast(input.data.dptr); + if (output->has_data()) { + NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Rowwise output must have FP4 type."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Rowwise scaling tensor must be allocated."); + NVTE_CHECK(output->amax.dptr != nullptr, "Rowwise per-token amax tensor must be allocated."); + auto *data_ptr = reinterpret_cast(output->data.dptr); + auto *scale_ptr = reinterpret_cast(output->scale_inv.dptr); + const int scale_stride = static_cast(output->scale_inv.shape.back()); + quantize_per_token_kernel::launch_quantize_per_token_nvfp4<__nv_bfloat16>( + static_cast(rows), static_cast(cols), input_ptr, row_offsets, data_ptr, + scale_ptr, amax_ptr, scale_stride, stream, noop_ptr); + } else { + quantize_per_token_kernel::launch_compute_per_token_amax<__nv_bfloat16>( + static_cast(rows), static_cast(cols), input_ptr, per_token_amax_ptr, stream, + noop_ptr); + } + if (output->has_columnwise_data()) { + NVTE_CHECK(is_fp4_dtype(output->columnwise_data.dtype), + "Columnwise output must have FP4 type."); + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, + "Columnwise scaling tensor must be allocated."); + if (amax_ptr != nullptr && amax_colwise_ptr != nullptr && amax_ptr != amax_colwise_ptr) { + NVTE_CHECK_CUDA(cudaMemcpyAsync(amax_colwise_ptr, amax_ptr, rows * sizeof(float), + cudaMemcpyDeviceToDevice, stream)); + } + auto *data_t_ptr = reinterpret_cast(output->columnwise_data.dptr); + auto *scale_t_ptr = reinterpret_cast(output->columnwise_scale_inv.dptr); + const int scale_stride_t = static_cast(output->columnwise_scale_inv.shape.back()); + quantize_per_token_kernel::launch_quantize_per_token_nvfp4_columnwise<__nv_bfloat16>( + static_cast(rows), static_cast(cols), input_ptr, data_t_ptr, scale_t_ptr, + per_token_amax_ptr, scale_stride_t, stream, noop_ptr); + } + } else if (input.dtype() == DType::kFloat16) { + const auto *input_ptr = reinterpret_cast(input.data.dptr); + if (output->has_data()) { + NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Rowwise output must have FP4 type."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Rowwise scaling tensor must be allocated."); + NVTE_CHECK(output->amax.dptr != nullptr, "Rowwise per-token amax tensor must be allocated."); + auto *data_ptr = reinterpret_cast(output->data.dptr); + auto *scale_ptr = reinterpret_cast(output->scale_inv.dptr); + const int scale_stride = static_cast(output->scale_inv.shape.back()); + quantize_per_token_kernel::launch_quantize_per_token_nvfp4( + static_cast(rows), static_cast(cols), input_ptr, row_offsets, data_ptr, + scale_ptr, amax_ptr, scale_stride, stream, noop_ptr); + } else { + quantize_per_token_kernel::launch_compute_per_token_amax( + static_cast(rows), static_cast(cols), input_ptr, per_token_amax_ptr, stream, + noop_ptr); + } + if (output->has_columnwise_data()) { + NVTE_CHECK(is_fp4_dtype(output->columnwise_data.dtype), + "Columnwise output must have FP4 type."); + 
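The columnwise branches dispatched here transpose the blocking: a 16-row block along each column shares one E4M3 decode scale, while every element is still encoded with its own row's global scale, exactly as in the per-token reference path added to the C++ test. A hedged PyTorch sketch of that arithmetic (our names, same scale constants and assumptions as the rowwise sketch above):

import torch

F4_MAX, F8_MAX = 6.0, 448.0

def per_token_nvfp4_columnwise_ref(x: torch.Tensor, row_amax: torch.Tensor, block: int = 16):
    """Sketch of the columnwise path: number of rows divisible by `block`."""
    m, n = x.shape
    s_enc = (F8_MAX * F4_MAX) / row_amax.clamp_min(1e-38)  # reuses the per-token amax
    scaled = x.float().abs() * (s_enc / F4_MAX)[:, None]
    # One decode scale per (16-row block, column), analogous to scales_t in the test
    s_dec_b = scaled.view(m // block, block, n).amax(dim=1).clamp(max=F8_MAX)
    s_dec_b = s_dec_b.to(torch.float8_e4m3fn).float()
    s_dec_full = s_dec_b.repeat_interleave(block, dim=0)   # broadcast back to (m, n)
    q = torch.where(s_dec_full > 0,
                    x.float() * s_enc[:, None] / s_dec_full,
                    torch.zeros_like(s_dec_full))
    return q.t(), s_dec_b.t()  # the kernel writes the transposed, E2M1-packed data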
NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, + "Columnwise scaling tensor must be allocated."); + if (amax_ptr != nullptr && amax_colwise_ptr != nullptr && amax_ptr != amax_colwise_ptr) { + NVTE_CHECK_CUDA(cudaMemcpyAsync(amax_colwise_ptr, amax_ptr, rows * sizeof(float), + cudaMemcpyDeviceToDevice, stream)); + } + auto *data_t_ptr = reinterpret_cast(output->columnwise_data.dptr); + auto *scale_t_ptr = reinterpret_cast(output->columnwise_scale_inv.dptr); + const int scale_stride_t = static_cast(output->columnwise_scale_inv.shape.back()); + quantize_per_token_kernel::launch_quantize_per_token_nvfp4_columnwise( + static_cast(rows), static_cast(cols), input_ptr, data_t_ptr, scale_t_ptr, + per_token_amax_ptr, scale_stride_t, stream, noop_ptr); + } + } else if (input.dtype() == DType::kFloat32) { + const auto *input_ptr = reinterpret_cast(input.data.dptr); + if (output->has_data()) { + NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Rowwise output must have FP4 type."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Rowwise scaling tensor must be allocated."); + NVTE_CHECK(output->amax.dptr != nullptr, "Rowwise per-token amax tensor must be allocated."); + auto *data_ptr = reinterpret_cast(output->data.dptr); + auto *scale_ptr = reinterpret_cast(output->scale_inv.dptr); + const int scale_stride = static_cast(output->scale_inv.shape.back()); + quantize_per_token_kernel::launch_quantize_per_token_nvfp4( + static_cast(rows), static_cast(cols), input_ptr, row_offsets, data_ptr, + scale_ptr, amax_ptr, scale_stride, stream, noop_ptr); + } else { + quantize_per_token_kernel::launch_compute_per_token_amax( + static_cast(rows), static_cast(cols), input_ptr, per_token_amax_ptr, stream, + noop_ptr); + } + if (output->has_columnwise_data()) { + NVTE_CHECK(is_fp4_dtype(output->columnwise_data.dtype), + "Columnwise output must have FP4 type."); + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, + "Columnwise scaling tensor must be allocated."); + if (amax_ptr != nullptr && amax_colwise_ptr != nullptr && amax_ptr != amax_colwise_ptr) { + NVTE_CHECK_CUDA(cudaMemcpyAsync(amax_colwise_ptr, amax_ptr, rows * sizeof(float), + cudaMemcpyDeviceToDevice, stream)); + } + auto *data_t_ptr = reinterpret_cast(output->columnwise_data.dptr); + auto *scale_t_ptr = reinterpret_cast(output->columnwise_scale_inv.dptr); + const int scale_stride_t = static_cast(output->columnwise_scale_inv.shape.back()); + quantize_per_token_kernel::launch_quantize_per_token_nvfp4_columnwise( + static_cast(rows), static_cast(cols), input_ptr, data_t_ptr, scale_t_ptr, + per_token_amax_ptr, scale_stride_t, stream, noop_ptr); + } + } else { + NVTE_ERROR( + "Unsupported input dtype for per-token NVFP4 quantization. 
" + "Expected BFloat16, Float16, or Float32."); + } +#else + NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!"); +#endif +} + +} // namespace nvfp4 +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_QUANTIZE_PERTOKEN_NVFP4_CUH_ diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index c1b3f8f427..c5b4254e8b 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -470,6 +470,7 @@ struct QuantizationConfig { bool nvfp4_2d_quantization = false; bool stochastic_rounding = false; bool use_fast_math = false; + bool nvfp4_per_token_activation = false; static constexpr size_t attr_sizes[] = { sizeof(uint8_t), // force_pow_2_scales @@ -479,7 +480,8 @@ struct QuantizationConfig { sizeof(NVTETensor), // rng_seed and offset sizeof(uint8_t), // nvfp4_2d_quantization sizeof(uint8_t), // stochastic_rounding - sizeof(uint8_t) // use_fast_math + sizeof(uint8_t), // use_fast_math + sizeof(uint8_t) // nvfp4_per_token_activation }; }; diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index b7461a85d1..0463d51d1c 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -370,6 +370,8 @@ enum NVTEQuantizationConfigAttribute { * inconsistently between kernels. */ kNVTEQuantizationConfigUseFastMath = 7, + /*! Whether to enable per-token (per-row) NVFP4 quantization */ + kNVTEQuantizationConfigNVFP4PerTokenActivation = 8, kNVTEQuantizationConfigNumAttributes }; @@ -1296,6 +1298,13 @@ class QuantizationConfigWrapper { sizeof(val)); } + /*! \brief Set whether to enable per-token NVFP4 quantization */ + void set_nvfp4_per_token_activation(bool nvfp4_per_token_activation) { + const auto val = static_cast(nvfp4_per_token_activation); + nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigNVFP4PerTokenActivation, + &val, sizeof(val)); + } + private: /*! \brief Wrapped NVTEQuantizationConfig. */ NVTEQuantizationConfig config_ = nullptr; diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 67b6f87067..e59d01d82a 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -478,6 +478,10 @@ class NVFP4BlockScaling(Recipe): If set to `True`, stochastic rounding is disabled during quantization for all tensors. disable_2d_quantization : bool, default = False If set to `True`, 1D block scaling with block size 16 is used for all tensors. + per_token_activation : bool, default = False + If set to `True`, GroupedLinear activation split quantization uses per-token + (per-row) NVFP4 global amax values. In this mode, rowwise ``amax`` metadata + is stored as a vector with one FP32 value per token. backward_override : {None, 'high_precision', 'dequantized'}, default = None Backward precision mode. 
None does not modify backward behavior, `high_precision` keeps original high-precision operands for backward, @@ -491,6 +495,7 @@ class NVFP4BlockScaling(Recipe): os.getenv("NVTE_NVFP4_DISABLE_STOCHASTIC_ROUNDING", "0") == "1" ) disable_2d_quantization: bool = os.getenv("NVTE_NVFP4_DISABLE_2D_QUANTIZATION", "0") == "1" + per_token_activation: bool = os.getenv("NVTE_NVFP4_PER_TOKEN_ACTIVATION", "0") == "1" fp4_format: Format = Format.E2M1 fp8_format: Format = Format.E4M3 @@ -534,6 +539,7 @@ def __repr__(self) -> str: f"fp8_dpa={self.fp8_dpa}, " f"fp8_mha={self.fp8_mha}, " f"backward_override={self.backward_override}, " + f"per_token_activation={self.per_token_activation}, " f"fp4_quant_fwd_inp={self.fp4_quant_fwd_inp}, " f"fp4_quant_fwd_weight={self.fp4_quant_fwd_weight}, " f"fp4_quant_bwd_grad={self.fp4_quant_bwd_grad}, " diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 1261879a8b..a0a0ffa45f 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -1043,6 +1043,9 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config, case kNVTEQuantizationConfigUseFastMath: bool_to_uint8(config_.use_fast_math, buf); break; + case kNVTEQuantizationConfigNVFP4PerTokenActivation: + bool_to_uint8(config_.nvfp4_per_token_activation, buf); + break; default: NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); } @@ -1098,6 +1101,9 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config, case kNVTEQuantizationConfigUseFastMath: uint8_to_bool(buf, config_.use_fast_math); break; + case kNVTEQuantizationConfigNVFP4PerTokenActivation: + uint8_to_bool(buf, config_.nvfp4_per_token_activation); + break; default: NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); } diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 6f3553bf94..79a7d28df5 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -15,6 +15,7 @@ from ..quantized_tensor import Quantizer from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage +from ..tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage from ..tensor.utils import is_custom from ..custom_recipes.gemm import custom_gemm from ...debug.pytorch.debug_quantization import DebugQuantizer @@ -69,6 +70,30 @@ def validate_gemm_scale(scale: Optional[float], required: bool) -> float: return 0.0 +def _is_nvfp4_per_token_tensor(tensor: torch.Tensor) -> bool: + """Whether tensor carries per-token NVFP4 global amax metadata.""" + if not isinstance(tensor, NVFP4TensorStorage): + return False + amax = tensor._amax_rowwise if tensor._amax_rowwise is not None else tensor._amax_columnwise + return amax is not None and amax.numel() > 1 + + +def _nvfp4_per_token_gemm_input( + tensor: NVFP4TensorStorage, +) -> Tuple[NVFP4TensorStorage, torch.Tensor]: + """Return a GEMM alias with identity activation amax and the original per-token amax.""" + metadata = tensor.get_metadata() + if tensor._amax_rowwise is not None: + amax = tensor._amax_rowwise + assert amax is not None and amax.numel() > 1 + metadata["amax_rowwise"] = amax.new_ones(1) + else: + amax = tensor._amax_columnwise + assert amax is not None and amax.numel() > 1 + metadata["amax_columnwise"] = amax.new_ones(1) + return 
NVFP4TensorStorage(**metadata), amax + + def general_gemm( A: torch.Tensor, B: torch.Tensor, @@ -174,7 +199,58 @@ def general_gemm( "beta": beta, } - out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs) + if not _is_nvfp4_per_token_tensor(B): + out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs) + else: + assert layout[1] == "N", "Per-token NVFP4 GEMM currently supports N-layout B only." + if grad: + raise RuntimeError( + "Per-token NVFP4 GEMM currently supports fprop only. " + "Backward NVFP4 gradient quantizers should use scalar global amax." + ) + assert not gelu, "Per-token NVFP4 GEMM currently does not support fused GELU." + assert not accumulate, "Per-token NVFP4 GEMM currently does not support accumulation." + assert ( + quantization_params is None + ), "Per-token NVFP4 GEMM currently does not support output quantization." + assert out is None or ( + isinstance(out, torch.Tensor) and not is_custom(out) + ), "Per-token NVFP4 GEMM currently supports only plain torch.Tensor outputs." + # cuBLAS folds the first activation amax into GEMM alpha. Keep per-token amax out of + # alpha by using identity here, then apply the true per-token scale in FP32 below. + gemm_B, amax = _nvfp4_per_token_gemm_input(B) + per_token_scales = amax.view(-1, 1) + + requested_out, requested_out_dtype = out, out_dtype + fp32_out = ( + torch.empty_like(requested_out, dtype=torch.float32) + if requested_out is not None + else None + ) + gemm_args = list(args) + gemm_args[2] = gemm_B # B + gemm_args[4] = fp32_out # out + gemm_args[5] = None # quantization_params + gemm_args[6] = TE_DType[torch.float32] # out_dtype + out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*gemm_args, **kwargs) + out_2d = out.reshape(-1, out.shape[-1]) + + assert amax.dtype == torch.float32 and out.dtype == torch.float32 + assert amax.numel() == out_2d.shape[0] + + if bias is not None: + bias_cast = bias.to(dtype=torch.float32) + out_2d.sub_(bias_cast) + out_2d.mul_(per_token_scales) + out_2d.add_(bias_cast) + else: + out_2d.mul_(per_token_scales) + + if requested_out is not None: + requested_out.copy_(out.to(dtype=requested_out.dtype)) + out = requested_out + elif requested_out_dtype is not None and requested_out_dtype != torch.float32: + out = out.to(dtype=requested_out_dtype) if debug_quantizer is not None: out = debug_quantizer.process_gemm_output(out) @@ -229,6 +305,52 @@ def general_grouped_gemm( else: bias_dtype = TE_DType[torch.bfloat16] + if any(_is_nvfp4_per_token_tensor(tensor) for tensor in B): + assert layout[1] == "N", "Per-token NVFP4 grouped GEMM currently supports N-layout B only." + if grad: + raise RuntimeError( + "Per-token NVFP4 grouped GEMM currently supports fprop only. " + "Backward NVFP4 gradient quantizers should use scalar global amax." + ) + assert not gelu, "Per-token NVFP4 grouped GEMM currently does not support fused GELU." + assert ( + not accumulate + ), "Per-token NVFP4 grouped GEMM currently does not support accumulation." + assert D_dtype is None, "Per-token NVFP4 grouped GEMM currently does not support D_dtype." + assert all( + q is None for q in quantization_params + ), "Per-token NVFP4 grouped GEMM currently does not support output quantization." + if single_output: + assert ( + m_splits is not None + ), "Per-token NVFP4 grouped GEMM requires m_splits with single output." 
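The FP32 fixup above is worth spelling out. cuBLAS folds a single scalar activation amax into the GEMM alpha, so a per-row amax cannot ride through alpha; instead the amax is forced to 1 inside the GEMM and each output row is rescaled by its token's true amax afterwards. Because a fused bias would already have been added in unscaled units, it is subtracted, the row is rescaled, and the bias is re-added. A sketch of that arithmetic (our helper name, not TE API; `out32` is the FP32 GEMM output):

import torch

def apply_per_token_scale(out32: torch.Tensor, per_token_amax: torch.Tensor, bias=None):
    """Sketch of the post-GEMM FP32 fixup performed above."""
    out2d = out32.reshape(-1, out32.shape[-1])
    scale = per_token_amax.view(-1, 1)        # one FP32 amax per token/row
    if bias is not None:
        b = bias.to(torch.float32)
        out2d.sub_(b).mul_(scale).add_(b)     # the bias itself must not be scaled
    else:
        out2d.mul_(scale)
    return out32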
+ out_init = out[0] if single_output else None + if single_output: + start_idx = 0 + out_views = [] + for i in range(num_gemms): + size = m_splits[i] + out_views.append(out_init[start_idx : start_idx + size]) + start_idx += size + else: + out_views = out + for i in range(num_gemms): + if out_views[i].numel() == 0: + continue + gemm_out, _, _, _ = general_gemm( + A[i], + B[i], + quantization_params=None, + out_dtype=out_views[i].dtype, + layout=layout, + bias=bias[i] if use_bias else None, + use_split_accumulator=use_split_accumulator, + ) + out_views[i].copy_(gemm_out) + if single_output: + out = out_init + return out, bias, gelu_input + if isinstance(quantization_params[0], DebugQuantizer): assert not gelu, "GELU not supported in debug mode" if single_output: diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 8e3bcdd5b3..b9f852c07d 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -320,6 +320,7 @@ class NVFP4Quantizer : public Quantizer { // 2D block scaling bool with_2d_quantization; bool stochastic_rounding; + bool per_token_activation; int rht_matrix_random_sign_mask_t; at::Tensor rht_matrix; diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 4a2ea7412b..f62853bb2b 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -326,6 +326,8 @@ py::object group_dequantize(const py::handle &input, DType otype); py::object bgrad_group_quantize(const at::Tensor &tensor, py::handle quantizer, const size_t num_tensors, std::optional first_dims); +std::tuple quantize_nvfp4_per_token(at::Tensor input); + std::vector multi_tensor_quantize(const std::vector &tensor_list, std::vector quantizer_list); diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index 2df3b66553..17f86d63d6 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -42,8 +42,9 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); - if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) { - // Post-RHT amax is handled within NVFP4 quantizer + if (nvfp4_quantizer_cpp->per_token_activation || + (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax)) { + // Amax is handled within NVFP4 quantizer impl = Impl::UNFUSED; } else { impl = Impl::FUSED_ACTIVATION_AMAX_NVFP4; @@ -154,8 +155,9 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); - if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) { - // Post-RHT amax is handled within NVFP4 quantizer + if (nvfp4_quantizer_cpp->per_token_activation || + (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax)) { + // Amax is handled within NVFP4 quantizer impl = Impl::UNFUSED; } else { impl = Impl::FUSED_ACTIVATION_AMAX_NVFP4; diff --git 
a/transformer_engine/pytorch/csrc/extensions/bias.cpp b/transformer_engine/pytorch/csrc/extensions/bias.cpp index 0cf2025f1b..e2dba46370 100644 --- a/transformer_engine/pytorch/csrc/extensions/bias.cpp +++ b/transformer_engine/pytorch/csrc/extensions/bias.cpp @@ -152,8 +152,9 @@ std::vector dact_dbias( } else if (detail::IsNVFP4Quantizers(quantizer_py.ptr())) { auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); - if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) { - // Post-RHT amax is handled within NVFP4 quantizer + if (nvfp4_quantizer_cpp->per_token_activation || + (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax)) { + // Amax is handled within NVFP4 quantizer impl = Impl::UNFUSED; } else { impl = Impl::FUSED_DACT_AMAX_NVFP4; diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 50fe4c109e..ba75867a15 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -801,6 +801,7 @@ std::tuple, std::vector, bool> bulk_alloc const auto columnwise_usage = quantizer_cpp_list[0]->columnwise_usage; const auto scaling_mode = quantizer_cpp_list[0]->get_scaling_mode(); const auto fp4_dtype = quantizer_cpp_list[0]->dtype; + const bool per_token_activation = quantizer_cpp_list[0]->per_token_activation; const bool with_gemm_swizzled_scales = false; /// TODO (tmoon) Enable based on optimize_for_gemm; constexpr size_t scale_elem_size = 1; @@ -828,6 +829,16 @@ std::tuple, std::vector, bool> bulk_alloc } return fp4_shape; }; + auto flat_first_dim = [](const std::vector &shape) -> size_t { + if (shape.empty()) { + return 1; + } + size_t rows = 1; + for (size_t i = 0; i + 1 < shape.size(); ++i) { + rows *= shape[i]; + } + return rows; + }; // Allocate row-wise data std::vector rowwise_data_list, rowwise_scale_list, amax_rowwise_list; @@ -866,7 +877,9 @@ std::tuple, std::vector, bool> bulk_alloc // Note: Multi-quantize kernel does not require contiguous amaxes. const auto offset = roundup(buffer_size, 16); amax_offsets.push_back(offset); - buffer_size = offset + 4; + const size_t amax_size = + per_token_activation ? 4 * flat_first_dim(rowwise_data_shapes[i]) : 4; + buffer_size = offset + amax_size; } // Allocate full buffer @@ -879,8 +892,11 @@ std::tuple, std::vector, bool> bulk_alloc data_offsets[i], torch::kUInt8)); rowwise_scale_list.emplace_back( make_torch_view(buffer, rowwise_scale_shapes[i], scale_offsets[i], torch::kUInt8)); + const std::vector amax_shape = + per_token_activation ? std::vector{flat_first_dim(rowwise_data_shapes[i])} + : std::vector{1}; amax_rowwise_list.emplace_back( - make_torch_view(buffer, std::vector{1}, amax_offsets[i], torch::kFloat32)); + make_torch_view(buffer, amax_shape, amax_offsets[i], torch::kFloat32)); } } @@ -928,7 +944,8 @@ std::tuple, std::vector, bool> bulk_alloc // Note: Multi-quantize kernel does not require contiguous amaxes. const auto offset = roundup(buffer_size, 16); amax_offsets.push_back(offset); - buffer_size = offset + 4; + const size_t amax_size = per_token_activation ? 
4 * flat_first_dim(shape_list[i]) : 4; + buffer_size = offset + amax_size; } // Allocate full buffer @@ -941,8 +958,11 @@ std::tuple, std::vector, bool> bulk_alloc buffer, to_fp4_shape(columnwise_data_shapes[i]), data_offsets[i], torch::kUInt8)); columnwise_scale_list.emplace_back( make_torch_view(buffer, columnwise_scale_shapes[i], scale_offsets[i], torch::kUInt8)); + const std::vector amax_shape = + per_token_activation ? std::vector{flat_first_dim(shape_list[i])} + : std::vector{1}; amax_columnwise_list.emplace_back( - make_torch_view(buffer, std::vector{1}, amax_offsets[i], torch::kFloat32)); + make_torch_view(buffer, amax_shape, amax_offsets[i], torch::kFloat32)); } } @@ -983,11 +1003,11 @@ std::tuple, std::vector, bool> bulk_alloc // Set the amax rowwise and amax columnwise if available if (rowwise_usage) { tensor_wrapper.set_amax(amax_rowwise_list[i].data_ptr(), DType::kFloat32, - std::vector{1}); + getTensorShape(amax_rowwise_list[i])); } if (columnwise_usage) { tensor_wrapper.set_columnwise_amax(amax_columnwise_list[i].data_ptr(), DType::kFloat32, - std::vector{1}); + getTensorShape(amax_columnwise_list[i])); } tensor_cpp_list.emplace_back(std::move(tensor_wrapper)); @@ -1263,6 +1283,33 @@ void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, nvte_tensor_output_list.push_back(output_list[i].data()); } + if (quantizer.per_token_activation) { + NVTE_CHECK(!quantizer.with_rht, "Per-token NVFP4 split quantize does not support RHT."); + NVTE_CHECK(!quantizer.with_2d_quantization, + "Per-token NVFP4 split quantize does not support 2D quantization."); + NVTE_CHECK(!quantizer.stochastic_rounding, + "Per-token NVFP4 split quantize does not support stochastic rounding."); + + std::vector quant_config_list; + quant_config_list.reserve(num_tensors); + for (size_t i = 0; i < num_tensors; ++i) { + quant_config_list.emplace_back(QuantizationConfigWrapper()); + quant_config_list.back().set_nvfp4_per_token_activation(true); + } + + for (size_t i = 0; i < num_tensors; i++) { + if (input_list[i].numel() == 0) { + continue; + } + const size_t input_ndim = input_list[i].ndim(); + const size_t cols = input_ndim > 0 ? input_list[i].size(input_ndim - 1) : 1; + NVTE_CHECK(cols % 16 == 0, + "Per-token NVFP4 split quantize requires split inner dim divisible by 16."); + nvte_quantize_v2(input_list[i].data(), output_list[i].data(), quant_config_list[i], stream); + } + return; + } + // In this case without RHT, the rowwise and colwise quantization are fused // we don't need separate rng states for rowwise and colwise bool need_separate_rng_states = false; @@ -1360,8 +1407,13 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input, // Check input tensor shape const size_t input_last_dim = input.ndim() > 0 ? 
input.size(input.ndim() - 1) : 1; - NVTE_CHECK(input_last_dim % 128 == 0, - "NVFP4 multi-quantize requires inner dim to be multiple of 128."); + if (quantizer.per_token_activation) { + NVTE_CHECK(input_last_dim % 16 == 0, + "Per-token NVFP4 split-quantize requires inner dim to be multiple of 16."); + } else { + NVTE_CHECK(input_last_dim % 128 == 0, + "NVFP4 multi-quantize requires inner dim to be multiple of 128."); + } // CUDA stream auto stream = at::cuda::getCurrentCUDAStream(); @@ -1433,12 +1485,25 @@ std::vector split_quantize(const at::Tensor &tensor, for (size_t i = 0; i < num_splits; i++) { quantizer_cpp_list.push_back(convert_quantizer(quantizer_list[i])); } + const bool all_nvfp4_quantizers = std::all_of(quantizer_list.begin(), quantizer_list.end(), + [](const py::handle &quantizer) -> bool { + return detail::IsNVFP4Quantizers(quantizer.ptr()); + }); + const bool all_nvfp4_per_token_activation = + all_nvfp4_quantizers && + std::all_of(quantizer_cpp_list.begin(), quantizer_cpp_list.end(), + [](const std::unique_ptr &quantizer) -> bool { + return static_cast(quantizer.get())->per_token_activation; + }); // Choose implementation for allocating and populating tensors enum class AllocationMethod { UNFUSED, BULK_FP8_BLOCKWISE, BULK_MXFP8, BULK_NVFP4 }; enum class QuantizationMethod { UNFUSED, FUSED_NVFP4 }; AllocationMethod allocation_method = AllocationMethod::UNFUSED; QuantizationMethod quantization_method = QuantizationMethod::UNFUSED; + if (all_nvfp4_per_token_activation) { + quantization_method = QuantizationMethod::FUSED_NVFP4; + } if (!disable_bulk_allocation) { if (std::all_of(quantizer_list.begin(), quantizer_list.end(), [](const py::handle &quantizer) -> bool { @@ -1450,10 +1515,7 @@ std::vector split_quantize(const at::Tensor &tensor, return detail::IsMXFP8Quantizers(quantizer.ptr()); })) { allocation_method = AllocationMethod::BULK_MXFP8; - } else if (std::all_of(quantizer_list.begin(), quantizer_list.end(), - [](const py::handle &quantizer) -> bool { - return detail::IsNVFP4Quantizers(quantizer.ptr()); - })) { + } else if (all_nvfp4_quantizers) { allocation_method = AllocationMethod::BULK_NVFP4; quantization_method = QuantizationMethod::FUSED_NVFP4; } @@ -1492,7 +1554,8 @@ std::vector split_quantize(const at::Tensor &tensor, bool contiguous_data_and_scale = false; std::tie(output_py_list, output_cpp_list, contiguous_data_and_scale) = bulk_allocate_nvfp4_tensors(split_shapes, quantizer_list, nvfp4_quantizers); - if (!input_shape.empty() && input_shape.back() % 128 != 0) { + if (!all_nvfp4_per_token_activation && !input_shape.empty() && + input_shape.back() % 128 != 0) { static std::once_flag once_unfused_nvfp4_fallback_warning; std::call_once(once_unfused_nvfp4_fallback_warning, []() { NVTE_WARN( @@ -1502,7 +1565,7 @@ std::vector split_quantize(const at::Tensor &tensor, }); quantization_method = QuantizationMethod::UNFUSED; } - if (!contiguous_data_and_scale) { + if (!all_nvfp4_per_token_activation && !contiguous_data_and_scale) { // Avoid fused quantize kernel if data is not contiguous quantization_method = QuantizationMethod::UNFUSED; } @@ -1540,5 +1603,50 @@ std::vector split_quantize(const at::Tensor &tensor, return output_py_list; } +std::tuple quantize_nvfp4_per_token(at::Tensor input) { + init_extension(); + + NVTE_CHECK(input.dim() == 2, "Input must be 2D (num_rows, num_cols)"); + NVTE_CHECK(input.is_cuda(), "Input must be on CUDA device"); + + const int num_rows = input.size(0); + const int num_cols = input.size(1); + NVTE_CHECK(num_cols % 16 == 0, + "num_cols must 
be a multiple of 16 for per-token NVFP4 quantization"); + + if (num_rows == 0) { + auto options = input.options(); + return {at::empty({0, num_cols / 2}, options.dtype(at::kByte)), + at::empty({0, num_cols / 16}, options.dtype(at::kByte)), + at::empty({0}, options.dtype(at::kFloat))}; + } + + auto input_contig = input.contiguous(); + auto options = input_contig.options(); + + auto output_data = at::empty({num_rows, num_cols / 2}, options.dtype(at::kByte)); + auto output_scales = at::empty({num_rows, num_cols / 16}, options.dtype(at::kByte)); + auto output_per_token_amax = at::empty({num_rows}, options.dtype(at::kFloat)); + + auto te_input = makeTransformerEngineTensor(input_contig); + TensorWrapper te_output(NVTE_NVFP4_1D_SCALING); + te_output.set_rowwise_data( + output_data.data_ptr(), DType::kFloat4E2M1, + std::vector{static_cast(num_rows), static_cast(num_cols)}); + te_output.set_rowwise_scale_inv( + output_scales.data_ptr(), DType::kFloat8E4M3, + std::vector{static_cast(num_rows), static_cast(num_cols / 16)}); + te_output.set_amax(output_per_token_amax.data_ptr(), DType::kFloat32, + std::vector{static_cast(num_rows)}); + QuantizationConfigWrapper quant_config; + quant_config.set_nvfp4_per_token_activation(true); + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + NVTE_SCOPED_GIL_RELEASE( + { nvte_quantize_v2(te_input.data(), te_output.data(), quant_config, stream); }); + + return {output_data, output_scales, output_per_token_amax}; +} + } // namespace pytorch } // namespace transformer_engine diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index fb4c7aa1c9..3975c01fa5 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -120,8 +120,9 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); - if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) { - // Post-RHT amax is handled within NVFP4 quantizer + if (nvfp4_quantizer_cpp->per_token_activation || + (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax)) { + // Amax is handled within NVFP4 quantizer impl = Impl::UNFUSED; } else if (!transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { // TE kernel supports amax output @@ -357,8 +358,9 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); - if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) { - // Post-RHT amax is handled within NVFP4 quantizer + if (nvfp4_quantizer_cpp->per_token_activation || + (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax)) { + // Amax is handled within NVFP4 quantizer impl = Impl::UNFUSED; } else if (!transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { // TE kernel supports amax output diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index eb7576d905..b2d74205cc 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ 
b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -145,6 +145,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Dequantize group tensor", py::arg("input"), py::arg("otype")); m.def("bgrad_group_quantize", transformer_engine::pytorch::bgrad_group_quantize, py::arg("tensor"), py::arg("quantizer"), py::arg("num_tensors"), py::arg("first_dims")); + m.def("quantize_nvfp4_per_token", transformer_engine::pytorch::quantize_nvfp4_per_token, + "Per-token NVFP4 quantization", py::arg("input")); m.def("bgrad_quantize", transformer_engine::pytorch::bgrad_quantize, "Compute bias gradient and quantize", py::arg("input"), py::arg("quantizer")); m.def("generic_gemm", transformer_engine::pytorch::gemm, "Compute GEMM (matrix-matrix multiply)", diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index da91e5c170..6e6e38a1dd 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1696,6 +1696,7 @@ NVFP4Quantizer::NVFP4Quantizer(const py::handle& quantizer) : Quantizer(quantize this->with_post_rht_amax = quantizer.attr("with_post_rht_amax").cast(); this->with_2d_quantization = quantizer.attr("with_2d_quantization").cast(); this->stochastic_rounding = quantizer.attr("stochastic_rounding").cast(); + this->per_token_activation = quantizer.attr("per_token_activation").cast(); // Get amax reduction group if needed for NVFP4 AG const bool with_amax_reduction = quantizer.attr("with_amax_reduction").cast(); @@ -1760,9 +1761,10 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve rowwise_scale_inv_shape.end()); rowwise_data_tensor = at::empty(convert_shape_for_fp4(shape_int64), bit8_tensor_opts); rowwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts); + const int64_t amax_rows = this->per_token_activation ? static_cast(flat_first_dim) : 1; // hadamard amax kernel will zero out pointer with ZeroAmaxKernel // nvte_compute_amax_with_config will zero out the pointer if needed - amax_rowwise = at::empty({1}, bit32_tensor_opts); + amax_rowwise = at::empty({amax_rows}, bit32_tensor_opts); } if (columnwise_usage) { const std::vector scale_inv_shape_int64(columnwise_scale_inv_shape.begin(), @@ -1777,7 +1779,8 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve columnwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts); // hadamard amax kernel will zero out pointer with ZeroAmaxKernel // nvte_compute_amax_with_config will zero out the pointer if needed - amax_columnwise = at::empty({1}, bit32_tensor_opts); + const int64_t amax_rows = this->per_token_activation ? 
static_cast(flat_first_dim) : 1; + amax_columnwise = at::empty({amax_rows}, bit32_tensor_opts); } // Convert tensors to Python @@ -1850,7 +1853,7 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve out_cpp.set_rowwise_data(rowwise_data_tensor.data_ptr(), DType::kFloat4E2M1, shape); out_cpp.set_rowwise_scale_inv(rowwise_scale_inv_tensor.data_ptr(), DType::kFloat8E4M3, rowwise_scale_inv_shape); - out_cpp.set_amax(amax_rowwise.data_ptr(), DType::kFloat32, std::vector{1}); + out_cpp.set_amax(amax_rowwise.data_ptr(), DType::kFloat32, getTensorShape(amax_rowwise)); } if (columnwise_usage) { // enforce 2D shape to avoid [S, B, H] shape and B and be 1 @@ -1862,7 +1865,7 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve out_cpp.set_columnwise_scale_inv(columnwise_scale_inv_tensor.data_ptr(), DType::kFloat8E4M3, columnwise_scale_inv_shape); out_cpp.set_columnwise_amax(amax_columnwise.data_ptr(), DType::kFloat32, - std::vector{1}); + getTensorShape(amax_columnwise)); } out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); this->set_quantization_params(&out_cpp); @@ -1975,15 +1978,22 @@ std::pair NVFP4Quantizer::create_unquantized_tensor_w auto [out_cpp, out_py] = NoneQuantizer(py::none()).create_tensor(shape, dtype); // Register amax pointer from quantized tensor - void* amax_ptr = quantized_tensor.amax(); + auto rowwise_amax = quantized_tensor.get_amax(); + auto columnwise_amax = quantized_tensor.get_columnwise_amax(); + + void* amax_ptr = rowwise_amax.data_ptr; + std::vector amax_shape = convertShape(rowwise_amax.shape); if (amax_ptr == nullptr) { - amax_ptr = quantized_tensor.get_columnwise_amax().data_ptr; + amax_ptr = columnwise_amax.data_ptr; + amax_shape = convertShape(columnwise_amax.shape); } NVTE_CHECK(amax_ptr != nullptr, "Could not extract amax pointer from NVFP4 tensor."); - out_cpp.set_amax(amax_ptr, DType::kFloat32, std::vector{1}); + out_cpp.set_amax(amax_ptr, DType::kFloat32, amax_shape); // Zero out amax - NVTE_CHECK_CUDA(cudaMemsetAsync(amax_ptr, 0, sizeof(float), at::cuda::getCurrentCUDAStream())); + const size_t amax_numel = product(amax_shape); + NVTE_CHECK_CUDA( + cudaMemsetAsync(amax_ptr, 0, amax_numel * sizeof(float), at::cuda::getCurrentCUDAStream())); return {std::move(out_cpp), std::move(out_py)}; } @@ -2050,9 +2060,11 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( } if (!amax_rowwise) { const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + const int64_t amax_rows = + this->per_token_activation ? static_cast(flat_first_dim) : 1; // hadamard amax kernel will zero out pointer with ZeroAmaxKernel // nvte_compute_amax_with_config will zero out the pointer if needed - amax_rowwise = at::empty({1}, opts); + amax_rowwise = at::empty({amax_rows}, opts); tensor.attr("_amax_rowwise") = *amax_rowwise; } } else { // rowwise_usage == false @@ -2094,7 +2106,9 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); // hadamard amax kernel will zero out pointer with ZeroAmaxKernel // nvte_compute_amax_with_config will zero out the pointer if needed - amax_columnwise = at::empty({1}, opts); + const int64_t amax_rows = + this->per_token_activation ? 
static_cast(flat_first_dim) : 1; + amax_columnwise = at::empty({amax_rows}, opts); tensor.attr("_amax_columnwise") = *amax_columnwise; } } else { // columnwise_usage == false @@ -2118,7 +2132,7 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( out_cpp.set_rowwise_data(rowwise_data->data_ptr(), DType::kFloat4E2M1, shape); out_cpp.set_rowwise_scale_inv(rowwise_scale_inv->data_ptr(), DType::kFloat8E4M3, getTensorShape(*rowwise_scale_inv)); - out_cpp.set_amax(amax_rowwise->data_ptr(), DType::kFloat32, std::vector{1}); + out_cpp.set_amax(amax_rowwise->data_ptr(), DType::kFloat32, getTensorShape(*amax_rowwise)); } if (columnwise_usage) { // enforce 2D shape to avoid [S, B, H] shape and B and be 1 @@ -2130,7 +2144,7 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( out_cpp.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat8E4M3, getTensorShape(*columnwise_scale_inv)); out_cpp.set_columnwise_amax(amax_columnwise->data_ptr(), DType::kFloat32, - std::vector{1}); + getTensorShape(*amax_columnwise)); } out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); this->set_quantization_params(&out_cpp); @@ -2241,6 +2255,18 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou } size_t cols = input.size(input.ndim() - 1); + if (this->per_token_activation) { + NVTE_CHECK(!this->with_rht, "Per-token NVFP4 activation does not support RHT."); + NVTE_CHECK(!this->with_2d_quantization, + "Per-token NVFP4 activation does not support 2D quantization."); + NVTE_CHECK(!this->stochastic_rounding, + "Per-token NVFP4 activation does not support stochastic rounding."); + NVTE_CHECK(!this->with_amax_reduction, + "Per-token NVFP4 activation does not support amax reduction."); + NVTE_CHECK(cols % 16 == 0, "Per-token NVFP4 activation requires last dim divisible by 16."); + quant_config.set_nvfp4_per_token_activation(true); + } + // Restriction for the RHT cast fusion kernel because we are using MMA hardware for computing RHT bool eligible_for_rht_cast_fusion = input.dtype() == DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0; @@ -2307,7 +2333,7 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou "Use with_post_rht_amax=true instead."); } } else { // Without RHT - if (compute_amax) { + if (compute_amax && !this->per_token_activation) { // Amax pointers auto rowwise_amax_ptr = out.get_amax().data_ptr; auto columnwise_amax_ptr = out.get_columnwise_amax().data_ptr; @@ -2408,6 +2434,8 @@ void NVFP4Quantizer::quantize(const TensorWrapper& input, TensorWrapper& out, } void NVFP4Quantizer::quantize_with_amax(TensorWrapper& input, TensorWrapper& out) { + NVTE_CHECK(!this->per_token_activation, + "quantize_with_amax is not supported for per-token NVFP4 activation."); // Update output tensor amaxes with input tensor amax auto input_amax_ptr = input.amax(); auto output_rowwise_amax_ptr = out.get_amax().data_ptr; diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py index dd01ae05d3..d57ea792dd 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py @@ -350,6 +350,7 @@ def __init__( pow_2_scales: bool = False, eps: float = 0.0, quant_tile_shape: Tuple[int, int] = (1, 16), + per_token_activation: bool = False, with_rht: bool = False, with_random_sign_mask: bool = True, ): @@ -360,6 +361,7 @@ def __init__( self.pow_2_scales = pow_2_scales self.eps = eps 
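The per_token_activation flag stored just below changes only the amax reduction inside _quantize further down in this file: the scalar path takes one max over the whole tensor, while the per-token path reduces over the last dimension. For example:

import torch

x = torch.randn(4, 32)
per_tensor_amax = x.abs().max().to(torch.float32).view(1)  # scalar path: shape [1]
per_token_amax = x.abs().amax(dim=1).to(torch.float32)     # per-token path: shape [4]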
self.quant_tile_shape = quant_tile_shape + self.per_token_activation = per_token_activation self.with_rht = with_rht self.with_random_sign_mask = with_random_sign_mask @@ -447,9 +449,14 @@ def _quantize_blockwise_reference( tile_len_y: int, *, pow_2_scales: bool, + per_token_rowwise: bool = False, + per_token_columnwise: bool = False, eps: float, # pylint: disable=unused-argument ) -> Tuple[torch.Tensor, torch.Tensor]: + assert not ( + per_token_rowwise and per_token_columnwise + ), "Per-token rowwise and columnwise reference modes are mutually exclusive." if x.ndim != 2: raise ValueError( f"_quantize_blockwise_reference expects a 2D tensor, got {x.ndim}D with shape" @@ -488,6 +495,11 @@ def _quantize_blockwise_reference( decode_scale.to(torch.float32), ) else: + if per_token_rowwise: + global_amax = global_amax.to(torch.float32).view(m, 1, 1) + if per_token_columnwise: + global_amax = global_amax.to(torch.float32).view(1, n // tile_len_x, tile_len_x) + global_encode_scale = torch.div(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX, global_amax) global_encode_scale = torch.min( global_encode_scale, @@ -497,14 +509,28 @@ def _quantize_blockwise_reference( dtype=torch.float32, ), ) - if global_encode_scale == torch.tensor(0.0, device=x.device, dtype=torch.float32): - global_encode_scale = torch.tensor(1.0, device=x.device, dtype=torch.float32) + if global_encode_scale.numel() == 1: + if global_encode_scale == torch.tensor(0.0, device=x.device, dtype=torch.float32): + global_encode_scale = torch.tensor(1.0, device=x.device, dtype=torch.float32) + else: + global_encode_scale = torch.where( + global_encode_scale == 0.0, + torch.ones_like(global_encode_scale), + global_encode_scale, + ) global_decode_scale = torch.div(1.0, global_encode_scale) global_encode_scale_multiplier = global_encode_scale * torch.reciprocal(FLOAT4_E2M1_MAX) - # Match the kernel's default path: fold the FP4 reciprocal into the - # global scale multiplier, but keep the final reciprocal exact. - decode_scale = vec_max * global_encode_scale_multiplier + if per_token_columnwise: + decode_scale = torch.amax( + torch.abs(x.to(torch.float32)) * global_encode_scale_multiplier, + dim=-1, + keepdim=True, + ) + else: + # Match the kernel's default path: fold the FP4 reciprocal into the + # global scale multiplier, but keep the final reciprocal exact. + decode_scale = vec_max * global_encode_scale_multiplier decode_scale = torch.min( decode_scale, torch.tensor( @@ -609,6 +635,10 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ raise ValueError( f"MXFP4 only supports 1x32 tile shape, got {self.quant_tile_shape}" ) + if self.per_token_activation: + raise ValueError( + "Per-token activation is only supported for NVFP4 (non-pow2) mode." 
diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py
index 9956fb77ec..2cb6c21946 100644
--- a/transformer_engine/pytorch/quantization.py
+++ b/transformer_engine/pytorch/quantization.py
@@ -1375,6 +1375,7 @@ def _make_quantizer(idx: int) -> NVFP4Quantizer:
                 with_post_rht_amax=qparams.random_hadamard_transform,
                 with_2d_quantization=qparams.fp4_2d_quantization,
                 stochastic_rounding=qparams.stochastic_rounding,
+                per_token_activation=self.recipe.per_token_activation and idx % 3 != 1,
             )
 
         return [_make_quantizer(idx) for idx in range(self.num_quantizers)]
@@ -1389,6 +1390,7 @@ def _make_quantizer(idx: int) -> NVFP4Quantizer:
                 with_post_rht_amax=self.recipe.fp4_quant_bwd_grad.random_hadamard_transform,
                 with_2d_quantization=self.recipe.fp4_quant_bwd_grad.fp4_2d_quantization,
                 stochastic_rounding=self.recipe.fp4_quant_bwd_grad.stochastic_rounding,
+                per_token_activation=False,
             )
             for _ in range(self.num_quantizers)
         ]
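Editorial note: the `idx % 3 != 1` guard enables the per-token path only for non-weight quantizers, and backward gradient quantizers are always created with per_token_activation=False. Assuming the usual layout where index 1 of each forward quantizer triple is the weight quantizer, the selection rule reduces to (sketch, not patch code):

    def enables_per_token(idx: int, recipe_flag: bool) -> bool:
        # index 1 of each triple is assumed to be the weight quantizer,
        # which keeps a single per-tensor amax
        return recipe_flag and idx % 3 != 1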
diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py
index 65678aa347..53f77da9e4 100644
--- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py
+++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py
@@ -128,6 +128,9 @@ class NVFP4Quantizer(Quantizer):
     """Stochastic rounding, only applicable for gradients."""
     stochastic_rounding: bool
 
+    """Per-token activation quantization path (grouped split quantize)."""
+    per_token_activation: bool
+
     """RHT matrix random sign mask"""
     rht_matrix_random_sign_mask_t: int
     rht_matrix: torch.Tensor
@@ -143,6 +146,7 @@ def __init__(
         with_post_rht_amax: bool = False,
         with_2d_quantization: bool = False,
         stochastic_rounding: bool = False,
+        per_token_activation: bool = False,
         with_random_sign_mask: bool = True,
     ) -> None:
         super().__init__(rowwise=rowwise, columnwise=columnwise)
@@ -153,6 +157,7 @@ def __init__(
         self.amax_reduction_group = amax_reduction_group
         self.with_2d_quantization = with_2d_quantization
         self.stochastic_rounding = stochastic_rounding
+        self.per_token_activation = per_token_activation
         self.rht_matrix_random_sign_mask_t = get_random_sign_mask_for_rht(
             with_random_sign_mask, torch.cuda.current_device()
         )
@@ -198,6 +203,7 @@ def copy(self) -> NVFP4Quantizer:
             with_post_rht_amax=self.with_post_rht_amax,
             with_2d_quantization=self.with_2d_quantization,
             stochastic_rounding=self.stochastic_rounding,
+            per_token_activation=self.per_token_activation,
         )
         quantizer.internal = self.internal
         quantizer.optimize_for_gemm = self.optimize_for_gemm
@@ -330,7 +336,10 @@ def make_empty(
             scale_shape, dtype=torch.uint8, device=device, pin_memory=pin_memory
         )
         # Allocate per tensor scale inverse. FP32 format.
-        amax_rowwise = torch.zeros(1, dtype=torch.float32, device=device, pin_memory=pin_memory)
+        amax_rows = flat_first_dim if self.per_token_activation else 1
+        amax_rowwise = torch.zeros(
+            amax_rows, dtype=torch.float32, device=device, pin_memory=pin_memory
+        )
 
         # Allocate FP8 data transpose if needed
         columnwise_data = None
@@ -353,8 +362,9 @@ def make_empty(
                 device=device,
                 pin_memory=pin_memory,
             )
+            amax_rows = flat_first_dim if self.per_token_activation else 1
             amax_columnwise = torch.zeros(
-                1, dtype=torch.float32, device=device, pin_memory=pin_memory
+                amax_rows, dtype=torch.float32, device=device, pin_memory=pin_memory
             )
 
         # Construct FP8 tensor
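Editorial note: make_empty sizes the amax buffers from flat_first_dim, i.e. the product of all leading dimensions, so a [S, B, H] activation gets S * B per-token entries. A small sketch of that allocation rule, under the stated assumption (helper name hypothetical):

    import math
    import torch

    def amax_buffer(shape, per_token_activation: bool) -> torch.Tensor:
        # e.g. shape (512, 4, 4096) -> 2048 per-token entries, else 1
        flat_first_dim = math.prod(shape[:-1])
        rows = flat_first_dim if per_token_activation else 1
        return torch.zeros(rows, dtype=torch.float32)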
diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py
index 5f12c3ed8c..1732abf57c 100644
--- a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py
+++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py
@@ -662,6 +662,10 @@ def make_grouped_tensor(
         # Amax buffer for delayed scaling - one per tensor
         amax = torch.empty(num_tensors, dtype=torch.float32, device=device)
     elif quantizer._get_compatible_recipe().nvfp4():
+        per_token_activation = getattr(quantizer, "per_token_activation", False)
+        total_amax_elements = (
+            sum(math.prod(s[:-1]) for s in shape) if per_token_activation else num_tensors
+        )
 
         if rowwise_usage:
             # Allocate rowwise data buffer (1D flattened, uint8, but FP4 packs 2 values per byte)
@@ -675,8 +679,7 @@ def make_grouped_tensor(
                 total_scale_elements += math.prod(scale_inv_shape)
                 scale_inv_offsets.append(total_scale_elements)
             scale_inv = torch.empty(total_scale_elements, dtype=torch.uint8, device=device)
-            # Amax buffer - one per tensor
-            amax = torch.empty(num_tensors, dtype=torch.float32, device=device)
+            amax = torch.empty(total_amax_elements, dtype=torch.float32, device=device)
 
         if columnwise_usage:
             # Allocate columnwise data buffer (1D flattened, uint8, FP4 packed)
@@ -693,8 +696,9 @@ def make_grouped_tensor(
             columnwise_scale_inv = torch.empty(
                 total_columnwise_scale_elements, dtype=torch.uint8, device=device
             )
-            # Columnwise amax buffer - one per tensor
-            columnwise_amax = torch.empty(num_tensors, dtype=torch.float32, device=device)
+            columnwise_amax = torch.empty(
+                total_amax_elements, dtype=torch.float32, device=device
+            )
     elif quantizer._get_compatible_recipe().float8_block_scaling():
         if rowwise_usage:
             # Allocate rowwise data buffer (1D flattened, uint8)
@@ -891,6 +895,13 @@ def split_into_quantized_tensors(
                 cum += math.prod(scale_shape)
                 columnwise_scale_inv_offsets.append(cum)
             self.columnwise_scale_inv_offsets = columnwise_scale_inv_offsets
+        nvfp4_per_token_amax_offsets = None
+        if recipe.nvfp4() and getattr(self.quantizer, "per_token_activation", False):
+            cum = 0
+            nvfp4_per_token_amax_offsets = [0]
+            for i in range(self.num_tensors):
+                cum += math.prod(self.tensor_shapes[i][:-1])
+                nvfp4_per_token_amax_offsets.append(cum)
 
         for i in range(self.num_tensors):
             quantizer = self.quantizer
@@ -1083,12 +1094,21 @@ def split_into_quantized_tensors(
                     cscale_shape
                 )
 
-            # Extract amax - one per tensor
             if self.amax is not None:
-                amax_rowwise = self.amax[i : i + 1]
+                if nvfp4_per_token_amax_offsets is not None:
+                    amax_start = nvfp4_per_token_amax_offsets[i]
+                    amax_end = nvfp4_per_token_amax_offsets[i + 1]
+                    amax_rowwise = self.amax[amax_start:amax_end]
+                else:
+                    amax_rowwise = self.amax[i : i + 1]
             if self.columnwise_amax is not None:
-                amax_columnwise = self.columnwise_amax[i : i + 1]
+                if nvfp4_per_token_amax_offsets is not None:
+                    amax_start = nvfp4_per_token_amax_offsets[i]
+                    amax_end = nvfp4_per_token_amax_offsets[i + 1]
+                    amax_columnwise = self.columnwise_amax[amax_start:amax_end]
+                else:
+                    amax_columnwise = self.columnwise_amax[i : i + 1]
 
             if quantizer.internal:
                 nvfp4_tensor_class = NVFP4TensorStorage