Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/envvars.rst
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,12 @@ Kernel Configuration
:Default: ``0``
:Description: Emit a warning when falling back from CUTLASS to cuBLAS for grouped GEMM operations.

.. envvar:: NVTE_NVFP4_PER_TOKEN_ACTIVATION

:Type: ``int`` (0 or 1)
:Default: ``0``
:Description: Enable per-token activation quantization for the ``NVFP4BlockScaling`` recipe in GroupedLinear split-quantize paths. When set to ``1`` (or when ``NVFP4BlockScaling(per_token_activation=True)`` is used), NVFP4 rowwise ``amax`` metadata stores one FP32 value per token (row) instead of a single scalar.

Torch Compilation and Fusion
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Expand Down
211 changes: 182 additions & 29 deletions tests/cpp/operator/test_cast_nvfp4_transpose.cu
Original file line number Diff line number Diff line change
Expand Up @@ -114,16 +114,14 @@ void quantize_nvfp4_1d(float (*OP)(const float),
block_amax = std::max(block_amax, std::abs(elt));
}

// 2. Compute E4M3 scaling factor
// Compute per-block encoding/decoding scaling factor
const float S_dec_b = block_amax / 6.0f;

// Scale & Store per-block decoding scaling factor
const fp8e4m3 S_dec_b_fp8 = static_cast<fp8e4m3>(S_dec_b * S_enc);
// Compute and store the per-block FP8 decode scale
const float S_dec_b = block_amax * (S_enc * (1.0f / 6.0f));
const fp8e4m3 S_dec_b_fp8 = static_cast<fp8e4m3>(fminf(S_dec_b, Numeric_Traits<float>::maxNorm));
const float S_dec_b_fp32 = static_cast<float>(S_dec_b_fp8);

// Compute "correct" per-block encoding scaling factor
const float S_enc_b_fp8 = S_dec_b_fp32 == 0.f ? 0.f : S_enc / S_dec_b_fp32;
const float S_enc_b_fp8 = S_dec_b_fp32 == 0.f ? 0.f :
fminf(1.0f / (S_dec_b_fp32 * (1.0f / S_enc)), Numeric_Traits<float>::maxNorm);
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have to change this code to stay aligned with the PyTorch reference implementation.


