From 14f77bee4f7231a0d2b9b343c16582b214a9c498 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sun, 26 Apr 2026 00:37:14 -0700 Subject: [PATCH 01/21] Adapt initial implementation and make quantization bitwise exact Signed-off-by: Ziang Li Co-authored-by: Yigong Qin --- docs/envvars.rst | 6 + .../nvfp4/test_nvfp4_quantize_exact.py | 55 +++++ tests/pytorch/test_backward_override.py | 27 ++- tests/pytorch/utils.py | 10 +- .../common/cast/dispatch/quantize.cuh | 19 ++ .../common/cast/nvfp4/dequantize_nvfp4.cuh | 14 +- .../cast/nvfp4/quantize_pertoken_nvfp4.cuh | 220 ++++++++++++++++++ transformer_engine/common/common.h | 4 +- .../transformer_engine/transformer_engine.h | 9 + transformer_engine/common/recipe/__init__.py | 6 + .../common/transformer_engine.cpp | 6 + .../pytorch/cpp_extensions/gemm.py | 74 +++++- transformer_engine/pytorch/csrc/common.h | 1 + transformer_engine/pytorch/csrc/extensions.h | 2 + .../pytorch/csrc/extensions/cast.cpp | 131 ++++++++++- .../pytorch/csrc/extensions/pybind.cpp | 2 + transformer_engine/pytorch/csrc/quantizer.cpp | 49 +++- .../custom_recipes/quantization_nvfp4.py | 60 ++++- transformer_engine/pytorch/quantization.py | 2 + .../pytorch/tensor/nvfp4_tensor.py | 11 +- 20 files changed, 663 insertions(+), 45 deletions(-) create mode 100644 transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh diff --git a/docs/envvars.rst b/docs/envvars.rst index 1e040b4c3e..58988b5473 100644 --- a/docs/envvars.rst +++ b/docs/envvars.rst @@ -281,6 +281,12 @@ Kernel Configuration :Default: ``0`` :Description: Emit a warning when falling back from CUTLASS to cuBLAS for grouped GEMM operations. +.. envvar:: NVTE_NVFP4_PER_TOKEN_ACTIVATION + + :Type: ``int`` (0 or 1) + :Default: ``0`` + :Description: Enable per-token activation quantization for the ``NVFP4BlockScaling`` recipe in GroupedLinear split-quantize paths. When set to ``1`` (or when ``NVFP4BlockScaling(per_token_activation=True)`` is used), NVFP4 rowwise ``amax`` metadata stores one FP32 value per token (row) instead of a single scalar. 
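+
+   For illustration, a minimal sketch of enabling the mode from Python via the recipe API added in this patch (assumed equivalent to setting this variable to ``1``):
+
+   .. code-block:: python
+
+       from transformer_engine.common.recipe import NVFP4BlockScaling
+
+       # The per-token path is rowwise-only and excludes RHT, stochastic
+       # rounding, and 2D quantization, matching the checks in this patch.
+       recipe = NVFP4BlockScaling(
+           disable_rht=True,
+           disable_stochastic_rounding=True,
+           disable_2d_quantization=True,
+           per_token_activation=True,
+       )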
+ Torch Compilation and Fusion ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index bf3f545b8b..7e94911ddd 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -23,6 +23,20 @@ def unpack_fp4(x: torch.Tensor) -> torch.Tensor: return repeated +def maybe_skip_pertoken_nvfp4( + x_dtype: torch.dtype = torch.bfloat16, + *, + return_transpose: bool = False, + with_2d_quantization: bool = False, +) -> None: + if x_dtype == torch.float32: + pytest.skip("Per-token NVFP4 kernel supports BF16/FP16 inputs only") + if return_transpose: + pytest.skip("Per-token NVFP4 currently supports rowwise-only quantization") + if with_2d_quantization: + pytest.skip("Per-token NVFP4 does not support 2D quantization") + + def check_quantization_nvfp4_versus_reference( x_dtype: torch.dtype, M: int, @@ -31,6 +45,7 @@ def check_quantization_nvfp4_versus_reference( swizzled_scale: bool, use_cpp_allocator: bool, with_2d_quantization: bool, + per_token_activation: bool = False, ) -> None: te_dtype = tex.DType.kFloat4E2M1 @@ -52,6 +67,7 @@ def check_quantization_nvfp4_versus_reference( with_rht=False, with_post_rht_amax=False, with_2d_quantization=with_2d_quantization, + per_token_activation=per_token_activation, ) if use_cpp_allocator: x_nvfp4_sut = nvfp4_quantizer(x) @@ -83,6 +99,7 @@ def check_quantization_nvfp4_versus_reference( pow_2_scales=False, eps=0.0, quant_tile_shape=quant_tile_shape, + per_token_activation=per_token_activation, ) x_nvfp4_ref = ref_quantizer.quantize(x) @@ -155,6 +172,9 @@ def check_quantization_nvfp4_versus_reference( @pytest.mark.parametrize( "with_2d_quantization", [True, False], ids=["2d_quantization", "1d_quantization"] ) +@pytest.mark.parametrize( + "per_token_activation", [False, True], ids=["nvfp4_per_tensor", "nvfp4_pertoken"] +) def test_quantization_block_tiling_versus_reference( x_dtype: torch.dtype, M: int, @@ -163,7 +183,14 @@ def test_quantization_block_tiling_versus_reference( swizzled_scale: bool, use_cpp_allocator: bool, with_2d_quantization: bool, + per_token_activation: bool, ) -> None: + if per_token_activation: + maybe_skip_pertoken_nvfp4( + x_dtype=x_dtype, + return_transpose=return_transpose, + with_2d_quantization=with_2d_quantization, + ) check_quantization_nvfp4_versus_reference( x_dtype=x_dtype, M=M, @@ -172,6 +199,7 @@ def test_quantization_block_tiling_versus_reference( swizzled_scale=swizzled_scale, use_cpp_allocator=use_cpp_allocator, with_2d_quantization=with_2d_quantization, + per_token_activation=per_token_activation, ) @@ -188,6 +216,9 @@ def test_quantization_block_tiling_versus_reference( @pytest.mark.parametrize( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) +@pytest.mark.parametrize( + "per_token_activation", [False, True], ids=["nvfp4_per_tensor", "nvfp4_pertoken"] +) def test_nvfp4_quantization_extrema_versus_reference( x_dtype: torch.dtype, M: int, @@ -195,6 +226,7 @@ def test_nvfp4_quantization_extrema_versus_reference( extrema_high: bool, return_transpose: bool, use_cpp_allocator: bool, + per_token_activation: bool, ): te_dtype = tex.DType.kFloat4E2M1 @@ -208,6 +240,9 @@ def test_nvfp4_quantization_extrema_versus_reference( else: x = torch.zeros((M, N), dtype=x_dtype, device=device) + if per_token_activation: + maybe_skip_pertoken_nvfp4(x_dtype=x_dtype, return_transpose=return_transpose) + nvfp4_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, rowwise=True, @@ 
-216,6 +251,7 @@ def test_nvfp4_quantization_extrema_versus_reference( amax_reduction_group=None, with_rht=False, with_post_rht_amax=False, + per_token_activation=per_token_activation, ) if use_cpp_allocator: @@ -245,6 +281,7 @@ def test_nvfp4_quantization_extrema_versus_reference( pow_2_scales=False, eps=0.0, quant_tile_shape=(1, 16), + per_token_activation=per_token_activation, ) x_nvfp4_ref = ref_quantizer.quantize(x) @@ -286,12 +323,16 @@ def test_nvfp4_quantization_extrema_versus_reference( @pytest.mark.parametrize( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) +@pytest.mark.parametrize( + "per_token_activation", [False, True], ids=["nvfp4_per_tensor", "nvfp4_pertoken"] +) def test_nvfp4_quantization_boundary_values( x_dtype: torch.dtype, M: int, N: int, return_transpose: bool, use_cpp_allocator: bool, + per_token_activation: bool, ): """ Stress rounding/threshold behavior by placing values just below/above @@ -319,6 +360,9 @@ def test_nvfp4_quantization_boundary_values( row[1::2] = upper x = row.unsqueeze(0).repeat(M, 1).to(dtype=x_dtype) + if per_token_activation: + maybe_skip_pertoken_nvfp4(x_dtype=x_dtype, return_transpose=return_transpose) + nvfp4_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, rowwise=True, @@ -327,6 +371,7 @@ def test_nvfp4_quantization_boundary_values( amax_reduction_group=None, with_rht=False, with_post_rht_amax=False, + per_token_activation=per_token_activation, ) if use_cpp_allocator: @@ -356,6 +401,7 @@ def test_nvfp4_quantization_boundary_values( pow_2_scales=False, eps=0.0, quant_tile_shape=(1, 16), + per_token_activation=per_token_activation, ) x_nvfp4_ref = ref_quantizer.quantize(x) @@ -397,12 +443,16 @@ def test_nvfp4_quantization_boundary_values( @pytest.mark.parametrize( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) +@pytest.mark.parametrize( + "per_token_activation", [False, True], ids=["nvfp4_per_tensor", "nvfp4_pertoken"] +) def test_nvfp4_quantization_noncontiguous_inputs( x_dtype: torch.dtype, M: int, N: int, return_transpose: bool, use_cpp_allocator: bool, + per_token_activation: bool, ): te_dtype = tex.DType.kFloat4E2M1 @@ -416,6 +466,9 @@ def test_nvfp4_quantization_noncontiguous_inputs( x_nc = x_base.t() # shape (N, M), non-contiguous assert not x_nc.is_contiguous() + if per_token_activation: + maybe_skip_pertoken_nvfp4(x_dtype=x_dtype, return_transpose=return_transpose) + nvfp4_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, rowwise=True, @@ -424,6 +477,7 @@ def test_nvfp4_quantization_noncontiguous_inputs( amax_reduction_group=None, with_rht=False, with_post_rht_amax=False, + per_token_activation=per_token_activation, ) if use_cpp_allocator: @@ -453,6 +507,7 @@ def test_nvfp4_quantization_noncontiguous_inputs( pow_2_scales=False, eps=0.0, quant_tile_shape=(1, 16), + per_token_activation=per_token_activation, ) x_nvfp4_ref = ref_quantizer.quantize(x_nc) diff --git a/tests/pytorch/test_backward_override.py b/tests/pytorch/test_backward_override.py index ed4f73adbc..5da55b14b6 100644 --- a/tests/pytorch/test_backward_override.py +++ b/tests/pytorch/test_backward_override.py @@ -78,6 +78,11 @@ marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4), id="NVFP4BlockScaling", ), + pytest.param( + "nvfp4_pertoken", + marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4), + id="NVFP4PerTokenBlockScaling", + ), ] @@ -165,7 +170,7 @@ def _maybe_skip_recipe_dtype( ) -> None: if dtype == torch.bfloat16 and not bf16_available: 
pytest.skip(reason_for_no_bf16) - if recipe_name == "nvfp4": + if recipe_name in ("nvfp4", "nvfp4_pertoken"): if module_type in ("linear", "layernorm_linear") and dtype not in ( torch.bfloat16, torch.float32, @@ -178,6 +183,16 @@ def _maybe_skip_recipe_dtype( def _maybe_skip_unsupported_recipe_module_combo(recipe_name: str, module_type: str) -> None: if module_type == "ops_linear" and recipe_name == "fp8_block_scaling": pytest.skip("Fusible ops (te_ops.Linear) do not support Float8BlockScaling recipe") + if recipe_name == "nvfp4_pertoken" and module_type in ( + "linear", + "layernorm_linear", + "ops_linear", + "grouped_linear", + ): + pytest.skip( + "Per-token NVFP4 currently supports rowwise-only quantization paths " + "(columnwise usage is unsupported for these modules)." + ) def _maybe_skip_unsupported_recipe_shape( @@ -195,7 +210,9 @@ def _maybe_skip_unsupported_recipe_shape( " by 32." ) return - if recipe_name == "nvfp4" and (flat_first_dim % 16 != 0 or last_dim % 16 != 0): + if recipe_name in ("nvfp4", "nvfp4_pertoken") and ( + flat_first_dim % 16 != 0 or last_dim % 16 != 0 + ): pytest.skip( "Linear/LayerNormLinear + NVFP4 requires prod(shape[:-1]) and shape[-1] divisible" " by 16." @@ -220,7 +237,9 @@ def _maybe_skip_unsupported_recipe_shape( pytest.skip( "te_ops.Linear + MXFP8 requires prod(shape[:-1]) and shape[-1] divisible by 32." ) - if recipe_name == "nvfp4" and (flat_first_dim % 16 != 0 or last_dim % 16 != 0): + if recipe_name in ("nvfp4", "nvfp4_pertoken") and ( + flat_first_dim % 16 != 0 or last_dim % 16 != 0 + ): pytest.skip( "te_ops.Linear + NVFP4 requires prod(shape[:-1]) and shape[-1] divisible by 16." ) @@ -239,7 +258,7 @@ def _maybe_skip_unsupported_grouped_splits(recipe_name: str, m_splits: list[int] ) if recipe_name == "mxfp8" and any(m % 32 != 0 for m in non_empty_splits): pytest.skip("GroupedLinear + MXFP8 requires each non-empty m_split divisible by 32.") - if recipe_name == "nvfp4" and any(m % 16 != 0 for m in non_empty_splits): + if recipe_name in ("nvfp4", "nvfp4_pertoken") and any(m % 16 != 0 for m in non_empty_splits): pytest.skip("GroupedLinear + NVFP4 requires each non-empty m_split divisible by 16.") if recipe_name == "nvfp4" and any(m % 64 != 0 for m in non_empty_splits): pytest.skip( diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 8f8852edc2..04ac2becbc 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -115,7 +115,7 @@ def quantization_tols(name: str) -> dict[str, float]: "mxfp8_block_scaling", ): return dtype_tols(tex.DType.kFloat8E4M3) - if name == "nvfp4": + if name in ("nvfp4", "nvfp4_pertoken"): return dtype_tols(tex.DType.kFloat4E2M1) raise ValueError(f"Unsupported quantization scheme ({name})") @@ -149,6 +149,14 @@ def make_recipe(name: Optional[str], **recipe_kwargs: Any) -> Optional[Recipe]: disable_2d_quantization=True, **recipe_kwargs, ) + if name == "nvfp4_pertoken": + return transformer_engine.common.recipe.NVFP4BlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + disable_2d_quantization=True, + per_token_activation=True, + **recipe_kwargs, + ) raise ValueError(f"Unsupported quantization scheme ({name})") diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index 5d0d3c28e8..0d86022cc1 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -21,6 +21,7 @@ #include "../mxfp8/group_quantize_mxfp8.cuh" #include "../mxfp8/quantize_mxfp8.cuh" #include 
"../nvfp4/group_quantize_transpose_nvfp4.cuh" +#include "../nvfp4/quantize_pertoken_nvfp4.cuh" #include "../nvfp4/quantize_transpose_nvfp4.cuh" namespace transformer_engine { @@ -100,6 +101,15 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, int32_t rows = input_tensor->flat_first_dim(); int32_t cols = input_tensor->flat_last_dim(); auto dtype = input_tensor->dtype(); + const bool per_token_activation = quant_config_cpp.nvfp4_per_token_activation; + if (per_token_activation) { + NVTE_CHECK(!output_tensor->has_columnwise_data(), + "Per-token NVFP4 quantization supports rowwise-only output."); + NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, + "Per-token NVFP4 quantization does not support 2D quantization."); + nvfp4::quantize_pertoken(*input_tensor, noop_tensor, output_tensor, stream); + break; + } bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && (cols % 32 == 0) && output_tensor->has_data(); @@ -239,6 +249,15 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens int32_t rows = grad_tensor->flat_first_dim(); int32_t cols = grad_tensor->flat_last_dim(); auto dtype = grad_tensor->dtype(); + const bool per_token_activation = quant_config_cpp.nvfp4_per_token_activation; + if (per_token_activation) { + NVTE_CHECK(!output_tensor->has_columnwise_data(), + "Per-token NVFP4 quantization supports rowwise-only output."); + NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, + "Per-token NVFP4 quantization does not support 2D quantization."); + nvfp4::quantize_pertoken(*grad_tensor, noop_tensor, output_tensor, stream); + break; + } bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && (cols % 32 == 0) && output_tensor->has_data(); diff --git a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh index 4143208153..9436b94939 100644 --- a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh @@ -34,7 +34,7 @@ namespace dequantize_kernel { template __global__ void __launch_bounds__(512) dequantize_fp4_kernel(const void *const input, OType *output, const fp8e4m3 *const scales, - const float *const tensor_amax, const size_t N, const size_t M, + const float *const tensor_amax, const size_t amax_numel, const size_t N, const size_t M, const size_t scale_stride, const size_t num_scale_tiles_X) { const size_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x; const size_t x = thread_idx % M; @@ -63,7 +63,7 @@ __global__ void __launch_bounds__(512) fp4vec value; value.vec = input_vectorized[my_index]; fp8e4m3 scale = scales[my_scale_index]; - float amax = *tensor_amax; + float amax = (amax_numel == 1) ? 
tensor_amax[0] : tensor_amax[y];
   constexpr float factor_inv = 1.0 / (6.0 * 448.0);
   float final_scale = static_cast<float>(scale) * amax * factor_inv;
 #pragma unroll
@@ -110,11 +110,11 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream)
           with_gemm_swizzled_scales, WITH_GEMM_SWIZZLED_SCALES,
           dequantize_fp4_kernel<<>>(
-              input.data.dptr, reinterpret_cast<OType *>(output->data.dptr),
-              reinterpret_cast<const fp8e4m3 *>(input.scale_inv.dptr),
-              reinterpret_cast<const float *>(input.amax.dptr), N, Mread, input.scale_inv.shape.back(),
-              num_scale_tiles_X);); // NOLINT(*)
-  ); // NOLINT(*)
+              input.data.dptr, reinterpret_cast<OType *>(output->data.dptr),
+              reinterpret_cast<const fp8e4m3 *>(input.scale_inv.dptr),
+              reinterpret_cast<const float *>(input.amax.dptr), input.amax.numel(), N, Mread, input.scale_inv.shape.back(),
+              num_scale_tiles_X);); // NOLINT(*)
+); // NOLINT(*)
   NVTE_CHECK_CUDA(cudaGetLastError());
 #else
   NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!");
diff --git a/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh
new file mode 100644
index 0000000000..5e1e23f5d5
--- /dev/null
+++ b/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh
@@ -0,0 +1,220 @@
+/*************************************************************************
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+/*! \file quantize_pertoken_nvfp4.cuh
+ *  \brief CUDA kernels to cast to NVFP4 with per-token (per-row) global scaling.
+ */
+
+#ifndef TRANSFORMER_ENGINE_QUANTIZE_PERTOKEN_NVFP4_CUH_
+#define TRANSFORMER_ENGINE_QUANTIZE_PERTOKEN_NVFP4_CUH_
+
+#include <cuda.h>
+#include <cudaTypedefs.h>
+
+#include <cub/cub.cuh>
+#include <cuda_bf16.h>
+
+#include "../../common.h"
+#include "../../util/ptx.cuh"
+#include "../../utils.cuh"
+#include "core_nvfp4.cuh"
+
+#if FP4_TYPE_SUPPORTED
+#include <cuda_fp4.h>
+#endif
+
+namespace transformer_engine {
+namespace dispatch {
+namespace nvfp4 {
+namespace quantize_pertoken_kernel {
+
+using namespace core;
+using namespace ptx;
+
+constexpr int PERTOKEN_BLOCK_SIZE = 256;
+constexpr int PERTOKEN_SF_VEC_SIZE = 16;
+
+template <typename IType, int BLOCK_SIZE>
+__global__ void
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+__launch_bounds__(BLOCK_SIZE)
+#endif
+    quantize_pertoken_nvfp4_kernel(
+        const int num_rows, const int num_cols, const IType *__restrict__ input,
+        const int *__restrict__ row_offsets, uint8_t *__restrict__ output_data,
+        fp8e4m3 *__restrict__ output_scales, float *__restrict__ output_per_token_amax,
+        const int scale_stride, const float *__restrict__ noop) {
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+  using namespace detail;
+  if (noop != nullptr && noop[0] == 1.0f) {
+    return;
+  }
+
+  using IType2 = typename ptx::FPx2<IType>;
+
+  const int row_idx = blockIdx.x;
+  if (row_idx >= num_rows) return;
+
+  const int actual_row = (row_offsets != nullptr) ? row_offsets[row_idx] : row_idx;
+  if (actual_row < 0) return;
+
+  const int num_vec2 = num_cols / 2;
+  const IType2 *input_row = reinterpret_cast<const IType2 *>(input + actual_row * num_cols);
+
+  IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
+  for (int i = threadIdx.x; i < num_vec2; i += BLOCK_SIZE) {
+    const IType2 val = input_row[i];
+    ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, val);
+  }
+  const float thread_max =
+      static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
+
+  using BlockReduce = cub::BlockReduce<float, BLOCK_SIZE>;
+  __shared__ typename BlockReduce::TempStorage temp_storage;
+  float row_amax =
+      BlockReduce(temp_storage).Reduce(thread_max, [](float a, float b) { return fmaxf(a, b); });
+
+  __shared__ float shared_s_enc;
+  if (threadIdx.x == 0) {
+    const float s_enc = compute_global_encode_scaling_factor_FP4(row_amax);
+    output_per_token_amax[row_idx] = row_amax;
+    shared_s_enc = s_enc;
+  }
+  __syncthreads();
+  const float S_enc = shared_s_enc;
+  const float S_dec_rowwise = 1.0 / S_enc;
+  constexpr float fp4_max_inv = 1.0f / detail::TypeExtrema<fp4e2m1>::max;
+  const float global_encode_scale_multiplier = S_enc * fp4_max_inv;
+
+  const int num_sf_blocks = num_cols / PERTOKEN_SF_VEC_SIZE;
+  for (int sf_idx = threadIdx.x; sf_idx < num_sf_blocks; sf_idx += BLOCK_SIZE) {
+    const int col_start = sf_idx * PERTOKEN_SF_VEC_SIZE;
+
+    IType2 block_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
+    alignas(8) IType2 vals[PERTOKEN_SF_VEC_SIZE / 2];
+    const IType2 *input_block =
+        reinterpret_cast<const IType2 *>(input + actual_row * num_cols + col_start);
+    for (int j = 0; j < PERTOKEN_SF_VEC_SIZE / 2; ++j) {
+      vals[j] = input_block[j];
+      ptx::abs_max_2x(block_amax_2x, block_amax_2x, vals[j]);
+    }
+    const float block_max =
+        static_cast<float>(__hmax(__habs(block_amax_2x.x), __habs(block_amax_2x.y)));
+
+    const float S_dec_b_f32 =
+        fminf(block_max * global_encode_scale_multiplier, detail::TypeExtrema<fp8e4m3>::max);
+    const nvfp4_scale_t S_dec_b_fp8 = static_cast<nvfp4_scale_t>(S_dec_b_f32);
+    output_scales[row_idx * scale_stride + sf_idx] = S_dec_b_fp8;
+
+    constexpr float float_max = detail::TypeExtrema<float>::max;
+    const float block_scale_inverse =
+        fminf(1.0f / (static_cast<float>(S_dec_b_fp8) * S_dec_rowwise), float_max);
+    const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse};
+
+    uint8_t *out_ptr = output_data + actual_row * (num_cols / 2) + col_start / 2;
+    if constexpr (std::is_same_v<IType, __nv_bfloat16>) {
+      auto *out_fp4_8x = reinterpret_cast<uint32_t *>(out_ptr);
+      for (int j = 0; j < PERTOKEN_SF_VEC_SIZE / 2; j += 4) {
+        const uint64_t elts03 = *reinterpret_cast<const uint64_t *>(&vals[j]);
+        const uint64_t elts47 = *reinterpret_cast<const uint64_t *>(&vals[j + 2]);
+        out_fp4_8x[j / 4] = ptx::mul_cvt_bf16_to_fp4_8x_round_to_nearest(
+            elts03, elts47, block_scale_inverse);
+      }
+    } else {
+      auto *out_fp4 = reinterpret_cast<uint16_t *>(out_ptr);
+      for (int j = 0; j < PERTOKEN_SF_VEC_SIZE / 2; j += 2) {
+        const float2 in01 =
+            make_float2(static_cast<float>(vals[j].x), static_cast<float>(vals[j].y));
+        const float2 in23 =
+            make_float2(static_cast<float>(vals[j + 1].x), static_cast<float>(vals[j + 1].y));
+        out_fp4[j / 2] = ptx::mul_cvt_fp32_to_fp4_4x(
+            in01, in23, block_scale_inverse_2x, /*rbits=*/0u);
+      }
+    }
+  }
+#endif
+}
+
+template <typename IType>
+void launch_quantize_pertoken_nvfp4(const int num_rows, const int num_cols, const IType *input,
+                                    const int *row_offsets, uint8_t *output_data,
+                                    fp8e4m3 *output_scales, float *output_per_token_amax,
+                                    const int scale_stride, cudaStream_t stream,
+                                    const float *noop = nullptr) {
+#if FP4_TYPE_SUPPORTED
+  if (num_rows == 0 || num_cols == 0) return;
+
+  NVTE_CHECK(num_cols % PERTOKEN_SF_VEC_SIZE == 0, "num_cols must be a multiple of ",
+             PERTOKEN_SF_VEC_SIZE, " for per-token NVFP4 quantization, got ", num_cols);
+  dim3 grid(num_rows);
+  dim3 block(PERTOKEN_BLOCK_SIZE);
+
+  quantize_pertoken_nvfp4_kernel<IType, PERTOKEN_BLOCK_SIZE>
+      <<<grid, block, 0, stream>>>(num_rows, num_cols, input, row_offsets, output_data,
+                                   output_scales, output_per_token_amax, scale_stride, noop);
+  NVTE_CHECK_CUDA(cudaGetLastError());
+#else
+  NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!");
+#endif
+}
+
+}  // namespace quantize_pertoken_kernel
+
+inline void quantize_pertoken(const Tensor &input, const Tensor *noop, Tensor *output,
+                              cudaStream_t stream) {
+#if FP4_TYPE_SUPPORTED
+  checkCuDriverContext(stream);
+  CheckNoopTensor(*noop, "cast_noop");
+  CheckInputTensor(input, "input");
+  CheckOutputTensor(*output, "output", false);
+
+  NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data.");
+  NVTE_CHECK(output->has_data(), "NVFP4 output tensor must be allocated.");
+  NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type.");
+  NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated.");
+  NVTE_CHECK(output->amax.dptr != nullptr, "Per-token amax tensor must be allocated.");
+  NVTE_CHECK(!output->has_columnwise_data(),
+             "Per-token NVFP4 quantization supports rowwise-only output.");
+  NVTE_CHECK(!output->with_gemm_swizzled_scales, "Output must have scales in compact format.");
+
+  const size_t rows = input.flat_first_dim();
+  const size_t cols = input.flat_last_dim();
+  NVTE_CHECK(cols % quantize_pertoken_kernel::PERTOKEN_SF_VEC_SIZE == 0,
+             "Per-token NVFP4 quantization requires last dim divisible by ",
+             quantize_pertoken_kernel::PERTOKEN_SF_VEC_SIZE, ".");
+
+  const auto *noop_ptr = reinterpret_cast<const float *>(noop->data.dptr);
+  auto *data_ptr = reinterpret_cast<uint8_t *>(output->data.dptr);
+  auto *scale_ptr = reinterpret_cast<fp8e4m3 *>(output->scale_inv.dptr);
+  auto *amax_ptr = reinterpret_cast<float *>(output->amax.dptr);
+  const int *row_offsets = nullptr;
+  const int scale_stride = static_cast<int>(output->scale_inv.shape.back());
+
+  if (input.dtype() == DType::kBFloat16) {
+    quantize_pertoken_kernel::launch_quantize_pertoken_nvfp4<__nv_bfloat16>(
+        static_cast<int>(rows), static_cast<int>(cols),
+        reinterpret_cast<const __nv_bfloat16 *>(input.data.dptr), row_offsets, data_ptr, scale_ptr,
+        amax_ptr, scale_stride, stream, noop_ptr);
+  } else if (input.dtype() == DType::kFloat16) {
+    quantize_pertoken_kernel::launch_quantize_pertoken_nvfp4<half>(
+        static_cast<int>(rows), static_cast<int>(cols),
+        reinterpret_cast<const half *>(input.data.dptr), row_offsets, data_ptr, scale_ptr, amax_ptr,
+        scale_stride, stream, noop_ptr);
+  } else {
+    NVTE_ERROR(
+        "Unsupported input dtype for per-token NVFP4 quantization. 
" + "Expected BFloat16 or Float16."); + } +#else + NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!"); +#endif +} + +} // namespace nvfp4 +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_QUANTIZE_PERTOKEN_NVFP4_CUH_ diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index c1b3f8f427..c5b4254e8b 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -470,6 +470,7 @@ struct QuantizationConfig { bool nvfp4_2d_quantization = false; bool stochastic_rounding = false; bool use_fast_math = false; + bool nvfp4_per_token_activation = false; static constexpr size_t attr_sizes[] = { sizeof(uint8_t), // force_pow_2_scales @@ -479,7 +480,8 @@ struct QuantizationConfig { sizeof(NVTETensor), // rng_seed and offset sizeof(uint8_t), // nvfp4_2d_quantization sizeof(uint8_t), // stochastic_rounding - sizeof(uint8_t) // use_fast_math + sizeof(uint8_t), // use_fast_math + sizeof(uint8_t) // nvfp4_per_token_activation }; }; diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index b7461a85d1..0463d51d1c 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -370,6 +370,8 @@ enum NVTEQuantizationConfigAttribute { * inconsistently between kernels. */ kNVTEQuantizationConfigUseFastMath = 7, + /*! Whether to enable per-token (per-row) NVFP4 quantization */ + kNVTEQuantizationConfigNVFP4PerTokenActivation = 8, kNVTEQuantizationConfigNumAttributes }; @@ -1296,6 +1298,13 @@ class QuantizationConfigWrapper { sizeof(val)); } + /*! \brief Set whether to enable per-token NVFP4 quantization */ + void set_nvfp4_per_token_activation(bool nvfp4_per_token_activation) { + const auto val = static_cast(nvfp4_per_token_activation); + nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigNVFP4PerTokenActivation, + &val, sizeof(val)); + } + private: /*! \brief Wrapped NVTEQuantizationConfig. */ NVTEQuantizationConfig config_ = nullptr; diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 67b6f87067..e59d01d82a 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -478,6 +478,10 @@ class NVFP4BlockScaling(Recipe): If set to `True`, stochastic rounding is disabled during quantization for all tensors. disable_2d_quantization : bool, default = False If set to `True`, 1D block scaling with block size 16 is used for all tensors. + per_token_activation : bool, default = False + If set to `True`, GroupedLinear activation split quantization uses per-token + (per-row) NVFP4 global amax values. In this mode, rowwise ``amax`` metadata + is stored as a vector with one FP32 value per token. backward_override : {None, 'high_precision', 'dequantized'}, default = None Backward precision mode. 
None does not modify backward behavior, `high_precision` keeps original high-precision operands for backward, @@ -491,6 +495,7 @@ class NVFP4BlockScaling(Recipe): os.getenv("NVTE_NVFP4_DISABLE_STOCHASTIC_ROUNDING", "0") == "1" ) disable_2d_quantization: bool = os.getenv("NVTE_NVFP4_DISABLE_2D_QUANTIZATION", "0") == "1" + per_token_activation: bool = os.getenv("NVTE_NVFP4_PER_TOKEN_ACTIVATION", "0") == "1" fp4_format: Format = Format.E2M1 fp8_format: Format = Format.E4M3 @@ -534,6 +539,7 @@ def __repr__(self) -> str: f"fp8_dpa={self.fp8_dpa}, " f"fp8_mha={self.fp8_mha}, " f"backward_override={self.backward_override}, " + f"per_token_activation={self.per_token_activation}, " f"fp4_quant_fwd_inp={self.fp4_quant_fwd_inp}, " f"fp4_quant_fwd_weight={self.fp4_quant_fwd_weight}, " f"fp4_quant_bwd_grad={self.fp4_quant_bwd_grad}, " diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 1261879a8b..a0a0ffa45f 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -1043,6 +1043,9 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config, case kNVTEQuantizationConfigUseFastMath: bool_to_uint8(config_.use_fast_math, buf); break; + case kNVTEQuantizationConfigNVFP4PerTokenActivation: + bool_to_uint8(config_.nvfp4_per_token_activation, buf); + break; default: NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); } @@ -1098,6 +1101,9 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config, case kNVTEQuantizationConfigUseFastMath: uint8_to_bool(buf, config_.use_fast_math); break; + case kNVTEQuantizationConfigNVFP4PerTokenActivation: + uint8_to_bool(buf, config_.nvfp4_per_token_activation); + break; default: NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); } diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 6f3553bf94..9cf58f9dce 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -15,6 +15,7 @@ from ..quantized_tensor import Quantizer from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage +from ..tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage from ..tensor.utils import is_custom from ..custom_recipes.gemm import custom_gemm from ...debug.pytorch.debug_quantization import DebugQuantizer @@ -69,6 +70,50 @@ def validate_gemm_scale(scale: Optional[float], required: bool) -> float: return 0.0 +def _maybe_apply_nvfp4_pertoken_output_rescale( + out: torch.Tensor, + B: torch.Tensor, + *, + layout: str, + bias: Optional[torch.Tensor], + grad: bool, + gelu: bool, + accumulate: bool, +) -> None: + """Apply per-token NVFP4 global-scale correction for TN forward GEMM outputs. + + Current NVFP4 GEMM alpha path consumes one scalar amax. Per-token NVFP4 stores + rowwise amax vector in B._amax_rowwise, so we correct by row using ratio + (amax[row] / amax[0]). If bias was fused in epilogue, remove/reapply it around + the row rescale to avoid bias distortion. 
+ """ + + if grad or gelu or accumulate or layout != "TN": + return + if not isinstance(B, NVFP4TensorStorage): + return + if not isinstance(out, torch.Tensor) or is_custom(out): + return + if out.numel() == 0: + return + amax = B._amax_rowwise + if amax is None or amax.numel() <= 1: + return + + out_2d = out.reshape(-1, out.shape[-1]) + if amax.numel() != out_2d.shape[0]: + return + + ratios = (amax / amax[0]).to(dtype=out.dtype).view(-1, 1) + if bias is not None: + bias_cast = bias.to(dtype=out.dtype) + out_2d.sub_(bias_cast) + out_2d.mul_(ratios) + out_2d.add_(bias_cast) + else: + out_2d.mul_(ratios) + + def general_gemm( A: torch.Tensor, B: torch.Tensor, @@ -147,6 +192,22 @@ def general_gemm( # FP8 block-scaling requires split accumulator use_split_accumulator = True + requested_out_dtype = out_dtype + needs_fp32_rescale_path = ( + layout == "TN" + and not grad + and not gelu + and not accumulate + and isinstance(B, NVFP4TensorStorage) + and B._amax_rowwise is not None + and B._amax_rowwise.numel() > 1 + and quantization_params is None + and out is None + and requested_out_dtype is not None + and requested_out_dtype != torch.float32 + ) + effective_out_dtype = torch.float32 if needs_fp32_rescale_path else requested_out_dtype + args = ( A, transa, # transa @@ -154,7 +215,7 @@ def general_gemm( transb, # transb out, quantization_params, - TE_DType[out_dtype] if out_dtype is not None else None, + TE_DType[effective_out_dtype] if effective_out_dtype is not None else None, bias, bias_dtype, gelu, @@ -175,6 +236,17 @@ def general_gemm( } out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs) + _maybe_apply_nvfp4_pertoken_output_rescale( + out, + B, + layout=layout, + bias=bias, + grad=grad, + gelu=gelu, + accumulate=accumulate, + ) + if needs_fp32_rescale_path: + out = out.to(dtype=requested_out_dtype) if debug_quantizer is not None: out = debug_quantizer.process_gemm_output(out) diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 8e3bcdd5b3..b9f852c07d 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -320,6 +320,7 @@ class NVFP4Quantizer : public Quantizer { // 2D block scaling bool with_2d_quantization; bool stochastic_rounding; + bool per_token_activation; int rht_matrix_random_sign_mask_t; at::Tensor rht_matrix; diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 4a2ea7412b..06478b54e0 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -326,6 +326,8 @@ py::object group_dequantize(const py::handle &input, DType otype); py::object bgrad_group_quantize(const at::Tensor &tensor, py::handle quantizer, const size_t num_tensors, std::optional first_dims); +std::tuple quantize_nvfp4_pertoken(at::Tensor input); + std::vector multi_tensor_quantize(const std::vector &tensor_list, std::vector quantizer_list); diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 50fe4c109e..f1654f5525 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -801,6 +801,7 @@ std::tuple, std::vector, bool> bulk_alloc const auto columnwise_usage = quantizer_cpp_list[0]->columnwise_usage; const auto scaling_mode = quantizer_cpp_list[0]->get_scaling_mode(); const auto fp4_dtype = quantizer_cpp_list[0]->dtype; + const bool 
per_token_activation = quantizer_cpp_list[0]->per_token_activation; const bool with_gemm_swizzled_scales = false; /// TODO (tmoon) Enable based on optimize_for_gemm; constexpr size_t scale_elem_size = 1; @@ -828,6 +829,16 @@ std::tuple, std::vector, bool> bulk_alloc } return fp4_shape; }; + auto flat_first_dim = [](const std::vector &shape) -> size_t { + if (shape.empty()) { + return 1; + } + size_t rows = 1; + for (size_t i = 0; i + 1 < shape.size(); ++i) { + rows *= shape[i]; + } + return rows; + }; // Allocate row-wise data std::vector rowwise_data_list, rowwise_scale_list, amax_rowwise_list; @@ -866,7 +877,9 @@ std::tuple, std::vector, bool> bulk_alloc // Note: Multi-quantize kernel does not require contiguous amaxes. const auto offset = roundup(buffer_size, 16); amax_offsets.push_back(offset); - buffer_size = offset + 4; + const size_t amax_size = + per_token_activation ? 4 * flat_first_dim(rowwise_data_shapes[i]) : 4; + buffer_size = offset + amax_size; } // Allocate full buffer @@ -879,8 +892,11 @@ std::tuple, std::vector, bool> bulk_alloc data_offsets[i], torch::kUInt8)); rowwise_scale_list.emplace_back( make_torch_view(buffer, rowwise_scale_shapes[i], scale_offsets[i], torch::kUInt8)); + const std::vector amax_shape = + per_token_activation ? std::vector{flat_first_dim(rowwise_data_shapes[i])} + : std::vector{1}; amax_rowwise_list.emplace_back( - make_torch_view(buffer, std::vector{1}, amax_offsets[i], torch::kFloat32)); + make_torch_view(buffer, amax_shape, amax_offsets[i], torch::kFloat32)); } } @@ -983,7 +999,7 @@ std::tuple, std::vector, bool> bulk_alloc // Set the amax rowwise and amax columnwise if available if (rowwise_usage) { tensor_wrapper.set_amax(amax_rowwise_list[i].data_ptr(), DType::kFloat32, - std::vector{1}); + getTensorShape(amax_rowwise_list[i])); } if (columnwise_usage) { tensor_wrapper.set_columnwise_amax(amax_columnwise_list[i].data_ptr(), DType::kFloat32, @@ -1263,6 +1279,35 @@ void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, nvte_tensor_output_list.push_back(output_list[i].data()); } + if (quantizer.per_token_activation) { + NVTE_CHECK(!quantizer.with_rht, "Per-token NVFP4 split quantize does not support RHT."); + NVTE_CHECK(!quantizer.columnwise_usage, + "Per-token NVFP4 split quantize currently supports rowwise-only quantization."); + NVTE_CHECK(!quantizer.with_2d_quantization, + "Per-token NVFP4 split quantize does not support 2D quantization."); + NVTE_CHECK(!quantizer.stochastic_rounding, + "Per-token NVFP4 split quantize does not support stochastic rounding."); + + std::vector quant_config_list; + quant_config_list.reserve(num_tensors); + for (size_t i = 0; i < num_tensors; ++i) { + quant_config_list.emplace_back(QuantizationConfigWrapper()); + quant_config_list.back().set_nvfp4_per_token_activation(true); + } + + for (size_t i = 0; i < num_tensors; i++) { + if (input_list[i].numel() == 0) { + continue; + } + const size_t input_ndim = input_list[i].ndim(); + const size_t cols = input_ndim > 0 ? 
input_list[i].size(input_ndim - 1) : 1; + NVTE_CHECK(cols % 16 == 0, + "Per-token NVFP4 split quantize requires split inner dim divisible by 16."); + nvte_quantize_v2(input_list[i].data(), output_list[i].data(), quant_config_list[i], stream); + } + return; + } + // In this case without RHT, the rowwise and colwise quantization are fused // we don't need separate rng states for rowwise and colwise bool need_separate_rng_states = false; @@ -1360,8 +1405,13 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input, // Check input tensor shape const size_t input_last_dim = input.ndim() > 0 ? input.size(input.ndim() - 1) : 1; - NVTE_CHECK(input_last_dim % 128 == 0, - "NVFP4 multi-quantize requires inner dim to be multiple of 128."); + if (quantizer.per_token_activation) { + NVTE_CHECK(input_last_dim % 16 == 0, + "Per-token NVFP4 split-quantize requires inner dim to be multiple of 16."); + } else { + NVTE_CHECK(input_last_dim % 128 == 0, + "NVFP4 multi-quantize requires inner dim to be multiple of 128."); + } // CUDA stream auto stream = at::cuda::getCurrentCUDAStream(); @@ -1433,12 +1483,25 @@ std::vector split_quantize(const at::Tensor &tensor, for (size_t i = 0; i < num_splits; i++) { quantizer_cpp_list.push_back(convert_quantizer(quantizer_list[i])); } + const bool all_nvfp4_quantizers = std::all_of(quantizer_list.begin(), quantizer_list.end(), + [](const py::handle &quantizer) -> bool { + return detail::IsNVFP4Quantizers(quantizer.ptr()); + }); + const bool all_nvfp4_per_token_activation = + all_nvfp4_quantizers && + std::all_of(quantizer_cpp_list.begin(), quantizer_cpp_list.end(), + [](const std::unique_ptr &quantizer) -> bool { + return static_cast(quantizer.get())->per_token_activation; + }); // Choose implementation for allocating and populating tensors enum class AllocationMethod { UNFUSED, BULK_FP8_BLOCKWISE, BULK_MXFP8, BULK_NVFP4 }; enum class QuantizationMethod { UNFUSED, FUSED_NVFP4 }; AllocationMethod allocation_method = AllocationMethod::UNFUSED; QuantizationMethod quantization_method = QuantizationMethod::UNFUSED; + if (all_nvfp4_per_token_activation) { + quantization_method = QuantizationMethod::FUSED_NVFP4; + } if (!disable_bulk_allocation) { if (std::all_of(quantizer_list.begin(), quantizer_list.end(), [](const py::handle &quantizer) -> bool { @@ -1450,10 +1513,7 @@ std::vector split_quantize(const at::Tensor &tensor, return detail::IsMXFP8Quantizers(quantizer.ptr()); })) { allocation_method = AllocationMethod::BULK_MXFP8; - } else if (std::all_of(quantizer_list.begin(), quantizer_list.end(), - [](const py::handle &quantizer) -> bool { - return detail::IsNVFP4Quantizers(quantizer.ptr()); - })) { + } else if (all_nvfp4_quantizers) { allocation_method = AllocationMethod::BULK_NVFP4; quantization_method = QuantizationMethod::FUSED_NVFP4; } @@ -1492,7 +1552,8 @@ std::vector split_quantize(const at::Tensor &tensor, bool contiguous_data_and_scale = false; std::tie(output_py_list, output_cpp_list, contiguous_data_and_scale) = bulk_allocate_nvfp4_tensors(split_shapes, quantizer_list, nvfp4_quantizers); - if (!input_shape.empty() && input_shape.back() % 128 != 0) { + if (!all_nvfp4_per_token_activation && !input_shape.empty() && + input_shape.back() % 128 != 0) { static std::once_flag once_unfused_nvfp4_fallback_warning; std::call_once(once_unfused_nvfp4_fallback_warning, []() { NVTE_WARN( @@ -1502,7 +1563,7 @@ std::vector split_quantize(const at::Tensor &tensor, }); quantization_method = QuantizationMethod::UNFUSED; } - if (!contiguous_data_and_scale) { + if 
(!all_nvfp4_per_token_activation && !contiguous_data_and_scale) { // Avoid fused quantize kernel if data is not contiguous quantization_method = QuantizationMethod::UNFUSED; } @@ -1540,5 +1601,53 @@ std::vector split_quantize(const at::Tensor &tensor, return output_py_list; } +std::tuple quantize_nvfp4_pertoken(at::Tensor input) { + init_extension(); + + NVTE_CHECK(input.dim() == 2, "Input must be 2D (num_rows, num_cols)"); + NVTE_CHECK(input.is_cuda(), "Input must be on CUDA device"); + NVTE_CHECK(input.scalar_type() == at::ScalarType::BFloat16 || + input.scalar_type() == at::ScalarType::Half, + "Input must be BFloat16 or Half"); + + const int num_rows = input.size(0); + const int num_cols = input.size(1); + NVTE_CHECK(num_cols % 16 == 0, + "num_cols must be a multiple of 16 for per-token NVFP4 quantization"); + + if (num_rows == 0) { + auto options = input.options(); + return {at::empty({0, num_cols / 2}, options.dtype(at::kByte)), + at::empty({0, num_cols / 16}, options.dtype(at::kByte)), + at::empty({0}, options.dtype(at::kFloat))}; + } + + auto input_contig = input.contiguous(); + auto options = input_contig.options(); + + auto output_data = at::empty({num_rows, num_cols / 2}, options.dtype(at::kByte)); + auto output_scales = at::empty({num_rows, num_cols / 16}, options.dtype(at::kByte)); + auto output_per_token_amax = at::empty({num_rows}, options.dtype(at::kFloat)); + + auto te_input = makeTransformerEngineTensor(input_contig); + TensorWrapper te_output(NVTE_NVFP4_1D_SCALING); + te_output.set_rowwise_data( + output_data.data_ptr(), DType::kFloat4E2M1, + std::vector{static_cast(num_rows), static_cast(num_cols)}); + te_output.set_rowwise_scale_inv( + output_scales.data_ptr(), DType::kFloat8E4M3, + std::vector{static_cast(num_rows), static_cast(num_cols / 16)}); + te_output.set_amax(output_per_token_amax.data_ptr(), DType::kFloat32, + std::vector{static_cast(num_rows)}); + QuantizationConfigWrapper quant_config; + quant_config.set_nvfp4_per_token_activation(true); + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + NVTE_SCOPED_GIL_RELEASE( + { nvte_quantize_v2(te_input.data(), te_output.data(), quant_config, stream); }); + + return {output_data, output_scales, output_per_token_amax}; +} + } // namespace pytorch } // namespace transformer_engine diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index eb7576d905..4021792f86 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -145,6 +145,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Dequantize group tensor", py::arg("input"), py::arg("otype")); m.def("bgrad_group_quantize", transformer_engine::pytorch::bgrad_group_quantize, py::arg("tensor"), py::arg("quantizer"), py::arg("num_tensors"), py::arg("first_dims")); + m.def("quantize_nvfp4_pertoken", transformer_engine::pytorch::quantize_nvfp4_pertoken, + "Per-token NVFP4 quantization", py::arg("input")); m.def("bgrad_quantize", transformer_engine::pytorch::bgrad_quantize, "Compute bias gradient and quantize", py::arg("input"), py::arg("quantizer")); m.def("generic_gemm", transformer_engine::pytorch::gemm, "Compute GEMM (matrix-matrix multiply)", diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index da91e5c170..d6fedc707b 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1696,6 +1696,7 @@ 
NVFP4Quantizer::NVFP4Quantizer(const py::handle& quantizer) : Quantizer(quantize this->with_post_rht_amax = quantizer.attr("with_post_rht_amax").cast(); this->with_2d_quantization = quantizer.attr("with_2d_quantization").cast(); this->stochastic_rounding = quantizer.attr("stochastic_rounding").cast(); + this->per_token_activation = quantizer.attr("per_token_activation").cast(); // Get amax reduction group if needed for NVFP4 AG const bool with_amax_reduction = quantizer.attr("with_amax_reduction").cast(); @@ -1760,9 +1761,10 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve rowwise_scale_inv_shape.end()); rowwise_data_tensor = at::empty(convert_shape_for_fp4(shape_int64), bit8_tensor_opts); rowwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts); + const int64_t amax_rows = this->per_token_activation ? static_cast(flat_first_dim) : 1; // hadamard amax kernel will zero out pointer with ZeroAmaxKernel // nvte_compute_amax_with_config will zero out the pointer if needed - amax_rowwise = at::empty({1}, bit32_tensor_opts); + amax_rowwise = at::empty({amax_rows}, bit32_tensor_opts); } if (columnwise_usage) { const std::vector scale_inv_shape_int64(columnwise_scale_inv_shape.begin(), @@ -1850,7 +1852,7 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve out_cpp.set_rowwise_data(rowwise_data_tensor.data_ptr(), DType::kFloat4E2M1, shape); out_cpp.set_rowwise_scale_inv(rowwise_scale_inv_tensor.data_ptr(), DType::kFloat8E4M3, rowwise_scale_inv_shape); - out_cpp.set_amax(amax_rowwise.data_ptr(), DType::kFloat32, std::vector{1}); + out_cpp.set_amax(amax_rowwise.data_ptr(), DType::kFloat32, getTensorShape(amax_rowwise)); } if (columnwise_usage) { // enforce 2D shape to avoid [S, B, H] shape and B and be 1 @@ -1862,7 +1864,7 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve out_cpp.set_columnwise_scale_inv(columnwise_scale_inv_tensor.data_ptr(), DType::kFloat8E4M3, columnwise_scale_inv_shape); out_cpp.set_columnwise_amax(amax_columnwise.data_ptr(), DType::kFloat32, - std::vector{1}); + getTensorShape(amax_columnwise)); } out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); this->set_quantization_params(&out_cpp); @@ -1975,15 +1977,22 @@ std::pair NVFP4Quantizer::create_unquantized_tensor_w auto [out_cpp, out_py] = NoneQuantizer(py::none()).create_tensor(shape, dtype); // Register amax pointer from quantized tensor - void* amax_ptr = quantized_tensor.amax(); + auto rowwise_amax = quantized_tensor.get_amax(); + auto columnwise_amax = quantized_tensor.get_columnwise_amax(); + + void* amax_ptr = rowwise_amax.data_ptr; + std::vector amax_shape = convertShape(rowwise_amax.shape); if (amax_ptr == nullptr) { - amax_ptr = quantized_tensor.get_columnwise_amax().data_ptr; + amax_ptr = columnwise_amax.data_ptr; + amax_shape = convertShape(columnwise_amax.shape); } NVTE_CHECK(amax_ptr != nullptr, "Could not extract amax pointer from NVFP4 tensor."); - out_cpp.set_amax(amax_ptr, DType::kFloat32, std::vector{1}); + out_cpp.set_amax(amax_ptr, DType::kFloat32, amax_shape); // Zero out amax - NVTE_CHECK_CUDA(cudaMemsetAsync(amax_ptr, 0, sizeof(float), at::cuda::getCurrentCUDAStream())); + const size_t amax_numel = product(amax_shape); + NVTE_CHECK_CUDA( + cudaMemsetAsync(amax_ptr, 0, amax_numel * sizeof(float), at::cuda::getCurrentCUDAStream())); return {std::move(out_cpp), std::move(out_py)}; } @@ -2050,9 +2059,11 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( } if (!amax_rowwise) { const auto opts = 
at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + const int64_t amax_rows = + this->per_token_activation ? static_cast(flat_first_dim) : 1; // hadamard amax kernel will zero out pointer with ZeroAmaxKernel // nvte_compute_amax_with_config will zero out the pointer if needed - amax_rowwise = at::empty({1}, opts); + amax_rowwise = at::empty({amax_rows}, opts); tensor.attr("_amax_rowwise") = *amax_rowwise; } } else { // rowwise_usage == false @@ -2118,7 +2129,7 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( out_cpp.set_rowwise_data(rowwise_data->data_ptr(), DType::kFloat4E2M1, shape); out_cpp.set_rowwise_scale_inv(rowwise_scale_inv->data_ptr(), DType::kFloat8E4M3, getTensorShape(*rowwise_scale_inv)); - out_cpp.set_amax(amax_rowwise->data_ptr(), DType::kFloat32, std::vector{1}); + out_cpp.set_amax(amax_rowwise->data_ptr(), DType::kFloat32, getTensorShape(*amax_rowwise)); } if (columnwise_usage) { // enforce 2D shape to avoid [S, B, H] shape and B and be 1 @@ -2241,6 +2252,22 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou } size_t cols = input.size(input.ndim() - 1); + if (this->per_token_activation) { + NVTE_CHECK(!this->with_rht, "Per-token NVFP4 activation does not support RHT."); + NVTE_CHECK(!this->with_2d_quantization, + "Per-token NVFP4 activation does not support 2D quantization."); + NVTE_CHECK(!this->stochastic_rounding, + "Per-token NVFP4 activation does not support stochastic rounding."); + NVTE_CHECK(!this->columnwise_usage, + "Per-token NVFP4 activation currently supports rowwise-only quantization."); + NVTE_CHECK(!this->with_amax_reduction, + "Per-token NVFP4 activation does not support amax reduction."); + NVTE_CHECK(input.dtype() == DType::kBFloat16 || input.dtype() == DType::kFloat16, + "Per-token NVFP4 activation supports BF16/FP16 inputs only."); + NVTE_CHECK(cols % 16 == 0, "Per-token NVFP4 activation requires last dim divisible by 16."); + quant_config.set_nvfp4_per_token_activation(true); + } + // Restriction for the RHT cast fusion kernel because we are using MMA hardware for computing RHT bool eligible_for_rht_cast_fusion = input.dtype() == DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0; @@ -2307,7 +2334,7 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou "Use with_post_rht_amax=true instead."); } } else { // Without RHT - if (compute_amax) { + if (compute_amax && !this->per_token_activation) { // Amax pointers auto rowwise_amax_ptr = out.get_amax().data_ptr; auto columnwise_amax_ptr = out.get_columnwise_amax().data_ptr; @@ -2408,6 +2435,8 @@ void NVFP4Quantizer::quantize(const TensorWrapper& input, TensorWrapper& out, } void NVFP4Quantizer::quantize_with_amax(TensorWrapper& input, TensorWrapper& out) { + NVTE_CHECK(!this->per_token_activation, + "quantize_with_amax is not supported for per-token NVFP4 activation."); // Update output tensor amaxes with input tensor amax auto input_amax_ptr = input.amax(); auto output_rowwise_amax_ptr = out.get_amax().data_ptr; diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py index dd01ae05d3..6a5e400592 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py @@ -350,6 +350,7 @@ def __init__( pow_2_scales: bool = False, eps: float = 0.0, quant_tile_shape: Tuple[int, int] = (1, 16), + per_token_activation: bool = False, with_rht: bool = False, 
with_random_sign_mask: bool = True, ): @@ -360,6 +361,7 @@ def __init__( self.pow_2_scales = pow_2_scales self.eps = eps self.quant_tile_shape = quant_tile_shape + self.per_token_activation = per_token_activation self.with_rht = with_rht self.with_random_sign_mask = with_random_sign_mask @@ -447,6 +449,7 @@ def _quantize_blockwise_reference( tile_len_y: int, *, pow_2_scales: bool, + per_token_activation: bool, eps: float, # pylint: disable=unused-argument ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -488,6 +491,11 @@ def _quantize_blockwise_reference( decode_scale.to(torch.float32), ) else: + if per_token_activation: + global_amax = global_amax.to(torch.float32).view(m, 1, 1) + else: + global_amax = global_amax.to(torch.float32) + global_encode_scale = torch.div(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX, global_amax) global_encode_scale = torch.min( global_encode_scale, @@ -497,8 +505,15 @@ def _quantize_blockwise_reference( dtype=torch.float32, ), ) - if global_encode_scale == torch.tensor(0.0, device=x.device, dtype=torch.float32): - global_encode_scale = torch.tensor(1.0, device=x.device, dtype=torch.float32) + if global_encode_scale.numel() == 1: + if global_encode_scale == torch.tensor(0.0, device=x.device, dtype=torch.float32): + global_encode_scale = torch.tensor(1.0, device=x.device, dtype=torch.float32) + else: + global_encode_scale = torch.where( + global_encode_scale == 0.0, + torch.ones_like(global_encode_scale), + global_encode_scale, + ) global_decode_scale = torch.div(1.0, global_encode_scale) global_encode_scale_multiplier = global_encode_scale * torch.reciprocal(FLOAT4_E2M1_MAX) @@ -609,6 +624,10 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ raise ValueError( f"MXFP4 only supports 1x32 tile shape, got {self.quant_tile_shape}" ) + if self.per_token_activation: + raise ValueError( + "Per-token activation is only supported for NVFP4 (non-pow2) mode." 
+ ) # TODO(etsykunov): Fix bug where global_amax_row and # global_amax_col are not defined # global_amax = torch.empty(0, device=tensor.device, dtype=torch.float32) @@ -625,13 +644,24 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ if self.with_rht else tensor.t().contiguous() ) - # Compute amax for rowwise and columnwise paths separately - global_amax_row = torch.max(torch.abs(row_input)).to(torch.float32).view(1) - global_amax_col = ( - torch.max(torch.abs(col_input)).to(torch.float32).view(1) - if self.columnwise_usage - else global_amax_row - ) + if self.per_token_activation: + if self.quant_tile_shape != (1, 16): + raise ValueError( + "Per-token activation only supports NVFP4 1x16 tile shape, " + f"got {self.quant_tile_shape}" + ) + if self.columnwise_usage: + raise ValueError("Per-token activation reference supports rowwise-only usage.") + global_amax_row = torch.max(torch.abs(row_input), dim=1).values.to(torch.float32) + global_amax_col = global_amax_row + else: + # Compute amax for rowwise and columnwise paths separately + global_amax_row = torch.max(torch.abs(row_input)).to(torch.float32).view(1) + global_amax_col = ( + torch.max(torch.abs(col_input)).to(torch.float32).view(1) + if self.columnwise_usage + else global_amax_row + ) transpose_scales = False @@ -648,6 +678,7 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ self.quant_tile_shape[1], self.quant_tile_shape[0], pow_2_scales=self.pow_2_scales, + per_token_activation=self.per_token_activation, eps=self.eps, ) if transpose_scales: @@ -671,6 +702,7 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ self.quant_tile_shape[1], self.quant_tile_shape[0], pow_2_scales=self.pow_2_scales, + per_token_activation=False, eps=self.eps, ) @@ -863,6 +895,16 @@ def qgemm( sw = sw.to(torch.float32) factor = 6.0 * 6.0 * 448.0 * 448.0 + if ( + qresult_x.global_amax_row.numel() != 1 + or qresult_w.global_amax_row.numel() != 1 + or qresult_w.global_amax_col.numel() != 1 + or qresult_x.global_amax_col.numel() != 1 + ): + raise ValueError( + "NVFP4QuantizerRef.qgemm expects scalar global amax values; " + "per-token amax vectors are not supported in reference GEMM." 
+ ) if gemm_type == quantization.GEMMType.WGRAD: partial_alpha = qresult_x.global_amax_col * qresult_w.global_amax_col diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 9956fb77ec..6ffca84a7d 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -1375,6 +1375,7 @@ def _make_quantizer(idx: int) -> NVFP4Quantizer: with_post_rht_amax=qparams.random_hadamard_transform, with_2d_quantization=qparams.fp4_2d_quantization, stochastic_rounding=qparams.stochastic_rounding, + per_token_activation=self.recipe.per_token_activation, ) return [_make_quantizer(idx) for idx in range(self.num_quantizers)] @@ -1389,6 +1390,7 @@ def _make_quantizer(idx: int) -> NVFP4Quantizer: with_post_rht_amax=self.recipe.fp4_quant_bwd_grad.random_hadamard_transform, with_2d_quantization=self.recipe.fp4_quant_bwd_grad.fp4_2d_quantization, stochastic_rounding=self.recipe.fp4_quant_bwd_grad.stochastic_rounding, + per_token_activation=self.recipe.per_token_activation, ) for _ in range(self.num_quantizers) ] diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 65678aa347..cd63fb5221 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -128,6 +128,9 @@ class NVFP4Quantizer(Quantizer): """Stochastic rounding, only applicable for gradients.""" stochastic_rounding: bool + """Per-token activation quantization path (grouped split quantize).""" + per_token_activation: bool + """RHT matrix random sign mask""" rht_matrix_random_sign_mask_t: int rht_matrix: torch.Tensor @@ -143,6 +146,7 @@ def __init__( with_post_rht_amax: bool = False, with_2d_quantization: bool = False, stochastic_rounding: bool = False, + per_token_activation: bool = False, with_random_sign_mask: bool = True, ) -> None: super().__init__(rowwise=rowwise, columnwise=columnwise) @@ -153,6 +157,7 @@ def __init__( self.amax_reduction_group = amax_reduction_group self.with_2d_quantization = with_2d_quantization self.stochastic_rounding = stochastic_rounding + self.per_token_activation = per_token_activation self.rht_matrix_random_sign_mask_t = get_random_sign_mask_for_rht( with_random_sign_mask, torch.cuda.current_device() ) @@ -198,6 +203,7 @@ def copy(self) -> NVFP4Quantizer: with_post_rht_amax=self.with_post_rht_amax, with_2d_quantization=self.with_2d_quantization, stochastic_rounding=self.stochastic_rounding, + per_token_activation=self.per_token_activation, ) quantizer.internal = self.internal quantizer.optimize_for_gemm = self.optimize_for_gemm @@ -330,7 +336,10 @@ def make_empty( scale_shape, dtype=torch.uint8, device=device, pin_memory=pin_memory ) # Allocate per tensor scale inverse. FP32 format. 
- amax_rowwise = torch.zeros(1, dtype=torch.float32, device=device, pin_memory=pin_memory) + amax_rows = flat_first_dim if self.per_token_activation else 1 + amax_rowwise = torch.zeros( + amax_rows, dtype=torch.float32, device=device, pin_memory=pin_memory + ) # Allocate FP8 data transpose if needed columnwise_data = None From 700cbce0f382a585a09e122bb4ec8dda5913bf5a Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sun, 26 Apr 2026 01:45:09 -0700 Subject: [PATCH 02/21] Add col Signed-off-by: Ziang Li --- tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py | 110 +++++--- .../nvfp4/test_nvfp4_quantize_exact.py | 4 +- tests/pytorch/test_backward_override.py | 13 +- .../common/cast/dispatch/quantize.cuh | 4 - .../cast/nvfp4/quantize_pertoken_nvfp4.cuh | 245 ++++++++++++++++-- .../pytorch/cpp_extensions/gemm.py | 18 +- .../pytorch/csrc/extensions/activation.cpp | 10 +- .../pytorch/csrc/extensions/bias.cpp | 5 +- .../pytorch/csrc/extensions/cast.cpp | 12 +- .../pytorch/csrc/extensions/normalization.cpp | 10 +- transformer_engine/pytorch/csrc/quantizer.cpp | 11 +- .../custom_recipes/quantization_nvfp4.py | 109 ++++++-- .../pytorch/tensor/nvfp4_tensor.py | 3 +- 13 files changed, 438 insertions(+), 116 deletions(-) diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py index 911b7660dc..d22442cd64 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -8,6 +8,7 @@ import transformer_engine_torch as tex from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch import NVFP4Quantizer +from transformer_engine.pytorch.cpp_extensions import general_gemm from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef from transformer_engine.pytorch.custom_recipes import utils @@ -15,6 +16,20 @@ recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) +def maybe_skip_pertoken_nvfp4_gemm( + x_dtype: torch.dtype, + *, + accumulate: bool, + x_columnwise: bool, +) -> None: + if x_dtype == torch.float32: + pytest.skip("Per-token NVFP4 kernel supports BF16/FP16 inputs only") + if accumulate: + pytest.skip("Per-token NVFP4 GEMM output rescale does not support accumulation") + if x_columnwise: + pytest.skip("Per-token NVFP4 GEMM output rescale requires rowwise activation usage") + + def check_nvfp4_gemm_versus_reference( x_dtype: torch.dtype, w_dtype: torch.dtype, @@ -26,6 +41,7 @@ def check_nvfp4_gemm_versus_reference( *, x_columnwise: bool = False, w_columnwise: bool = False, + per_token_activation: bool = False, ): te_dtype = tex.DType.kFloat4E2M1 @@ -56,6 +72,7 @@ def check_nvfp4_gemm_versus_reference( amax_reduction_group=None, with_rht=False, with_post_rht_amax=False, + per_token_activation=per_token_activation, ) w_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, @@ -112,7 +129,16 @@ def check_nvfp4_gemm_versus_reference( sw_trimmed = sw_trimmed.view(torch.float8_e4m3fn) # Create reference quantizer for reference GEMM - ref_quantizer = NVFP4QuantizerRef( + x_ref_quantizer = NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + rowwise=True, + columnwise=not per_token_activation, + pow_2_scales=False, + eps=0.0, + quant_tile_shape=(1, 16), + per_token_activation=per_token_activation, + ) + w_ref_quantizer = NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, rowwise=True, columnwise=True, @@ -124,16 +150,16 @@ def check_nvfp4_gemm_versus_reference( # Create reference quantized tensors needed by reference GEMM # Reference GEMM is only 
rowwise. if x_columnwise: - x_nvfp4_ref = ref_quantizer.quantize(x.t().contiguous()) + x_nvfp4_ref = x_ref_quantizer.quantize(x.t().contiguous()) else: - x_nvfp4_ref = ref_quantizer.quantize(x) + x_nvfp4_ref = x_ref_quantizer.quantize(x) if w_columnwise: - w_nvfp4_ref = ref_quantizer.quantize(w.t().contiguous()) + w_nvfp4_ref = w_ref_quantizer.quantize(w.t().contiguous()) else: - w_nvfp4_ref = ref_quantizer.quantize(w) + w_nvfp4_ref = w_ref_quantizer.quantize(w) # Reference GEMM using quantizer's qgemm method - y_ref = ref_quantizer.qgemm( + y_ref = x_ref_quantizer.qgemm( qx=qx_data, qw=qw_data, m_params=None, # MMParams not used in reference @@ -148,7 +174,7 @@ def check_nvfp4_gemm_versus_reference( qresult_w=w_nvfp4_ref, ) - # Native TE GEMM using tex.generic_gemm (cuBLAS GEMM) + # Native TE GEMM path # Allocate cuBLAS workspace workspace = torch.empty(4, dtype=torch.uint8, device=device) @@ -166,27 +192,38 @@ def check_nvfp4_gemm_versus_reference( x_nvfp4_native.update_usage(rowwise_usage=False) if w_columnwise: w_nvfp4_native.update_usage(rowwise_usage=False) - # Native cuBLAS GEMM - # return type is out, bias_grad, gelu_input, extra_output - # We are just capturing out. - y_native = tex.generic_gemm( - w_nvfp4_native, - transa, - x_nvfp4_native, - transb, - out.clone() if accumulate else None, - out_quantizer, - TE_DType[out_dtype], - bias, - bias_dtype, - use_gelu, - gelu_input, - use_grad, - workspace, - workspace.shape[0], - accumulate, - use_split_accumulator, - )[0] + if per_token_activation: + layout = ("T" if transa else "N") + ("T" if transb else "N") + y_native = general_gemm( + w_nvfp4_native, + x_nvfp4_native, + out_dtype=out_dtype, + accumulate=accumulate, + layout=layout, + out=out.clone() if accumulate else None, + )[0] + else: + # Native cuBLAS GEMM + # return type is out, bias_grad, gelu_input, extra_output + # We are just capturing out. 
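The per-token branch above routes through general_gemm rather than calling tex.generic_gemm directly, because the row-wise output correction lives in the Python wrapper: cuBLAS consumes one scalar alpha derived from amax[0], and the wrapper then rescales each output row by amax[row] / amax[0]. A minimal sketch of that correction with illustrative names (the real helper also removes and reapplies a fused bias around the rescale):

    import torch

    def rescale_rows(out: torch.Tensor, amax: torch.Tensor) -> None:
        # out: (num_tokens, n) GEMM output computed with alpha from amax[0].
        # amax: (num_tokens,) per-token FP32 amax vector.
        ratios = (amax / amax[0]).to(out.dtype).view(-1, 1)
        out.mul_(ratios)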
+ y_native = tex.generic_gemm( + w_nvfp4_native, + transa, + x_nvfp4_native, + transb, + out.clone() if accumulate else None, + out_quantizer, + TE_DType[out_dtype], + bias, + bias_dtype, + use_gelu, + gelu_input, + use_grad, + workspace, + workspace.shape[0], + accumulate, + use_split_accumulator, + )[0] # just in case of accumulation, make sure y_ref and y_native are not the same tensor assert y_ref is not y_native, "y_ref and y_native should not be the same tensor" @@ -224,10 +261,14 @@ def check_nvfp4_gemm_versus_reference( "is_x_columnwise, is_w_columnwise", [ (False, False), # TN - (True, False), # NN + (False, True), # NN + (True, False), # TT (True, True), # NT ], - ids=["rowxrow", "colxrow", "colxcol"], + ids=["rowxrow", "rowxcol", "colxrow", "colxcol"], +) +@pytest.mark.parametrize( + "per_token_activation", [False, True], ids=["nvfp4_per_tensor", "nvfp4_pertoken"] ) def test_nvfp4_gemm_versus_reference( M: int, @@ -239,7 +280,15 @@ def test_nvfp4_gemm_versus_reference( accumulate: bool, is_x_columnwise: bool, is_w_columnwise: bool, + per_token_activation: bool, ): + if per_token_activation: + maybe_skip_pertoken_nvfp4_gemm( + x_dtype=x_dtype, + accumulate=accumulate, + x_columnwise=is_x_columnwise, + ) + check_nvfp4_gemm_versus_reference( x_dtype=x_dtype, w_dtype=w_dtype, @@ -250,4 +299,5 @@ def test_nvfp4_gemm_versus_reference( accumulate=accumulate, x_columnwise=is_x_columnwise, w_columnwise=is_w_columnwise, + per_token_activation=per_token_activation, ) diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index 7e94911ddd..7e2a587223 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -26,13 +26,11 @@ def unpack_fp4(x: torch.Tensor) -> torch.Tensor: def maybe_skip_pertoken_nvfp4( x_dtype: torch.dtype = torch.bfloat16, *, - return_transpose: bool = False, + return_transpose: bool = False, # pylint: disable=unused-argument with_2d_quantization: bool = False, ) -> None: if x_dtype == torch.float32: pytest.skip("Per-token NVFP4 kernel supports BF16/FP16 inputs only") - if return_transpose: - pytest.skip("Per-token NVFP4 currently supports rowwise-only quantization") if with_2d_quantization: pytest.skip("Per-token NVFP4 does not support 2D quantization") diff --git a/tests/pytorch/test_backward_override.py b/tests/pytorch/test_backward_override.py index 5da55b14b6..06de1a06f7 100644 --- a/tests/pytorch/test_backward_override.py +++ b/tests/pytorch/test_backward_override.py @@ -170,6 +170,9 @@ def _maybe_skip_recipe_dtype( ) -> None: if dtype == torch.bfloat16 and not bf16_available: pytest.skip(reason_for_no_bf16) + if recipe_name == "nvfp4_pertoken" and module_type in ("linear", "layernorm_linear"): + if dtype != torch.bfloat16: + pytest.skip("Per-token NVFP4 activation supports BF16 inputs only in this test") if recipe_name in ("nvfp4", "nvfp4_pertoken"): if module_type in ("linear", "layernorm_linear") and dtype not in ( torch.bfloat16, @@ -183,16 +186,6 @@ def _maybe_skip_recipe_dtype( def _maybe_skip_unsupported_recipe_module_combo(recipe_name: str, module_type: str) -> None: if module_type == "ops_linear" and recipe_name == "fp8_block_scaling": pytest.skip("Fusible ops (te_ops.Linear) do not support Float8BlockScaling recipe") - if recipe_name == "nvfp4_pertoken" and module_type in ( - "linear", - "layernorm_linear", - "ops_linear", - "grouped_linear", - ): - pytest.skip( - "Per-token NVFP4 currently supports rowwise-only quantization paths " - 
"(columnwise usage is unsupported for these modules)." - ) def _maybe_skip_unsupported_recipe_shape( diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index 0d86022cc1..eab27a6e7e 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -103,8 +103,6 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, auto dtype = input_tensor->dtype(); const bool per_token_activation = quant_config_cpp.nvfp4_per_token_activation; if (per_token_activation) { - NVTE_CHECK(!output_tensor->has_columnwise_data(), - "Per-token NVFP4 quantization supports rowwise-only output."); NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, "Per-token NVFP4 quantization does not support 2D quantization."); nvfp4::quantize_pertoken(*input_tensor, noop_tensor, output_tensor, stream); @@ -251,8 +249,6 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens auto dtype = grad_tensor->dtype(); const bool per_token_activation = quant_config_cpp.nvfp4_per_token_activation; if (per_token_activation) { - NVTE_CHECK(!output_tensor->has_columnwise_data(), - "Per-token NVFP4 quantization supports rowwise-only output."); NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, "Per-token NVFP4 quantization does not support 2D quantization."); nvfp4::quantize_pertoken(*grad_tensor, noop_tensor, output_tensor, stream); diff --git a/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh index 5e1e23f5d5..3f6f809d32 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh @@ -161,6 +161,157 @@ void launch_quantize_pertoken_nvfp4(const int num_rows, const int num_cols, cons #endif } +template +__global__ void +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +__launch_bounds__(BLOCK_SIZE) +#endif + compute_pertoken_amax_kernel(const int num_rows, const int num_cols, + const IType *__restrict__ input, + float *__restrict__ output_per_token_amax, + const float *__restrict__ noop) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + + using IType2 = typename ptx::FPx2; + + const int row_idx = blockIdx.x; + if (row_idx >= num_rows) return; + + const int num_vec2 = num_cols / 2; + const IType2 *input_row = reinterpret_cast(input + row_idx * num_cols); + + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; + for (int i = threadIdx.x; i < num_vec2; i += BLOCK_SIZE) { + const IType2 val = input_row[i]; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, val); + } + const float thread_max = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + float row_amax = + BlockReduce(temp_storage).Reduce(thread_max, [](float a, float b) { return fmaxf(a, b); }); + + if (threadIdx.x == 0) { + output_per_token_amax[row_idx] = row_amax; + } +#endif +} + +template +void launch_compute_pertoken_amax(const int num_rows, const int num_cols, const IType *input, + float *output_per_token_amax, cudaStream_t stream, + const float *noop = nullptr) { +#if FP4_TYPE_SUPPORTED + if (num_rows == 0 || num_cols == 0) return; + + NVTE_CHECK(num_cols % 2 == 0, "num_cols must be even for per-token amax computation, got ", + 
num_cols); + dim3 grid(num_rows); + dim3 block(PERTOKEN_BLOCK_SIZE); + + compute_pertoken_amax_kernel + <<>>(num_rows, num_cols, input, output_per_token_amax, noop); + NVTE_CHECK_CUDA(cudaGetLastError()); +#else + NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!"); +#endif +} + +template +__global__ void +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +__launch_bounds__(BLOCK_SIZE) +#endif + quantize_pertoken_nvfp4_columnwise_kernel( + const int num_rows, const int num_cols, const IType *__restrict__ input, + uint8_t *__restrict__ output_data_t, fp8e4m3 *__restrict__ output_scales_t, + const float *__restrict__ per_token_amax, const int scale_stride_t, + const float *__restrict__ noop) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + using namespace detail; + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + + const int col_idx = blockIdx.x; + if (col_idx >= num_cols) return; + + constexpr float fp4_max_inv = 1.0f / TypeExtrema::max; + constexpr float float_max = TypeExtrema::max; + constexpr float one = 1.0f; + const float2 one_2x{one, one}; + const int num_row_blocks = num_rows / PERTOKEN_SF_VEC_SIZE; + + for (int row_block = threadIdx.x; row_block < num_row_blocks; row_block += BLOCK_SIZE) { + const int row_start = row_block * PERTOKEN_SF_VEC_SIZE; + + float vals[PERTOKEN_SF_VEC_SIZE]; + float s_enc[PERTOKEN_SF_VEC_SIZE]; + float scaled_block_amax = 0.0f; +#pragma unroll + for (int i = 0; i < PERTOKEN_SF_VEC_SIZE; ++i) { + const int row_idx = row_start + i; + const float val = static_cast(input[row_idx * num_cols + col_idx]); + const float S_enc = compute_global_encode_scaling_factor_FP4(per_token_amax[row_idx]); + vals[i] = val; + s_enc[i] = S_enc; + scaled_block_amax = fmaxf(scaled_block_amax, fabsf(val) * (S_enc * fp4_max_inv)); + } + + const float S_dec_b_f32 = fminf(scaled_block_amax, float_max); + const nvfp4_scale_t S_dec_b_fp8 = static_cast(S_dec_b_f32); + output_scales_t[col_idx * scale_stride_t + row_block] = S_dec_b_fp8; + + float scaled_vals[PERTOKEN_SF_VEC_SIZE]; +#pragma unroll + for (int i = 0; i < PERTOKEN_SF_VEC_SIZE; ++i) { + const float S_dec_rowwise = 1.0f / s_enc[i]; + const float block_scale_inverse = + fminf(1.0f / (static_cast(S_dec_b_fp8) * S_dec_rowwise), float_max); + scaled_vals[i] = vals[i] * block_scale_inverse; + } + + uint8_t *out_ptr = output_data_t + col_idx * (num_rows / 2) + row_start / 2; + auto *out_fp4 = reinterpret_cast(out_ptr); +#pragma unroll + for (int j = 0; j < PERTOKEN_SF_VEC_SIZE; j += 4) { + const float2 in01 = make_float2(scaled_vals[j], scaled_vals[j + 1]); + const float2 in23 = make_float2(scaled_vals[j + 2], scaled_vals[j + 3]); + out_fp4[j / 4] = ptx::mul_cvt_fp32_to_fp4_4x( + in01, in23, one_2x, /*rbits=*/0u); + } + } +#endif +} + +template +void launch_quantize_pertoken_nvfp4_columnwise( + const int num_rows, const int num_cols, const IType *input, uint8_t *output_data_t, + fp8e4m3 *output_scales_t, const float *per_token_amax, const int scale_stride_t, + cudaStream_t stream, const float *noop = nullptr) { +#if FP4_TYPE_SUPPORTED + if (num_rows == 0 || num_cols == 0) return; + + NVTE_CHECK(num_rows % PERTOKEN_SF_VEC_SIZE == 0, "num_rows must be a multiple of ", + PERTOKEN_SF_VEC_SIZE, " for per-token NVFP4 columnwise quantization, got ", + num_rows); + dim3 grid(num_cols); + dim3 block(PERTOKEN_BLOCK_SIZE); + + quantize_pertoken_nvfp4_columnwise_kernel + <<>>(num_rows, num_cols, input, output_data_t, output_scales_t, + per_token_amax, scale_stride_t, noop); + 
NVTE_CHECK_CUDA(cudaGetLastError()); +#else + NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!"); +#endif +} + } // namespace quantize_pertoken_kernel inline void quantize_pertoken(const Tensor &input, const Tensor *noop, Tensor *output, @@ -172,12 +323,8 @@ inline void quantize_pertoken(const Tensor &input, const Tensor *noop, Tensor *o CheckOutputTensor(*output, "output", false); NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); - NVTE_CHECK(output->has_data(), "NVFP4 output tensor must be allocated."); - NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type."); - NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); - NVTE_CHECK(output->amax.dptr != nullptr, "Per-token amax tensor must be allocated."); - NVTE_CHECK(!output->has_columnwise_data(), - "Per-token NVFP4 quantization supports rowwise-only output."); + NVTE_CHECK(output->has_data() || output->has_columnwise_data(), + "NVFP4 output tensor must be allocated."); NVTE_CHECK(!output->with_gemm_swizzled_scales, "Output must have scales in compact format."); const size_t rows = input.flat_first_dim(); @@ -187,22 +334,86 @@ inline void quantize_pertoken(const Tensor &input, const Tensor *noop, Tensor *o quantize_pertoken_kernel::PERTOKEN_SF_VEC_SIZE, "."); const auto *noop_ptr = reinterpret_cast(noop->data.dptr); - auto *data_ptr = reinterpret_cast(output->data.dptr); - auto *scale_ptr = reinterpret_cast(output->scale_inv.dptr); auto *amax_ptr = reinterpret_cast(output->amax.dptr); + auto *amax_colwise_ptr = reinterpret_cast(output->columnwise_amax.dptr); + auto *per_token_amax_ptr = (amax_ptr != nullptr) ? amax_ptr : amax_colwise_ptr; + NVTE_CHECK(per_token_amax_ptr != nullptr, "Per-token amax tensor must be allocated."); + if (amax_ptr != nullptr) { + NVTE_CHECK(output->amax.numel() == rows, "Per-token rowwise amax must have ", rows, + " entries, got ", output->amax.shape, "."); + } + if (amax_colwise_ptr != nullptr) { + NVTE_CHECK(output->columnwise_amax.numel() == rows, "Per-token columnwise amax must have ", + rows, " entries, got ", output->columnwise_amax.shape, "."); + } const int *row_offsets = nullptr; - const int scale_stride = static_cast(output->scale_inv.shape.back()); if (input.dtype() == DType::kBFloat16) { - quantize_pertoken_kernel::launch_quantize_pertoken_nvfp4<__nv_bfloat16>( - static_cast(rows), static_cast(cols), - reinterpret_cast(input.data.dptr), row_offsets, data_ptr, scale_ptr, - amax_ptr, scale_stride, stream, noop_ptr); + const auto *input_ptr = reinterpret_cast(input.data.dptr); + if (output->has_data()) { + NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Rowwise output must have FP4 type."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Rowwise scaling tensor must be allocated."); + NVTE_CHECK(output->amax.dptr != nullptr, "Rowwise per-token amax tensor must be allocated."); + auto *data_ptr = reinterpret_cast(output->data.dptr); + auto *scale_ptr = reinterpret_cast(output->scale_inv.dptr); + const int scale_stride = static_cast(output->scale_inv.shape.back()); + quantize_pertoken_kernel::launch_quantize_pertoken_nvfp4<__nv_bfloat16>( + static_cast(rows), static_cast(cols), input_ptr, row_offsets, data_ptr, + scale_ptr, amax_ptr, scale_stride, stream, noop_ptr); + } else { + quantize_pertoken_kernel::launch_compute_pertoken_amax<__nv_bfloat16>( + static_cast(rows), static_cast(cols), input_ptr, per_token_amax_ptr, stream, + noop_ptr); + } + if (output->has_columnwise_data()) { + 
NVTE_CHECK(is_fp4_dtype(output->columnwise_data.dtype), + "Columnwise output must have FP4 type."); + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, + "Columnwise scaling tensor must be allocated."); + if (amax_ptr != nullptr && amax_colwise_ptr != nullptr && amax_ptr != amax_colwise_ptr) { + NVTE_CHECK_CUDA(cudaMemcpyAsync(amax_colwise_ptr, amax_ptr, rows * sizeof(float), + cudaMemcpyDeviceToDevice, stream)); + } + auto *data_t_ptr = reinterpret_cast(output->columnwise_data.dptr); + auto *scale_t_ptr = reinterpret_cast(output->columnwise_scale_inv.dptr); + const int scale_stride_t = static_cast(output->columnwise_scale_inv.shape.back()); + quantize_pertoken_kernel::launch_quantize_pertoken_nvfp4_columnwise<__nv_bfloat16>( + static_cast(rows), static_cast(cols), input_ptr, data_t_ptr, scale_t_ptr, + per_token_amax_ptr, scale_stride_t, stream, noop_ptr); + } } else if (input.dtype() == DType::kFloat16) { - quantize_pertoken_kernel::launch_quantize_pertoken_nvfp4( - static_cast(rows), static_cast(cols), - reinterpret_cast(input.data.dptr), row_offsets, data_ptr, scale_ptr, amax_ptr, - scale_stride, stream, noop_ptr); + const auto *input_ptr = reinterpret_cast(input.data.dptr); + if (output->has_data()) { + NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Rowwise output must have FP4 type."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Rowwise scaling tensor must be allocated."); + NVTE_CHECK(output->amax.dptr != nullptr, "Rowwise per-token amax tensor must be allocated."); + auto *data_ptr = reinterpret_cast(output->data.dptr); + auto *scale_ptr = reinterpret_cast(output->scale_inv.dptr); + const int scale_stride = static_cast(output->scale_inv.shape.back()); + quantize_pertoken_kernel::launch_quantize_pertoken_nvfp4( + static_cast(rows), static_cast(cols), input_ptr, row_offsets, data_ptr, + scale_ptr, amax_ptr, scale_stride, stream, noop_ptr); + } else { + quantize_pertoken_kernel::launch_compute_pertoken_amax( + static_cast(rows), static_cast(cols), input_ptr, per_token_amax_ptr, stream, + noop_ptr); + } + if (output->has_columnwise_data()) { + NVTE_CHECK(is_fp4_dtype(output->columnwise_data.dtype), + "Columnwise output must have FP4 type."); + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, + "Columnwise scaling tensor must be allocated."); + if (amax_ptr != nullptr && amax_colwise_ptr != nullptr && amax_ptr != amax_colwise_ptr) { + NVTE_CHECK_CUDA(cudaMemcpyAsync(amax_colwise_ptr, amax_ptr, rows * sizeof(float), + cudaMemcpyDeviceToDevice, stream)); + } + auto *data_t_ptr = reinterpret_cast(output->columnwise_data.dptr); + auto *scale_t_ptr = reinterpret_cast(output->columnwise_scale_inv.dptr); + const int scale_stride_t = static_cast(output->columnwise_scale_inv.shape.back()); + quantize_pertoken_kernel::launch_quantize_pertoken_nvfp4_columnwise( + static_cast(rows), static_cast(cols), input_ptr, data_t_ptr, scale_t_ptr, + per_token_amax_ptr, scale_stride_t, stream, noop_ptr); + } } else { NVTE_ERROR( "Unsupported input dtype for per-token NVFP4 quantization. " diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 9cf58f9dce..4895054758 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -80,15 +80,15 @@ def _maybe_apply_nvfp4_pertoken_output_rescale( gelu: bool, accumulate: bool, ) -> None: - """Apply per-token NVFP4 global-scale correction for TN forward GEMM outputs. 
+ """Apply per-token NVFP4 global-scale correction for forward GEMM outputs. Current NVFP4 GEMM alpha path consumes one scalar amax. Per-token NVFP4 stores - rowwise amax vector in B._amax_rowwise, so we correct by row using ratio - (amax[row] / amax[0]). If bias was fused in epilogue, remove/reapply it around + rowwise amax vector in B, so we correct by row using ratio (amax[row] / amax[0]) + when B is not transposed. If bias was fused in epilogue, remove/reapply it around the row rescale to avoid bias distortion. """ - if grad or gelu or accumulate or layout != "TN": + if grad or gelu or accumulate or layout[1] != "N": return if not isinstance(B, NVFP4TensorStorage): return @@ -96,7 +96,7 @@ def _maybe_apply_nvfp4_pertoken_output_rescale( return if out.numel() == 0: return - amax = B._amax_rowwise + amax = B._amax_rowwise if B._amax_rowwise is not None else B._amax_columnwise if amax is None or amax.numel() <= 1: return @@ -194,13 +194,15 @@ def general_gemm( requested_out_dtype = out_dtype needs_fp32_rescale_path = ( - layout == "TN" + layout[1] == "N" and not grad and not gelu and not accumulate and isinstance(B, NVFP4TensorStorage) - and B._amax_rowwise is not None - and B._amax_rowwise.numel() > 1 + and ( + (B._amax_rowwise is not None and B._amax_rowwise.numel() > 1) + or (B._amax_columnwise is not None and B._amax_columnwise.numel() > 1) + ) and quantization_params is None and out is None and requested_out_dtype is not None diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index 2df3b66553..17f86d63d6 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -42,8 +42,9 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); - if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) { - // Post-RHT amax is handled within NVFP4 quantizer + if (nvfp4_quantizer_cpp->per_token_activation || + (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax)) { + // Amax is handled within NVFP4 quantizer impl = Impl::UNFUSED; } else { impl = Impl::FUSED_ACTIVATION_AMAX_NVFP4; @@ -154,8 +155,9 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); - if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) { - // Post-RHT amax is handled within NVFP4 quantizer + if (nvfp4_quantizer_cpp->per_token_activation || + (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax)) { + // Amax is handled within NVFP4 quantizer impl = Impl::UNFUSED; } else { impl = Impl::FUSED_ACTIVATION_AMAX_NVFP4; diff --git a/transformer_engine/pytorch/csrc/extensions/bias.cpp b/transformer_engine/pytorch/csrc/extensions/bias.cpp index 0cf2025f1b..e2dba46370 100644 --- a/transformer_engine/pytorch/csrc/extensions/bias.cpp +++ b/transformer_engine/pytorch/csrc/extensions/bias.cpp @@ -152,8 +152,9 @@ std::vector dact_dbias( } else if (detail::IsNVFP4Quantizers(quantizer_py.ptr())) { auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); 
NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); - if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) { - // Post-RHT amax is handled within NVFP4 quantizer + if (nvfp4_quantizer_cpp->per_token_activation || + (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax)) { + // Amax is handled within NVFP4 quantizer impl = Impl::UNFUSED; } else { impl = Impl::FUSED_DACT_AMAX_NVFP4; diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index f1654f5525..b05d399414 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -944,7 +944,8 @@ std::tuple, std::vector, bool> bulk_alloc // Note: Multi-quantize kernel does not require contiguous amaxes. const auto offset = roundup(buffer_size, 16); amax_offsets.push_back(offset); - buffer_size = offset + 4; + const size_t amax_size = per_token_activation ? 4 * flat_first_dim(shape_list[i]) : 4; + buffer_size = offset + amax_size; } // Allocate full buffer @@ -957,8 +958,11 @@ std::tuple, std::vector, bool> bulk_alloc buffer, to_fp4_shape(columnwise_data_shapes[i]), data_offsets[i], torch::kUInt8)); columnwise_scale_list.emplace_back( make_torch_view(buffer, columnwise_scale_shapes[i], scale_offsets[i], torch::kUInt8)); + const std::vector amax_shape = + per_token_activation ? std::vector{flat_first_dim(shape_list[i])} + : std::vector{1}; amax_columnwise_list.emplace_back( - make_torch_view(buffer, std::vector{1}, amax_offsets[i], torch::kFloat32)); + make_torch_view(buffer, amax_shape, amax_offsets[i], torch::kFloat32)); } } @@ -1003,7 +1007,7 @@ std::tuple, std::vector, bool> bulk_alloc } if (columnwise_usage) { tensor_wrapper.set_columnwise_amax(amax_columnwise_list[i].data_ptr(), DType::kFloat32, - std::vector{1}); + getTensorShape(amax_columnwise_list[i])); } tensor_cpp_list.emplace_back(std::move(tensor_wrapper)); @@ -1281,8 +1285,6 @@ void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, if (quantizer.per_token_activation) { NVTE_CHECK(!quantizer.with_rht, "Per-token NVFP4 split quantize does not support RHT."); - NVTE_CHECK(!quantizer.columnwise_usage, - "Per-token NVFP4 split quantize currently supports rowwise-only quantization."); NVTE_CHECK(!quantizer.with_2d_quantization, "Per-token NVFP4 split quantize does not support 2D quantization."); NVTE_CHECK(!quantizer.stochastic_rounding, diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index fb4c7aa1c9..3975c01fa5 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -120,8 +120,9 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); - if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) { - // Post-RHT amax is handled within NVFP4 quantizer + if (nvfp4_quantizer_cpp->per_token_activation || + (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax)) { + // Amax is handled within NVFP4 quantizer impl = Impl::UNFUSED; } else if (!transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { // TE kernel supports amax output @@ -357,8 +358,9 @@ std::vector 
rmsnorm_fwd(const py::handle &input, const py::handle &w } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); - if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) { - // Post-RHT amax is handled within NVFP4 quantizer + if (nvfp4_quantizer_cpp->per_token_activation || + (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax)) { + // Amax is handled within NVFP4 quantizer impl = Impl::UNFUSED; } else if (!transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { // TE kernel supports amax output diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index d6fedc707b..6cc6560d8b 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1779,7 +1779,8 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve columnwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts); // hadamard amax kernel will zero out pointer with ZeroAmaxKernel // nvte_compute_amax_with_config will zero out the pointer if needed - amax_columnwise = at::empty({1}, bit32_tensor_opts); + const int64_t amax_rows = this->per_token_activation ? static_cast(flat_first_dim) : 1; + amax_columnwise = at::empty({amax_rows}, bit32_tensor_opts); } // Convert tensors to Python @@ -2105,7 +2106,9 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); // hadamard amax kernel will zero out pointer with ZeroAmaxKernel // nvte_compute_amax_with_config will zero out the pointer if needed - amax_columnwise = at::empty({1}, opts); + const int64_t amax_rows = + this->per_token_activation ? 
static_cast(flat_first_dim) : 1; + amax_columnwise = at::empty({amax_rows}, opts); tensor.attr("_amax_columnwise") = *amax_columnwise; } } else { // columnwise_usage == false @@ -2141,7 +2144,7 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( out_cpp.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat8E4M3, getTensorShape(*columnwise_scale_inv)); out_cpp.set_columnwise_amax(amax_columnwise->data_ptr(), DType::kFloat32, - std::vector{1}); + getTensorShape(*amax_columnwise)); } out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); this->set_quantization_params(&out_cpp); @@ -2258,8 +2261,6 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou "Per-token NVFP4 activation does not support 2D quantization."); NVTE_CHECK(!this->stochastic_rounding, "Per-token NVFP4 activation does not support stochastic rounding."); - NVTE_CHECK(!this->columnwise_usage, - "Per-token NVFP4 activation currently supports rowwise-only quantization."); NVTE_CHECK(!this->with_amax_reduction, "Per-token NVFP4 activation does not support amax reduction."); NVTE_CHECK(input.dtype() == DType::kBFloat16 || input.dtype() == DType::kFloat16, diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py index 6a5e400592..430af6c581 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py @@ -546,6 +546,71 @@ def _quantize_blockwise_reference( return cast_to_fp4x2(clipped_x), decode_scale.squeeze(-1) + @classmethod + def _quantize_blockwise_pertoken_columnwise_reference( + cls, + x: torch.Tensor, + global_amax: torch.Tensor, + tile_len_x: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if x.ndim != 2: + raise ValueError( + "_quantize_blockwise_pertoken_columnwise_reference expects a 2D tensor, got" + f" {x.ndim}D with shape {x.shape}" + ) + + m, n = x.shape + x = x.view(m, n // tile_len_x, tile_len_x) + FLOAT4_E2M1_MAX = torch.tensor(6.0, device=x.device, dtype=torch.float32) + FLOAT8_E4M3_MAX = torch.tensor(448.0, device=x.device, dtype=torch.float32) + + global_amax = global_amax.to(torch.float32).view(1, n // tile_len_x, tile_len_x) + global_encode_scale = torch.div(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX, global_amax) + global_encode_scale = torch.min( + global_encode_scale, + torch.tensor( + torch.finfo(torch.float32).max, + device=global_encode_scale.device, + dtype=torch.float32, + ), + ) + global_encode_scale = torch.where( + global_encode_scale == 0.0, + torch.ones_like(global_encode_scale), + global_encode_scale, + ) + global_decode_scale = torch.div(1.0, global_encode_scale) + global_encode_scale_multiplier = global_encode_scale * torch.reciprocal(FLOAT4_E2M1_MAX) + + decode_scale = torch.amax( + torch.abs(x.to(torch.float32)) * global_encode_scale_multiplier, + dim=-1, + keepdim=True, + ) + decode_scale = torch.min( + decode_scale, + torch.tensor( + torch.finfo(torch.float32).max, + device=decode_scale.device, + dtype=torch.float32, + ), + ) + decode_scale = torch.clamp(decode_scale, min=-FLOAT8_E4M3_MAX, max=FLOAT8_E4M3_MAX) + decode_scale = decode_scale.to(torch.float8_e4m3fn) + + encode_scale = torch.min( + torch.div(1.0, decode_scale.to(torch.float32) * global_decode_scale), + torch.tensor( + torch.finfo(torch.float32).max, + device=decode_scale.device, + dtype=torch.float32, + ), + ) + scaled_x = x.to(torch.float32) * encode_scale + clipped_x = torch.clamp(scaled_x, 
-FLOAT4_E2M1_MAX, FLOAT4_E2M1_MAX).reshape(m, n) + + return cast_to_fp4x2(clipped_x), decode_scale.squeeze(-1) + @staticmethod def _pad_tensor( tensor: torch.Tensor, row_divisor: Optional[int], col_divisor: Optional[int] @@ -650,8 +715,6 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ "Per-token activation only supports NVFP4 1x16 tile shape, " f"got {self.quant_tile_shape}" ) - if self.columnwise_usage: - raise ValueError("Per-token activation reference supports rowwise-only usage.") global_amax_row = torch.max(torch.abs(row_input), dim=1).values.to(torch.float32) global_amax_col = global_amax_row else: @@ -696,15 +759,22 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ x_t, row_divisor=self.quant_tile_shape[0], col_divisor=self.quant_tile_shape[1] ) - qx_t, sx_t = self._quantize_blockwise_reference( - x_t_padded, - global_amax_col, - self.quant_tile_shape[1], - self.quant_tile_shape[0], - pow_2_scales=self.pow_2_scales, - per_token_activation=False, - eps=self.eps, - ) + if self.per_token_activation: + qx_t, sx_t = self._quantize_blockwise_pertoken_columnwise_reference( + x_t_padded, + global_amax_col, + self.quant_tile_shape[1], + ) + else: + qx_t, sx_t = self._quantize_blockwise_reference( + x_t_padded, + global_amax_col, + self.quant_tile_shape[1], + self.quant_tile_shape[0], + pow_2_scales=self.pow_2_scales, + per_token_activation=False, + eps=self.eps, + ) qx_t = self._rm_pad_tensor(qx_t, (N, M // 2)) @@ -895,22 +965,15 @@ def qgemm( sw = sw.to(torch.float32) factor = 6.0 * 6.0 * 448.0 * 448.0 - if ( - qresult_x.global_amax_row.numel() != 1 - or qresult_w.global_amax_row.numel() != 1 - or qresult_w.global_amax_col.numel() != 1 - or qresult_x.global_amax_col.numel() != 1 - ): - raise ValueError( - "NVFP4QuantizerRef.qgemm expects scalar global amax values; " - "per-token amax vectors are not supported in reference GEMM." 
- ) - if gemm_type == quantization.GEMMType.WGRAD: partial_alpha = qresult_x.global_amax_col * qresult_w.global_amax_col else: partial_alpha = qresult_x.global_amax_row * qresult_w.global_amax_row - alpha = torch.div(partial_alpha, factor).squeeze(-1) + if partial_alpha.numel() > 1 and partial_alpha.numel() == high_precision_x.shape[0]: + partial_alpha = partial_alpha.view(-1, 1) + else: + partial_alpha = partial_alpha.squeeze(-1) + alpha = torch.div(partial_alpha, factor) M, K = high_precision_x.shape N, K_w = high_precision_w.shape diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index cd63fb5221..53f77da9e4 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -362,8 +362,9 @@ def make_empty( device=device, pin_memory=pin_memory, ) + amax_rows = flat_first_dim if self.per_token_activation else 1 amax_columnwise = torch.zeros( - 1, dtype=torch.float32, device=device, pin_memory=pin_memory + amax_rows, dtype=torch.float32, device=device, pin_memory=pin_memory ) # Construct FP8 tensor From cfd13bb97392fa971a0d1b7adbb3904514cdbfec Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sun, 26 Apr 2026 02:03:16 -0700 Subject: [PATCH 03/21] Add fp32 Signed-off-by: Ziang Li --- tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py | 4 -- .../nvfp4/test_nvfp4_quantize_exact.py | 10 +-- tests/pytorch/test_backward_override.py | 3 - .../cast/nvfp4/quantize_pertoken_nvfp4.cuh | 70 ++++++++++++++++--- .../pytorch/csrc/extensions/cast.cpp | 3 - transformer_engine/pytorch/csrc/quantizer.cpp | 2 - 6 files changed, 63 insertions(+), 29 deletions(-) diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py index d22442cd64..231fb62468 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -17,13 +17,10 @@ def maybe_skip_pertoken_nvfp4_gemm( - x_dtype: torch.dtype, *, accumulate: bool, x_columnwise: bool, ) -> None: - if x_dtype == torch.float32: - pytest.skip("Per-token NVFP4 kernel supports BF16/FP16 inputs only") if accumulate: pytest.skip("Per-token NVFP4 GEMM output rescale does not support accumulation") if x_columnwise: @@ -284,7 +281,6 @@ def test_nvfp4_gemm_versus_reference( ): if per_token_activation: maybe_skip_pertoken_nvfp4_gemm( - x_dtype=x_dtype, accumulate=accumulate, x_columnwise=is_x_columnwise, ) diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index 7e2a587223..93359b6179 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -24,13 +24,10 @@ def unpack_fp4(x: torch.Tensor) -> torch.Tensor: def maybe_skip_pertoken_nvfp4( - x_dtype: torch.dtype = torch.bfloat16, *, return_transpose: bool = False, # pylint: disable=unused-argument with_2d_quantization: bool = False, ) -> None: - if x_dtype == torch.float32: - pytest.skip("Per-token NVFP4 kernel supports BF16/FP16 inputs only") if with_2d_quantization: pytest.skip("Per-token NVFP4 does not support 2D quantization") @@ -185,7 +182,6 @@ def test_quantization_block_tiling_versus_reference( ) -> None: if per_token_activation: maybe_skip_pertoken_nvfp4( - x_dtype=x_dtype, return_transpose=return_transpose, with_2d_quantization=with_2d_quantization, ) @@ -239,7 +235,7 @@ def test_nvfp4_quantization_extrema_versus_reference( x = torch.zeros((M, N), dtype=x_dtype, device=device) if per_token_activation: 
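With per-token amax vectors now permitted in the reference GEMM above, alpha becomes a per-row column vector that broadcasts over the output instead of a scalar. A sketch of the shape handling (illustrative names):

    import torch

    def make_alpha(partial_alpha: torch.Tensor, num_rows: int) -> torch.Tensor:
        factor = 6.0 * 6.0 * 448.0 * 448.0
        if partial_alpha.numel() > 1 and partial_alpha.numel() == num_rows:
            return (partial_alpha / factor).view(-1, 1)  # (M, 1): one alpha per token
        return (partial_alpha / factor).squeeze(-1)      # per-tensor scalar path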
- maybe_skip_pertoken_nvfp4(x_dtype=x_dtype, return_transpose=return_transpose) + maybe_skip_pertoken_nvfp4(return_transpose=return_transpose) nvfp4_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, @@ -359,7 +355,7 @@ def test_nvfp4_quantization_boundary_values( x = row.unsqueeze(0).repeat(M, 1).to(dtype=x_dtype) if per_token_activation: - maybe_skip_pertoken_nvfp4(x_dtype=x_dtype, return_transpose=return_transpose) + maybe_skip_pertoken_nvfp4(return_transpose=return_transpose) nvfp4_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, @@ -465,7 +461,7 @@ def test_nvfp4_quantization_noncontiguous_inputs( assert not x_nc.is_contiguous() if per_token_activation: - maybe_skip_pertoken_nvfp4(x_dtype=x_dtype, return_transpose=return_transpose) + maybe_skip_pertoken_nvfp4(return_transpose=return_transpose) nvfp4_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, diff --git a/tests/pytorch/test_backward_override.py b/tests/pytorch/test_backward_override.py index 06de1a06f7..6f11069e28 100644 --- a/tests/pytorch/test_backward_override.py +++ b/tests/pytorch/test_backward_override.py @@ -170,9 +170,6 @@ def _maybe_skip_recipe_dtype( ) -> None: if dtype == torch.bfloat16 and not bf16_available: pytest.skip(reason_for_no_bf16) - if recipe_name == "nvfp4_pertoken" and module_type in ("linear", "layernorm_linear"): - if dtype != torch.bfloat16: - pytest.skip("Per-token NVFP4 activation supports BF16 inputs only in this test") if recipe_name in ("nvfp4", "nvfp4_pertoken"): if module_type in ("linear", "layernorm_linear") and dtype not in ( torch.bfloat16, diff --git a/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh index 3f6f809d32..feacc2ff6a 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh @@ -37,6 +37,26 @@ using namespace ptx; constexpr int PERTOKEN_BLOCK_SIZE = 256; constexpr int PERTOKEN_SF_VEC_SIZE = 16; +template +__device__ __forceinline__ void abs_max_2x_update(ptx::FPx2 &dst, + const ptx::FPx2 &val) { + if constexpr (std::is_same_v) { + dst.x = fmaxf(fabsf(dst.x), fabsf(val.x)); + dst.y = fmaxf(fabsf(dst.y), fabsf(val.y)); + } else { + ptx::abs_max_2x(dst, dst, val); + } +} + +template +__device__ __forceinline__ float abs_max_2x_to_float(const ptx::FPx2 &val) { + if constexpr (std::is_same_v) { + return fmaxf(fabsf(val.x), fabsf(val.y)); + } else { + return static_cast(__hmax(__habs(val.x), __habs(val.y))); + } +} + template __global__ void #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) @@ -67,10 +87,9 @@ __launch_bounds__(BLOCK_SIZE) IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; for (int i = threadIdx.x; i < num_vec2; i += BLOCK_SIZE) { const IType2 val = input_row[i]; - ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, val); + abs_max_2x_update(thread_amax_2x, val); } - const float thread_max = - static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + const float thread_max = abs_max_2x_to_float(thread_amax_2x); using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; @@ -99,10 +118,9 @@ __launch_bounds__(BLOCK_SIZE) reinterpret_cast(input + actual_row * num_cols + col_start); for (int j = 0; j < PERTOKEN_SF_VEC_SIZE / 2; ++j) { vals[j] = input_block[j]; - ptx::abs_max_2x(block_amax_2x, block_amax_2x, vals[j]); + abs_max_2x_update(block_amax_2x, vals[j]); } - const float block_max = - static_cast(__hmax(__habs(block_amax_2x.x), 
__habs(block_amax_2x.y))); + const float block_max = abs_max_2x_to_float(block_amax_2x); const float S_dec_b_f32 = fminf(block_max * global_encode_scale_multiplier, detail::TypeExtrema::max); @@ -186,10 +204,9 @@ __launch_bounds__(BLOCK_SIZE) IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; for (int i = threadIdx.x; i < num_vec2; i += BLOCK_SIZE) { const IType2 val = input_row[i]; - ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, val); + abs_max_2x_update(thread_amax_2x, val); } - const float thread_max = - static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + const float thread_max = abs_max_2x_to_float(thread_amax_2x); using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; @@ -414,10 +431,43 @@ inline void quantize_pertoken(const Tensor &input, const Tensor *noop, Tensor *o static_cast(rows), static_cast(cols), input_ptr, data_t_ptr, scale_t_ptr, per_token_amax_ptr, scale_stride_t, stream, noop_ptr); } + } else if (input.dtype() == DType::kFloat32) { + const auto *input_ptr = reinterpret_cast(input.data.dptr); + if (output->has_data()) { + NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Rowwise output must have FP4 type."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Rowwise scaling tensor must be allocated."); + NVTE_CHECK(output->amax.dptr != nullptr, "Rowwise per-token amax tensor must be allocated."); + auto *data_ptr = reinterpret_cast(output->data.dptr); + auto *scale_ptr = reinterpret_cast(output->scale_inv.dptr); + const int scale_stride = static_cast(output->scale_inv.shape.back()); + quantize_pertoken_kernel::launch_quantize_pertoken_nvfp4( + static_cast(rows), static_cast(cols), input_ptr, row_offsets, data_ptr, + scale_ptr, amax_ptr, scale_stride, stream, noop_ptr); + } else { + quantize_pertoken_kernel::launch_compute_pertoken_amax( + static_cast(rows), static_cast(cols), input_ptr, per_token_amax_ptr, stream, + noop_ptr); + } + if (output->has_columnwise_data()) { + NVTE_CHECK(is_fp4_dtype(output->columnwise_data.dtype), + "Columnwise output must have FP4 type."); + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, + "Columnwise scaling tensor must be allocated."); + if (amax_ptr != nullptr && amax_colwise_ptr != nullptr && amax_ptr != amax_colwise_ptr) { + NVTE_CHECK_CUDA(cudaMemcpyAsync(amax_colwise_ptr, amax_ptr, rows * sizeof(float), + cudaMemcpyDeviceToDevice, stream)); + } + auto *data_t_ptr = reinterpret_cast(output->columnwise_data.dptr); + auto *scale_t_ptr = reinterpret_cast(output->columnwise_scale_inv.dptr); + const int scale_stride_t = static_cast(output->columnwise_scale_inv.shape.back()); + quantize_pertoken_kernel::launch_quantize_pertoken_nvfp4_columnwise( + static_cast(rows), static_cast(cols), input_ptr, data_t_ptr, scale_t_ptr, + per_token_amax_ptr, scale_stride_t, stream, noop_ptr); + } } else { NVTE_ERROR( "Unsupported input dtype for per-token NVFP4 quantization. 
" - "Expected BFloat16 or Float16."); + "Expected BFloat16, Float16, or Float32."); } #else NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!"); diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index b05d399414..9423aa7296 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -1608,9 +1608,6 @@ std::tuple quantize_nvfp4_pertoken(at::Tenso NVTE_CHECK(input.dim() == 2, "Input must be 2D (num_rows, num_cols)"); NVTE_CHECK(input.is_cuda(), "Input must be on CUDA device"); - NVTE_CHECK(input.scalar_type() == at::ScalarType::BFloat16 || - input.scalar_type() == at::ScalarType::Half, - "Input must be BFloat16 or Half"); const int num_rows = input.size(0); const int num_cols = input.size(1); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 6cc6560d8b..6e6e38a1dd 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -2263,8 +2263,6 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou "Per-token NVFP4 activation does not support stochastic rounding."); NVTE_CHECK(!this->with_amax_reduction, "Per-token NVFP4 activation does not support amax reduction."); - NVTE_CHECK(input.dtype() == DType::kBFloat16 || input.dtype() == DType::kFloat16, - "Per-token NVFP4 activation supports BF16/FP16 inputs only."); NVTE_CHECK(cols % 16 == 0, "Per-token NVFP4 activation requires last dim divisible by 16."); quant_config.set_nvfp4_per_token_activation(true); } From 866d337ff0754682b9b674bbfae3fddc4500c71f Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sun, 26 Apr 2026 15:32:57 -0700 Subject: [PATCH 04/21] Clean up tests Signed-off-by: Ziang Li --- tests/pytorch/test_backward_override.py | 8 +++- tests/pytorch/test_cuda_graphs.py | 20 +++++++-- tests/pytorch/test_sanity.py | 31 ++++++++++--- tests/pytorch/utils.py | 15 +++++++ transformer_engine/common/recipe/__init__.py | 1 - .../pytorch/cpp_extensions/gemm.py | 45 +++++++++++++++++++ 6 files changed, 109 insertions(+), 11 deletions(-) diff --git a/tests/pytorch/test_backward_override.py b/tests/pytorch/test_backward_override.py index 6f11069e28..c91442562f 100644 --- a/tests/pytorch/test_backward_override.py +++ b/tests/pytorch/test_backward_override.py @@ -1042,6 +1042,7 @@ def test_grouped_linear_backward_override_matches_reference( quantized_ref_recipe = make_recipe(recipe_name) mode_recipe = make_recipe(recipe_name, backward_override=backward_override) + skip_unsupported_backward_override("grouped_linear", mode_recipe, backward_override) module_quantized_ref = te.GroupedLinear( num_gemms, @@ -1280,6 +1281,7 @@ def test_grouped_linear_runtime_backward_override_switch_updates_ctx( default_recipe = make_recipe(recipe_name) mode_recipe = make_recipe(recipe_name, backward_override=backward_override) + skip_unsupported_backward_override("grouped_linear", mode_recipe, backward_override) *_, default_ctx = _run_grouped_linear_single_step_with_ctx_state( module, @@ -1724,7 +1726,11 @@ def test_backward_override_memory_peak_report( x = torch.randn(*input_shape, dtype=dtype, device="cuda") dy = torch.randn(*input_shape[:-1], out_features, dtype=dtype, device="cuda") - modes = (None, "high_precision", "dequantized") + modes = ( + ("high_precision", "dequantized") + if recipe_name == "nvfp4_pertoken" + else (None, "high_precision", "dequantized") + ) mode_results: 
dict[str, dict[str, float] | str] = {} for mode in modes: diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index a782dadc60..8a01acf0eb 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -20,17 +20,19 @@ is_fp8_available, is_fp8_block_scaling_available, is_mxfp8_available, + is_nvfp4_available, is_bf16_available, ) from transformer_engine.pytorch.quantization import FP8GlobalStateManager import transformer_engine.pytorch.ops as te_ops from transformer_engine.common import recipe -from utils import ModelConfig, reset_rng_states, skip_unsupported_backward_override +from utils import ModelConfig, recipe_id, reset_rng_states, skip_unsupported_backward_override # Check if FP8 is supported. fp8_available = is_fp8_available() fp8_block_scaling_available = is_fp8_block_scaling_available() mxfp8_available = is_mxfp8_available() +nvfp4_available = is_nvfp4_available() # Reset RNG states. reset_rng_states() @@ -62,6 +64,14 @@ def nvfp4_rht_and_2d_quantization(): return nvfp4_recipe +def nvfp4_per_token(): + nvfp4_recipe = recipe.NVFP4BlockScaling(per_token_activation=True) + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() + return nvfp4_recipe + + def check_rht_usage(recipe: recipe.Recipe) -> bool: # if using RHT, we can only support bf16 # check fp4_quant_fwd_inp, fp4_quant_fwd_weight, fp4_quant_bwd_grad @@ -88,7 +98,9 @@ def get_nvfp4_inp_supported_dtypes(recipe: recipe.Recipe, dtype: torch.dtype) -> fp8_recipes = [] if mxfp8_available: fp8_recipes.append(recipe.MXFP8BlockScaling()) +if nvfp4_available: fp8_recipes.append(nvfp4_rht_and_2d_quantization()) + fp8_recipes.append(nvfp4_per_token()) if fp8_block_scaling_available: fp8_recipes.append(recipe.Float8BlockScaling()) if fp8_available: @@ -360,7 +372,7 @@ def _test_cuda_graphs( @pytest.mark.parametrize("module", _test_cuda_graphs_modules) @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("fp8_params", (False, True)) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None], ids=lambda r: type(r).__name__) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None], ids=recipe_id) @pytest.mark.parametrize("backward_override", (None, "high_precision", "dequantized")) def test_make_graphed_callables( *, @@ -390,6 +402,8 @@ def test_make_graphed_callables( f"Module not yet supported for {fp8_recipe.__class__.__name__} with CUDA graphs" ) if fp8 and fp8_recipe.nvfp4(): + if getattr(fp8_recipe, "per_token_activation", False) and module == "mha": + pytest.skip("Per-token NVFP4 CUDA graph coverage applies to GEMM modules.") if dtype not in get_nvfp4_inp_supported_dtypes(fp8_recipe, dtype): pytest.skip( f"Input dtype {dtype} not supported for NVFP4 Recipe" @@ -448,7 +462,7 @@ def test_make_graphed_callables( ) @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("fp8_params", (False, True)) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=lambda r: type(r).__name__) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=recipe_id) @pytest.mark.parametrize("backward_override", (None, "high_precision", "dequantized")) def test_make_graphed_callables_with_fp8_weight_caching( *, diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 7f2f24fd69..c3951c28ed 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -33,17 +33,19 @@ checkpoint, QuantizedTensor, is_bf16_available, + 
is_nvfp4_available, ) from transformer_engine.common import recipe import transformer_engine_torch as tex from transformer_engine.pytorch.cpp_extensions import general_gemm from transformer_engine.pytorch.tensor.utils import replace_raw_data -from utils import ModelConfig, skip_unsupported_backward_override +from utils import ModelConfig, recipe_id, skip_unsupported_backward_override # Only run FP8 tests on supported devices. fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) fp8_block_scaling_available, _ = te.is_fp8_block_scaling_available(return_reason=True) mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) +nvfp4_available, _ = is_nvfp4_available(return_reason=True) # Record initial RNG state from script run. seed = 1234 @@ -93,9 +95,18 @@ def nvfp4_vanilla(): return nvfp4_recipe +def nvfp4_per_token(): + nvfp4_recipe = recipe.NVFP4BlockScaling(per_token_activation=True) + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() + return nvfp4_recipe + + fp8_recipes = [] if mxfp8_available: fp8_recipes.append(recipe.MXFP8BlockScaling()) +if nvfp4_available: fp8_recipes.append(nvfp4_vanilla()) # TODO: fix check for this if fp8_block_scaling_available: fp8_recipes.append(recipe.Float8BlockScaling()) @@ -103,6 +114,9 @@ def nvfp4_vanilla(): fp8_recipes.append(recipe.Float8CurrentScaling()) fp8_recipes.append(recipe.DelayedScaling()) fp8_recipes.append(None) +fp8_recipes_with_per_token = fp8_recipes.copy() +if nvfp4_available: + fp8_recipes_with_per_token.insert(-1, nvfp4_per_token()) param_types = [torch.float32, torch.float16] if is_bf16_available(): # bf16 requires sm_80 or higher @@ -402,7 +416,7 @@ def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normaliz @pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes_with_per_token, ids=recipe_id) @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) @pytest.mark.parametrize("model", ["small", "weird"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) @@ -450,7 +464,7 @@ def test_sanity_layernorm_linear( @pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes_with_per_token, ids=recipe_id) @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) @pytest.mark.parametrize("model", ["small", "weird"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) @@ -488,7 +502,7 @@ def test_sanity_linear( @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes_with_zero) @pytest.mark.parametrize("model", ["small", "weird"]) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes_with_per_token, ids=recipe_id) @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("use_bias", all_boolean) @@ -529,7 +543,7 @@ def test_sanity_linear_with_zero_tokens( @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes_with_zero) @pytest.mark.parametrize("model", ["small", "weird"]) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes_with_per_token, 
ids=recipe_id) @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("use_bias", all_boolean) @@ -563,7 +577,12 @@ def test_sanity_grouped_linear( if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") if fp8_recipe.nvfp4(): - pytest.skip("NVFP4 not supported for grouped linear") + if dtype == torch.float16: + pytest.skip("FP16 output for NVFP4 not supported") + if not getattr(fp8_recipe, "per_token_activation", False): + pytest.skip("Only per-token NVFP4 is supported for grouped linear") + if fp8_model_params: + pytest.skip("Per-token NVFP4 grouped linear does not support FP8 model params") use_fp8 = fp8_recipe is not None with quantized_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe): diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 04ac2becbc..6e58538a4a 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -160,12 +160,27 @@ def make_recipe(name: Optional[str], **recipe_kwargs: Any) -> Optional[Recipe]: raise ValueError(f"Unsupported quantization scheme ({name})") +def recipe_id(fp8_recipe: Optional[Recipe]) -> str: + """Readable pytest id for FP8/FP4 recipes.""" + if fp8_recipe is None: + return "None" + if fp8_recipe.nvfp4() and getattr(fp8_recipe, "per_token_activation", False): + return "NVFP4PerTokenBlockScaling" + return type(fp8_recipe).__name__ + + def skip_unsupported_backward_override( layer_type: str, quant_recipe: Optional[Recipe], backward_override: Optional[str], ) -> None: """Skip known unsupported layer/recipe/backward-override combinations used in tests.""" + if ( + quant_recipe is not None + and getattr(quant_recipe, "per_token_activation", False) + and backward_override is None + ): + pytest.skip("Per-token NVFP4 requires an explicit backward override.") if backward_override is None: return if quant_recipe is None and backward_override is not None: diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index e59d01d82a..8f549c8979 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -511,7 +511,6 @@ def __post_init__(self) -> None: assert ( self.backward_override in _BACKWARD_OVERRIDES ), "NVTE_BACKWARD_OVERRIDE must be unset or one of: 'high_precision', 'dequantized'." - # Quantization params # Note: RHT is currently only applied to column-wise usage so that # it can be used for wgrad GEMM. 
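For reference, a minimal standalone sketch of the recipe these tests exercise (a sketch, not part of the patch: it assumes an NVFP4-capable GPU, uses the long-standing `te.fp8_autocast` entry point, and sets an explicit backward override, which the skip logic above requires for per-token NVFP4):

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Per-token NVFP4: rowwise amax metadata holds one FP32 value per token (row)
# instead of a single per-tensor scalar.
nvfp4 = recipe.NVFP4BlockScaling(
    per_token_activation=True,
    backward_override="dequantized",  # per-token needs an explicit override
)
layer = te.Linear(128, 128, params_dtype=torch.bfloat16).cuda()
x = torch.randn(32, 128, dtype=torch.bfloat16, device="cuda", requires_grad=True)
with te.fp8_autocast(enabled=True, fp8_recipe=nvfp4):
    y = layer(x)
y.sum().backward()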
diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 4895054758..de693d823a 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -114,6 +114,14 @@ def _maybe_apply_nvfp4_pertoken_output_rescale( out_2d.mul_(ratios) +def _is_nvfp4_pertoken_tensor(tensor: torch.Tensor) -> bool: + """Whether tensor carries per-token NVFP4 global amax metadata.""" + if not isinstance(tensor, NVFP4TensorStorage): + return False + amax = tensor._amax_rowwise if tensor._amax_rowwise is not None else tensor._amax_columnwise + return amax is not None and amax.numel() > 1 + + def general_gemm( A: torch.Tensor, B: torch.Tensor, @@ -303,6 +311,43 @@ def general_grouped_gemm( else: bias_dtype = TE_DType[torch.bfloat16] + use_pertoken_unfused_fprop = ( + not grad + and not gelu + and not accumulate + and layout[1] == "N" + and D_dtype is None + and all(q is None for q in quantization_params) + and any(_is_nvfp4_pertoken_tensor(tensor) for tensor in B) + ) + if use_pertoken_unfused_fprop: + out_init = out[0] if single_output else None + if single_output: + start_idx = 0 + out_views = [] + for i in range(num_gemms): + size = m_splits[i] + out_views.append(out_init[start_idx : start_idx + size]) + start_idx += size + else: + out_views = out + for i in range(num_gemms): + if out_views[i].numel() == 0: + continue + gemm_out, _, _, _ = general_gemm( + A[i], + B[i], + quantization_params=None, + out_dtype=out_views[i].dtype, + layout=layout, + bias=bias[i] if use_bias else None, + use_split_accumulator=use_split_accumulator, + ) + out_views[i].copy_(gemm_out) + if single_output: + out = out_init + return out, bias, gelu_input + if isinstance(quantization_params[0], DebugQuantizer): assert not gelu, "GELU not supported in debug mode" if single_output: From 5a6ea130cc57a7ed0eabcb21ea9213e818dc093d Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sun, 26 Apr 2026 15:43:04 -0700 Subject: [PATCH 05/21] Clean up ref Signed-off-by: Ziang Li --- .../custom_recipes/quantization_nvfp4.py | 118 +++++------------- 1 file changed, 29 insertions(+), 89 deletions(-) diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py index 430af6c581..d57ea792dd 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py @@ -449,10 +449,14 @@ def _quantize_blockwise_reference( tile_len_y: int, *, pow_2_scales: bool, - per_token_activation: bool, + per_token_rowwise: bool = False, + per_token_columnwise: bool = False, eps: float, # pylint: disable=unused-argument ) -> Tuple[torch.Tensor, torch.Tensor]: + assert not ( + per_token_rowwise and per_token_columnwise + ), "Per-token rowwise and columnwise reference modes are mutually exclusive." 
if x.ndim != 2: raise ValueError( f"_quantize_blockwise_reference expects a 2D tensor, got {x.ndim}D with shape" @@ -491,10 +495,10 @@ def _quantize_blockwise_reference( decode_scale.to(torch.float32), ) else: - if per_token_activation: + if per_token_rowwise: global_amax = global_amax.to(torch.float32).view(m, 1, 1) - else: - global_amax = global_amax.to(torch.float32) + if per_token_columnwise: + global_amax = global_amax.to(torch.float32).view(1, n // tile_len_x, tile_len_x) global_encode_scale = torch.div(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX, global_amax) global_encode_scale = torch.min( @@ -517,9 +521,16 @@ def _quantize_blockwise_reference( global_decode_scale = torch.div(1.0, global_encode_scale) global_encode_scale_multiplier = global_encode_scale * torch.reciprocal(FLOAT4_E2M1_MAX) - # Match the kernel's default path: fold the FP4 reciprocal into the - # global scale multiplier, but keep the final reciprocal exact. - decode_scale = vec_max * global_encode_scale_multiplier + if per_token_columnwise: + decode_scale = torch.amax( + torch.abs(x.to(torch.float32)) * global_encode_scale_multiplier, + dim=-1, + keepdim=True, + ) + else: + # Match the kernel's default path: fold the FP4 reciprocal into the + # global scale multiplier, but keep the final reciprocal exact. + decode_scale = vec_max * global_encode_scale_multiplier decode_scale = torch.min( decode_scale, torch.tensor( @@ -546,71 +557,6 @@ def _quantize_blockwise_reference( return cast_to_fp4x2(clipped_x), decode_scale.squeeze(-1) - @classmethod - def _quantize_blockwise_pertoken_columnwise_reference( - cls, - x: torch.Tensor, - global_amax: torch.Tensor, - tile_len_x: int, - ) -> Tuple[torch.Tensor, torch.Tensor]: - if x.ndim != 2: - raise ValueError( - "_quantize_blockwise_pertoken_columnwise_reference expects a 2D tensor, got" - f" {x.ndim}D with shape {x.shape}" - ) - - m, n = x.shape - x = x.view(m, n // tile_len_x, tile_len_x) - FLOAT4_E2M1_MAX = torch.tensor(6.0, device=x.device, dtype=torch.float32) - FLOAT8_E4M3_MAX = torch.tensor(448.0, device=x.device, dtype=torch.float32) - - global_amax = global_amax.to(torch.float32).view(1, n // tile_len_x, tile_len_x) - global_encode_scale = torch.div(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX, global_amax) - global_encode_scale = torch.min( - global_encode_scale, - torch.tensor( - torch.finfo(torch.float32).max, - device=global_encode_scale.device, - dtype=torch.float32, - ), - ) - global_encode_scale = torch.where( - global_encode_scale == 0.0, - torch.ones_like(global_encode_scale), - global_encode_scale, - ) - global_decode_scale = torch.div(1.0, global_encode_scale) - global_encode_scale_multiplier = global_encode_scale * torch.reciprocal(FLOAT4_E2M1_MAX) - - decode_scale = torch.amax( - torch.abs(x.to(torch.float32)) * global_encode_scale_multiplier, - dim=-1, - keepdim=True, - ) - decode_scale = torch.min( - decode_scale, - torch.tensor( - torch.finfo(torch.float32).max, - device=decode_scale.device, - dtype=torch.float32, - ), - ) - decode_scale = torch.clamp(decode_scale, min=-FLOAT8_E4M3_MAX, max=FLOAT8_E4M3_MAX) - decode_scale = decode_scale.to(torch.float8_e4m3fn) - - encode_scale = torch.min( - torch.div(1.0, decode_scale.to(torch.float32) * global_decode_scale), - torch.tensor( - torch.finfo(torch.float32).max, - device=decode_scale.device, - dtype=torch.float32, - ), - ) - scaled_x = x.to(torch.float32) * encode_scale - clipped_x = torch.clamp(scaled_x, -FLOAT4_E2M1_MAX, FLOAT4_E2M1_MAX).reshape(m, n) - - return cast_to_fp4x2(clipped_x), decode_scale.squeeze(-1) - 
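# Worked sketch (illustrative, not part of the patch) of the two-level NVFP4
# scaling used by the reference above: a global FP32 encode scale derived
# from the global amax, plus one FP8 (E4M3) decode scale per 16-element block.
import torch

FLOAT4_E2M1_MAX = 6.0
FLOAT8_E4M3_MAX = 448.0

x = torch.randn(4, 16, dtype=torch.float32)  # one block per row
global_amax = x.abs().amax()
global_encode_scale = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / global_amax
# Fold the FP4 reciprocal into the global multiplier, as the kernel does.
multiplier = global_encode_scale / FLOAT4_E2M1_MAX
vec_max = x.abs().amax(dim=-1, keepdim=True)
decode_scale = (vec_max * multiplier).to(torch.float8_e4m3fn)
# encode_scale = 1 / (decode_scale * global_decode_scale)
encode_scale = global_encode_scale / decode_scale.float()
q = torch.clamp(x * encode_scale, -FLOAT4_E2M1_MAX, FLOAT4_E2M1_MAX)
# Dequantization would be q * decode_scale.float() / global_encode_scale.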
@staticmethod def _pad_tensor( tensor: torch.Tensor, row_divisor: Optional[int], col_divisor: Optional[int] @@ -741,7 +687,7 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ self.quant_tile_shape[1], self.quant_tile_shape[0], pow_2_scales=self.pow_2_scales, - per_token_activation=self.per_token_activation, + per_token_rowwise=self.per_token_activation, eps=self.eps, ) if transpose_scales: @@ -759,22 +705,15 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ x_t, row_divisor=self.quant_tile_shape[0], col_divisor=self.quant_tile_shape[1] ) - if self.per_token_activation: - qx_t, sx_t = self._quantize_blockwise_pertoken_columnwise_reference( - x_t_padded, - global_amax_col, - self.quant_tile_shape[1], - ) - else: - qx_t, sx_t = self._quantize_blockwise_reference( - x_t_padded, - global_amax_col, - self.quant_tile_shape[1], - self.quant_tile_shape[0], - pow_2_scales=self.pow_2_scales, - per_token_activation=False, - eps=self.eps, - ) + qx_t, sx_t = self._quantize_blockwise_reference( + x_t_padded, + global_amax_col, + self.quant_tile_shape[1], + self.quant_tile_shape[0], + pow_2_scales=self.pow_2_scales, + per_token_columnwise=self.per_token_activation, + eps=self.eps, + ) qx_t = self._rm_pad_tensor(qx_t, (N, M // 2)) @@ -965,6 +904,7 @@ def qgemm( sw = sw.to(torch.float32) factor = 6.0 * 6.0 * 448.0 * 448.0 + if gemm_type == quantization.GEMMType.WGRAD: partial_alpha = qresult_x.global_amax_col * qresult_w.global_amax_col else: From ee0aafb3ee360b51d8b151177d93f3b712909b00 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sun, 26 Apr 2026 17:58:25 -0700 Subject: [PATCH 06/21] Clean up gemm wrapper Signed-off-by: Ziang Li --- tests/pytorch/test_backward_override.py | 57 ++++--- transformer_engine/common/recipe/__init__.py | 1 + .../pytorch/cpp_extensions/gemm.py | 152 ++++++++---------- 3 files changed, 105 insertions(+), 105 deletions(-) diff --git a/tests/pytorch/test_backward_override.py b/tests/pytorch/test_backward_override.py index c91442562f..0921035d1e 100644 --- a/tests/pytorch/test_backward_override.py +++ b/tests/pytorch/test_backward_override.py @@ -99,6 +99,12 @@ def backward_override(request: pytest.FixtureRequest) -> str: return request.param +def _make_backward_test_recipe(recipe_name: str, **recipe_kwargs) -> Optional[recipe.Recipe]: + if recipe_name == "nvfp4_pertoken" and "backward_override" not in recipe_kwargs: + recipe_kwargs["backward_override"] = "dequantized" + return make_recipe(recipe_name, **recipe_kwargs) + + # -------------------------- # Test cases # -------------------------- @@ -185,6 +191,11 @@ def _maybe_skip_unsupported_recipe_module_combo(recipe_name: str, module_type: s pytest.skip("Fusible ops (te_ops.Linear) do not support Float8BlockScaling recipe") +def _maybe_skip_unsupported_fused_ops(recipe_name: str) -> None: + if recipe_name == "nvfp4_pertoken": + pytest.skip("Per-token NVFP4 currently does not support fused te_ops paths.") + + def _maybe_skip_unsupported_recipe_shape( recipe_name: str, input_shape: tuple[int, ...], @@ -856,7 +867,7 @@ def test_linear_like_backward_override_matches_reference( _maybe_skip_unsupported_recipe_shape(recipe_name, input_shape, module_type) in_features = input_shape[-1] - quantized_ref_recipe = make_recipe(recipe_name) + quantized_ref_recipe = _make_backward_test_recipe(recipe_name) mode_recipe = make_recipe(recipe_name, backward_override=backward_override) skip_unsupported_backward_override(module_type, mode_recipe, backward_override) @@ -1040,7 +1051,7 @@ def 
test_grouped_linear_backward_override_matches_reference( num_gemms = len(m_splits) num_tokens = sum(m_splits) - quantized_ref_recipe = make_recipe(recipe_name) + quantized_ref_recipe = _make_backward_test_recipe(recipe_name) mode_recipe = make_recipe(recipe_name, backward_override=backward_override) skip_unsupported_backward_override("grouped_linear", mode_recipe, backward_override) @@ -1209,9 +1220,11 @@ def test_linear_like_runtime_backward_override_switch_updates_ctx( x = torch.randn(*input_shape, dtype=dtype, device="cuda") dy = torch.randn(*input_shape[:-1], out_features, dtype=dtype, device="cuda") - default_recipe = make_recipe(recipe_name) + default_recipe = _make_backward_test_recipe(recipe_name) mode_recipe = make_recipe(recipe_name, backward_override=backward_override) skip_unsupported_backward_override(module_type, mode_recipe, backward_override) + expected_default_mode = default_recipe.backward_override + expected_default_fp8 = expected_default_mode is None *_, default_ctx = _run_single_step_with_ctx_state(module, x, dy, default_recipe) ( @@ -1220,10 +1233,10 @@ def test_linear_like_runtime_backward_override_switch_updates_ctx( default_grad_output_quantizer, default_reduce_and_update, ) = default_ctx - assert default_mode is None - assert default_fp8 - assert default_grad_output_quantizer is not None - assert default_reduce_and_update + assert default_mode == expected_default_mode + assert default_fp8 == expected_default_fp8 + assert (default_grad_output_quantizer is not None) == expected_default_fp8 + assert default_reduce_and_update == expected_default_fp8 *_, switched_ctx = _run_single_step_with_ctx_state(module, x, dy, mode_recipe) switched_mode, switched_fp8, switched_grad_output_quantizer, switched_reduce_and_update = ( @@ -1241,10 +1254,10 @@ def test_linear_like_runtime_backward_override_switch_updates_ctx( default_grad_output_quantizer_after, default_reduce_and_update_after, ) = default_ctx_after - assert default_mode_after is None - assert default_fp8_after - assert default_grad_output_quantizer_after is not None - assert default_reduce_and_update_after + assert default_mode_after == expected_default_mode + assert default_fp8_after == expected_default_fp8 + assert (default_grad_output_quantizer_after is not None) == expected_default_fp8 + assert default_reduce_and_update_after == expected_default_fp8 @pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list) @@ -1279,9 +1292,11 @@ def test_grouped_linear_runtime_backward_override_switch_updates_ctx( x = torch.randn(num_tokens, in_features, dtype=dtype, device="cuda") dy = torch.randn(num_tokens, out_features, dtype=dtype, device="cuda") - default_recipe = make_recipe(recipe_name) + default_recipe = _make_backward_test_recipe(recipe_name) mode_recipe = make_recipe(recipe_name, backward_override=backward_override) skip_unsupported_backward_override("grouped_linear", mode_recipe, backward_override) + expected_default_mode = default_recipe.backward_override + expected_default_fp8 = expected_default_mode is None *_, default_ctx = _run_grouped_linear_single_step_with_ctx_state( module, @@ -1291,9 +1306,9 @@ def test_grouped_linear_runtime_backward_override_switch_updates_ctx( default_recipe, ) default_mode, default_fp8, default_reduce_and_update = default_ctx - assert default_mode is None - assert default_fp8 - assert default_reduce_and_update + assert default_mode == expected_default_mode + assert default_fp8 == expected_default_fp8 + assert default_reduce_and_update == expected_default_fp8 *_, switched_ctx 
= _run_grouped_linear_single_step_with_ctx_state( module, @@ -1315,9 +1330,9 @@ def test_grouped_linear_runtime_backward_override_switch_updates_ctx( default_recipe, ) default_mode_after, default_fp8_after, default_reduce_and_update_after = default_ctx_after - assert default_mode_after is None - assert default_fp8_after - assert default_reduce_and_update_after + assert default_mode_after == expected_default_mode + assert default_fp8_after == expected_default_fp8 + assert default_reduce_and_update_after == expected_default_fp8 @pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list) @@ -1344,10 +1359,11 @@ def test_fused_linear_paths_match_backward_override_reference( _maybe_skip_recipe_dtype(recipe_name, dtype, "ops_linear") _maybe_skip_unsupported_recipe_module_combo(recipe_name, "ops_linear") _maybe_skip_unsupported_recipe_shape(recipe_name, (m, in_features), "ops_linear") + _maybe_skip_unsupported_fused_ops(recipe_name) reset_rng_states() - quantized_ref_recipe = make_recipe(recipe_name) + quantized_ref_recipe = _make_backward_test_recipe(recipe_name) mode_recipe = make_recipe(recipe_name, backward_override=backward_override) skip_unsupported_backward_override("ops_linear", mode_recipe, backward_override) @@ -1483,11 +1499,12 @@ def test_fused_bias_activation_matches_masked_linear_backward( _maybe_skip_recipe_dtype(recipe_name, dtype, "ops_linear") _maybe_skip_unsupported_recipe_module_combo(recipe_name, "ops_linear") _maybe_skip_unsupported_recipe_shape(recipe_name, input_shape, "ops_linear") + _maybe_skip_unsupported_fused_ops(recipe_name) reset_rng_states() in_features = input_shape[-1] - quantized_ref_recipe = make_recipe(recipe_name) + quantized_ref_recipe = _make_backward_test_recipe(recipe_name) mode_recipe = make_recipe(recipe_name, backward_override=backward_override) skip_unsupported_backward_override("ops_linear", mode_recipe, backward_override) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 8f549c8979..e59d01d82a 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -511,6 +511,7 @@ def __post_init__(self) -> None: assert ( self.backward_override in _BACKWARD_OVERRIDES ), "NVTE_BACKWARD_OVERRIDE must be unset or one of: 'high_precision', 'dequantized'." + # Quantization params # Note: RHT is currently only applied to column-wise usage so that # it can be used for wgrad GEMM. diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index de693d823a..f19b175969 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -70,50 +70,6 @@ def validate_gemm_scale(scale: Optional[float], required: bool) -> float: return 0.0 -def _maybe_apply_nvfp4_pertoken_output_rescale( - out: torch.Tensor, - B: torch.Tensor, - *, - layout: str, - bias: Optional[torch.Tensor], - grad: bool, - gelu: bool, - accumulate: bool, -) -> None: - """Apply per-token NVFP4 global-scale correction for forward GEMM outputs. - - Current NVFP4 GEMM alpha path consumes one scalar amax. Per-token NVFP4 stores - rowwise amax vector in B, so we correct by row using ratio (amax[row] / amax[0]) - when B is not transposed. If bias was fused in epilogue, remove/reapply it around - the row rescale to avoid bias distortion. 
- """ - - if grad or gelu or accumulate or layout[1] != "N": - return - if not isinstance(B, NVFP4TensorStorage): - return - if not isinstance(out, torch.Tensor) or is_custom(out): - return - if out.numel() == 0: - return - amax = B._amax_rowwise if B._amax_rowwise is not None else B._amax_columnwise - if amax is None or amax.numel() <= 1: - return - - out_2d = out.reshape(-1, out.shape[-1]) - if amax.numel() != out_2d.shape[0]: - return - - ratios = (amax / amax[0]).to(dtype=out.dtype).view(-1, 1) - if bias is not None: - bias_cast = bias.to(dtype=out.dtype) - out_2d.sub_(bias_cast) - out_2d.mul_(ratios) - out_2d.add_(bias_cast) - else: - out_2d.mul_(ratios) - - def _is_nvfp4_pertoken_tensor(tensor: torch.Tensor) -> bool: """Whether tensor carries per-token NVFP4 global amax metadata.""" if not isinstance(tensor, NVFP4TensorStorage): @@ -200,24 +156,6 @@ def general_gemm( # FP8 block-scaling requires split accumulator use_split_accumulator = True - requested_out_dtype = out_dtype - needs_fp32_rescale_path = ( - layout[1] == "N" - and not grad - and not gelu - and not accumulate - and isinstance(B, NVFP4TensorStorage) - and ( - (B._amax_rowwise is not None and B._amax_rowwise.numel() > 1) - or (B._amax_columnwise is not None and B._amax_columnwise.numel() > 1) - ) - and quantization_params is None - and out is None - and requested_out_dtype is not None - and requested_out_dtype != torch.float32 - ) - effective_out_dtype = torch.float32 if needs_fp32_rescale_path else requested_out_dtype - args = ( A, transa, # transa @@ -225,7 +163,7 @@ def general_gemm( transb, # transb out, quantization_params, - TE_DType[effective_out_dtype] if effective_out_dtype is not None else None, + TE_DType[out_dtype] if out_dtype is not None else None, bias, bias_dtype, gelu, @@ -245,18 +183,57 @@ def general_gemm( "beta": beta, } - out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs) - _maybe_apply_nvfp4_pertoken_output_rescale( - out, - B, - layout=layout, - bias=bias, - grad=grad, - gelu=gelu, - accumulate=accumulate, - ) - if needs_fp32_rescale_path: - out = out.to(dtype=requested_out_dtype) + if not _is_nvfp4_pertoken_tensor(B): + out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs) + else: + assert layout[1] == "N", "Per-token NVFP4 GEMM currently supports N-layout B only." + assert not grad, "Per-token NVFP4 GEMM currently supports fprop only." + assert not gelu, "Per-token NVFP4 GEMM currently does not support fused GELU." + assert not accumulate, "Per-token NVFP4 GEMM currently does not support accumulation." + assert ( + quantization_params is None + ), "Per-token NVFP4 GEMM currently does not support output quantization." + assert out is None or ( + isinstance(out, torch.Tensor) and not is_custom(out) + ), "Per-token NVFP4 GEMM currently supports only plain torch.Tensor outputs." + requested_out = out + requested_out_dtype = out_dtype + fp32_out = ( + torch.empty_like(requested_out, dtype=torch.float32) + if requested_out is not None + else None + ) + # Override only output, output quantizer, and output dtype for the FP32 correction path. 
+ args = ( + *args[:4], + fp32_out, + None, + TE_DType[torch.float32], + *args[7:], + ) + out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs) + + assert isinstance(out, torch.Tensor) and not is_custom(out) + assert out.numel() > 0 + amax = B._amax_rowwise if B._amax_rowwise is not None else B._amax_columnwise + assert amax is not None and amax.numel() > 1 + + out_2d = out.reshape(-1, out.shape[-1]) + assert amax.numel() == out_2d.shape[0] + ratios = (amax / amax[0]).to(dtype=out.dtype).view(-1, 1) + if bias is not None: + bias_cast = bias.to(dtype=out.dtype) + out_2d.sub_(bias_cast) + out_2d.mul_(ratios) + out_2d.add_(bias_cast) + else: + out_2d.mul_(ratios) + + if requested_out is not None: + requested_out.copy_(out.to(dtype=requested_out.dtype)) + out = requested_out + elif requested_out_dtype is not None and requested_out_dtype != torch.float32: + out = out.to(dtype=requested_out_dtype) if debug_quantizer is not None: out = debug_quantizer.process_gemm_output(out) @@ -311,16 +288,21 @@ def general_grouped_gemm( else: bias_dtype = TE_DType[torch.bfloat16] - use_pertoken_unfused_fprop = ( - not grad - and not gelu - and not accumulate - and layout[1] == "N" - and D_dtype is None - and all(q is None for q in quantization_params) - and any(_is_nvfp4_pertoken_tensor(tensor) for tensor in B) - ) - if use_pertoken_unfused_fprop: + if any(_is_nvfp4_pertoken_tensor(tensor) for tensor in B): + assert layout[1] == "N", "Per-token NVFP4 grouped GEMM currently supports N-layout B only." + assert not grad, "Per-token NVFP4 grouped GEMM currently supports fprop only." + assert not gelu, "Per-token NVFP4 grouped GEMM currently does not support fused GELU." + assert ( + not accumulate + ), "Per-token NVFP4 grouped GEMM currently does not support accumulation." + assert D_dtype is None, "Per-token NVFP4 grouped GEMM currently does not support D_dtype." + assert all( + q is None for q in quantization_params + ), "Per-token NVFP4 grouped GEMM currently does not support output quantization." + if single_output: + assert ( + m_splits is not None + ), "Per-token NVFP4 grouped GEMM requires m_splits with single output." 
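# Sketch (not part of the patch) of why the (amax / amax[0]) row correction
# above is exact: cuBLAS folds amax[0] into alpha for every output row, so
# row i comes out scaled by amax[0] where it should carry amax[i].
import torch

amax = torch.tensor([2.0, 0.5, 4.0])
base = torch.randn(3, 8)                    # GEMM result before any global scale
out_alpha0 = base * amax[0]                 # what the alpha-folded GEMM produced
out_fixed = out_alpha0 * (amax / amax[0]).view(-1, 1)
torch.testing.assert_close(out_fixed, base * amax.view(-1, 1))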
out_init = out[0] if single_output else None if single_output: start_idx = 0 From e852804fddbc17c3b610a6dad7fc3d4533e0802a Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sun, 26 Apr 2026 18:18:30 -0700 Subject: [PATCH 07/21] Clean up test Signed-off-by: Ziang Li --- tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py | 28 +++++---------- .../nvfp4/test_nvfp4_quantize_exact.py | 34 ++++--------------- 2 files changed, 15 insertions(+), 47 deletions(-) diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py index 231fb62468..27b5d0626f 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -16,17 +16,6 @@ recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) -def maybe_skip_pertoken_nvfp4_gemm( - *, - accumulate: bool, - x_columnwise: bool, -) -> None: - if accumulate: - pytest.skip("Per-token NVFP4 GEMM output rescale does not support accumulation") - if x_columnwise: - pytest.skip("Per-token NVFP4 GEMM output rescale requires rowwise activation usage") - - def check_nvfp4_gemm_versus_reference( x_dtype: torch.dtype, w_dtype: torch.dtype, @@ -171,7 +160,7 @@ def check_nvfp4_gemm_versus_reference( qresult_w=w_nvfp4_ref, ) - # Native TE GEMM path + # Native TE GEMM using tex.generic_gemm (cuBLAS GEMM) # Allocate cuBLAS workspace workspace = torch.empty(4, dtype=torch.uint8, device=device) @@ -258,14 +247,13 @@ def check_nvfp4_gemm_versus_reference( "is_x_columnwise, is_w_columnwise", [ (False, False), # TN - (False, True), # NN - (True, False), # TT + (True, False), # NN (True, True), # NT ], - ids=["rowxrow", "rowxcol", "colxrow", "colxcol"], + ids=["rowxrow", "colxrow", "colxcol"], ) @pytest.mark.parametrize( - "per_token_activation", [False, True], ids=["nvfp4_per_tensor", "nvfp4_pertoken"] + "per_token_activation", [False, True], ids=["nvfp4", "nvfp4_pertoken"] ) def test_nvfp4_gemm_versus_reference( M: int, @@ -280,10 +268,10 @@ def test_nvfp4_gemm_versus_reference( per_token_activation: bool, ): if per_token_activation: - maybe_skip_pertoken_nvfp4_gemm( - accumulate=accumulate, - x_columnwise=is_x_columnwise, - ) + if accumulate: + pytest.skip("Per-token NVFP4 GEMM output rescale does not support accumulation") + if is_x_columnwise: + pytest.skip("Per-token NVFP4 GEMM output rescale requires rowwise activation usage") check_nvfp4_gemm_versus_reference( x_dtype=x_dtype, diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index 93359b6179..a804e8f1ba 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -23,15 +23,6 @@ def unpack_fp4(x: torch.Tensor) -> torch.Tensor: return repeated -def maybe_skip_pertoken_nvfp4( - *, - return_transpose: bool = False, # pylint: disable=unused-argument - with_2d_quantization: bool = False, -) -> None: - if with_2d_quantization: - pytest.skip("Per-token NVFP4 does not support 2D quantization") - - def check_quantization_nvfp4_versus_reference( x_dtype: torch.dtype, M: int, @@ -168,7 +159,7 @@ def check_quantization_nvfp4_versus_reference( "with_2d_quantization", [True, False], ids=["2d_quantization", "1d_quantization"] ) @pytest.mark.parametrize( - "per_token_activation", [False, True], ids=["nvfp4_per_tensor", "nvfp4_pertoken"] + "per_token_activation", [False, True], ids=["nvfp4", "nvfp4_pertoken"] ) def test_quantization_block_tiling_versus_reference( x_dtype: torch.dtype, @@ -180,11 +171,9 @@ def 
test_quantization_block_tiling_versus_reference( with_2d_quantization: bool, per_token_activation: bool, ) -> None: - if per_token_activation: - maybe_skip_pertoken_nvfp4( - return_transpose=return_transpose, - with_2d_quantization=with_2d_quantization, - ) + if per_token_activation and with_2d_quantization: + pytest.skip("Per-token NVFP4 does not support 2D quantization") + check_quantization_nvfp4_versus_reference( x_dtype=x_dtype, M=M, @@ -211,7 +200,7 @@ def test_quantization_block_tiling_versus_reference( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) @pytest.mark.parametrize( - "per_token_activation", [False, True], ids=["nvfp4_per_tensor", "nvfp4_pertoken"] + "per_token_activation", [False, True], ids=["nvfp4", "nvfp4_pertoken"] ) def test_nvfp4_quantization_extrema_versus_reference( x_dtype: torch.dtype, @@ -234,9 +223,6 @@ def test_nvfp4_quantization_extrema_versus_reference( else: x = torch.zeros((M, N), dtype=x_dtype, device=device) - if per_token_activation: - maybe_skip_pertoken_nvfp4(return_transpose=return_transpose) - nvfp4_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, rowwise=True, @@ -318,7 +304,7 @@ def test_nvfp4_quantization_extrema_versus_reference( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) @pytest.mark.parametrize( - "per_token_activation", [False, True], ids=["nvfp4_per_tensor", "nvfp4_pertoken"] + "per_token_activation", [False, True], ids=["nvfp4", "nvfp4_pertoken"] ) def test_nvfp4_quantization_boundary_values( x_dtype: torch.dtype, @@ -354,9 +340,6 @@ def test_nvfp4_quantization_boundary_values( row[1::2] = upper x = row.unsqueeze(0).repeat(M, 1).to(dtype=x_dtype) - if per_token_activation: - maybe_skip_pertoken_nvfp4(return_transpose=return_transpose) - nvfp4_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, rowwise=True, @@ -438,7 +421,7 @@ def test_nvfp4_quantization_boundary_values( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) @pytest.mark.parametrize( - "per_token_activation", [False, True], ids=["nvfp4_per_tensor", "nvfp4_pertoken"] + "per_token_activation", [False, True], ids=["nvfp4", "nvfp4_pertoken"] ) def test_nvfp4_quantization_noncontiguous_inputs( x_dtype: torch.dtype, @@ -460,9 +443,6 @@ def test_nvfp4_quantization_noncontiguous_inputs( x_nc = x_base.t() # shape (N, M), non-contiguous assert not x_nc.is_contiguous() - if per_token_activation: - maybe_skip_pertoken_nvfp4(return_transpose=return_transpose) - nvfp4_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, rowwise=True, From 9dbb3ad02f1550764805363f2cf2e8c7ff8084b4 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sun, 26 Apr 2026 18:45:02 -0700 Subject: [PATCH 08/21] Clean up Signed-off-by: Ziang Li --- tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py | 4 +--- tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py | 16 ++++------------ tests/pytorch/test_sanity.py | 12 +++--------- 3 files changed, 8 insertions(+), 24 deletions(-) diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py index 27b5d0626f..b2862cc63d 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -252,9 +252,7 @@ def check_nvfp4_gemm_versus_reference( ], ids=["rowxrow", "colxrow", "colxcol"], ) -@pytest.mark.parametrize( - "per_token_activation", [False, True], ids=["nvfp4", "nvfp4_pertoken"] -) +@pytest.mark.parametrize("per_token_activation", [False, True], ids=["nvfp4", "nvfp4_pertoken"]) def 
test_nvfp4_gemm_versus_reference( M: int, K: int, diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index a804e8f1ba..d21e6a6e37 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -158,9 +158,7 @@ def check_quantization_nvfp4_versus_reference( @pytest.mark.parametrize( "with_2d_quantization", [True, False], ids=["2d_quantization", "1d_quantization"] ) -@pytest.mark.parametrize( - "per_token_activation", [False, True], ids=["nvfp4", "nvfp4_pertoken"] -) +@pytest.mark.parametrize("per_token_activation", [False, True], ids=["nvfp4", "nvfp4_pertoken"]) def test_quantization_block_tiling_versus_reference( x_dtype: torch.dtype, M: int, @@ -199,9 +197,7 @@ def test_quantization_block_tiling_versus_reference( @pytest.mark.parametrize( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) -@pytest.mark.parametrize( - "per_token_activation", [False, True], ids=["nvfp4", "nvfp4_pertoken"] -) +@pytest.mark.parametrize("per_token_activation", [False, True], ids=["nvfp4", "nvfp4_pertoken"]) def test_nvfp4_quantization_extrema_versus_reference( x_dtype: torch.dtype, M: int, @@ -303,9 +299,7 @@ def test_nvfp4_quantization_extrema_versus_reference( @pytest.mark.parametrize( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) -@pytest.mark.parametrize( - "per_token_activation", [False, True], ids=["nvfp4", "nvfp4_pertoken"] -) +@pytest.mark.parametrize("per_token_activation", [False, True], ids=["nvfp4", "nvfp4_pertoken"]) def test_nvfp4_quantization_boundary_values( x_dtype: torch.dtype, M: int, @@ -420,9 +414,7 @@ def test_nvfp4_quantization_boundary_values( @pytest.mark.parametrize( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) -@pytest.mark.parametrize( - "per_token_activation", [False, True], ids=["nvfp4", "nvfp4_pertoken"] -) +@pytest.mark.parametrize("per_token_activation", [False, True], ids=["nvfp4", "nvfp4_pertoken"]) def test_nvfp4_quantization_noncontiguous_inputs( x_dtype: torch.dtype, M: int, diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index c3951c28ed..73c291cc15 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -33,7 +33,6 @@ checkpoint, QuantizedTensor, is_bf16_available, - is_nvfp4_available, ) from transformer_engine.common import recipe import transformer_engine_torch as tex @@ -45,7 +44,7 @@ fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) fp8_block_scaling_available, _ = te.is_fp8_block_scaling_available(return_reason=True) mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) -nvfp4_available, _ = is_nvfp4_available(return_reason=True) +nvfp4_available, _ = te.is_nvfp4_available(return_reason=True) # Record initial RNG state from script run. 
seed = 1234 @@ -543,7 +542,7 @@ def test_sanity_linear_with_zero_tokens( @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes_with_zero) @pytest.mark.parametrize("model", ["small", "weird"]) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes_with_per_token, ids=recipe_id) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes) @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("use_bias", all_boolean) @@ -577,12 +576,7 @@ def test_sanity_grouped_linear( if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") if fp8_recipe.nvfp4(): - if dtype == torch.float16: - pytest.skip("FP16 output for NVFP4 not supported") - if not getattr(fp8_recipe, "per_token_activation", False): - pytest.skip("Only per-token NVFP4 is supported for grouped linear") - if fp8_model_params: - pytest.skip("Per-token NVFP4 grouped linear does not support FP8 model params") + pytest.skip("NVFP4 not supported for grouped linear") use_fp8 = fp8_recipe is not None with quantized_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe): From 475de8a604c43775129db2993709bf052fd07ac7 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sun, 26 Apr 2026 18:48:50 -0700 Subject: [PATCH 09/21] Rename and reformat Signed-off-by: Ziang Li --- tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py | 2 +- .../nvfp4/test_nvfp4_quantize_exact.py | 8 ++-- tests/pytorch/test_backward_override.py | 16 ++++---- tests/pytorch/utils.py | 4 +- .../cast/nvfp4/quantize_pertoken_nvfp4.cuh | 37 +++++++++++-------- 5 files changed, 36 insertions(+), 31 deletions(-) diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py index b2862cc63d..3708205aef 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -252,7 +252,7 @@ def check_nvfp4_gemm_versus_reference( ], ids=["rowxrow", "colxrow", "colxcol"], ) -@pytest.mark.parametrize("per_token_activation", [False, True], ids=["nvfp4", "nvfp4_pertoken"]) +@pytest.mark.parametrize("per_token_activation", [False, True], ids=["nvfp4", "nvfp4_per_token"]) def test_nvfp4_gemm_versus_reference( M: int, K: int, diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index d21e6a6e37..cf801639b7 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -158,7 +158,7 @@ def check_quantization_nvfp4_versus_reference( @pytest.mark.parametrize( "with_2d_quantization", [True, False], ids=["2d_quantization", "1d_quantization"] ) -@pytest.mark.parametrize("per_token_activation", [False, True], ids=["nvfp4", "nvfp4_pertoken"]) +@pytest.mark.parametrize("per_token_activation", [False, True], ids=["nvfp4", "nvfp4_per_token"]) def test_quantization_block_tiling_versus_reference( x_dtype: torch.dtype, M: int, @@ -197,7 +197,7 @@ def test_quantization_block_tiling_versus_reference( @pytest.mark.parametrize( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) -@pytest.mark.parametrize("per_token_activation", [False, True], ids=["nvfp4", "nvfp4_pertoken"]) +@pytest.mark.parametrize("per_token_activation", [False, True], ids=["nvfp4", "nvfp4_per_token"]) def test_nvfp4_quantization_extrema_versus_reference( x_dtype: torch.dtype, M: int, @@ -299,7 +299,7 @@ def test_nvfp4_quantization_extrema_versus_reference( 
@pytest.mark.parametrize( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) -@pytest.mark.parametrize("per_token_activation", [False, True], ids=["nvfp4", "nvfp4_pertoken"]) +@pytest.mark.parametrize("per_token_activation", [False, True], ids=["nvfp4", "nvfp4_per_token"]) def test_nvfp4_quantization_boundary_values( x_dtype: torch.dtype, M: int, @@ -414,7 +414,7 @@ def test_nvfp4_quantization_boundary_values( @pytest.mark.parametrize( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) -@pytest.mark.parametrize("per_token_activation", [False, True], ids=["nvfp4", "nvfp4_pertoken"]) +@pytest.mark.parametrize("per_token_activation", [False, True], ids=["nvfp4", "nvfp4_per_token"]) def test_nvfp4_quantization_noncontiguous_inputs( x_dtype: torch.dtype, M: int, diff --git a/tests/pytorch/test_backward_override.py b/tests/pytorch/test_backward_override.py index 0921035d1e..2156d6cef0 100644 --- a/tests/pytorch/test_backward_override.py +++ b/tests/pytorch/test_backward_override.py @@ -79,7 +79,7 @@ id="NVFP4BlockScaling", ), pytest.param( - "nvfp4_pertoken", + "nvfp4_per_token", marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4), id="NVFP4PerTokenBlockScaling", ), @@ -100,7 +100,7 @@ def backward_override(request: pytest.FixtureRequest) -> str: def _make_backward_test_recipe(recipe_name: str, **recipe_kwargs) -> Optional[recipe.Recipe]: - if recipe_name == "nvfp4_pertoken" and "backward_override" not in recipe_kwargs: + if recipe_name == "nvfp4_per_token" and "backward_override" not in recipe_kwargs: recipe_kwargs["backward_override"] = "dequantized" return make_recipe(recipe_name, **recipe_kwargs) @@ -176,7 +176,7 @@ def _maybe_skip_recipe_dtype( ) -> None: if dtype == torch.bfloat16 and not bf16_available: pytest.skip(reason_for_no_bf16) - if recipe_name in ("nvfp4", "nvfp4_pertoken"): + if recipe_name in ("nvfp4", "nvfp4_per_token"): if module_type in ("linear", "layernorm_linear") and dtype not in ( torch.bfloat16, torch.float32, @@ -192,7 +192,7 @@ def _maybe_skip_unsupported_recipe_module_combo(recipe_name: str, module_type: s def _maybe_skip_unsupported_fused_ops(recipe_name: str) -> None: - if recipe_name == "nvfp4_pertoken": + if recipe_name == "nvfp4_per_token": pytest.skip("Per-token NVFP4 currently does not support fused te_ops paths.") @@ -211,7 +211,7 @@ def _maybe_skip_unsupported_recipe_shape( " by 32." ) return - if recipe_name in ("nvfp4", "nvfp4_pertoken") and ( + if recipe_name in ("nvfp4", "nvfp4_per_token") and ( flat_first_dim % 16 != 0 or last_dim % 16 != 0 ): pytest.skip( @@ -238,7 +238,7 @@ def _maybe_skip_unsupported_recipe_shape( pytest.skip( "te_ops.Linear + MXFP8 requires prod(shape[:-1]) and shape[-1] divisible by 32." 
) - if recipe_name in ("nvfp4", "nvfp4_pertoken") and ( + if recipe_name in ("nvfp4", "nvfp4_per_token") and ( flat_first_dim % 16 != 0 or last_dim % 16 != 0 ): pytest.skip( @@ -259,7 +259,7 @@ def _maybe_skip_unsupported_grouped_splits(recipe_name: str, m_splits: list[int] ) if recipe_name == "mxfp8" and any(m % 32 != 0 for m in non_empty_splits): pytest.skip("GroupedLinear + MXFP8 requires each non-empty m_split divisible by 32.") - if recipe_name in ("nvfp4", "nvfp4_pertoken") and any(m % 16 != 0 for m in non_empty_splits): + if recipe_name in ("nvfp4", "nvfp4_per_token") and any(m % 16 != 0 for m in non_empty_splits): pytest.skip("GroupedLinear + NVFP4 requires each non-empty m_split divisible by 16.") if recipe_name == "nvfp4" and any(m % 64 != 0 for m in non_empty_splits): pytest.skip( @@ -1745,7 +1745,7 @@ def test_backward_override_memory_peak_report( modes = ( ("high_precision", "dequantized") - if recipe_name == "nvfp4_pertoken" + if recipe_name == "nvfp4_per_token" else (None, "high_precision", "dequantized") ) mode_results: dict[str, dict[str, float] | str] = {} diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 6e58538a4a..b88bcd31b5 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -115,7 +115,7 @@ def quantization_tols(name: str) -> dict[str, float]: "mxfp8_block_scaling", ): return dtype_tols(tex.DType.kFloat8E4M3) - if name in ("nvfp4", "nvfp4_pertoken"): + if name in ("nvfp4", "nvfp4_per_token"): return dtype_tols(tex.DType.kFloat4E2M1) raise ValueError(f"Unsupported quantization scheme ({name})") @@ -149,7 +149,7 @@ def make_recipe(name: Optional[str], **recipe_kwargs: Any) -> Optional[Recipe]: disable_2d_quantization=True, **recipe_kwargs, ) - if name == "nvfp4_pertoken": + if name == "nvfp4_per_token": return transformer_engine.common.recipe.NVFP4BlockScaling( disable_rht=True, disable_stochastic_rounding=True, diff --git a/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh index feacc2ff6a..36eb05115d 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh @@ -62,11 +62,13 @@ __global__ void #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) __launch_bounds__(BLOCK_SIZE) #endif - quantize_pertoken_nvfp4_kernel( - const int num_rows, const int num_cols, const IType *__restrict__ input, - const int *__restrict__ row_offsets, uint8_t *__restrict__ output_data, - fp8e4m3 *__restrict__ output_scales, float *__restrict__ output_per_token_amax, - const int scale_stride, const float *__restrict__ noop) { + quantize_pertoken_nvfp4_kernel(const int num_rows, const int num_cols, + const IType *__restrict__ input, + const int *__restrict__ row_offsets, + uint8_t *__restrict__ output_data, + fp8e4m3 *__restrict__ output_scales, + float *__restrict__ output_per_token_amax, + const int scale_stride, const float *__restrict__ noop) { #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) using namespace detail; if (noop != nullptr && noop[0] == 1.0f) { @@ -244,11 +246,13 @@ __global__ void #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) __launch_bounds__(BLOCK_SIZE) #endif - quantize_pertoken_nvfp4_columnwise_kernel( - const int num_rows, const int num_cols, const IType *__restrict__ input, - uint8_t *__restrict__ output_data_t, fp8e4m3 *__restrict__ output_scales_t, - const float *__restrict__ per_token_amax, const int scale_stride_t, - const float *__restrict__ noop) { 
+ quantize_pertoken_nvfp4_columnwise_kernel(const int num_rows, const int num_cols, + const IType *__restrict__ input, + uint8_t *__restrict__ output_data_t, + fp8e4m3 *__restrict__ output_scales_t, + const float *__restrict__ per_token_amax, + const int scale_stride_t, + const float *__restrict__ noop) { #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) using namespace detail; if (noop != nullptr && noop[0] == 1.0f) { @@ -307,16 +311,17 @@ __launch_bounds__(BLOCK_SIZE) } template -void launch_quantize_pertoken_nvfp4_columnwise( - const int num_rows, const int num_cols, const IType *input, uint8_t *output_data_t, - fp8e4m3 *output_scales_t, const float *per_token_amax, const int scale_stride_t, - cudaStream_t stream, const float *noop = nullptr) { +void launch_quantize_pertoken_nvfp4_columnwise(const int num_rows, const int num_cols, + const IType *input, uint8_t *output_data_t, + fp8e4m3 *output_scales_t, + const float *per_token_amax, + const int scale_stride_t, cudaStream_t stream, + const float *noop = nullptr) { #if FP4_TYPE_SUPPORTED if (num_rows == 0 || num_cols == 0) return; NVTE_CHECK(num_rows % PERTOKEN_SF_VEC_SIZE == 0, "num_rows must be a multiple of ", - PERTOKEN_SF_VEC_SIZE, " for per-token NVFP4 columnwise quantization, got ", - num_rows); + PERTOKEN_SF_VEC_SIZE, " for per-token NVFP4 columnwise quantization, got ", num_rows); dim3 grid(num_cols); dim3 block(PERTOKEN_BLOCK_SIZE); From 62a1c1ed95c51423f198020c40d9404688788d01 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sun, 26 Apr 2026 22:27:25 -0700 Subject: [PATCH 10/21] Avoid partial amax folding in gemm Signed-off-by: Ziang Li --- .../pytorch/cpp_extensions/gemm.py | 55 ++++++++++++------- 1 file changed, 34 insertions(+), 21 deletions(-) diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index f19b175969..fec82b8a02 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -78,6 +78,22 @@ def _is_nvfp4_pertoken_tensor(tensor: torch.Tensor) -> bool: return amax is not None and amax.numel() > 1 +def _nvfp4_pertoken_gemm_input( + tensor: NVFP4TensorStorage, +) -> Tuple[NVFP4TensorStorage, torch.Tensor]: + """Return a GEMM alias with identity activation amax and the original per-token amax.""" + metadata = tensor.get_metadata() + if tensor._amax_rowwise is not None: + amax = tensor._amax_rowwise + assert amax is not None and amax.numel() > 1 + metadata["amax_rowwise"] = amax.new_ones(1) + else: + amax = tensor._amax_columnwise + assert amax is not None and amax.numel() > 1 + metadata["amax_columnwise"] = amax.new_ones(1) + return NVFP4TensorStorage(**metadata), amax + + def general_gemm( A: torch.Tensor, B: torch.Tensor, @@ -196,38 +212,35 @@ def general_gemm( assert out is None or ( isinstance(out, torch.Tensor) and not is_custom(out) ), "Per-token NVFP4 GEMM currently supports only plain torch.Tensor outputs." - requested_out = out - requested_out_dtype = out_dtype + # cuBLAS folds the first activation amax into GEMM alpha. Keep per-token amax out of + # alpha by using identity here, then apply the true per-token scale in FP32 below. + gemm_B, amax = _nvfp4_pertoken_gemm_input(B) + per_token_scales = amax.view(-1, 1) + + requested_out, requested_out_dtype = out, out_dtype fp32_out = ( torch.empty_like(requested_out, dtype=torch.float32) if requested_out is not None else None ) - # Override only output, output quantizer, and output dtype for the FP32 correction path. 
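# Sketch (not part of the patch) of why swapping in an identity amax and
# rescaling afterwards is exact: the per-token global scale multiplies whole
# rows of the activation, and a GEMM is linear in each activation row, so the
# scale commutes out to the FP32 output rows.
import torch

w = torch.randn(8, 4)                       # stands in for the dequantized weight
act = torch.randn(3, 4)                     # dequantized activations at amax == 1
amax = torch.rand(3) + 0.5                  # true per-token global scales
ref = (act * amax.view(-1, 1)) @ w.t()
out = (act @ w.t()) * amax.view(-1, 1)
torch.testing.assert_close(out, ref)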
- args = ( - *args[:4], - fp32_out, - None, - TE_DType[torch.float32], - *args[7:], - ) - out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs) - - assert isinstance(out, torch.Tensor) and not is_custom(out) - assert out.numel() > 0 - amax = B._amax_rowwise if B._amax_rowwise is not None else B._amax_columnwise - assert amax is not None and amax.numel() > 1 - + gemm_args = list(args) + gemm_args[2] = gemm_B # B + gemm_args[4] = fp32_out # out + gemm_args[5] = None # quantization_params + gemm_args[6] = TE_DType[torch.float32] # out_dtype + out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*gemm_args, **kwargs) out_2d = out.reshape(-1, out.shape[-1]) + + assert amax.dtype == torch.float32 and out.dtype == torch.float32 assert amax.numel() == out_2d.shape[0] - ratios = (amax / amax[0]).to(dtype=out.dtype).view(-1, 1) + if bias is not None: - bias_cast = bias.to(dtype=out.dtype) + bias_cast = bias.to(dtype=torch.float32) out_2d.sub_(bias_cast) - out_2d.mul_(ratios) + out_2d.mul_(per_token_scales) out_2d.add_(bias_cast) else: - out_2d.mul_(ratios) + out_2d.mul_(per_token_scales) if requested_out is not None: requested_out.copy_(out.to(dtype=requested_out.dtype)) From 44e4e0fd5c28b2c1c7899be2e8b44f8ffc404cdb Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sun, 26 Apr 2026 22:27:41 -0700 Subject: [PATCH 11/21] Expand test coverage Signed-off-by: Ziang Li --- tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py | 138 +++++++++++++++++- .../nvfp4/test_nvfp4_quantize_exact.py | 12 ++ 2 files changed, 149 insertions(+), 1 deletion(-) diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py index 3708205aef..1a6784ed24 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -8,7 +8,7 @@ import transformer_engine_torch as tex from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch import NVFP4Quantizer -from transformer_engine.pytorch.cpp_extensions import general_gemm +from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef from transformer_engine.pytorch.custom_recipes import utils @@ -222,6 +222,98 @@ def check_nvfp4_gemm_versus_reference( torch.testing.assert_close(y_native, y_ref, atol=8e-3, rtol=8e-3) +def check_nvfp4_pertoken_grouped_gemm_matches_per_gemm( + x_dtype: torch.dtype, + w_dtype: torch.dtype, + out_dtype: torch.dtype, + m_splits: list[int], + k: int, + n: int, + *, + use_bias: bool, + single_output: bool, +): + te_dtype = tex.DType.kFloat4E2M1 + device = "cuda" + torch.manual_seed(23) + torch.cuda.manual_seed(23) + + num_gemms = len(m_splits) + + x_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=True, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + per_token_activation=True, + ) + w_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=True, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + ) + + x_nvfp4 = [] + w_nvfp4 = [] + bias = [] + expected = [] + for m in m_splits: + x = torch.randn((m, k), dtype=x_dtype, device=device) + w = torch.randn((n, k), dtype=w_dtype, device=device) + x_nvfp4.append( + x_quantizer.update_quantized( + x, x_quantizer.make_empty(x.shape, dtype=x_dtype, device=device) + ) + ) + w_nvfp4.append( + 
w_quantizer.update_quantized( + w, w_quantizer.make_empty(w.shape, dtype=w_dtype, device=device) + ) + ) + bias.append(torch.randn(n, dtype=torch.bfloat16, device=device) if use_bias else None) + expected.append( + general_gemm( + w_nvfp4[-1], + x_nvfp4[-1], + out_dtype=out_dtype, + layout="TN", + bias=bias[-1], + )[0] + ) + + if single_output: + out = [torch.empty((sum(m_splits), n), dtype=out_dtype, device=device)] + else: + out = [torch.empty((m, n), dtype=out_dtype, device=device) for m in m_splits] + + grouped_out, _, _ = general_grouped_gemm( + w_nvfp4, + x_nvfp4, + out, + quantization_params=[None] * num_gemms, + out_dtype=out_dtype, + layout="TN", + m_splits=m_splits, + bias=bias, + use_bias=use_bias, + single_output=single_output, + ) + + if single_output: + grouped_slices = torch.split(grouped_out, m_splits, dim=0) + else: + grouped_slices = grouped_out + for grouped, ref in zip(grouped_slices, expected): + torch.testing.assert_close(grouped, ref, atol=0.0, rtol=0.0) + + @pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) @pytest.mark.parametrize( "M, K, N", @@ -283,3 +375,47 @@ def test_nvfp4_gemm_versus_reference( w_columnwise=is_w_columnwise, per_token_activation=per_token_activation, ) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "m_splits, k, n", + [ + ([32, 48, 48], 128, 128), + ([64, 80, 112], 128, 256), + ([64, 80, 112], 256, 256), + ([64, 80, 112], 1024, 256), + ([256, 256, 512], 1024, 1024), + ([1024, 1536, 1536], 512, 3072), + ([16, 32, 64], 128, 96), + ([80, 96, 128], 640, 304), + ([320, 336, 352], 3072, 992), + ([64, 80, 112], 64, 256), + ([32, 48, 48], 128, 112), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("use_bias", [False, True], ids=["no_bias", "bias"]) +@pytest.mark.parametrize("single_output", [False, True], ids=["list_output", "single_output"]) +def test_nvfp4_pertoken_grouped_gemm_matches_per_gemm( + m_splits: list[int], + k: int, + n: int, + x_dtype: torch.dtype, + w_dtype: torch.dtype, + out_dtype: torch.dtype, + use_bias: bool, + single_output: bool, +): + check_nvfp4_pertoken_grouped_gemm_matches_per_gemm( + x_dtype=x_dtype, + w_dtype=w_dtype, + out_dtype=out_dtype, + m_splits=m_splits, + k=k, + n=n, + use_bias=use_bias, + single_output=single_output, + ) diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index cf801639b7..098807b685 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -75,6 +75,7 @@ def check_quantization_nvfp4_versus_reference( ) sx_t = x_nvfp4_sut._columnwise_scale_inv qx_amax = x_nvfp4_sut._amax_rowwise + qx_amax_t = x_nvfp4_sut._amax_columnwise # Reference quantization quant_tile_shape = (1, 16) if not with_2d_quantization else (16, 16) @@ -105,6 +106,7 @@ def check_quantization_nvfp4_versus_reference( x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None ) ref_amax = x_nvfp4_ref.global_amax_row + ref_amax_t = x_nvfp4_ref.global_amax_col qx = unpack_fp4(qx) qx_t = unpack_fp4(qx_t) if qx_t is not None else None @@ -124,6 +126,7 @@ def check_quantization_nvfp4_versus_reference( ref_sx_t_shape = sx_t_ref.shape sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] 
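# Shape sketch (illustrative) for the amax metadata asserted above in
# per-token mode: one FP32 value per token (row of x) for both usages; the
# columnwise kernel reuses the rowwise per_token_amax along the columns of
# x^T rather than computing a separate per-column statistic.
import torch

x = torch.randn(32, 64)
amax_per_token = x.abs().amax(dim=1)        # shape (32,), one entry per token
assert amax_per_token.numel() == x.shape[0]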
torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0) + torch.testing.assert_close(qx_amax_t, ref_amax_t, atol=0.0, rtol=0.0) torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0) @@ -249,6 +252,7 @@ def test_nvfp4_quantization_extrema_versus_reference( ) sx_t = x_nvfp4_sut._columnwise_scale_inv qx_amax = x_nvfp4_sut._amax_rowwise + qx_amax_t = x_nvfp4_sut._amax_columnwise ref_quantizer = NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, @@ -270,6 +274,7 @@ def test_nvfp4_quantization_extrema_versus_reference( x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None ) ref_amax = x_nvfp4_ref.global_amax_row + ref_amax_t = x_nvfp4_ref.global_amax_col torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0) @@ -282,6 +287,7 @@ def test_nvfp4_quantization_extrema_versus_reference( ref_sx_t_shape = sx_t_ref.shape sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0) + torch.testing.assert_close(qx_amax_t, ref_amax_t, atol=0.0, rtol=0.0) torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0) @@ -364,6 +370,7 @@ def test_nvfp4_quantization_boundary_values( ) sx_t = x_nvfp4_sut._columnwise_scale_inv qx_amax = x_nvfp4_sut._amax_rowwise + qx_amax_t = x_nvfp4_sut._amax_columnwise ref_quantizer = NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, @@ -385,6 +392,7 @@ def test_nvfp4_quantization_boundary_values( x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None ) ref_amax = x_nvfp4_ref.global_amax_row + ref_amax_t = x_nvfp4_ref.global_amax_col torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0) @@ -398,6 +406,7 @@ def test_nvfp4_quantization_boundary_values( ref_sx_t_shape = sx_t_ref.shape sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0) + torch.testing.assert_close(qx_amax_t, ref_amax_t, atol=0.0, rtol=0.0) torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0) @@ -465,6 +474,7 @@ def test_nvfp4_quantization_noncontiguous_inputs( ) sx_t = x_nvfp4_sut._columnwise_scale_inv qx_amax = x_nvfp4_sut._amax_rowwise + qx_amax_t = x_nvfp4_sut._amax_columnwise ref_quantizer = NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, @@ -486,6 +496,7 @@ def test_nvfp4_quantization_noncontiguous_inputs( x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None ) ref_amax = x_nvfp4_ref.global_amax_row + ref_amax_t = x_nvfp4_ref.global_amax_col # Quantized must match torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0) @@ -500,5 +511,6 @@ def test_nvfp4_quantization_noncontiguous_inputs( ref_sx_t_shape = sx_t_ref.shape sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0) + torch.testing.assert_close(qx_amax_t, ref_amax_t, atol=0.0, rtol=0.0) torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0) From 4755f09941fee8cf2981fa3715277cf93b9e1b5f Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sun, 26 Apr 2026 22:39:35 -0700 Subject: [PATCH 12/21] Expand more tests Signed-off-by: Ziang Li --- tests/pytorch/test_recipe.py | 5 +++-- tests/pytorch/test_torch_compile.py | 4 +++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index 91d4b89013..b44f27765a 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -509,6 +509,7 @@ def 
test_quantizer_update(self, module_class): @pytest.mark.skipif(not fp4_available, reason=reason_for_no_fp4) @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("per_token_activation", [False, True], ids=["nvfp4", "nvfp4_per_token"]) @pytest.mark.parametrize( "M, N", [ @@ -524,8 +525,8 @@ def test_quantizer_update(self, module_class): (8192, 8192), ], ) -def test_fp4_dequantize(dtype, M, N): - q = NVFP4Quantizer() +def test_fp4_dequantize(dtype, per_token_activation, M, N): + q = NVFP4Quantizer(per_token_activation=per_token_activation) a = torch.rand((M, N)).cuda().to(dtype=dtype) starting_tensor = q(a) dequantized_tensor = starting_tensor.dequantize() diff --git a/tests/pytorch/test_torch_compile.py b/tests/pytorch/test_torch_compile.py index 9d0ed79888..d67c5e77b7 100644 --- a/tests/pytorch/test_torch_compile.py +++ b/tests/pytorch/test_torch_compile.py @@ -32,6 +32,7 @@ is_fp8_block_scaling_available, is_nvfp4_available, ) +from utils import recipe_id fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True) mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True) @@ -47,6 +48,7 @@ _all_recipes.append(recipe.MXFP8BlockScaling()) if nvfp4_available: _all_recipes.append(recipe.NVFP4BlockScaling()) + _all_recipes.append(recipe.NVFP4BlockScaling(per_token_activation=True)) # --------------------------------------------------------------------------- @@ -303,7 +305,7 @@ def fn(inp): @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) -@pytest.mark.parametrize("fp8_recipe", _all_recipes, ids=lambda r: type(r).__name__) +@pytest.mark.parametrize("fp8_recipe", _all_recipes, ids=recipe_id) def test_autocast_sanity(fp8_recipe): """Smoke test: torch.nn.Linear inside a single te.autocast with each built-in recipe. 
Forward + backward under torch.compile(fullgraph=True).""" From 55286ed5eb75e0147f361d09f1be11e95c71219a Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sun, 26 Apr 2026 23:06:17 -0700 Subject: [PATCH 13/21] Turn on test for grouped linear sanity Signed-off-by: Ziang Li --- tests/pytorch/test_sanity.py | 7 ++-- .../tensor/storage/grouped_tensor_storage.py | 34 +++++++++++++++---- 2 files changed, 32 insertions(+), 9 deletions(-) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 73c291cc15..c7527ecfe4 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -542,7 +542,7 @@ def test_sanity_linear_with_zero_tokens( @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes_with_zero) @pytest.mark.parametrize("model", ["small", "weird"]) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes_with_per_token, ids=recipe_id) @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("use_bias", all_boolean) @@ -576,7 +576,10 @@ def test_sanity_grouped_linear( if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") if fp8_recipe.nvfp4(): - pytest.skip("NVFP4 not supported for grouped linear") + if not getattr(fp8_recipe, "per_token_activation", False): + pytest.skip("NVFP4 not supported for grouped linear") + if dtype == torch.float16: + pytest.skip("FP16 output for NVFP4 not supported") use_fp8 = fp8_recipe is not None with quantized_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe): diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py index 5f12c3ed8c..1732abf57c 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py @@ -662,6 +662,10 @@ def make_grouped_tensor( # Amax buffer for delayed scaling - one per tensor amax = torch.empty(num_tensors, dtype=torch.float32, device=device) elif quantizer._get_compatible_recipe().nvfp4(): + per_token_activation = getattr(quantizer, "per_token_activation", False) + total_amax_elements = ( + sum(math.prod(s[:-1]) for s in shape) if per_token_activation else num_tensors + ) if rowwise_usage: # Allocate rowwise data buffer (1D flattened, uint8, but FP4 packs 2 values per byte) @@ -675,8 +679,7 @@ def make_grouped_tensor( total_scale_elements += math.prod(scale_inv_shape) scale_inv_offsets.append(total_scale_elements) scale_inv = torch.empty(total_scale_elements, dtype=torch.uint8, device=device) - # Amax buffer - one per tensor - amax = torch.empty(num_tensors, dtype=torch.float32, device=device) + amax = torch.empty(total_amax_elements, dtype=torch.float32, device=device) if columnwise_usage: # Allocate columnwise data buffer (1D flattened, uint8, FP4 packed) @@ -693,8 +696,9 @@ def make_grouped_tensor( columnwise_scale_inv = torch.empty( total_columnwise_scale_elements, dtype=torch.uint8, device=device ) - # Columnwise amax buffer - one per tensor - columnwise_amax = torch.empty(num_tensors, dtype=torch.float32, device=device) + columnwise_amax = torch.empty( + total_amax_elements, dtype=torch.float32, device=device + ) elif quantizer._get_compatible_recipe().float8_block_scaling(): if rowwise_usage: # Allocate rowwise data buffer (1D flattened, uint8) @@ -891,6 +895,13 @@ def 
split_into_quantized_tensors( cum += math.prod(scale_shape) columnwise_scale_inv_offsets.append(cum) self.columnwise_scale_inv_offsets = columnwise_scale_inv_offsets + nvfp4_per_token_amax_offsets = None + if recipe.nvfp4() and getattr(self.quantizer, "per_token_activation", False): + cum = 0 + nvfp4_per_token_amax_offsets = [0] + for i in range(self.num_tensors): + cum += math.prod(self.tensor_shapes[i][:-1]) + nvfp4_per_token_amax_offsets.append(cum) for i in range(self.num_tensors): quantizer = self.quantizer @@ -1083,12 +1094,21 @@ def split_into_quantized_tensors( cscale_shape ) - # Extract amax - one per tensor if self.amax is not None: - amax_rowwise = self.amax[i : i + 1] + if nvfp4_per_token_amax_offsets is not None: + amax_start = nvfp4_per_token_amax_offsets[i] + amax_end = nvfp4_per_token_amax_offsets[i + 1] + amax_rowwise = self.amax[amax_start:amax_end] + else: + amax_rowwise = self.amax[i : i + 1] if self.columnwise_amax is not None: - amax_columnwise = self.columnwise_amax[i : i + 1] + if nvfp4_per_token_amax_offsets is not None: + amax_start = nvfp4_per_token_amax_offsets[i] + amax_end = nvfp4_per_token_amax_offsets[i + 1] + amax_columnwise = self.columnwise_amax[amax_start:amax_end] + else: + amax_columnwise = self.columnwise_amax[i : i + 1] if quantizer.internal: nvfp4_tensor_class = NVFP4TensorStorage From e4829b8038bce6ede33c22eea395d63defd4283a Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sun, 26 Apr 2026 23:39:05 -0700 Subject: [PATCH 14/21] Rename pertoken to per_token Signed-off-by: Ziang Li --- tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py | 6 +- .../common/cast/dispatch/quantize.cuh | 6 +- ...nvfp4.cuh => quantize_per_token_nvfp4.cuh} | 102 +++++++++--------- .../pytorch/cpp_extensions/gemm.py | 10 +- transformer_engine/pytorch/csrc/extensions.h | 2 +- .../pytorch/csrc/extensions/cast.cpp | 2 +- .../pytorch/csrc/extensions/pybind.cpp | 2 +- 7 files changed, 65 insertions(+), 65 deletions(-) rename transformer_engine/common/cast/nvfp4/{quantize_pertoken_nvfp4.cuh => quantize_per_token_nvfp4.cuh} (83%) diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py index 1a6784ed24..5fdb0c7d26 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -222,7 +222,7 @@ def check_nvfp4_gemm_versus_reference( torch.testing.assert_close(y_native, y_ref, atol=8e-3, rtol=8e-3) -def check_nvfp4_pertoken_grouped_gemm_matches_per_gemm( +def check_nvfp4_per_token_grouped_gemm_matches_per_gemm( x_dtype: torch.dtype, w_dtype: torch.dtype, out_dtype: torch.dtype, @@ -399,7 +399,7 @@ def test_nvfp4_gemm_versus_reference( @pytest.mark.parametrize("out_dtype", [torch.float32, torch.bfloat16], ids=str) @pytest.mark.parametrize("use_bias", [False, True], ids=["no_bias", "bias"]) @pytest.mark.parametrize("single_output", [False, True], ids=["list_output", "single_output"]) -def test_nvfp4_pertoken_grouped_gemm_matches_per_gemm( +def test_nvfp4_per_token_grouped_gemm_matches_per_gemm( m_splits: list[int], k: int, n: int, @@ -409,7 +409,7 @@ def test_nvfp4_pertoken_grouped_gemm_matches_per_gemm( use_bias: bool, single_output: bool, ): - check_nvfp4_pertoken_grouped_gemm_matches_per_gemm( + check_nvfp4_per_token_grouped_gemm_matches_per_gemm( x_dtype=x_dtype, w_dtype=w_dtype, out_dtype=out_dtype, diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index eab27a6e7e..1200979f6b 100644 --- 
a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -21,7 +21,7 @@ #include "../mxfp8/group_quantize_mxfp8.cuh" #include "../mxfp8/quantize_mxfp8.cuh" #include "../nvfp4/group_quantize_transpose_nvfp4.cuh" -#include "../nvfp4/quantize_pertoken_nvfp4.cuh" +#include "../nvfp4/quantize_per_token_nvfp4.cuh" #include "../nvfp4/quantize_transpose_nvfp4.cuh" namespace transformer_engine { @@ -105,7 +105,7 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, if (per_token_activation) { NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, "Per-token NVFP4 quantization does not support 2D quantization."); - nvfp4::quantize_pertoken(*input_tensor, noop_tensor, output_tensor, stream); + nvfp4::quantize_per_token(*input_tensor, noop_tensor, output_tensor, stream); break; } bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && @@ -251,7 +251,7 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens if (per_token_activation) { NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, "Per-token NVFP4 quantization does not support 2D quantization."); - nvfp4::quantize_pertoken(*grad_tensor, noop_tensor, output_tensor, stream); + nvfp4::quantize_per_token(*grad_tensor, noop_tensor, output_tensor, stream); break; } bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && diff --git a/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_per_token_nvfp4.cuh similarity index 83% rename from transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh rename to transformer_engine/common/cast/nvfp4/quantize_per_token_nvfp4.cuh index 36eb05115d..c4b16c557e 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_per_token_nvfp4.cuh @@ -4,7 +4,7 @@ * See LICENSE for license information. ************************************************************************/ -/*! \file quantize_pertoken_nvfp4.cuh +/*! \file quantize_per_token_nvfp4.cuh * \brief CUDA kernels to cast to NVFP4 with per-token (per-row) global scaling. 
*/ @@ -29,7 +29,7 @@ namespace transformer_engine { namespace dispatch { namespace nvfp4 { -namespace quantize_pertoken_kernel { +namespace quantize_per_token_kernel { using namespace core; using namespace ptx; @@ -62,13 +62,13 @@ __global__ void #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) __launch_bounds__(BLOCK_SIZE) #endif - quantize_pertoken_nvfp4_kernel(const int num_rows, const int num_cols, - const IType *__restrict__ input, - const int *__restrict__ row_offsets, - uint8_t *__restrict__ output_data, - fp8e4m3 *__restrict__ output_scales, - float *__restrict__ output_per_token_amax, - const int scale_stride, const float *__restrict__ noop) { + quantize_per_token_nvfp4_kernel(const int num_rows, const int num_cols, + const IType *__restrict__ input, + const int *__restrict__ row_offsets, + uint8_t *__restrict__ output_data, + fp8e4m3 *__restrict__ output_scales, + float *__restrict__ output_per_token_amax, + const int scale_stride, const float *__restrict__ noop) { #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) using namespace detail; if (noop != nullptr && noop[0] == 1.0f) { @@ -159,11 +159,11 @@ __launch_bounds__(BLOCK_SIZE) } template -void launch_quantize_pertoken_nvfp4(const int num_rows, const int num_cols, const IType *input, - const int *row_offsets, uint8_t *output_data, - fp8e4m3 *output_scales, float *output_per_token_amax, - const int scale_stride, cudaStream_t stream, - const float *noop = nullptr) { +void launch_quantize_per_token_nvfp4(const int num_rows, const int num_cols, const IType *input, + const int *row_offsets, uint8_t *output_data, + fp8e4m3 *output_scales, float *output_per_token_amax, + const int scale_stride, cudaStream_t stream, + const float *noop = nullptr) { #if FP4_TYPE_SUPPORTED if (num_rows == 0 || num_cols == 0) return; @@ -172,7 +172,7 @@ void launch_quantize_pertoken_nvfp4(const int num_rows, const int num_cols, cons dim3 grid(num_rows); dim3 block(PERTOKEN_BLOCK_SIZE); - quantize_pertoken_nvfp4_kernel + quantize_per_token_nvfp4_kernel <<>>(num_rows, num_cols, input, row_offsets, output_data, output_scales, output_per_token_amax, scale_stride, noop); NVTE_CHECK_CUDA(cudaGetLastError()); @@ -186,10 +186,10 @@ __global__ void #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) __launch_bounds__(BLOCK_SIZE) #endif - compute_pertoken_amax_kernel(const int num_rows, const int num_cols, - const IType *__restrict__ input, - float *__restrict__ output_per_token_amax, - const float *__restrict__ noop) { + compute_per_token_amax_kernel(const int num_rows, const int num_cols, + const IType *__restrict__ input, + float *__restrict__ output_per_token_amax, + const float *__restrict__ noop) { #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) if (noop != nullptr && noop[0] == 1.0f) { return; @@ -222,9 +222,9 @@ __launch_bounds__(BLOCK_SIZE) } template -void launch_compute_pertoken_amax(const int num_rows, const int num_cols, const IType *input, - float *output_per_token_amax, cudaStream_t stream, - const float *noop = nullptr) { +void launch_compute_per_token_amax(const int num_rows, const int num_cols, const IType *input, + float *output_per_token_amax, cudaStream_t stream, + const float *noop = nullptr) { #if FP4_TYPE_SUPPORTED if (num_rows == 0 || num_cols == 0) return; @@ -233,7 +233,7 @@ void launch_compute_pertoken_amax(const int num_rows, const int num_cols, const dim3 grid(num_rows); dim3 block(PERTOKEN_BLOCK_SIZE); - compute_pertoken_amax_kernel + compute_per_token_amax_kernel <<>>(num_rows, num_cols, input, output_per_token_amax, 
noop); NVTE_CHECK_CUDA(cudaGetLastError()); #else @@ -246,13 +246,13 @@ __global__ void #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) __launch_bounds__(BLOCK_SIZE) #endif - quantize_pertoken_nvfp4_columnwise_kernel(const int num_rows, const int num_cols, - const IType *__restrict__ input, - uint8_t *__restrict__ output_data_t, - fp8e4m3 *__restrict__ output_scales_t, - const float *__restrict__ per_token_amax, - const int scale_stride_t, - const float *__restrict__ noop) { + quantize_per_token_nvfp4_columnwise_kernel(const int num_rows, const int num_cols, + const IType *__restrict__ input, + uint8_t *__restrict__ output_data_t, + fp8e4m3 *__restrict__ output_scales_t, + const float *__restrict__ per_token_amax, + const int scale_stride_t, + const float *__restrict__ noop) { #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) using namespace detail; if (noop != nullptr && noop[0] == 1.0f) { @@ -311,12 +311,12 @@ __launch_bounds__(BLOCK_SIZE) } template -void launch_quantize_pertoken_nvfp4_columnwise(const int num_rows, const int num_cols, - const IType *input, uint8_t *output_data_t, - fp8e4m3 *output_scales_t, - const float *per_token_amax, - const int scale_stride_t, cudaStream_t stream, - const float *noop = nullptr) { +void launch_quantize_per_token_nvfp4_columnwise(const int num_rows, const int num_cols, + const IType *input, uint8_t *output_data_t, + fp8e4m3 *output_scales_t, + const float *per_token_amax, + const int scale_stride_t, cudaStream_t stream, + const float *noop = nullptr) { #if FP4_TYPE_SUPPORTED if (num_rows == 0 || num_cols == 0) return; @@ -325,7 +325,7 @@ void launch_quantize_pertoken_nvfp4_columnwise(const int num_rows, const int num dim3 grid(num_cols); dim3 block(PERTOKEN_BLOCK_SIZE); - quantize_pertoken_nvfp4_columnwise_kernel + quantize_per_token_nvfp4_columnwise_kernel <<>>(num_rows, num_cols, input, output_data_t, output_scales_t, per_token_amax, scale_stride_t, noop); NVTE_CHECK_CUDA(cudaGetLastError()); @@ -334,10 +334,10 @@ void launch_quantize_pertoken_nvfp4_columnwise(const int num_rows, const int num #endif } -} // namespace quantize_pertoken_kernel +} // namespace quantize_per_token_kernel -inline void quantize_pertoken(const Tensor &input, const Tensor *noop, Tensor *output, - cudaStream_t stream) { +inline void quantize_per_token(const Tensor &input, const Tensor *noop, Tensor *output, + cudaStream_t stream) { #if FP4_TYPE_SUPPORTED checkCuDriverContext(stream); CheckNoopTensor(*noop, "cast_noop"); @@ -351,9 +351,9 @@ inline void quantize_pertoken(const Tensor &input, const Tensor *noop, Tensor *o const size_t rows = input.flat_first_dim(); const size_t cols = input.flat_last_dim(); - NVTE_CHECK(cols % quantize_pertoken_kernel::PERTOKEN_SF_VEC_SIZE == 0, + NVTE_CHECK(cols % quantize_per_token_kernel::PERTOKEN_SF_VEC_SIZE == 0, "Per-token NVFP4 quantization requires last dim divisible by ", - quantize_pertoken_kernel::PERTOKEN_SF_VEC_SIZE, "."); + quantize_per_token_kernel::PERTOKEN_SF_VEC_SIZE, "."); const auto *noop_ptr = reinterpret_cast(noop->data.dptr); auto *amax_ptr = reinterpret_cast(output->amax.dptr); @@ -379,11 +379,11 @@ inline void quantize_pertoken(const Tensor &input, const Tensor *noop, Tensor *o auto *data_ptr = reinterpret_cast(output->data.dptr); auto *scale_ptr = reinterpret_cast(output->scale_inv.dptr); const int scale_stride = static_cast(output->scale_inv.shape.back()); - quantize_pertoken_kernel::launch_quantize_pertoken_nvfp4<__nv_bfloat16>( + quantize_per_token_kernel::launch_quantize_per_token_nvfp4<__nv_bfloat16>( 
static_cast(rows), static_cast(cols), input_ptr, row_offsets, data_ptr, scale_ptr, amax_ptr, scale_stride, stream, noop_ptr); } else { - quantize_pertoken_kernel::launch_compute_pertoken_amax<__nv_bfloat16>( + quantize_per_token_kernel::launch_compute_per_token_amax<__nv_bfloat16>( static_cast(rows), static_cast(cols), input_ptr, per_token_amax_ptr, stream, noop_ptr); } @@ -399,7 +399,7 @@ inline void quantize_pertoken(const Tensor &input, const Tensor *noop, Tensor *o auto *data_t_ptr = reinterpret_cast(output->columnwise_data.dptr); auto *scale_t_ptr = reinterpret_cast(output->columnwise_scale_inv.dptr); const int scale_stride_t = static_cast(output->columnwise_scale_inv.shape.back()); - quantize_pertoken_kernel::launch_quantize_pertoken_nvfp4_columnwise<__nv_bfloat16>( + quantize_per_token_kernel::launch_quantize_per_token_nvfp4_columnwise<__nv_bfloat16>( static_cast(rows), static_cast(cols), input_ptr, data_t_ptr, scale_t_ptr, per_token_amax_ptr, scale_stride_t, stream, noop_ptr); } @@ -412,11 +412,11 @@ inline void quantize_pertoken(const Tensor &input, const Tensor *noop, Tensor *o auto *data_ptr = reinterpret_cast(output->data.dptr); auto *scale_ptr = reinterpret_cast(output->scale_inv.dptr); const int scale_stride = static_cast(output->scale_inv.shape.back()); - quantize_pertoken_kernel::launch_quantize_pertoken_nvfp4( + quantize_per_token_kernel::launch_quantize_per_token_nvfp4( static_cast(rows), static_cast(cols), input_ptr, row_offsets, data_ptr, scale_ptr, amax_ptr, scale_stride, stream, noop_ptr); } else { - quantize_pertoken_kernel::launch_compute_pertoken_amax( + quantize_per_token_kernel::launch_compute_per_token_amax( static_cast(rows), static_cast(cols), input_ptr, per_token_amax_ptr, stream, noop_ptr); } @@ -432,7 +432,7 @@ inline void quantize_pertoken(const Tensor &input, const Tensor *noop, Tensor *o auto *data_t_ptr = reinterpret_cast(output->columnwise_data.dptr); auto *scale_t_ptr = reinterpret_cast(output->columnwise_scale_inv.dptr); const int scale_stride_t = static_cast(output->columnwise_scale_inv.shape.back()); - quantize_pertoken_kernel::launch_quantize_pertoken_nvfp4_columnwise( + quantize_per_token_kernel::launch_quantize_per_token_nvfp4_columnwise( static_cast(rows), static_cast(cols), input_ptr, data_t_ptr, scale_t_ptr, per_token_amax_ptr, scale_stride_t, stream, noop_ptr); } @@ -445,11 +445,11 @@ inline void quantize_pertoken(const Tensor &input, const Tensor *noop, Tensor *o auto *data_ptr = reinterpret_cast(output->data.dptr); auto *scale_ptr = reinterpret_cast(output->scale_inv.dptr); const int scale_stride = static_cast(output->scale_inv.shape.back()); - quantize_pertoken_kernel::launch_quantize_pertoken_nvfp4( + quantize_per_token_kernel::launch_quantize_per_token_nvfp4( static_cast(rows), static_cast(cols), input_ptr, row_offsets, data_ptr, scale_ptr, amax_ptr, scale_stride, stream, noop_ptr); } else { - quantize_pertoken_kernel::launch_compute_pertoken_amax( + quantize_per_token_kernel::launch_compute_per_token_amax( static_cast(rows), static_cast(cols), input_ptr, per_token_amax_ptr, stream, noop_ptr); } @@ -465,7 +465,7 @@ inline void quantize_pertoken(const Tensor &input, const Tensor *noop, Tensor *o auto *data_t_ptr = reinterpret_cast(output->columnwise_data.dptr); auto *scale_t_ptr = reinterpret_cast(output->columnwise_scale_inv.dptr); const int scale_stride_t = static_cast(output->columnwise_scale_inv.shape.back()); - quantize_pertoken_kernel::launch_quantize_pertoken_nvfp4_columnwise( + 
quantize_per_token_kernel::launch_quantize_per_token_nvfp4_columnwise( static_cast(rows), static_cast(cols), input_ptr, data_t_ptr, scale_t_ptr, per_token_amax_ptr, scale_stride_t, stream, noop_ptr); } diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index fec82b8a02..d23fdf1b59 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -70,7 +70,7 @@ def validate_gemm_scale(scale: Optional[float], required: bool) -> float: return 0.0 -def _is_nvfp4_pertoken_tensor(tensor: torch.Tensor) -> bool: +def _is_nvfp4_per_token_tensor(tensor: torch.Tensor) -> bool: """Whether tensor carries per-token NVFP4 global amax metadata.""" if not isinstance(tensor, NVFP4TensorStorage): return False @@ -78,7 +78,7 @@ def _is_nvfp4_pertoken_tensor(tensor: torch.Tensor) -> bool: return amax is not None and amax.numel() > 1 -def _nvfp4_pertoken_gemm_input( +def _nvfp4_per_token_gemm_input( tensor: NVFP4TensorStorage, ) -> Tuple[NVFP4TensorStorage, torch.Tensor]: """Return a GEMM alias with identity activation amax and the original per-token amax.""" @@ -199,7 +199,7 @@ def general_gemm( "beta": beta, } - if not _is_nvfp4_pertoken_tensor(B): + if not _is_nvfp4_per_token_tensor(B): out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs) else: assert layout[1] == "N", "Per-token NVFP4 GEMM currently supports N-layout B only." @@ -214,7 +214,7 @@ def general_gemm( ), "Per-token NVFP4 GEMM currently supports only plain torch.Tensor outputs." # cuBLAS folds the first activation amax into GEMM alpha. Keep per-token amax out of # alpha by using identity here, then apply the true per-token scale in FP32 below. - gemm_B, amax = _nvfp4_pertoken_gemm_input(B) + gemm_B, amax = _nvfp4_per_token_gemm_input(B) per_token_scales = amax.view(-1, 1) requested_out, requested_out_dtype = out, out_dtype @@ -301,7 +301,7 @@ def general_grouped_gemm( else: bias_dtype = TE_DType[torch.bfloat16] - if any(_is_nvfp4_pertoken_tensor(tensor) for tensor in B): + if any(_is_nvfp4_per_token_tensor(tensor) for tensor in B): assert layout[1] == "N", "Per-token NVFP4 grouped GEMM currently supports N-layout B only." assert not grad, "Per-token NVFP4 grouped GEMM currently supports fprop only." assert not gelu, "Per-token NVFP4 grouped GEMM currently does not support fused GELU." 
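
Note added in review (not part of the patch series): the gemm.py hunk above keeps the
per-token amax out of cuBLAS alpha and applies the row-wise factor to the FP32 output
instead. Below is a minimal PyTorch sketch of the per-token NVFP4 round trip that this
series implements. All names here are illustrative (not TE API), the constants follow
the reference kernels (E2M1 max 6.0, E4M3 max 448.0), rounding to the E2M1 value grid
is omitted, and torch.float8_e4m3fn assumes PyTorch 2.1 or newer.

import torch

FP4_MAX, E4M3_MAX = 6.0, 448.0  # max magnitudes of E2M1 and E4M3

def per_token_nvfp4_roundtrip(x: torch.Tensor, block: int = 16) -> torch.Tensor:
    rows, cols = x.shape
    x = x.float()
    # One FP32 amax per token (row) instead of a single tensor-wide scalar.
    row_amax = x.abs().amax(dim=1, keepdim=True)
    # Per-row global encode scale, mirroring the per-tensor recipe row by row.
    s_enc = torch.where(row_amax > 0, (E4M3_MAX * FP4_MAX) / row_amax,
                        torch.ones_like(row_amax))
    blocks = x.reshape(rows, cols // block, block)
    block_amax = blocks.abs().amax(dim=-1, keepdim=True)
    # Per-16-element decode scale, stored as E4M3: S_dec_b = block_amax * S_enc / 6.
    s_dec_b = (block_amax * s_enc.unsqueeze(-1) / FP4_MAX).to(torch.float8_e4m3fn).float()
    s_enc_b = torch.where(s_dec_b > 0, s_enc.unsqueeze(-1) / s_dec_b,
                          torch.zeros_like(s_dec_b))
    q = (blocks * s_enc_b).clamp(-FP4_MAX, FP4_MAX)  # the kernel casts this to E2M1
    # Dequantize: per-block scale times the per-row factor amax / (448 * 6).
    deq = q * s_dec_b * (row_amax.unsqueeze(-1) / (E4M3_MAX * FP4_MAX))
    return deq.reshape(rows, cols)

Because a cuBLAS GEMM exposes a single scalar alpha, this per-row factor cannot be
folded into it; _nvfp4_per_token_gemm_input therefore hands the GEMM an identity amax,
and general_gemm applies per_token_scales to the output in FP32, as the comments in
the hunk above describe.
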
diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 06478b54e0..f62853bb2b 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -326,7 +326,7 @@ py::object group_dequantize(const py::handle &input, DType otype); py::object bgrad_group_quantize(const at::Tensor &tensor, py::handle quantizer, const size_t num_tensors, std::optional first_dims); -std::tuple quantize_nvfp4_pertoken(at::Tensor input); +std::tuple quantize_nvfp4_per_token(at::Tensor input); std::vector multi_tensor_quantize(const std::vector &tensor_list, std::vector quantizer_list); diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 9423aa7296..ba75867a15 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -1603,7 +1603,7 @@ std::vector split_quantize(const at::Tensor &tensor, return output_py_list; } -std::tuple quantize_nvfp4_pertoken(at::Tensor input) { +std::tuple quantize_nvfp4_per_token(at::Tensor input) { init_extension(); NVTE_CHECK(input.dim() == 2, "Input must be 2D (num_rows, num_cols)"); diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 4021792f86..b2d74205cc 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -145,7 +145,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Dequantize group tensor", py::arg("input"), py::arg("otype")); m.def("bgrad_group_quantize", transformer_engine::pytorch::bgrad_group_quantize, py::arg("tensor"), py::arg("quantizer"), py::arg("num_tensors"), py::arg("first_dims")); - m.def("quantize_nvfp4_pertoken", transformer_engine::pytorch::quantize_nvfp4_pertoken, + m.def("quantize_nvfp4_per_token", transformer_engine::pytorch::quantize_nvfp4_per_token, "Per-token NVFP4 quantization", py::arg("input")); m.def("bgrad_quantize", transformer_engine::pytorch::bgrad_quantize, "Compute bias gradient and quantize", py::arg("input"), py::arg("quantizer")); From dbbdecbf195a1aa68cce37987d4e2f489865a938 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 27 Apr 2026 02:09:11 -0700 Subject: [PATCH 15/21] Expand .cu test Signed-off-by: Ziang Li --- .../cpp/operator/test_cast_nvfp4_transpose.cu | 189 +++++++++++++++--- tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py | 2 +- 2 files changed, 158 insertions(+), 33 deletions(-) diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index 15d7c695c9..c59c895965 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -114,16 +114,14 @@ void quantize_nvfp4_1d(float (*OP)(const float), block_amax = std::max(block_amax, std::abs(elt)); } - // 2. Compute E4M3 scaling factor - // Compute per-block encoding/decoding scaling factor - const float S_dec_b = block_amax / 6.0f; - - // Scale & Store per-block decoding scaling factor - const fp8e4m3 S_dec_b_fp8 = static_cast(S_dec_b * S_enc); + // Compute and store the per-block FP8 decode scale + const float S_dec_b = block_amax * (S_enc * (1.0f / 6.0f)); + const fp8e4m3 S_dec_b_fp8 = static_cast(fminf(S_dec_b, Numeric_Traits::maxNorm)); const float S_dec_b_fp32 = static_cast(S_dec_b_fp8); // Compute "correct" per-block encoding scaling factor - const float S_enc_b_fp8 = S_dec_b_fp32 == 0.f ? 
0.f : S_enc / S_dec_b_fp32; + const float S_enc_b_fp8 = S_dec_b_fp32 == 0.f ? 0.f : + fminf(1.0f / (S_dec_b_fp32 * (1.0f / S_enc)), Numeric_Traits::maxNorm); const size_t scale_idx = i * scales_stride + block_X; scales[scale_idx] = S_dec_b_fp8; @@ -317,11 +315,69 @@ void compute_ref(float (*OP)(const float), const size_t scales_stride, const size_t scales_stride_t, const bool use_fast_math, - const bool use_2d_quantization = false) + const bool use_2d_quantization = false, + std::vector *per_token_amax = nullptr) { std::vector input_t = create_transpose(input, rows, cols); - if (use_2d_quantization) { + if (per_token_amax != nullptr) { + constexpr size_t kBlockSize = 16; + constexpr float fp4_max_inv = 1.0f / 6.0f; + constexpr float float_max = Numeric_Traits::maxNorm; + + per_token_amax->resize(rows, 0.0f); + for (size_t row = 0; row < rows; ++row) { + float row_amax = 0.0f; + for (size_t col = 0; col < cols; ++col) { + row_amax = fmaxf(row_amax, fabsf(static_cast(input[row * cols + col]))); + } + (*per_token_amax)[row] = row_amax; + quantize_nvfp4(OP, + input + row * cols, + output + row * (cols / 2), + scales + row * scales_stride, + 1, + cols, + scales_stride, + row_amax, + use_fast_math, + use_2d_quantization); + } + + for (size_t col = 0; col < cols; ++col) { + for (size_t row_start = 0; row_start < rows; row_start += kBlockSize) { + float vals[kBlockSize]; + float s_enc[kBlockSize]; + float scaled_block_amax = 0.0f; + for (size_t i = 0; i < kBlockSize; ++i) { + const size_t row = row_start + i; + const float val = static_cast(input[row * cols + col]); + const float S_enc = + compute_global_encode_scaling_factor_FP4((*per_token_amax)[row], false); + vals[i] = val; + s_enc[i] = S_enc; + scaled_block_amax = fmaxf(scaled_block_amax, fabsf(val) * (S_enc * fp4_max_inv)); + } + + const float S_dec_b_f32 = fminf(scaled_block_amax, float_max); + const fp8e4m3 S_dec_b_fp8 = static_cast(S_dec_b_f32); + scales_t[col * scales_stride_t + row_start / kBlockSize] = S_dec_b_fp8; + + for (size_t i = 0; i < kBlockSize; i += 2) { + const float S_dec_rowwise_x = 1.0f / s_enc[i]; + const float S_dec_rowwise_y = 1.0f / s_enc[i + 1]; + const float S_dec_b_fp32 = static_cast(S_dec_b_fp8); + const float S_enc_b_fp8_x = + fminf(1.0f / (S_dec_b_fp32 * S_dec_rowwise_x), float_max); + const float S_enc_b_fp8_y = + fminf(1.0f / (S_dec_b_fp32 * S_dec_rowwise_y), float_max); + const float2 scaled_elt_pair = {vals[i] * S_enc_b_fp8_x, + vals[i + 1] * S_enc_b_fp8_y}; + output_t[(col * rows + row_start + i) / 2] = fp4e2m1x2(scaled_elt_pair); + } + } + } + } else if (use_2d_quantization) { // Step 1: Compute mathematical 8×8 scaling factors std::vector> math_scales; compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales, use_fast_math); @@ -526,10 +582,20 @@ void compareResults_nvfp4(const Tensor &test, compare_nvfp4_tensors("output_t", test_data_t, ref_data_t, cols, rows, atol, rtol); } +void compare_per_token_amax(const Tensor &test_amax, const std::vector &ref_amax) { + test_amax.to_cpu(); + const float *test_amax_data = test_amax.rowwise_cpu_dptr(); + for (size_t row = 0; row < ref_amax.size(); ++row) { + ASSERT_EQ(test_amax_data[row], ref_amax[row]) + << "Per-token amax mismatch at row " << row; + } +} + template void performTest(float (*OP)(const float), const std::vector& shape, - const bool use_fast_math) { + const bool use_fast_math, + const bool per_token_activation = false) { using namespace test; DType itype = TypeInfo::dtype; @@ -557,6 +623,7 @@ void performTest(float (*OP)(const 
float), Tensor input("input", shape, itype); Tensor output("output", shape, otype, true, true, NVTE_NVFP4_1D_SCALING); + Tensor per_token_amax; std::unique_ptr ref_output = std::make_unique(rows * (cols / 2)); std::unique_ptr ref_output_t = std::make_unique(cols * (rows / 2)); @@ -565,28 +632,53 @@ void performTest(float (*OP)(const float), fillCase(&input, InputsFillCase::uniform); - // Golden value of amax chosen to make the 2nd-stage scaling mantissa zero and avoid rounding issues - const float amax = 448.0f * 6.0f * 8.0f; - - // Set 2nd stage NVFP4 scaling factor - output.set_tensor_amax(amax); - output.set_tensor_amax_columnwise(amax); - bool use_2d_quantization = false; - - compute_ref(OP, - input.rowwise_cpu_dptr(), - ref_output.get(), - ref_output_t.get(), - ref_scales.get(), - ref_scales_t.get(), - amax, - rows, - cols, - scales_stride, - scales_stride_t, - use_fast_math, - use_2d_quantization); + std::vector ref_per_token_amax; + if (per_token_activation) { + per_token_amax = Tensor("per_token_amax", std::vector{rows}, DType::kFloat32); + compute_ref(OP, + input.rowwise_cpu_dptr(), + ref_output.get(), + ref_output_t.get(), + ref_scales.get(), + ref_scales_t.get(), + 0.0f, + rows, + cols, + scales_stride, + scales_stride_t, + use_fast_math, + use_2d_quantization, + &ref_per_token_amax); + + std::vector per_token_amax_shape = {rows}; + NVTEBasicTensor amax_tensor = {per_token_amax.rowwise_dptr(), + static_cast(DType::kFloat32), + nvte_make_shape(per_token_amax_shape.data(), + per_token_amax_shape.size())}; + NVTETensor output_tensor = output.data(); + nvte_set_tensor_param_v2(output_tensor, kNVTEAmax, &amax_tensor, sizeof(amax_tensor)); + } else { + // Golden value of amax chosen to make the 2nd-stage scaling mantissa zero and avoid rounding issues + const float amax = 448.0f * 6.0f * 8.0f; + // Set 2nd stage NVFP4 scaling factor + output.set_tensor_amax(amax); + output.set_tensor_amax_columnwise(amax); + + compute_ref(OP, + input.rowwise_cpu_dptr(), + ref_output.get(), + ref_output_t.get(), + ref_scales.get(), + ref_scales_t.get(), + amax, + rows, + cols, + scales_stride, + scales_stride_t, + use_fast_math, + use_2d_quantization); + } // Initialize stochastic rounding Tensor rng_state("rng_state", std::vector{2}, DType::kInt64); rng_state.rowwise_cpu_dptr()[0] = 123; // rng_seed @@ -600,6 +692,7 @@ void performTest(float (*OP)(const float), // Set 2D quantization based on compile-time flag quant_config.set_nvfp4_2d_quantization(use_2d_quantization); + quant_config.set_nvfp4_per_token_activation(per_token_activation); // Call appropriate function based on operation type // Activation functions take 3 parameters (input, output, stream) @@ -646,6 +739,10 @@ void performTest(float (*OP)(const float), ref_scales_t.get(), unpadded_blocks_Y_t, unpadded_blocks_X_t, scales_stride_t, scale_mismatches_num); + + if (per_token_activation) { + compare_per_token_amax(per_token_amax, ref_per_token_amax); + } } std::vector> tensor_dims = { @@ -678,6 +775,7 @@ class FusedCastTransposeNVFP4TestSuite : public ::testing::TestWithParam , transformer_engine::DType, + bool, bool>> {}; TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { @@ -693,6 +791,7 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { const auto tensor_dims = std::get<1>(GetParam()); const DType input_type = std::get<2>(GetParam()); const bool use_fast_math = std::get<3>(GetParam()); + const bool per_token_activation = std::get<4>(GetParam()); // Skip tests if the input tensor is 1D if 
(tensor_dims.size() < 2) { @@ -710,7 +809,7 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { } TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, - performTest(OP, tensor_dims, use_fast_math); + performTest(OP, tensor_dims, use_fast_math, per_token_activation); ); } @@ -733,6 +832,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(Activation_types), ::testing::ValuesIn(tensor_dims), ::testing::Values(DType::kBFloat16), + ::testing::Values(false), ::testing::Values(false)), [](const testing::TestParamInfo& info) { std::string name = to_string(std::get<0>(info.param)); @@ -746,3 +846,28 @@ INSTANTIATE_TEST_SUITE_P( } return name; }); + +INSTANTIATE_TEST_SUITE_P( + OperatorTestPerToken, + FusedCastTransposeNVFP4TestSuite, + ::testing::Combine( + ::testing::Values(ActivationType::Identity), + ::testing::Values(tensor_dims[4], tensor_dims[9], tensor_dims[12]), + ::testing::Values(DType::kBFloat16, DType::kFloat32), + ::testing::Values(false), + ::testing::Values(true)), + [](const testing::TestParamInfo& info) { + std::string name = to_string(std::get<0>(info.param)); + const auto& shape = std::get<1>(info.param); + for (const auto& s: shape) { + name += "X" + std::to_string(s); + } + name += "X" + test::typeName(std::get<2>(info.param)); + if (std::get<3>(info.param)) { + name += "X_FAST_SCALING"; + } + if (std::get<4>(info.param)) { + name += "XPER_TOKEN"; + } + return name; + }); diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py index 5fdb0c7d26..ef6eda8dcd 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -118,7 +118,7 @@ def check_nvfp4_gemm_versus_reference( x_ref_quantizer = NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, rowwise=True, - columnwise=not per_token_activation, + columnwise=True, pow_2_scales=False, eps=0.0, quant_tile_shape=(1, 16), From 2374a6e0757dff7e09bfee3f2ef2a07b026aa20b Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sat, 2 May 2026 12:08:19 -0700 Subject: [PATCH 16/21] Format after rebase Signed-off-by: Ziang Li --- .../common/cast/nvfp4/dequantize_nvfp4.cuh | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh index 9436b94939..85e858e146 100644 --- a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh @@ -34,8 +34,9 @@ namespace dequantize_kernel { template __global__ void __launch_bounds__(512) dequantize_fp4_kernel(const void *const input, OType *output, const fp8e4m3 *const scales, - const float *const tensor_amax, const size_t amax_numel, const size_t N, const size_t M, - const size_t scale_stride, const size_t num_scale_tiles_X) { + const float *const tensor_amax, const size_t amax_numel, const size_t N, + const size_t M, const size_t scale_stride, + const size_t num_scale_tiles_X) { const size_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x; const size_t x = thread_idx % M; const size_t y = thread_idx / M; @@ -110,11 +111,12 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) with_gemm_swizzled_scales, WITH_GEMM_SWIZZLED_SCALES, dequantize_fp4_kernel<<>>( - input.data.dptr, reinterpret_cast(output->data.dptr), - reinterpret_cast(input.scale_inv.dptr), - reinterpret_cast(input.amax.dptr), input.amax.numel(), N, Mread, input.scale_inv.shape.back(), - 
num_scale_tiles_X);); // NOLINT(*)
-);  // NOLINT(*)
+            input.data.dptr, reinterpret_cast<OType *>(output->data.dptr),
+            reinterpret_cast<const fp8e4m3 *>(input.scale_inv.dptr),
+            reinterpret_cast<const float *>(input.amax.dptr), input.amax.numel(), N, Mread,
+            input.scale_inv.shape.back(),
+            num_scale_tiles_X);); // NOLINT(*)
+  );  // NOLINT(*)
   NVTE_CHECK_CUDA(cudaGetLastError());
 #else
   NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!");

From 57982850fb4189d4fafbac39e1afe0afa4d5ba22 Mon Sep 17 00:00:00 2001
From: Ziang Li
Date: Sat, 2 May 2026 12:09:53 -0700
Subject: [PATCH 17/21] Fix test after rebase

Signed-off-by: Ziang Li
---
 .../cpp/operator/test_cast_nvfp4_transpose.cu | 42 +++++++++++++++----
 1 file changed, 35 insertions(+), 7 deletions(-)

diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu
index c59c895965..cc45b2fce5 100644
--- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu
+++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu
@@ -582,9 +582,13 @@ void compareResults_nvfp4(const Tensor &test,
   compare_nvfp4_tensors("output_t", test_data_t, ref_data_t, cols, rows, atol, rtol);
 }
 
-void compare_per_token_amax(const Tensor &test_amax, const std::vector<float> &ref_amax) {
-  test_amax.to_cpu();
-  const float *test_amax_data = test_amax.rowwise_cpu_dptr<float>();
+void compare_per_token_amax(const float *test_amax, const std::vector<float> &ref_amax) {
+  std::vector<float> test_amax_data(ref_amax.size());
+  ASSERT_EQ(cudaMemcpy(test_amax_data.data(),
+                       test_amax,
+                       ref_amax.size() * sizeof(float),
+                       cudaMemcpyDeviceToHost),
+            cudaSuccess);
   for (size_t row = 0; row < ref_amax.size(); ++row) {
     ASSERT_EQ(test_amax_data[row], ref_amax[row])
         << "Per-token amax mismatch at row " << row;
@@ -623,7 +627,8 @@ void performTest(float (*OP)(const float),
 
   Tensor input("input", shape, itype);
   Tensor output("output", shape, otype, true, true, NVTE_NVFP4_1D_SCALING);
-  Tensor per_token_amax;
+  float *per_token_amax = nullptr;
+  float *per_token_columnwise_amax = nullptr;
 
   std::unique_ptr<fp4e2m1x2[]> ref_output = std::make_unique<fp4e2m1x2[]>(rows * (cols / 2));
   std::unique_ptr<fp4e2m1x2[]> ref_output_t = std::make_unique<fp4e2m1x2[]>(cols * (rows / 2));
@@ -635,7 +640,6 @@ void performTest(float (*OP)(const float),
   bool use_2d_quantization = false;
   std::vector<float> ref_per_token_amax;
   if (per_token_activation) {
-    per_token_amax = Tensor("per_token_amax", std::vector<size_t>{rows}, DType::kFloat32);
     compute_ref(OP,
                 input.rowwise_cpu_dptr(),
                 ref_output.get(),
@@ -651,20 +655,44 @@ void performTest(float (*OP)(const float),
                 use_2d_quantization,
                 &ref_per_token_amax);
 
+    NVTETensor output_tensor = output.data();
+    NVTEBasicTensor old_amax;
+    NVTEBasicTensor old_columnwise_amax;
+    nvte_get_tensor_param_v2(output_tensor, kNVTEAmax, &old_amax, sizeof(old_amax), nullptr);
+    nvte_get_tensor_param_v2(output_tensor, kNVTEColumnwiseAmax, &old_columnwise_amax,
+                             sizeof(old_columnwise_amax), nullptr);
+    if (old_amax.data_ptr != nullptr) {
+      NVTE_CHECK_CUDA(cudaFree(old_amax.data_ptr));
+    }
+    if (old_columnwise_amax.data_ptr != nullptr) {
+      NVTE_CHECK_CUDA(cudaFree(old_columnwise_amax.data_ptr));
+    }
+    NVTE_CHECK_CUDA(cudaMalloc(&per_token_amax, rows * sizeof(float)));
+    NVTE_CHECK_CUDA(cudaMalloc(&per_token_columnwise_amax, rows * sizeof(float)));
+    NVTE_CHECK_CUDA(cudaMemset(per_token_amax, 0, rows * sizeof(float)));
+    NVTE_CHECK_CUDA(cudaMemset(per_token_columnwise_amax, 0, rows * sizeof(float)));
     std::vector<size_t> per_token_amax_shape = {rows};
-    NVTEBasicTensor amax_tensor = {per_token_amax.rowwise_dptr<float>(),
+    NVTEBasicTensor amax_tensor = {per_token_amax,
static_cast(DType::kFloat32), nvte_make_shape(per_token_amax_shape.data(), per_token_amax_shape.size())}; - NVTETensor output_tensor = output.data(); + NVTEBasicTensor columnwise_amax_tensor = {per_token_columnwise_amax, + static_cast(DType::kFloat32), + nvte_make_shape(per_token_amax_shape.data(), + per_token_amax_shape.size())}; nvte_set_tensor_param_v2(output_tensor, kNVTEAmax, &amax_tensor, sizeof(amax_tensor)); + nvte_set_tensor_param_v2(output_tensor, kNVTEColumnwiseAmax, &columnwise_amax_tensor, + sizeof(columnwise_amax_tensor)); } else { // Golden value of amax chosen to make the 2nd-stage scaling mantissa zero and avoid rounding issues const float amax = 448.0f * 6.0f * 8.0f; + // Set 2nd stage NVFP4 scaling factor output.set_tensor_amax(amax); output.set_tensor_amax_columnwise(amax); + bool use_2d_quantization = false; + compute_ref(OP, input.rowwise_cpu_dptr(), ref_output.get(), From 233bb4456cee29c767d6a632dd00cc53479a8558 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sat, 2 May 2026 16:24:16 -0700 Subject: [PATCH 18/21] Clean up cpp test Signed-off-by: Ziang Li --- .../cpp/operator/test_cast_nvfp4_transpose.cu | 81 +++++++++---------- 1 file changed, 40 insertions(+), 41 deletions(-) diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index cc45b2fce5..f7a16539cc 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -582,10 +582,16 @@ void compareResults_nvfp4(const Tensor &test, compare_nvfp4_tensors("output_t", test_data_t, ref_data_t, cols, rows, atol, rtol); } -void compare_per_token_amax(const float *test_amax, const std::vector &ref_amax) { +void compare_per_token_amax(const Tensor &output, const std::vector &ref_amax) { + NVTEBasicTensor amax; + nvte_get_tensor_param_v2(output.data(), kNVTEAmax, &amax, sizeof(amax), nullptr); + ASSERT_NE(amax.data_ptr, nullptr); + ASSERT_EQ(amax.shape.ndim, 1); + ASSERT_EQ(amax.shape.data[0], ref_amax.size()); + std::vector test_amax_data(ref_amax.size()); ASSERT_EQ(cudaMemcpy(test_amax_data.data(), - test_amax, + amax.data_ptr, ref_amax.size() * sizeof(float), cudaMemcpyDeviceToHost), cudaSuccess); @@ -595,6 +601,32 @@ void compare_per_token_amax(const float *test_amax, const std::vector &re } } +void set_per_token_amax_metadata(Tensor &output, const size_t rows) { + const std::vector shape = {rows}; + NVTETensor output_tensor = output.data(); + + auto replace_amax = [&](const NVTETensorParam param) { + NVTEBasicTensor old_amax; + nvte_get_tensor_param_v2(output_tensor, param, &old_amax, sizeof(old_amax), nullptr); + if (old_amax.data_ptr != nullptr) { + NVTE_CHECK_CUDA(cudaFree(old_amax.data_ptr)); + } + + float *amax = nullptr; + NVTE_CHECK_CUDA(cudaMalloc(&amax, rows * sizeof(float))); + NVTE_CHECK_CUDA(cudaMemset(amax, 0, rows * sizeof(float))); + + NVTEBasicTensor amax_tensor = {amax, + static_cast(DType::kFloat32), + nvte_make_shape(shape.data(), shape.size())}; + nvte_set_tensor_param_v2(output_tensor, param, &amax_tensor, sizeof(amax_tensor)); + return amax; + }; + + replace_amax(kNVTEAmax); + replace_amax(kNVTEColumnwiseAmax); +} + template void performTest(float (*OP)(const float), const std::vector& shape, @@ -627,8 +659,6 @@ void performTest(float (*OP)(const float), Tensor input("input", shape, itype); Tensor output("output", shape, otype, true, true, NVTE_NVFP4_1D_SCALING); - float *per_token_amax = nullptr; - float *per_token_columnwise_amax = nullptr; std::unique_ptr ref_output = 
std::make_unique(rows * (cols / 2)); std::unique_ptr ref_output_t = std::make_unique(cols * (rows / 2)); @@ -637,9 +667,12 @@ void performTest(float (*OP)(const float), fillCase(&input, InputsFillCase::uniform); - bool use_2d_quantization = false; + // Golden value of amax chosen to make the 2nd-stage scaling mantissa zero and avoid rounding issues + const float amax = 448.0f * 6.0f * 8.0f; std::vector ref_per_token_amax; + bool use_2d_quantization = false; if (per_token_activation) { + set_per_token_amax_metadata(output, rows); compute_ref(OP, input.rowwise_cpu_dptr(), ref_output.get(), @@ -654,45 +687,10 @@ void performTest(float (*OP)(const float), use_fast_math, use_2d_quantization, &ref_per_token_amax); - - NVTETensor output_tensor = output.data(); - NVTEBasicTensor old_amax; - NVTEBasicTensor old_columnwise_amax; - nvte_get_tensor_param_v2(output_tensor, kNVTEAmax, &old_amax, sizeof(old_amax), nullptr); - nvte_get_tensor_param_v2(output_tensor, kNVTEColumnwiseAmax, &old_columnwise_amax, - sizeof(old_columnwise_amax), nullptr); - if (old_amax.data_ptr != nullptr) { - NVTE_CHECK_CUDA(cudaFree(old_amax.data_ptr)); - } - if (old_columnwise_amax.data_ptr != nullptr) { - NVTE_CHECK_CUDA(cudaFree(old_columnwise_amax.data_ptr)); - } - NVTE_CHECK_CUDA(cudaMalloc(&per_token_amax, rows * sizeof(float))); - NVTE_CHECK_CUDA(cudaMalloc(&per_token_columnwise_amax, rows * sizeof(float))); - NVTE_CHECK_CUDA(cudaMemset(per_token_amax, 0, rows * sizeof(float))); - NVTE_CHECK_CUDA(cudaMemset(per_token_columnwise_amax, 0, rows * sizeof(float))); - std::vector per_token_amax_shape = {rows}; - NVTEBasicTensor amax_tensor = {per_token_amax, - static_cast(DType::kFloat32), - nvte_make_shape(per_token_amax_shape.data(), - per_token_amax_shape.size())}; - NVTEBasicTensor columnwise_amax_tensor = {per_token_columnwise_amax, - static_cast(DType::kFloat32), - nvte_make_shape(per_token_amax_shape.data(), - per_token_amax_shape.size())}; - nvte_set_tensor_param_v2(output_tensor, kNVTEAmax, &amax_tensor, sizeof(amax_tensor)); - nvte_set_tensor_param_v2(output_tensor, kNVTEColumnwiseAmax, &columnwise_amax_tensor, - sizeof(columnwise_amax_tensor)); } else { - // Golden value of amax chosen to make the 2nd-stage scaling mantissa zero and avoid rounding issues - const float amax = 448.0f * 6.0f * 8.0f; - // Set 2nd stage NVFP4 scaling factor output.set_tensor_amax(amax); output.set_tensor_amax_columnwise(amax); - - bool use_2d_quantization = false; - compute_ref(OP, input.rowwise_cpu_dptr(), ref_output.get(), @@ -707,6 +705,7 @@ void performTest(float (*OP)(const float), use_fast_math, use_2d_quantization); } + // Initialize stochastic rounding Tensor rng_state("rng_state", std::vector{2}, DType::kInt64); rng_state.rowwise_cpu_dptr()[0] = 123; // rng_seed @@ -769,7 +768,7 @@ void performTest(float (*OP)(const float), scale_mismatches_num); if (per_token_activation) { - compare_per_token_amax(per_token_amax, ref_per_token_amax); + compare_per_token_amax(output, ref_per_token_amax); } } From 47c9cde6f07ad46ca902c24e1337172fe65d786d Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sat, 2 May 2026 16:31:27 -0700 Subject: [PATCH 19/21] Extend cpp dequantize test Signed-off-by: Ziang Li --- tests/cpp/operator/test_dequantize_nvfp4.cu | 117 +++++++++++++++++--- 1 file changed, 99 insertions(+), 18 deletions(-) diff --git a/tests/cpp/operator/test_dequantize_nvfp4.cu b/tests/cpp/operator/test_dequantize_nvfp4.cu index 96e85cb5ed..f932c7dd7a 100644 --- a/tests/cpp/operator/test_dequantize_nvfp4.cu +++ 
b/tests/cpp/operator/test_dequantize_nvfp4.cu @@ -42,7 +42,7 @@ float2 cvt_fp4x2_to_float2(fp4e2m1x2 fp4_pair) { template void compute_ref_dequantize_nvfp4(const uint8_t *packed_data, const fp8e4m3 *scales, - float amax, + const std::vector &amax, OType *output, size_t rows, size_t cols, @@ -55,7 +55,8 @@ void compute_ref_dequantize_nvfp4(const uint8_t *packed_data, for (size_t row = 0; row < rows; ++row) { for (size_t block = 0; block < Mread; ++block) { const fp8e4m3 scale = scales[row * scale_stride + block]; - const float final_scale = static_cast(scale) * amax * factor_inv; + const float final_scale = + static_cast(scale) * (amax.size() == 1 ? amax[0] : amax[row]) * factor_inv; for (size_t pair_idx = 0; pair_idx < bytes_per_block; ++pair_idx) { const size_t byte_idx = @@ -74,6 +75,43 @@ void compute_ref_dequantize_nvfp4(const uint8_t *packed_data, } } +void set_per_token_amax_metadata(Tensor &output, const size_t rows) { + const std::vector shape = {rows}; + NVTETensor output_tensor = output.data(); + + auto replace_amax = [&](const NVTETensorParam param) { + NVTEBasicTensor old_amax; + nvte_get_tensor_param_v2(output_tensor, param, &old_amax, sizeof(old_amax), nullptr); + if (old_amax.data_ptr != nullptr) { + NVTE_CHECK_CUDA(cudaFree(old_amax.data_ptr)); + } + + float *amax = nullptr; + NVTE_CHECK_CUDA(cudaMalloc(&amax, rows * sizeof(float))); + NVTE_CHECK_CUDA(cudaMemset(amax, 0, rows * sizeof(float))); + + NVTEBasicTensor amax_tensor = {amax, + static_cast(DType::kFloat32), + nvte_make_shape(shape.data(), shape.size())}; + nvte_set_tensor_param_v2(output_tensor, param, &amax_tensor, sizeof(amax_tensor)); + }; + + replace_amax(kNVTEAmax); + replace_amax(kNVTEColumnwiseAmax); +} + +std::vector get_amax_values(const Tensor &tensor) { + NVTEBasicTensor amax; + nvte_get_tensor_param_v2(tensor.data(), kNVTEAmax, &amax, sizeof(amax), nullptr); + const size_t numel = amax.shape.ndim == 0 ? 1 : amax.shape.data[0]; + std::vector amax_values(numel); + if (numel > 0) { + NVTE_CHECK_CUDA(cudaMemcpy(amax_values.data(), amax.data_ptr, numel * sizeof(float), + cudaMemcpyDeviceToHost)); + } + return amax_values; +} + template float compute_amax(const test::Tensor &t, size_t rows, size_t cols) { t.to_cpu(); @@ -88,7 +126,8 @@ float compute_amax(const test::Tensor &t, size_t rows, size_t cols) { // Quantize a high-precision input to NVFP4, then dequantize and compare // against a CPU reference computed from the quantized data. 
template -void performTest_dequantize_nvfp4(const size_t rows, const size_t cols) { +void performTest_dequantize_nvfp4(const size_t rows, const size_t cols, + const bool per_token_activation) { using namespace test; DType otype = TypeInfo::dtype; @@ -97,14 +136,22 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols) { Tensor quantized("quantized", std::vector{rows, cols}, DType::kFloat4E2M1, true, false, NVTE_NVFP4_1D_SCALING); - if (rows > 0 && cols > 0) { + if (per_token_activation) { + set_per_token_amax_metadata(quantized, rows); + } else if (rows > 0 && cols > 0) { quantized.set_tensor_amax(compute_amax(input, rows, cols)); } else { quantized.set_tensor_amax(0.0f); } if (rows > 0 && cols > 0) { - nvte_quantize(input.data(), quantized.data(), 0); + if (per_token_activation) { + QuantizationConfigWrapper quant_config; + quant_config.set_nvfp4_per_token_activation(true); + nvte_quantize_v2(input.data(), quantized.data(), quant_config, 0); + } else { + nvte_quantize(input.data(), quantized.data(), 0); + } cudaDeviceSynchronize(); } @@ -120,7 +167,7 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols) { const uint8_t *fp4_data = reinterpret_cast(quantized.rowwise_cpu_dptr()); const fp8e4m3 *scales = quantized.rowwise_cpu_scale_inv_ptr(); - const float amax_val = quantized.amax(); + const std::vector amax_val = get_amax_values(quantized); const NVTEShape scale_shape = quantized.rowwise_scale_inv_shape(); const size_t scale_stride = scale_shape.data[scale_shape.ndim - 1]; @@ -137,7 +184,8 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols) { // Dequantize NVFP4 with GEMM-swizzled scales and compare against compact path. template -void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols) { +void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols, + const bool per_token_activation) { using namespace test; DType otype = TypeInfo::dtype; @@ -146,14 +194,22 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols) Tensor quantized_compact("quantized_compact", std::vector{rows, cols}, DType::kFloat4E2M1, true, false, NVTE_NVFP4_1D_SCALING); - if (rows > 0 && cols > 0) { + if (per_token_activation) { + set_per_token_amax_metadata(quantized_compact, rows); + } else if (rows > 0 && cols > 0) { quantized_compact.set_tensor_amax(compute_amax(input, rows, cols)); } else { quantized_compact.set_tensor_amax(0.0f); } if (rows > 0 && cols > 0) { - nvte_quantize(input.data(), quantized_compact.data(), 0); + if (per_token_activation) { + QuantizationConfigWrapper quant_config; + quant_config.set_nvfp4_per_token_activation(true); + nvte_quantize_v2(input.data(), quantized_compact.data(), quant_config, 0); + } else { + nvte_quantize(input.data(), quantized_compact.data(), 0); + } cudaDeviceSynchronize(); } @@ -165,13 +221,30 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols) // Create tensor with same FP4 data but swizzled scales Tensor quantized_swizzled("quantized_swizzled", std::vector{rows, cols}, DType::kFloat4E2M1, true, false, NVTE_NVFP4_1D_SCALING); - quantized_swizzled.set_tensor_amax(0.0f); + if (per_token_activation) { + set_per_token_amax_metadata(quantized_swizzled, rows); + } else { + quantized_swizzled.set_tensor_amax(0.0f); + } quantized_swizzled.set_with_gemm_swizzled_scales(true); // Copy amax and scale from compact to swizzled before FP4 data, // since from_cpu() uploads all CPU buffers (including zero-init data). 
  quantized_compact.to_cpu();
-  quantized_swizzled.set_tensor_amax(quantized_compact.amax());
+  if (per_token_activation) {
+    NVTEBasicTensor compact_amax;
+    NVTEBasicTensor swizzled_amax;
+    nvte_get_tensor_param_v2(quantized_compact.data(), kNVTEAmax, &compact_amax,
+                             sizeof(compact_amax), nullptr);
+    nvte_get_tensor_param_v2(quantized_swizzled.data(), kNVTEAmax, &swizzled_amax,
+                             sizeof(swizzled_amax), nullptr);
+    if (rows > 0) {
+      NVTE_CHECK_CUDA(cudaMemcpy(swizzled_amax.data_ptr, compact_amax.data_ptr,
+                                 rows * sizeof(float), cudaMemcpyDeviceToDevice));
+    }
+  } else {
+    quantized_swizzled.set_tensor_amax(quantized_compact.amax());
+  }
 
   // Copy FP4 data after from_cpu() to avoid being overwritten
   const size_t data_bytes = rows * cols / 2;
@@ -227,7 +300,8 @@ std::vector<std::pair<size_t, size_t>> nvfp4_tensor_dims = {
 
 class DequantizeNVFP4TestSuite : public ::testing::TestWithParam<std::tuple<std::pair<size_t, size_t>,
-                                                                 transformer_engine::DType>> {};
+                                                                 transformer_engine::DType,
+                                                                 bool>> {};
 
 TEST_P(DequantizeNVFP4TestSuite, TestDequantizeNVFP4)
 {
@@ -237,10 +311,11 @@ TEST_P(DequantizeNVFP4TestSuite, TestDequantizeNVFP4)
 
   const auto tensor_size = std::get<0>(GetParam());
   const DType output_type = std::get<1>(GetParam());
+  const bool per_token_activation = std::get<2>(GetParam());
 
   TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(output_type, OutputType,
     performTest_dequantize_nvfp4<OutputType>(
-      tensor_size.first, tensor_size.second);
+      tensor_size.first, tensor_size.second, per_token_activation);
   );
 }
 
@@ -249,19 +324,22 @@ INSTANTIATE_TEST_SUITE_P(
   DequantizeNVFP4TestSuite,
   ::testing::Combine(
     ::testing::ValuesIn(nvfp4_tensor_dims),
-    ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16)),
+    ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
+    ::testing::Bool()),
   [](const testing::TestParamInfo<DequantizeNVFP4TestSuite::ParamType>& info) {
     std::string name = std::to_string(std::get<0>(info.param).first) + "X" +
                        std::to_string(std::get<0>(info.param).second) + "X" +
-                       test::typeName(std::get<1>(info.param));
+                       test::typeName(std::get<1>(info.param)) + "X" +
+                       (std::get<2>(info.param) ? "PerToken" : "PerTensor");
     return name;
   }
 );
 
 class DequantizeNVFP4SwizzledTestSuite : public ::testing::TestWithParam<std::tuple<std::pair<size_t, size_t>,
-                                                                         transformer_engine::DType>> {};
+                                                                         transformer_engine::DType,
+                                                                         bool>> {};
 
 TEST_P(DequantizeNVFP4SwizzledTestSuite, TestDequantizeNVFP4Swizzled)
 {
@@ -271,10 +349,11 @@ TEST_P(DequantizeNVFP4SwizzledTestSuite, TestDequantizeNVFP4Swizzled)
 
   const auto tensor_size = std::get<0>(GetParam());
   const DType output_type = std::get<1>(GetParam());
+  const bool per_token_activation = std::get<2>(GetParam());
 
   TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(output_type, OutputType,
     performTest_dequantize_nvfp4_swizzled<OutputType>(
-      tensor_size.first, tensor_size.second);
+      tensor_size.first, tensor_size.second, per_token_activation);
   );
 }
 
@@ -283,12 +362,14 @@ INSTANTIATE_TEST_SUITE_P(
   DequantizeNVFP4SwizzledTestSuite,
   ::testing::Combine(
     ::testing::ValuesIn(nvfp4_tensor_dims),
-    ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16)),
+    ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
+    ::testing::Bool()),
   [](const testing::TestParamInfo<DequantizeNVFP4SwizzledTestSuite::ParamType>& info) {
     std::string name = std::to_string(std::get<0>(info.param).first) + "X" +
                        std::to_string(std::get<0>(info.param).second) + "X" +
                        test::typeName(std::get<1>(info.param)) + "X" +
+                       (std::get<2>(info.param) ?
"PerToken" : "PerTensor") + "X" + "Swizzled"; return name; } From 21a19f5ecf882b0c5faa463dcdf2721b6a9692dd Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sat, 2 May 2026 17:23:27 -0700 Subject: [PATCH 20/21] Only pass `per_token_activation` to forward activation quantizer and clean up Signed-off-by: Ziang Li --- tests/pytorch/test_backward_override.py | 56 +++++++------------ tests/pytorch/test_recipe.py | 27 ++++++++- tests/pytorch/utils.py | 6 -- .../pytorch/cpp_extensions/gemm.py | 12 +++- transformer_engine/pytorch/quantization.py | 4 +- 5 files changed, 59 insertions(+), 46 deletions(-) diff --git a/tests/pytorch/test_backward_override.py b/tests/pytorch/test_backward_override.py index 2156d6cef0..ed099314f8 100644 --- a/tests/pytorch/test_backward_override.py +++ b/tests/pytorch/test_backward_override.py @@ -99,12 +99,6 @@ def backward_override(request: pytest.FixtureRequest) -> str: return request.param -def _make_backward_test_recipe(recipe_name: str, **recipe_kwargs) -> Optional[recipe.Recipe]: - if recipe_name == "nvfp4_per_token" and "backward_override" not in recipe_kwargs: - recipe_kwargs["backward_override"] = "dequantized" - return make_recipe(recipe_name, **recipe_kwargs) - - # -------------------------- # Test cases # -------------------------- @@ -867,7 +861,7 @@ def test_linear_like_backward_override_matches_reference( _maybe_skip_unsupported_recipe_shape(recipe_name, input_shape, module_type) in_features = input_shape[-1] - quantized_ref_recipe = _make_backward_test_recipe(recipe_name) + quantized_ref_recipe = make_recipe(recipe_name) mode_recipe = make_recipe(recipe_name, backward_override=backward_override) skip_unsupported_backward_override(module_type, mode_recipe, backward_override) @@ -1051,7 +1045,7 @@ def test_grouped_linear_backward_override_matches_reference( num_gemms = len(m_splits) num_tokens = sum(m_splits) - quantized_ref_recipe = _make_backward_test_recipe(recipe_name) + quantized_ref_recipe = make_recipe(recipe_name) mode_recipe = make_recipe(recipe_name, backward_override=backward_override) skip_unsupported_backward_override("grouped_linear", mode_recipe, backward_override) @@ -1220,11 +1214,9 @@ def test_linear_like_runtime_backward_override_switch_updates_ctx( x = torch.randn(*input_shape, dtype=dtype, device="cuda") dy = torch.randn(*input_shape[:-1], out_features, dtype=dtype, device="cuda") - default_recipe = _make_backward_test_recipe(recipe_name) + default_recipe = make_recipe(recipe_name) mode_recipe = make_recipe(recipe_name, backward_override=backward_override) skip_unsupported_backward_override(module_type, mode_recipe, backward_override) - expected_default_mode = default_recipe.backward_override - expected_default_fp8 = expected_default_mode is None *_, default_ctx = _run_single_step_with_ctx_state(module, x, dy, default_recipe) ( @@ -1233,10 +1225,10 @@ def test_linear_like_runtime_backward_override_switch_updates_ctx( default_grad_output_quantizer, default_reduce_and_update, ) = default_ctx - assert default_mode == expected_default_mode - assert default_fp8 == expected_default_fp8 - assert (default_grad_output_quantizer is not None) == expected_default_fp8 - assert default_reduce_and_update == expected_default_fp8 + assert default_mode is None + assert default_fp8 + assert default_grad_output_quantizer is not None + assert default_reduce_and_update *_, switched_ctx = _run_single_step_with_ctx_state(module, x, dy, mode_recipe) switched_mode, switched_fp8, switched_grad_output_quantizer, switched_reduce_and_update = ( @@ -1254,10 +1246,10 @@ 
From 21a19f5ecf882b0c5faa463dcdf2721b6a9692dd Mon Sep 17 00:00:00 2001
From: Ziang Li
Date: Sat, 2 May 2026 17:23:27 -0700
Subject: [PATCH 20/21] Only pass `per_token_activation` to forward activation
 quantizer and clean up

Signed-off-by: Ziang Li
---
 tests/pytorch/test_backward_override.py           | 56 +++++++-------------
 tests/pytorch/test_recipe.py                      | 27 ++++++++-
 tests/pytorch/utils.py                            |  6 --
 transformer_engine/pytorch/cpp_extensions/gemm.py | 12 +++-
 transformer_engine/pytorch/quantization.py        |  4 +-
 5 files changed, 59 insertions(+), 46 deletions(-)

diff --git a/tests/pytorch/test_backward_override.py b/tests/pytorch/test_backward_override.py
index 2156d6cef0..ed099314f8 100644
--- a/tests/pytorch/test_backward_override.py
+++ b/tests/pytorch/test_backward_override.py
@@ -99,12 +99,6 @@ def backward_override(request: pytest.FixtureRequest) -> str:
     return request.param
 
 
-def _make_backward_test_recipe(recipe_name: str, **recipe_kwargs) -> Optional[recipe.Recipe]:
-    if recipe_name == "nvfp4_per_token" and "backward_override" not in recipe_kwargs:
-        recipe_kwargs["backward_override"] = "dequantized"
-    return make_recipe(recipe_name, **recipe_kwargs)
-
-
 # --------------------------
 # Test cases
 # --------------------------
@@ -867,7 +861,7 @@ def test_linear_like_backward_override_matches_reference(
     _maybe_skip_unsupported_recipe_shape(recipe_name, input_shape, module_type)
     in_features = input_shape[-1]
 
-    quantized_ref_recipe = _make_backward_test_recipe(recipe_name)
+    quantized_ref_recipe = make_recipe(recipe_name)
     mode_recipe = make_recipe(recipe_name, backward_override=backward_override)
     skip_unsupported_backward_override(module_type, mode_recipe, backward_override)
@@ -1051,7 +1045,7 @@ def test_grouped_linear_backward_override_matches_reference(
     num_gemms = len(m_splits)
     num_tokens = sum(m_splits)
 
-    quantized_ref_recipe = _make_backward_test_recipe(recipe_name)
+    quantized_ref_recipe = make_recipe(recipe_name)
     mode_recipe = make_recipe(recipe_name, backward_override=backward_override)
     skip_unsupported_backward_override("grouped_linear", mode_recipe, backward_override)
@@ -1220,11 +1214,9 @@ def test_linear_like_runtime_backward_override_switch_updates_ctx(
     x = torch.randn(*input_shape, dtype=dtype, device="cuda")
     dy = torch.randn(*input_shape[:-1], out_features, dtype=dtype, device="cuda")
 
-    default_recipe = _make_backward_test_recipe(recipe_name)
+    default_recipe = make_recipe(recipe_name)
     mode_recipe = make_recipe(recipe_name, backward_override=backward_override)
     skip_unsupported_backward_override(module_type, mode_recipe, backward_override)
-    expected_default_mode = default_recipe.backward_override
-    expected_default_fp8 = expected_default_mode is None
 
     *_, default_ctx = _run_single_step_with_ctx_state(module, x, dy, default_recipe)
     (
         default_mode,
         default_fp8,
         default_grad_output_quantizer,
         default_reduce_and_update,
     ) = default_ctx
-    assert default_mode == expected_default_mode
-    assert default_fp8 == expected_default_fp8
-    assert (default_grad_output_quantizer is not None) == expected_default_fp8
-    assert default_reduce_and_update == expected_default_fp8
+    assert default_mode is None
+    assert default_fp8
+    assert default_grad_output_quantizer is not None
+    assert default_reduce_and_update
 
     *_, switched_ctx = _run_single_step_with_ctx_state(module, x, dy, mode_recipe)
     switched_mode, switched_fp8, switched_grad_output_quantizer, switched_reduce_and_update = (
@@ -1254,10 +1246,10 @@ def test_linear_like_runtime_backward_override_switch_updates_ctx(
         default_grad_output_quantizer_after,
         default_reduce_and_update_after,
     ) = default_ctx_after
-    assert default_mode_after == expected_default_mode
-    assert default_fp8_after == expected_default_fp8
-    assert (default_grad_output_quantizer_after is not None) == expected_default_fp8
-    assert default_reduce_and_update_after == expected_default_fp8
+    assert default_mode_after is None
+    assert default_fp8_after
+    assert default_grad_output_quantizer_after is not None
+    assert default_reduce_and_update_after
 
 
 @pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list)
@@ -1292,11 +1284,9 @@ def test_grouped_linear_runtime_backward_override_switch_updates_ctx(
     x = torch.randn(num_tokens, in_features, dtype=dtype, device="cuda")
     dy = torch.randn(num_tokens, out_features, dtype=dtype, device="cuda")
 
-    default_recipe = _make_backward_test_recipe(recipe_name)
+    default_recipe = make_recipe(recipe_name)
     mode_recipe = make_recipe(recipe_name, backward_override=backward_override)
     skip_unsupported_backward_override("grouped_linear", mode_recipe, backward_override)
-    expected_default_mode = default_recipe.backward_override
-    expected_default_fp8 = expected_default_mode is None
 
     *_, default_ctx = _run_grouped_linear_single_step_with_ctx_state(
         module,
@@ -1306,9 +1296,9 @@ def test_grouped_linear_runtime_backward_override_switch_updates_ctx(
         default_recipe,
     )
     default_mode, default_fp8, default_reduce_and_update = default_ctx
-    assert default_mode == expected_default_mode
-    assert default_fp8 == expected_default_fp8
-    assert default_reduce_and_update == expected_default_fp8
+    assert default_mode is None
+    assert default_fp8
+    assert default_reduce_and_update
 
     *_, switched_ctx = _run_grouped_linear_single_step_with_ctx_state(
         module,
@@ -1330,9 +1320,9 @@ def test_grouped_linear_runtime_backward_override_switch_updates_ctx(
         default_recipe,
     )
     default_mode_after, default_fp8_after, default_reduce_and_update_after = default_ctx_after
-    assert default_mode_after == expected_default_mode
-    assert default_fp8_after == expected_default_fp8
-    assert default_reduce_and_update_after == expected_default_fp8
+    assert default_mode_after is None
+    assert default_fp8_after
+    assert default_reduce_and_update_after
 
 
 @pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list)
@@ -1363,7 +1353,7 @@ def test_fused_linear_paths_match_backward_override_reference(
     reset_rng_states()
 
-    quantized_ref_recipe = _make_backward_test_recipe(recipe_name)
+    quantized_ref_recipe = make_recipe(recipe_name)
     mode_recipe = make_recipe(recipe_name, backward_override=backward_override)
     skip_unsupported_backward_override("ops_linear", mode_recipe, backward_override)
@@ -1504,7 +1494,7 @@ def test_fused_bias_activation_matches_masked_linear_backward(
     reset_rng_states()
     in_features = input_shape[-1]
 
-    quantized_ref_recipe = _make_backward_test_recipe(recipe_name)
+    quantized_ref_recipe = make_recipe(recipe_name)
     mode_recipe = make_recipe(recipe_name, backward_override=backward_override)
     skip_unsupported_backward_override("ops_linear", mode_recipe, backward_override)
@@ -1743,11 +1733,7 @@ def test_backward_override_memory_peak_report(
     x = torch.randn(*input_shape, dtype=dtype, device="cuda")
     dy = torch.randn(*input_shape[:-1], out_features, dtype=dtype, device="cuda")
 
-    modes = (
-        ("high_precision", "dequantized")
-        if recipe_name == "nvfp4_per_token"
-        else (None, "high_precision", "dequantized")
-    )
+    modes = (None, "high_precision", "dequantized")
     mode_results: dict[str, dict[str, float] | str] = {}
     for mode in modes:
diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py
index b44f27765a..f12148232c 100644
--- a/tests/pytorch/test_recipe.py
+++ b/tests/pytorch/test_recipe.py
@@ -25,10 +25,16 @@
 import transformer_engine_torch as tex
 from transformer_engine.pytorch.quantization import (
     FP8GlobalStateManager,
+    NVFP4BlockScalingRecipeState,
     _amax_and_scale_update,
 )
 import transformer_engine.pytorch.ops as te_ops
-from transformer_engine.common.recipe import DelayedScaling, Float8BlockScaling, MXFP8BlockScaling
+from transformer_engine.common.recipe import (
+    DelayedScaling,
+    Float8BlockScaling,
+    MXFP8BlockScaling,
+    NVFP4BlockScaling,
+)
 
 # Check if FP8 is supported
 fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
@@ -507,6 +513,25 @@ def test_quantizer_update(self, module_class):
         y = module(x)
 
 
+@pytest.mark.skipif(not fp4_available, reason=reason_for_no_fp4)
+def test_nvfp4_per_token_quantizer_roles():
+    recipe = NVFP4BlockScaling(per_token_activation=True)
+
+    forward_quantizers = NVFP4BlockScalingRecipeState(
+        recipe,
+        mode="forward",
+        num_quantizers=3,
+    ).make_quantizers()
+    assert [q.per_token_activation for q in forward_quantizers] == [True, False, True]
+
+    backward_quantizers = NVFP4BlockScalingRecipeState(
+        recipe,
+        mode="backward",
+        num_quantizers=2,
+    ).make_quantizers()
+    assert [q.per_token_activation for q in backward_quantizers] == [False, False]
+
+
 @pytest.mark.skipif(not fp4_available, reason=reason_for_no_fp4)
 @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=str)
 @pytest.mark.parametrize("per_token_activation", [False, True], ids=["nvfp4", "nvfp4_per_token"])
diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py
index b88bcd31b5..3dc4cdffe8 100644
--- a/tests/pytorch/utils.py
+++ b/tests/pytorch/utils.py
@@ -175,12 +175,6 @@ def skip_unsupported_backward_override(
     backward_override: Optional[str],
 ) -> None:
     """Skip known unsupported layer/recipe/backward-override combinations used in tests."""
-    if (
-        quant_recipe is not None
-        and getattr(quant_recipe, "per_token_activation", False)
-        and backward_override is None
-    ):
-        pytest.skip("Per-token NVFP4 requires an explicit backward override.")
     if backward_override is None:
         return
     if quant_recipe is None and backward_override is not None:
diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py
index d23fdf1b59..79a7d28df5 100644
--- a/transformer_engine/pytorch/cpp_extensions/gemm.py
+++ b/transformer_engine/pytorch/cpp_extensions/gemm.py
@@ -203,7 +207,11 @@ def general_gemm(
         out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs)
     else:
         assert layout[1] == "N", "Per-token NVFP4 GEMM currently supports N-layout B only."
-        assert not grad, "Per-token NVFP4 GEMM currently supports fprop only."
+        if grad:
+            raise RuntimeError(
+                "Per-token NVFP4 GEMM currently supports fprop only. "
+                "Backward NVFP4 gradient quantizers should use scalar global amax."
+            )
         assert not gelu, "Per-token NVFP4 GEMM currently does not support fused GELU."
         assert not accumulate, "Per-token NVFP4 GEMM currently does not support accumulation."
         assert (
@@ -303,7 +307,11 @@ def general_grouped_gemm(
 
     if any(_is_nvfp4_per_token_tensor(tensor) for tensor in B):
         assert layout[1] == "N", "Per-token NVFP4 grouped GEMM currently supports N-layout B only."
-        assert not grad, "Per-token NVFP4 grouped GEMM currently supports fprop only."
+        if grad:
+            raise RuntimeError(
+                "Per-token NVFP4 grouped GEMM currently supports fprop only. "
+                "Backward NVFP4 gradient quantizers should use scalar global amax."
+            )
         assert not gelu, "Per-token NVFP4 grouped GEMM currently does not support fused GELU."
         assert (
             not accumulate
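
As a caller-facing illustration of the guards above, a hypothetical snippet; the tensor names and surrounding setup are assumptions, not code from this series. Raising RuntimeError instead of a bare assert also keeps the check alive under `python -O`, where asserts are stripped:

    # Sketch: grad GEMMs must not see per-token NVFP4 operands, since backward
    # gradient quantizers keep a scalar global amax.
    try:
        general_gemm(weight_nvfp4, x_per_token, workspace, out_dtype=torch.bfloat16, grad=True)
    except RuntimeError as err:
        assert "fprop only" in str(err)
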
diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py
index 6ffca84a7d..2cb6c21946 100644
--- a/transformer_engine/pytorch/quantization.py
+++ b/transformer_engine/pytorch/quantization.py
@@ -1375,7 +1375,7 @@ def _make_quantizer(idx: int) -> NVFP4Quantizer:
                 with_post_rht_amax=qparams.random_hadamard_transform,
                 with_2d_quantization=qparams.fp4_2d_quantization,
                 stochastic_rounding=qparams.stochastic_rounding,
-                per_token_activation=self.recipe.per_token_activation,
+                per_token_activation=self.recipe.per_token_activation and idx % 3 != 1,
             )
 
         return [_make_quantizer(idx) for idx in range(self.num_quantizers)]
@@ -1390,7 +1390,7 @@ def _make_quantizer(idx: int) -> NVFP4Quantizer:
                 with_post_rht_amax=self.recipe.fp4_quant_bwd_grad.random_hadamard_transform,
                 with_2d_quantization=self.recipe.fp4_quant_bwd_grad.fp4_2d_quantization,
                 stochastic_rounding=self.recipe.fp4_quant_bwd_grad.stochastic_rounding,
-                per_token_activation=self.recipe.per_token_activation,
+                per_token_activation=False,
             )
             for _ in range(self.num_quantizers)
         ]

From 75c19d0e172974a620fc38248851895f23a6c583 Mon Sep 17 00:00:00 2001
From: Ziang Li
Date: Sat, 2 May 2026 17:46:18 -0700
Subject: [PATCH 21/21] Minor test fixes

Signed-off-by: Ziang Li
---
 tests/pytorch/test_backward_override.py | 2 +-
 tests/pytorch/test_sanity.py            | 2 ++
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/tests/pytorch/test_backward_override.py b/tests/pytorch/test_backward_override.py
index ed099314f8..15f08975e2 100644
--- a/tests/pytorch/test_backward_override.py
+++ b/tests/pytorch/test_backward_override.py
@@ -255,7 +255,7 @@ def _maybe_skip_unsupported_grouped_splits(recipe_name: str, m_splits: list[int]
         pytest.skip("GroupedLinear + MXFP8 requires each non-empty m_split divisible by 32.")
     if recipe_name in ("nvfp4", "nvfp4_per_token") and any(m % 16 != 0 for m in non_empty_splits):
         pytest.skip("GroupedLinear + NVFP4 requires each non-empty m_split divisible by 16.")
-    if recipe_name == "nvfp4" and any(m % 64 != 0 for m in non_empty_splits):
+    if recipe_name in ("nvfp4", "nvfp4_per_token") and any(m % 64 != 0 for m in non_empty_splits):
         pytest.skip(
             "GroupedLinear + NVFP4 grouped split_quantize currently requires each non-empty "
             "m_split divisible by 64 due to grouped amax kernel constraints."
diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py
index c7527ecfe4..bb1c952163 100644
--- a/tests/pytorch/test_sanity.py
+++ b/tests/pytorch/test_sanity.py
@@ -580,6 +580,8 @@ def test_sanity_grouped_linear(
             pytest.skip("NVFP4 not supported for grouped linear")
         if dtype == torch.float16:
             pytest.skip("FP16 output for NVFP4 not supported")
+        if backward_override is None and dtype != torch.bfloat16:
+            pytest.skip("NVFP4 grouped default backward requires BF16 grad output")
 
         use_fp8 = fp8_recipe is not None
         with quantized_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
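
Taken together, the series makes per-token NVFP4 opt-in through the recipe. A minimal end-to-end usage sketch, assuming an NVFP4-capable GPU with these patches applied and that the `fp8_autocast` entry point accepts the NVFP4 recipe as in the tests above; module sizes and shapes are illustrative:

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import NVFP4BlockScaling

    # Per-token amax applies only to forward activation quantizers; weights stay
    # per-tensor and backward gradients keep a scalar global amax.
    recipe = NVFP4BlockScaling(per_token_activation=True)

    linear = te.Linear(1024, 1024, params_dtype=torch.bfloat16)
    x = torch.randn(128, 1024, dtype=torch.bfloat16, device="cuda", requires_grad=True)

    with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
        y = linear(x)
    y.sum().backward()
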