const size_t scale_idx = i * scales_stride + block_X;
scales[scale_idx] = S_dec_b_fp8;
Expand Down Expand Up @@ -317,11 +315,69 @@ void compute_ref(float (*OP)(const float),
const size_t scales_stride,
const size_t scales_stride_t,
const bool use_fast_math,
const bool use_2d_quantization = false)
const bool use_2d_quantization = false,
std::vector<float> *per_token_amax = nullptr)
{
std::vector<InputType> input_t = create_transpose(input, rows, cols);

if (use_2d_quantization) {
if (per_token_amax != nullptr) {
constexpr size_t kBlockSize = 16;
constexpr float fp4_max_inv = 1.0f / 6.0f;
constexpr float float_max = Numeric_Traits<float>::maxNorm;

per_token_amax->resize(rows, 0.0f);
for (size_t row = 0; row < rows; ++row) {
float row_amax = 0.0f;
for (size_t col = 0; col < cols; ++col) {
row_amax = fmaxf(row_amax, fabsf(static_cast<float>(input[row * cols + col])));
}
(*per_token_amax)[row] = row_amax;
quantize_nvfp4(OP,
input + row * cols,
output + row * (cols / 2),
scales + row * scales_stride,
1,
cols,
scales_stride,
row_amax,
use_fast_math,
use_2d_quantization);
}

for (size_t col = 0; col < cols; ++col) {
for (size_t row_start = 0; row_start < rows; row_start += kBlockSize) {
float vals[kBlockSize];
float s_enc[kBlockSize];
float scaled_block_amax = 0.0f;
for (size_t i = 0; i < kBlockSize; ++i) {
const size_t row = row_start + i;
const float val = static_cast<float>(input[row * cols + col]);
const float S_enc =
compute_global_encode_scaling_factor_FP4((*per_token_amax)[row], false);
vals[i] = val;
s_enc[i] = S_enc;
scaled_block_amax = fmaxf(scaled_block_amax, fabsf(val) * (S_enc * fp4_max_inv));
}

const float S_dec_b_f32 = fminf(scaled_block_amax, float_max);
const fp8e4m3 S_dec_b_fp8 = static_cast<fp8e4m3>(S_dec_b_f32);
scales_t[col * scales_stride_t + row_start / kBlockSize] = S_dec_b_fp8;

for (size_t i = 0; i < kBlockSize; i += 2) {
const float S_dec_rowwise_x = 1.0f / s_enc[i];
const float S_dec_rowwise_y = 1.0f / s_enc[i + 1];
const float S_dec_b_fp32 = static_cast<float>(S_dec_b_fp8);
const float S_enc_b_fp8_x =
fminf(1.0f / (S_dec_b_fp32 * S_dec_rowwise_x), float_max);
const float S_enc_b_fp8_y =
fminf(1.0f / (S_dec_b_fp32 * S_dec_rowwise_y), float_max);
const float2 scaled_elt_pair = {vals[i] * S_enc_b_fp8_x,
vals[i + 1] * S_enc_b_fp8_y};
output_t[(col * rows + row_start + i) / 2] = fp4e2m1x2(scaled_elt_pair);
}
}
}
} else if (use_2d_quantization) {
// Step 1: Compute mathematical 8×8 scaling factors
std::vector<std::vector<fp8e4m3>> math_scales;
compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales, use_fast_math);
Expand Down Expand Up @@ -526,10 +582,24 @@ void compareResults_nvfp4(const Tensor &test,
compare_nvfp4_tensors("output_t", test_data_t, ref_data_t, cols, rows, atol, rtol);
}

// Verify device-computed per-token (per-row) amax values against the CPU reference.
// The reduction is an exact max over finite floats, so bitwise equality is expected.
void compare_per_token_amax(const float *test_amax, const std::vector<float> &ref_amax) {
  const size_t num_rows = ref_amax.size();
  // Copy the device-side per-token amax buffer back to the host.
  std::vector<float> host_amax(num_rows);
  ASSERT_EQ(cudaMemcpy(host_amax.data(),
                       test_amax,
                       num_rows * sizeof(float),
                       cudaMemcpyDeviceToHost),
            cudaSuccess);
  for (size_t row = 0; row < num_rows; ++row) {
    ASSERT_EQ(host_amax[row], ref_amax[row])
        << "Per-token amax mismatch at row " << row;
  }
}
}

template <typename InputType>
void performTest(float (*OP)(const float),
const std::vector<size_t>& shape,
const bool use_fast_math) {
const bool use_fast_math,
const bool per_token_activation = false) {
using namespace test;

DType itype = TypeInfo<InputType>::dtype;
Expand Down Expand Up @@ -557,6 +627,8 @@ void performTest(float (*OP)(const float),

Tensor input("input", shape, itype);
Tensor output("output", shape, otype, true, true, NVTE_NVFP4_1D_SCALING);
float *per_token_amax = nullptr;
float *per_token_columnwise_amax = nullptr;

std::unique_ptr<fp4e2m1x2[]> ref_output = std::make_unique<fp4e2m1x2[]>(rows * (cols / 2));
std::unique_ptr<fp4e2m1x2[]> ref_output_t = std::make_unique<fp4e2m1x2[]>(cols * (rows / 2));
Expand All @@ -565,28 +637,76 @@ void performTest(float (*OP)(const float),

fillCase<fp32>(&input, InputsFillCase::uniform);

// Golden value of amax chosen to make the 2nd-stage scaling mantissa zero and avoid rounding issues
const float amax = 448.0f * 6.0f * 8.0f;
bool use_2d_quantization = false;
std::vector<float> ref_per_token_amax;
if (per_token_activation) {
compute_ref<InputType>(OP,
input.rowwise_cpu_dptr<InputType>(),
ref_output.get(),
ref_output_t.get(),
ref_scales.get(),
ref_scales_t.get(),
0.0f,
rows,
cols,
scales_stride,
scales_stride_t,
use_fast_math,
use_2d_quantization,
&ref_per_token_amax);

NVTETensor output_tensor = output.data();
NVTEBasicTensor old_amax;
NVTEBasicTensor old_columnwise_amax;
nvte_get_tensor_param_v2(output_tensor, kNVTEAmax, &old_amax, sizeof(old_amax), nullptr);
nvte_get_tensor_param_v2(output_tensor, kNVTEColumnwiseAmax, &old_columnwise_amax,
sizeof(old_columnwise_amax), nullptr);
if (old_amax.data_ptr != nullptr) {
NVTE_CHECK_CUDA(cudaFree(old_amax.data_ptr));
}
if (old_columnwise_amax.data_ptr != nullptr) {
NVTE_CHECK_CUDA(cudaFree(old_columnwise_amax.data_ptr));
}
NVTE_CHECK_CUDA(cudaMalloc(&per_token_amax, rows * sizeof(float)));
NVTE_CHECK_CUDA(cudaMalloc(&per_token_columnwise_amax, rows * sizeof(float)));
NVTE_CHECK_CUDA(cudaMemset(per_token_amax, 0, rows * sizeof(float)));
NVTE_CHECK_CUDA(cudaMemset(per_token_columnwise_amax, 0, rows * sizeof(float)));
std::vector<size_t> per_token_amax_shape = {rows};
NVTEBasicTensor amax_tensor = {per_token_amax,
static_cast<NVTEDType>(DType::kFloat32),
nvte_make_shape(per_token_amax_shape.data(),
per_token_amax_shape.size())};
NVTEBasicTensor columnwise_amax_tensor = {per_token_columnwise_amax,
static_cast<NVTEDType>(DType::kFloat32),
nvte_make_shape(per_token_amax_shape.data(),
per_token_amax_shape.size())};
nvte_set_tensor_param_v2(output_tensor, kNVTEAmax, &amax_tensor, sizeof(amax_tensor));
nvte_set_tensor_param_v2(output_tensor, kNVTEColumnwiseAmax, &columnwise_amax_tensor,
sizeof(columnwise_amax_tensor));
} else {
// Golden value of amax chosen to make the 2nd-stage scaling mantissa zero and avoid rounding issues
const float amax = 448.0f * 6.0f * 8.0f;

// Set 2nd stage NVFP4 scaling factor
output.set_tensor_amax(amax);
output.set_tensor_amax_columnwise(amax);
// Set 2nd stage NVFP4 scaling factor
output.set_tensor_amax(amax);
output.set_tensor_amax_columnwise(amax);

bool use_2d_quantization = false;

compute_ref<InputType>(OP,
input.rowwise_cpu_dptr<InputType>(),
ref_output.get(),
ref_output_t.get(),
ref_scales.get(),
ref_scales_t.get(),
amax,
rows,
cols,
scales_stride,
scales_stride_t,
use_fast_math,
use_2d_quantization);
compute_ref<InputType>(OP,
input.rowwise_cpu_dptr<InputType>(),
ref_output.get(),
ref_output_t.get(),
ref_scales.get(),
ref_scales_t.get(),
amax,
rows,
cols,
scales_stride,
scales_stride_t,
use_fast_math,
use_2d_quantization);
}
// Initialize stochastic rounding
Tensor rng_state("rng_state", std::vector<size_t>{2}, DType::kInt64);
rng_state.rowwise_cpu_dptr<int64_t>()[0] = 123; // rng_seed
Expand All @@ -600,6 +720,7 @@ void performTest(float (*OP)(const float),

// Set 2D quantization based on compile-time flag
quant_config.set_nvfp4_2d_quantization(use_2d_quantization);
quant_config.set_nvfp4_per_token_activation(per_token_activation);

// Call appropriate function based on operation type
// Activation functions take 3 parameters (input, output, stream)
Expand Down Expand Up @@ -646,6 +767,10 @@ void performTest(float (*OP)(const float),
ref_scales_t.get(),
unpadded_blocks_Y_t, unpadded_blocks_X_t, scales_stride_t,
scale_mismatches_num);

if (per_token_activation) {
compare_per_token_amax(per_token_amax, ref_per_token_amax);
}
}

std::vector<std::vector<size_t>> tensor_dims = {
Expand Down Expand Up @@ -678,6 +803,7 @@ class FusedCastTransposeNVFP4TestSuite : public ::testing::TestWithParam
<std::tuple<ActivationType,
std::vector<size_t>,
transformer_engine::DType,
bool,
bool>> {};

TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) {
Expand All @@ -693,6 +819,7 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) {
const auto tensor_dims = std::get<1>(GetParam());
const DType input_type = std::get<2>(GetParam());
const bool use_fast_math = std::get<3>(GetParam());
const bool per_token_activation = std::get<4>(GetParam());

// Skip tests if the input tensor is 1D
if (tensor_dims.size() < 2) {
Expand All @@ -710,7 +837,7 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) {
}

TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType,
performTest<InputType>(OP, tensor_dims, use_fast_math);
performTest<InputType>(OP, tensor_dims, use_fast_math, per_token_activation);
);
}

Expand All @@ -733,6 +860,7 @@ INSTANTIATE_TEST_SUITE_P(
::testing::ValuesIn(Activation_types),
::testing::ValuesIn(tensor_dims),
::testing::Values(DType::kBFloat16),
::testing::Values(false),
::testing::Values(false)),
[](const testing::TestParamInfo<FusedCastTransposeNVFP4TestSuite::ParamType>& info) {
std::string name = to_string(std::get<0>(info.param));
Expand All @@ -746,3 +874,28 @@ INSTANTIATE_TEST_SUITE_P(
}
return name;
});

// Per-token activation parameterization: exercises the fused cast-transpose
// NVFP4 path with per_token_activation = true (the last Combine slot), so the
// rowwise amax metadata holds one FP32 value per token (row) rather than a
// single scalar. Coverage is intentionally narrow: Identity activation only,
// three representative shapes, BF16 and FP32 inputs, fast math disabled.
INSTANTIATE_TEST_SUITE_P(
OperatorTestPerToken,
FusedCastTransposeNVFP4TestSuite,
::testing::Combine(
::testing::Values(ActivationType::Identity),
::testing::Values(tensor_dims[4], tensor_dims[9], tensor_dims[12]),
::testing::Values(DType::kBFloat16, DType::kFloat32),
::testing::Values(false),
::testing::Values(true)),
// Name generator: builds a test name from activation, shape dims, and input
// dtype, appending markers for the fast-math and per-token flags.
[](const testing::TestParamInfo<FusedCastTransposeNVFP4TestSuite::ParamType>& info) {
std::string name = to_string(std::get<0>(info.param));
const auto& shape = std::get<1>(info.param);
for (const auto& s: shape) {
name += "X" + std::to_string(s);
}
name += "X" + test::typeName(std::get<2>(info.param));
if (std::get<3>(info.param)) {
name += "X_FAST_SCALING";
}
if (std::get<4>(info.param)) {
name += "XPER_TOKEN";
}
return name;
});
Loading