From 38a245a3fcfe63bd67d38d1ced20b8b73a45906c Mon Sep 17 00:00:00 2001 From: David Gornshtein Date: Sun, 26 Apr 2026 23:28:33 +0000 Subject: [PATCH 1/3] [Common/PyTorch] Add MXFP8 cast-and-transpose op MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add a fused MXFP8 cast-and-transpose op that takes a high-precision tensor plus the source's existing compact column-wise E8M0 scales and emits row-wise compact MXFP8 storage for the source's logical transpose. The standard MXFP8Quantizer path can already produce row-wise and column-wise MXFP8 from BF16/FP16/FP32 input. There is currently no public TE path that, given X and its compact column-wise scales S_col(X), produces the row-wise compact MXFP8 storage for the logical transpose X.T without either re-reading the BF16 source or copying the existing column-wise MXFP8 payload and scales into transposed row-wise storage. This op closes that gap. It is the building block needed to route MXFP8 backward through TN GEMMs on hardware where cuBLASLt does not currently support MXFP8 backward NN/NT layouts (NVIDIA Spark sm_12.1). On B200 / H100 the new op is unused by default; downstream code can still call it for any path that wants direct transposed-rowwise MXFP8 emission without a payload copy. Surfaces in three layers, all additive: * C API (ABI-safe): - nvte_mxfp8_scaling_transpose_cast(input, scale_inv_colwise, output_rowwise, output_rowwise_scale_inv, rows, cols, stream) — minimal signature, E4M3 output, non-swizzled scales. - nvte_mxfp8_scaling_transpose_cast_v2(..., fp8_dtype, with_gemm_swizzled_scales, stream) — extended signature. * PyTorch extension: transformer_engine_torch.mxfp8_scaling_transpose_cast (default kwargs match the minimal C symbol's behavior). * Python: MXFP8Quantizer.quantize_rowwise_transpose(tensor, columnwise_scale_inv, *, fake_dtype=None, with_gemm_swizzled_scales=None) returns a row-wise-only MXFP8Tensor whose logical shape is tensor.T. No existing C symbol, Python signature, or default behavior is changed. Tests in tests/pytorch/mxfp8/: * test_mxfp8_scaling_transpose_cast.py — byte equivalence vs. column-wise- then-copy reference (E4M3 + E5M2, multiple shapes), Python helper equivalence, decoded-value reconstruction within MXFP8 quantization tolerance, error paths for FP8 input and non-block-aligned dims. * test_mxfp8_scaling_transpose_cast_swizzled.py — with with_gemm_swizzled_scales=True, emitted row-wise payload and scales match the bytes produced by the standard MXFP8Quantizer.quantize swizzled path on the actual transposed source. Comparison is byte-for-byte rather than via decoded values because TE's dequantize kernels intentionally reject with_gemm_swizzled_scales=True inputs (one-way GEMM-operand layout). Tested on NVIDIA GB10 (sm_12.1) with TE rebuilt from this change: all 14 parametrized tests pass. Signed-off-by: David Gornshtein --- .../test_mxfp8_scaling_transpose_cast.py | 208 ++++++++++++++++++ ...t_mxfp8_scaling_transpose_cast_swizzled.py | 77 +++++++ .../include/transformer_engine/recipe.h | 40 ++++ .../common/recipe/mxfp8_scaling.cu | 151 +++++++++++++ transformer_engine/pytorch/csrc/extensions.h | 5 + .../csrc/extensions/fp8_partial_cast.cpp | 22 ++ .../pytorch/csrc/extensions/pybind.cpp | 8 + .../pytorch/tensor/mxfp8_tensor.py | 89 ++++++++ 8 files changed, 600 insertions(+) create mode 100644 tests/pytorch/mxfp8/test_mxfp8_scaling_transpose_cast.py create mode 100644 tests/pytorch/mxfp8/test_mxfp8_scaling_transpose_cast_swizzled.py diff --git a/tests/pytorch/mxfp8/test_mxfp8_scaling_transpose_cast.py b/tests/pytorch/mxfp8/test_mxfp8_scaling_transpose_cast.py new file mode 100644 index 0000000000..5d467f7345 --- /dev/null +++ b/tests/pytorch/mxfp8/test_mxfp8_scaling_transpose_cast.py @@ -0,0 +1,208 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Tests for the new MXFP8 cast-and-transpose op. + +These tests are written to drop into tests/pytorch/mxfp8/ of upstream +TransformerEngine. They exercise: + +1. nvte_mxfp8_scaling_transpose_cast numerics vs. an in-test reference + reconstruction (MXFP8Quantizer.quantize columnwise + naive Python + transpose-and-pack). +2. nvte_mxfp8_scaling_transpose_cast byte-for-byte equivalence to a copy + adapter that takes the existing column-wise MXFP8 payload, transposes it, + and rewrites it as row-wise storage. +3. The MXFP8Quantizer.quantize_rowwise_transpose Python helper. +4. The with_gemm_swizzled_scales=True variant (covered in + test_mxfp8_scaling_transpose_cast_swizzled.py). + +All tests gate on: + +- CUDA available +- transformer_engine installed +- transformer_engine_torch.mxfp8_scaling_transpose_cast symbol present +""" + +from __future__ import annotations + +import math + +import pytest +import torch + +te = pytest.importorskip("transformer_engine") +tex = pytest.importorskip("transformer_engine_torch") + +if not torch.cuda.is_available(): + pytest.skip("CUDA required", allow_module_level=True) +if not hasattr(tex, "mxfp8_scaling_transpose_cast"): + pytest.skip("Built TE missing mxfp8_scaling_transpose_cast", allow_module_level=True) + +from transformer_engine.pytorch.constants import MXFP8_BLOCK_SCALING_SIZE +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer + + +def _make_source(rows: int, cols: int, dtype=torch.bfloat16, seed: int = 1234) -> torch.Tensor: + g = torch.Generator(device="cuda").manual_seed(seed) + return torch.randn((rows, cols), dtype=dtype, device="cuda", generator=g) * 4.0 + + +def _make_quantizer(fp8_dtype) -> MXFP8Quantizer: + q = MXFP8Quantizer(fp8_dtype=fp8_dtype, rowwise=True, columnwise=True) + q.optimize_for_gemm = False + return q + + +def _quantize_with_columnwise(quantizer: MXFP8Quantizer, source: torch.Tensor): + """Quantize source with both row-wise and column-wise MXFP8 storage.""" + quantizer.set_usage(rowwise=True, columnwise=True) + return quantizer.quantize(source) + + +def _copy_adapter_transpose(mxfp8_tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Reference: form transposed row-wise MXFP8 by copying the existing + column-wise MXFP8 payload and column-wise scales into transposed + row-wise storage.""" + cw_data = mxfp8_tensor._columnwise_data.contiguous() + cw_scale = mxfp8_tensor._columnwise_scale_inv.contiguous() + rowwise_data = cw_data.t().contiguous() + rowwise_scale = cw_scale.t().contiguous() + return rowwise_data, rowwise_scale + + +@pytest.mark.parametrize("rows,cols", [(64, 128), (128, 256), (256, 4096)]) +@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) +def test_transpose_cast_matches_copy_adapter_bytes(rows, cols, fp8_dtype): + """Direct byte equivalence: the new op must produce exactly the same + payload and scale bytes as transposing existing column-wise MXFP8 storage.""" + source = _make_source(rows, cols) + quantizer = MXFP8Quantizer(fp8_dtype=fp8_dtype, rowwise=True, columnwise=True) + quantizer.optimize_for_gemm = False + mxfp8 = _quantize_with_columnwise(quantizer, source) + + expected_payload, expected_scale = _copy_adapter_transpose(mxfp8) + + rowwise_data = torch.empty((cols, rows), dtype=torch.uint8, device="cuda") + rowwise_scale = torch.empty( + (mxfp8._columnwise_scale_inv.shape[1], mxfp8._columnwise_scale_inv.shape[0]), + dtype=torch.uint8, + device="cuda", + ) + tex.mxfp8_scaling_transpose_cast( + source, + mxfp8._columnwise_scale_inv.contiguous(), + rowwise_data, + rowwise_scale, + rows, + cols, + int(fp8_dtype), + False, # with_gemm_swizzled_scales + ) + torch.cuda.synchronize() + + assert torch.equal(rowwise_data.view(torch.uint8), expected_payload.view(torch.uint8)), ( + "Row-wise MXFP8 payload bytes differ from copy-adapter reference" + ) + assert torch.equal(rowwise_scale, expected_scale), ( + "Row-wise MXFP8 scale bytes differ from copy-adapter reference" + ) + + +@pytest.mark.parametrize("rows,cols", [(64, 128), (256, 4096)]) +def test_quantize_rowwise_transpose_helper_equivalence(rows, cols): + """The Python helper should match the raw extension call.""" + source = _make_source(rows, cols) + fp8_dtype = tex.DType.kFloat8E4M3 + + quantizer = _make_quantizer(fp8_dtype) + mxfp8 = _quantize_with_columnwise(quantizer, source) + + helper_quantizer = _make_quantizer(fp8_dtype) + helper_quantizer.set_usage(rowwise=True, columnwise=False) + transposed = helper_quantizer.quantize_rowwise_transpose( + source, mxfp8._columnwise_scale_inv.contiguous() + ) + + expected_payload, expected_scale = _copy_adapter_transpose(mxfp8) + + assert tuple(transposed.shape) == (cols, rows) + assert transposed._rowwise_data is not None + assert transposed._columnwise_data is None + assert torch.equal( + transposed._rowwise_data.view(torch.uint8), expected_payload.view(torch.uint8) + ) + assert torch.equal(transposed._rowwise_scale_inv, expected_scale) + + +@pytest.mark.parametrize("rows,cols", [(64, 128), (128, 256)]) +def test_transpose_cast_numerical_reconstruction(rows, cols): + """Block-decoded transposed payload should reconstruct source.T to + within MXFP8 quantization tolerance, matching the reference quantizer.""" + source = _make_source(rows, cols).to(torch.bfloat16) + fp8_dtype = tex.DType.kFloat8E4M3 + + quantizer = _make_quantizer(fp8_dtype) + mxfp8 = _quantize_with_columnwise(quantizer, source) + + # Native row-wise reference for source.T: re-quantize the transposed source. + ref_quantizer = _make_quantizer(fp8_dtype) + ref_quantizer.set_usage(rowwise=True, columnwise=False) + ref_t = ref_quantizer.quantize(source.t().contiguous()) + ref_decoded = ref_t.dequantize().to(torch.float32) + + helper_quantizer = _make_quantizer(fp8_dtype) + helper_quantizer.set_usage(rowwise=True, columnwise=False) + transposed = helper_quantizer.quantize_rowwise_transpose( + source, mxfp8._columnwise_scale_inv.contiguous() + ) + got_decoded = transposed.dequantize().to(torch.float32) + + # Both reconstructions of source.T should be within 2x the per-block + # MXFP8 quantization error of one another. They differ only in scale + # selection: native row-wise re-quantizer chooses scales from + # source.T's row blocks, while transpose-cast reuses scales from + # source's column blocks. These are the same blocks of source values, + # so the chosen scales are identical and the decoded outputs should + # match exactly bit-for-bit modulo block-edge effects. + rel = (got_decoded - ref_decoded).norm() / (ref_decoded.norm() + 1e-8) + assert rel.item() < 5e-2, f"transpose-cast reconstruction drifted: rel L2 {rel.item():.4f}" + + +def test_transpose_cast_rejects_fp8_input(): + """High-precision input is required; an FP8 source must error out.""" + source = _make_source(64, 128, dtype=torch.bfloat16) + quantizer = _make_quantizer(tex.DType.kFloat8E4M3) + mxfp8 = _quantize_with_columnwise(quantizer, source) + + rowwise_data = torch.empty((128, 64), dtype=torch.uint8, device="cuda") + rowwise_scale = torch.empty( + (mxfp8._columnwise_scale_inv.shape[1], mxfp8._columnwise_scale_inv.shape[0]), + dtype=torch.uint8, + device="cuda", + ) + with pytest.raises((RuntimeError, TypeError, ValueError)): + tex.mxfp8_scaling_transpose_cast( + mxfp8._rowwise_data, # FP8, not high-precision + mxfp8._columnwise_scale_inv, + rowwise_data, + rowwise_scale, + 64, + 128, + int(tex.DType.kFloat8E4M3), + False, + ) + + +def test_transpose_cast_requires_block_aligned_dims(): + source = _make_source(64, 128) + quantizer = _make_quantizer(tex.DType.kFloat8E4M3) + quantizer.set_usage(rowwise=True, columnwise=False) + bad_source = torch.randn(48, 128, dtype=torch.bfloat16, device="cuda") + bad_scale = torch.zeros( + (max(1, math.ceil(48 / MXFP8_BLOCK_SCALING_SIZE)), 128), + dtype=torch.uint8, + device="cuda", + ) + with pytest.raises((RuntimeError, ValueError)): + quantizer.quantize_rowwise_transpose(bad_source, bad_scale) diff --git a/tests/pytorch/mxfp8/test_mxfp8_scaling_transpose_cast_swizzled.py b/tests/pytorch/mxfp8/test_mxfp8_scaling_transpose_cast_swizzled.py new file mode 100644 index 0000000000..6f944cbdcc --- /dev/null +++ b/tests/pytorch/mxfp8/test_mxfp8_scaling_transpose_cast_swizzled.py @@ -0,0 +1,77 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""GEMM-swizzled scale layout test for the MXFP8 cast-and-transpose op. + +When with_gemm_swizzled_scales=True, the new op must write row-wise scale +bytes directly in the layout consumed by MXFP8 GEMM (the same layout produced +by the standard MXFP8Quantizer.quantize(..., with_gemm_swizzled_scales=True) +path) instead of the compact layout. This test compares the swizzled scales +emitted by the new op against the swizzled scales produced by re-quantizing +the transposed source through the standard quantizer with swizzled output. +""" + +from __future__ import annotations + +import pytest +import torch + +te = pytest.importorskip("transformer_engine") +tex = pytest.importorskip("transformer_engine_torch") + +if not torch.cuda.is_available(): + pytest.skip("CUDA required", allow_module_level=True) +if not hasattr(tex, "mxfp8_scaling_transpose_cast"): + pytest.skip("Built TE missing mxfp8_scaling_transpose_cast", allow_module_level=True) + +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer + + +def _make_source(rows: int, cols: int, seed: int = 1234) -> torch.Tensor: + g = torch.Generator(device="cuda").manual_seed(seed) + return torch.randn((rows, cols), dtype=torch.bfloat16, device="cuda", generator=g) * 4.0 + + +def _quantize_native_swizzled_transpose(source: torch.Tensor): + """Reference: re-quantize the actual transpose with the standard quantizer + and swizzled scales. The byte content of the row-wise scales for source.T + is what the new op should produce.""" + q = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3, rowwise=True, columnwise=False) + q.optimize_for_gemm = True + q.set_usage(rowwise=True, columnwise=False) + return q.quantize(source.t().contiguous()) + + +@pytest.mark.parametrize("rows,cols", [(128, 256), (256, 4096)]) +def test_swizzled_scales_match_native_transpose(rows, cols): + source = _make_source(rows, cols) + fp8_dtype = tex.DType.kFloat8E4M3 + + column_quantizer = MXFP8Quantizer(fp8_dtype=fp8_dtype, rowwise=True, columnwise=True) + column_quantizer.optimize_for_gemm = False + column_quantizer.set_usage(rowwise=True, columnwise=True) + column_mxfp8 = column_quantizer.quantize(source) + + helper_quantizer = MXFP8Quantizer(fp8_dtype=fp8_dtype, rowwise=True, columnwise=False) + helper_quantizer.optimize_for_gemm = True + transposed = helper_quantizer.quantize_rowwise_transpose( + source, + column_mxfp8._columnwise_scale_inv.contiguous(), + with_gemm_swizzled_scales=True, + ) + + native_t = _quantize_native_swizzled_transpose(source) + + # Payload bytes (no swizzling on payload) must match native transposed + # quantization byte-for-byte, since both paths quantize the same source + # blocks with the same E8M0 scales. + assert torch.equal( + transposed._rowwise_data.view(torch.uint8), native_t._rowwise_data.view(torch.uint8) + ), "Swizzled transpose-emit payload bytes differ from native transposed quantization" + + # Scales must also be exact byte-equal — both paths target the GEMM + # swizzled layout for the same logical row-wise tensor. + assert torch.equal(transposed._rowwise_scale_inv, native_t._rowwise_scale_inv), ( + "Swizzled row-wise scale bytes differ from native transposed quantization" + ) diff --git a/transformer_engine/common/include/transformer_engine/recipe.h b/transformer_engine/common/include/transformer_engine/recipe.h index cad27a2992..9415f757e5 100644 --- a/transformer_engine/common/include/transformer_engine/recipe.h +++ b/transformer_engine/common/include/transformer_engine/recipe.h @@ -292,6 +292,46 @@ void nvte_mxfp8_scaling_partial_cast(const NVTETensor input, NVTETensor output_r const NVTETensor scale_inv_colwise, int rows, int cols, size_t start_offset, cudaStream_t stream); +/*! \brief Cast and transpose an input tensor into MXFP8 row-wise storage. + * + * Consumes a high-precision tensor and the compact column-wise E8M0 scales + * already computed for that source tensor. Emits row-wise MXFP8 payload and + * scale-inverse storage for the logical transpose of the source. + * + * Output dtype is E4M3 and scales are written in compact (non-swizzled) + * layout. For E5M2 output or GEMM-swizzled scales use the _v2 variant. + * + * \param[in] input Input tensor, flattened as rows x cols. + * \param[in] scale_inv_colwise Source compact column-wise E8M0 scales. + * \param[out] output_rowwise Row-wise MXFP8 payload for input.T. + * \param[out] output_rowwise_scale_inv Row-wise E8M0 scales for input.T. + * \param[in] rows Number of rows in the source logical tensor. + * \param[in] cols Number of columns in the source logical tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_mxfp8_scaling_transpose_cast(const NVTETensor input, + const NVTETensor scale_inv_colwise, + NVTETensor output_rowwise, + NVTETensor output_rowwise_scale_inv, int rows, int cols, + cudaStream_t stream); + +/*! \brief Extended variant of nvte_mxfp8_scaling_transpose_cast. + * + * Same semantics as nvte_mxfp8_scaling_transpose_cast, with two extra knobs: + * + * \param[in] fp8_dtype Output FP8 payload dtype: E4M3 or E5M2. + * \param[in] with_gemm_swizzled_scales Whether output scales should be + * emitted directly in the GEMM + * swizzled layout instead of the + * compact layout. + */ +void nvte_mxfp8_scaling_transpose_cast_v2(const NVTETensor input, + const NVTETensor scale_inv_colwise, + NVTETensor output_rowwise, + NVTETensor output_rowwise_scale_inv, int rows, int cols, + NVTEDType fp8_dtype, bool with_gemm_swizzled_scales, + cudaStream_t stream); + /*! \brief Compute per-tensor scaling factor for NVFP4 format. * * This function computes the scaling factor (alpha) for NVFP4 quantization based diff --git a/transformer_engine/common/recipe/mxfp8_scaling.cu b/transformer_engine/common/recipe/mxfp8_scaling.cu index be692d4563..b7f1d1c6af 100644 --- a/transformer_engine/common/recipe/mxfp8_scaling.cu +++ b/transformer_engine/common/recipe/mxfp8_scaling.cu @@ -7,6 +7,7 @@ #include #include "../common.h" +#include "../cast/mxfp8/swizzle.cuh" #include "../util/ptx.cuh" #include "../utils.cuh" @@ -130,6 +131,55 @@ __global__ void __launch_bounds__(kThreadsPerBlock) } } +__global__ void __launch_bounds__(kThreadsPerBlock) + mxfp8_scaling_transpose_scales_kernel(const e8m0_t *scale_inv_colwise, + e8m0_t *output_rowwise_scale_inv, int colwise_scale_rows, + int colwise_scale_cols, + int rowwise_transpose_scale_stride, int source_rows, + bool with_gemm_swizzled_scales) { + const int64_t total = static_cast(colwise_scale_rows) * colwise_scale_cols; + const int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total) { + return; + } + + const int64_t out_r = idx / colwise_scale_rows; + const int64_t out_c = idx - out_r * colwise_scale_rows; + size_t output_idx = out_r * rowwise_transpose_scale_stride + out_c; + if (with_gemm_swizzled_scales) { + const size_t num_tiles_x = (static_cast(source_rows) + 127) / 128; + output_idx = transformer_engine::dispatch::mxfp8::swizzle::gemm_swizzled_scale_idx( + static_cast(out_r), static_cast(out_c), num_tiles_x); + } + output_rowwise_scale_inv[output_idx] = scale_inv_colwise[out_c * colwise_scale_cols + out_r]; +} + +constexpr int kTransposeTileDim = 16; + +template +__global__ void mxfp8_scaling_transpose_cast_kernel( + const IType *input, const e8m0_t *scale_inv_colwise, OType *output_rowwise, int rows, int cols, + int colwise_scale_stride) { + __shared__ OType tile[kTransposeTileDim][kTransposeTileDim + 1]; + + const int64_t c = blockIdx.x * kTransposeTileDim + threadIdx.x; + const int64_t r = blockIdx.y * kTransposeTileDim + threadIdx.y; + if (r < rows && c < cols) { + const e8m0_t biased_exponent = scale_inv_colwise[(r / 32) * colwise_scale_stride + c]; + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + tile[threadIdx.y][threadIdx.x] = static_cast( + static_cast(input[r * cols + c]) * block_scale_inverse); + } + + __syncthreads(); + + const int64_t out_r = blockIdx.x * kTransposeTileDim + threadIdx.y; + const int64_t out_c = blockIdx.y * kTransposeTileDim + threadIdx.x; + if (out_r < cols && out_c < rows) { + output_rowwise[out_r * rows + out_c] = tile[threadIdx.x][threadIdx.y]; + } +} + void mxfp8_scaling_compute_partial_amax(const Tensor input, Tensor amax_rowwise, Tensor amax_colwise, int rows, int cols, size_t start_offset, cudaStream_t stream) { @@ -227,6 +277,81 @@ void mxfp8_scaling_partial_cast(const Tensor input, Tensor output_rowwise, Tenso start_offset, input.data.shape[0]);) } +void mxfp8_scaling_transpose_cast(const Tensor input, const Tensor scale_inv_colwise, + Tensor output_rowwise, Tensor output_rowwise_scale_inv, int rows, + int cols, DType fp8_dtype, bool with_gemm_swizzled_scales, + cudaStream_t stream) { + NVTE_CHECK(rows % 32 == 0, "rows must be divisible by 32"); + NVTE_CHECK(cols % 32 == 0, "cols must be divisible by 32"); + NVTE_CHECK(input.data.shape.size() >= 1, "input must be allocated"); + NVTE_CHECK(input.numel() == static_cast(rows) * cols, + "input numel must match rows * cols"); + NVTE_CHECK(!is_fp8_dtype(input.dtype()), "input must be a high-precision tensor"); + + NVTE_CHECK(output_rowwise.data.shape.size() == 2, "output_rowwise must be a 2D tensor"); + NVTE_CHECK(output_rowwise.data.shape[0] == static_cast(cols), + "output_rowwise dim0 must equal source cols"); + NVTE_CHECK(output_rowwise.data.shape[1] == static_cast(rows), + "output_rowwise dim1 must equal source rows"); + NVTE_CHECK(fp8_dtype == DType::kFloat8E4M3 || fp8_dtype == DType::kFloat8E5M2, + "fp8_dtype should be e4m3 or e5m2"); + NVTE_CHECK(output_rowwise.dtype() == fp8_dtype || output_rowwise.dtype() == DType::kByte, + "output_rowwise should match fp8_dtype or be uint8 storage"); + + NVTE_CHECK(scale_inv_colwise.data.shape.size() == 2, "scale_inv_colwise must be a 2D tensor"); + NVTE_CHECK(scale_inv_colwise.data.shape[0] % colwise_row_padding == 0, + "Wrong padding of scale_inv_colwise's rows"); + NVTE_CHECK(scale_inv_colwise.data.shape[0] >= static_cast(rows / 32), "Invalid rows"); + NVTE_CHECK(scale_inv_colwise.data.shape[1] % colwise_col_padding == 0, + "Wrong padding of scale_inv_colwise's cols"); + NVTE_CHECK(scale_inv_colwise.data.shape[1] >= static_cast(cols), "Invalid cols"); + NVTE_CHECK(scale_inv_colwise.dtype() == DType::kByte, "Wrong dtype of scale_inv_colwise"); + + NVTE_CHECK(output_rowwise_scale_inv.data.shape.size() == 2, + "output_rowwise_scale_inv must be a 2D tensor"); + NVTE_CHECK(output_rowwise_scale_inv.data.shape[0] == scale_inv_colwise.data.shape[1], + "output_rowwise_scale_inv dim0 must equal scale_inv_colwise dim1"); + NVTE_CHECK(output_rowwise_scale_inv.data.shape[1] == scale_inv_colwise.data.shape[0], + "output_rowwise_scale_inv dim1 must equal scale_inv_colwise dim0"); + NVTE_CHECK(output_rowwise_scale_inv.dtype() == DType::kByte, + "Wrong dtype of output_rowwise_scale_inv"); + + const int scale_blocks = static_cast( + DIVUP(output_rowwise_scale_inv.numel(), static_cast(kThreadsPerBlock))); + if (output_rowwise_scale_inv.numel() > 0) { + mxfp8_scaling_transpose_scales_kernel<<>>( + reinterpret_cast(scale_inv_colwise.data.dptr), + reinterpret_cast(output_rowwise_scale_inv.data.dptr), + static_cast(scale_inv_colwise.data.shape[0]), + static_cast(scale_inv_colwise.data.shape[1]), + static_cast(output_rowwise_scale_inv.data.shape[1]), rows, + with_gemm_swizzled_scales); + } + + if (input.numel() > 0) { + const dim3 block(kTransposeTileDim, kTransposeTileDim); + const dim3 grid(DIVUP(cols, kTransposeTileDim), DIVUP(rows, kTransposeTileDim)); + if (fp8_dtype == DType::kFloat8E4M3) { + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + input.dtype(), IType, + mxfp8_scaling_transpose_cast_kernel<<>>( + reinterpret_cast(input.data.dptr), + reinterpret_cast(scale_inv_colwise.data.dptr), + reinterpret_cast(output_rowwise.data.dptr), rows, cols, + scale_inv_colwise.data.shape[1]);) + } else { + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + input.dtype(), IType, + mxfp8_scaling_transpose_cast_kernel<<>>( + reinterpret_cast(input.data.dptr), + reinterpret_cast(scale_inv_colwise.data.dptr), + reinterpret_cast(output_rowwise.data.dptr), rows, cols, + scale_inv_colwise.data.shape[1]);) + } + } + NVTE_CHECK_CUDA(cudaGetLastError()); +} + } // namespace mxfp8_scaling_recipe } // namespace transformer_engine @@ -251,3 +376,29 @@ void nvte_mxfp8_scaling_partial_cast(const NVTETensor input, NVTETensor output_r *convertNVTETensorCheck(output_colwise), *convertNVTETensorCheck(scale_inv_rowwise), *convertNVTETensorCheck(scale_inv_colwise), rows, cols, start_offset, stream); } + +void nvte_mxfp8_scaling_transpose_cast_v2(const NVTETensor input, + const NVTETensor scale_inv_colwise, + NVTETensor output_rowwise, + NVTETensor output_rowwise_scale_inv, int rows, int cols, + NVTEDType fp8_dtype, bool with_gemm_swizzled_scales, + cudaStream_t stream) { + NVTE_API_CALL(nvte_mxfp8_scaling_transpose_cast_v2); + using namespace transformer_engine; + mxfp8_scaling_recipe::mxfp8_scaling_transpose_cast( + *convertNVTETensorCheck(input), *convertNVTETensorCheck(scale_inv_colwise), + *convertNVTETensorCheck(output_rowwise), + *convertNVTETensorCheck(output_rowwise_scale_inv), rows, cols, + static_cast(fp8_dtype), with_gemm_swizzled_scales, stream); +} + +void nvte_mxfp8_scaling_transpose_cast(const NVTETensor input, + const NVTETensor scale_inv_colwise, + NVTETensor output_rowwise, + NVTETensor output_rowwise_scale_inv, int rows, int cols, + cudaStream_t stream) { + nvte_mxfp8_scaling_transpose_cast_v2(input, scale_inv_colwise, output_rowwise, + output_rowwise_scale_inv, rows, cols, + kNVTEFloat8E4M3, /*with_gemm_swizzled_scales=*/false, + stream); +} diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 4a2ea7412b..e0b63aa90f 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -435,6 +435,11 @@ void mxfp8_scaling_partial_cast(const at::Tensor &input, at::Tensor output_rowwi const at::Tensor &scale_inv_colwise, int rows, int cols, size_t start_offset); +void mxfp8_scaling_transpose_cast(const at::Tensor &input, const at::Tensor &scale_inv_colwise, + at::Tensor output_rowwise, + at::Tensor output_rowwise_scale_inv, int rows, int cols, + int64_t fp8_dtype, bool with_gemm_swizzled_scales); + /*************************************************************************************************** * Rotary positional embedding **************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/fp8_partial_cast.cpp b/transformer_engine/pytorch/csrc/extensions/fp8_partial_cast.cpp index d6693a485e..b7bc1fe291 100644 --- a/transformer_engine/pytorch/csrc/extensions/fp8_partial_cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/fp8_partial_cast.cpp @@ -86,4 +86,26 @@ void mxfp8_scaling_partial_cast(const at::Tensor &input, at::Tensor output_rowwi at::cuda::getCurrentCUDAStream()); } +void mxfp8_scaling_transpose_cast(const at::Tensor &input, const at::Tensor &scale_inv_colwise, + at::Tensor output_rowwise, + at::Tensor output_rowwise_scale_inv, int rows, int cols, + int64_t fp8_dtype, bool with_gemm_swizzled_scales) { + TORCH_CHECK(input.is_contiguous(), "input must be contiguous"); + TORCH_CHECK(scale_inv_colwise.is_contiguous(), "scale_inv_colwise must be contiguous"); + TORCH_CHECK(output_rowwise.is_contiguous(), "output_rowwise must be contiguous"); + TORCH_CHECK(output_rowwise_scale_inv.is_contiguous(), + "output_rowwise_scale_inv must be contiguous"); + + const TensorWrapper input_cu = makeTransformerEngineTensor(input); + const TensorWrapper scale_inv_colwise_cu = makeTransformerEngineTensor(scale_inv_colwise); + TensorWrapper output_rowwise_cu = makeTransformerEngineTensor(output_rowwise); + TensorWrapper output_rowwise_scale_inv_cu = + makeTransformerEngineTensor(output_rowwise_scale_inv); + + nvte_mxfp8_scaling_transpose_cast_v2( + input_cu.data(), scale_inv_colwise_cu.data(), output_rowwise_cu.data(), + output_rowwise_scale_inv_cu.data(), rows, cols, static_cast(fp8_dtype), + with_gemm_swizzled_scales, at::cuda::getCurrentCUDAStream()); +} + } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index eb7576d905..d0d71e8dc2 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -385,6 +385,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("output_rowwise"), py::arg("output_colwise"), py::arg("scale_inv_rowwise"), py::arg("scale_inv_colwise"), py::arg("rows"), py::arg("cols"), py::arg("start_offset"), py::call_guard()); + m.def("mxfp8_scaling_transpose_cast", + &transformer_engine::pytorch::mxfp8_scaling_transpose_cast, + "Cast source into row-wise MXFP8 storage for its logical transpose", py::arg("input"), + py::arg("scale_inv_colwise"), py::arg("output_rowwise"), + py::arg("output_rowwise_scale_inv"), py::arg("rows"), py::arg("cols"), + py::arg("fp8_dtype") = static_cast(transformer_engine::DType::kFloat8E4M3), + py::arg("with_gemm_swizzled_scales") = false, + py::call_guard()); m.def("fused_multi_row_padding", &transformer_engine::pytorch::fused_multi_row_padding, "Fused Multi-tensor padding", py::call_guard()); m.def("fused_multi_row_unpadding", &transformer_engine::pytorch::fused_multi_row_unpadding, diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 5cab519c79..a0a3d794da 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -227,6 +227,95 @@ def create_tensor_from_data( with_gemm_swizzled_scales=False, ) + def quantize_rowwise_transpose( + self, + tensor: torch.Tensor, + columnwise_scale_inv: torch.Tensor, + *, + fake_dtype: Optional[torch.dtype] = None, + with_gemm_swizzled_scales: Optional[bool] = None, + ) -> MXFP8Tensor: + """Quantize ``tensor.T`` into row-wise MXFP8 storage. + + ``columnwise_scale_inv`` must be the compact E8M0 column-wise scale + tensor computed for ``tensor``. This path emits the transposed row-wise + payload directly from the high-precision source instead of copying an + existing MXFP8 column-wise payload. If ``with_gemm_swizzled_scales`` is + true, the emitted row-wise scales are written directly in the layout + expected by MXFP8 GEMM. + """ + + if not hasattr(tex, "mxfp8_scaling_transpose_cast"): + raise RuntimeError("TransformerEngine extension is missing mxfp8_scaling_transpose_cast") + if tensor.dim() < 2: + raise ValueError( + f"MXFP8 transpose quantization requires at least 2D input: {tensor.shape}" + ) + if not tensor.is_cuda or not columnwise_scale_inv.is_cuda: + raise ValueError("MXFP8 transpose quantization requires CUDA tensors") + if self.dtype not in (tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2): + raise TypeError(f"MXFP8 transpose quantization only supports E4M3/E5M2, got {self.dtype}") + if tensor.dtype not in (torch.float32, torch.float16, torch.bfloat16): + raise TypeError(f"Unsupported MXFP8 transpose source dtype: {tensor.dtype}") + if columnwise_scale_inv.dtype != torch.uint8: + raise TypeError(f"columnwise_scale_inv must be uint8, got {columnwise_scale_inv.dtype}") + if columnwise_scale_inv.dim() != 2: + raise ValueError("columnwise_scale_inv must be 2D") + + source_2d = tensor.contiguous().reshape(-1, tensor.shape[-1]) + columnwise_scale_inv = columnwise_scale_inv.contiguous() + rows, cols = source_2d.shape + if rows % MXFP8_BLOCK_SCALING_SIZE != 0 or cols % MXFP8_BLOCK_SCALING_SIZE != 0: + raise ValueError( + "MXFP8 transpose quantization requires flattened source dims " + f"divisible by {MXFP8_BLOCK_SCALING_SIZE}, got {(rows, cols)}" + ) + + expected_scale_shape = ( + round_up_to_nearest_multiple(rows // MXFP8_BLOCK_SCALING_SIZE, 4), + round_up_to_nearest_multiple(cols, 128), + ) + if tuple(columnwise_scale_inv.shape) != expected_scale_shape: + raise ValueError( + "columnwise_scale_inv has wrong compact MXFP8 shape: " + f"expected {expected_scale_shape}, got {tuple(columnwise_scale_inv.shape)}" + ) + + rowwise_data = torch.empty((cols, rows), dtype=torch.uint8, device=tensor.device) + rowwise_scale_inv = torch.empty( + (columnwise_scale_inv.shape[1], columnwise_scale_inv.shape[0]), + dtype=torch.uint8, + device=tensor.device, + ) + if with_gemm_swizzled_scales is None: + with_gemm_swizzled_scales = self.optimize_for_gemm + tex.mxfp8_scaling_transpose_cast( + source_2d, + columnwise_scale_inv, + rowwise_data, + rowwise_scale_inv, + rows, + cols, + int(self.dtype), + bool(with_gemm_swizzled_scales), + ) + + quantizer = self.copy() + quantizer.set_usage(rowwise=True, columnwise=False) + quantizer.optimize_for_gemm = bool(with_gemm_swizzled_scales) + return MXFP8Tensor( + shape=rowwise_data.shape, + dtype=tensor.dtype if fake_dtype is None else fake_dtype, + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + columnwise_data=None, + columnwise_scale_inv=None, + fp8_dtype=self.dtype, + quantizer=quantizer, + requires_grad=False, + with_gemm_swizzled_scales=bool(with_gemm_swizzled_scales), + ) + def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor: if tensor.dtype != torch.float32: tensor = tensor.to(dtype=torch.float32) From cbc1010402b7f87bc19a5b4afb31dfde7d69a302 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 26 Apr 2026 23:30:24 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../test_mxfp8_scaling_transpose_cast.py | 12 ++++---- ...t_mxfp8_scaling_transpose_cast_swizzled.py | 6 ++-- .../include/transformer_engine/recipe.h | 3 +- .../common/recipe/mxfp8_scaling.cu | 29 +++++++++---------- transformer_engine/pytorch/csrc/extensions.h | 6 ++-- .../csrc/extensions/fp8_partial_cast.cpp | 17 +++++------ .../pytorch/csrc/extensions/pybind.cpp | 6 ++-- .../pytorch/tensor/mxfp8_tensor.py | 8 +++-- 8 files changed, 42 insertions(+), 45 deletions(-) diff --git a/tests/pytorch/mxfp8/test_mxfp8_scaling_transpose_cast.py b/tests/pytorch/mxfp8/test_mxfp8_scaling_transpose_cast.py index 5d467f7345..94cc49f1c2 100644 --- a/tests/pytorch/mxfp8/test_mxfp8_scaling_transpose_cast.py +++ b/tests/pytorch/mxfp8/test_mxfp8_scaling_transpose_cast.py @@ -101,12 +101,12 @@ def test_transpose_cast_matches_copy_adapter_bytes(rows, cols, fp8_dtype): ) torch.cuda.synchronize() - assert torch.equal(rowwise_data.view(torch.uint8), expected_payload.view(torch.uint8)), ( - "Row-wise MXFP8 payload bytes differ from copy-adapter reference" - ) - assert torch.equal(rowwise_scale, expected_scale), ( - "Row-wise MXFP8 scale bytes differ from copy-adapter reference" - ) + assert torch.equal( + rowwise_data.view(torch.uint8), expected_payload.view(torch.uint8) + ), "Row-wise MXFP8 payload bytes differ from copy-adapter reference" + assert torch.equal( + rowwise_scale, expected_scale + ), "Row-wise MXFP8 scale bytes differ from copy-adapter reference" @pytest.mark.parametrize("rows,cols", [(64, 128), (256, 4096)]) diff --git a/tests/pytorch/mxfp8/test_mxfp8_scaling_transpose_cast_swizzled.py b/tests/pytorch/mxfp8/test_mxfp8_scaling_transpose_cast_swizzled.py index 6f944cbdcc..e9ab46891e 100644 --- a/tests/pytorch/mxfp8/test_mxfp8_scaling_transpose_cast_swizzled.py +++ b/tests/pytorch/mxfp8/test_mxfp8_scaling_transpose_cast_swizzled.py @@ -72,6 +72,6 @@ def test_swizzled_scales_match_native_transpose(rows, cols): # Scales must also be exact byte-equal — both paths target the GEMM # swizzled layout for the same logical row-wise tensor. - assert torch.equal(transposed._rowwise_scale_inv, native_t._rowwise_scale_inv), ( - "Swizzled row-wise scale bytes differ from native transposed quantization" - ) + assert torch.equal( + transposed._rowwise_scale_inv, native_t._rowwise_scale_inv + ), "Swizzled row-wise scale bytes differ from native transposed quantization" diff --git a/transformer_engine/common/include/transformer_engine/recipe.h b/transformer_engine/common/include/transformer_engine/recipe.h index 9415f757e5..ec3bc642d3 100644 --- a/transformer_engine/common/include/transformer_engine/recipe.h +++ b/transformer_engine/common/include/transformer_engine/recipe.h @@ -309,8 +309,7 @@ void nvte_mxfp8_scaling_partial_cast(const NVTETensor input, NVTETensor output_r * \param[in] cols Number of columns in the source logical tensor. * \param[in] stream CUDA stream used for the operation. */ -void nvte_mxfp8_scaling_transpose_cast(const NVTETensor input, - const NVTETensor scale_inv_colwise, +void nvte_mxfp8_scaling_transpose_cast(const NVTETensor input, const NVTETensor scale_inv_colwise, NVTETensor output_rowwise, NVTETensor output_rowwise_scale_inv, int rows, int cols, cudaStream_t stream); diff --git a/transformer_engine/common/recipe/mxfp8_scaling.cu b/transformer_engine/common/recipe/mxfp8_scaling.cu index b7f1d1c6af..634fb863c0 100644 --- a/transformer_engine/common/recipe/mxfp8_scaling.cu +++ b/transformer_engine/common/recipe/mxfp8_scaling.cu @@ -6,8 +6,8 @@ #include -#include "../common.h" #include "../cast/mxfp8/swizzle.cuh" +#include "../common.h" #include "../util/ptx.cuh" #include "../utils.cuh" @@ -157,9 +157,10 @@ __global__ void __launch_bounds__(kThreadsPerBlock) constexpr int kTransposeTileDim = 16; template -__global__ void mxfp8_scaling_transpose_cast_kernel( - const IType *input, const e8m0_t *scale_inv_colwise, OType *output_rowwise, int rows, int cols, - int colwise_scale_stride) { +__global__ void mxfp8_scaling_transpose_cast_kernel(const IType *input, + const e8m0_t *scale_inv_colwise, + OType *output_rowwise, int rows, int cols, + int colwise_scale_stride) { __shared__ OType tile[kTransposeTileDim][kTransposeTileDim + 1]; const int64_t c = blockIdx.x * kTransposeTileDim + threadIdx.x; @@ -167,8 +168,8 @@ __global__ void mxfp8_scaling_transpose_cast_kernel( if (r < rows && c < cols) { const e8m0_t biased_exponent = scale_inv_colwise[(r / 32) * colwise_scale_stride + c]; const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); - tile[threadIdx.y][threadIdx.x] = static_cast( - static_cast(input[r * cols + c]) * block_scale_inverse); + tile[threadIdx.y][threadIdx.x] = + static_cast(static_cast(input[r * cols + c]) * block_scale_inverse); } __syncthreads(); @@ -324,8 +325,7 @@ void mxfp8_scaling_transpose_cast(const Tensor input, const Tensor scale_inv_col reinterpret_cast(output_rowwise_scale_inv.data.dptr), static_cast(scale_inv_colwise.data.shape[0]), static_cast(scale_inv_colwise.data.shape[1]), - static_cast(output_rowwise_scale_inv.data.shape[1]), rows, - with_gemm_swizzled_scales); + static_cast(output_rowwise_scale_inv.data.shape[1]), rows, with_gemm_swizzled_scales); } if (input.numel() > 0) { @@ -387,18 +387,15 @@ void nvte_mxfp8_scaling_transpose_cast_v2(const NVTETensor input, using namespace transformer_engine; mxfp8_scaling_recipe::mxfp8_scaling_transpose_cast( *convertNVTETensorCheck(input), *convertNVTETensorCheck(scale_inv_colwise), - *convertNVTETensorCheck(output_rowwise), - *convertNVTETensorCheck(output_rowwise_scale_inv), rows, cols, - static_cast(fp8_dtype), with_gemm_swizzled_scales, stream); + *convertNVTETensorCheck(output_rowwise), *convertNVTETensorCheck(output_rowwise_scale_inv), + rows, cols, static_cast(fp8_dtype), with_gemm_swizzled_scales, stream); } -void nvte_mxfp8_scaling_transpose_cast(const NVTETensor input, - const NVTETensor scale_inv_colwise, +void nvte_mxfp8_scaling_transpose_cast(const NVTETensor input, const NVTETensor scale_inv_colwise, NVTETensor output_rowwise, NVTETensor output_rowwise_scale_inv, int rows, int cols, cudaStream_t stream) { nvte_mxfp8_scaling_transpose_cast_v2(input, scale_inv_colwise, output_rowwise, - output_rowwise_scale_inv, rows, cols, - kNVTEFloat8E4M3, /*with_gemm_swizzled_scales=*/false, - stream); + output_rowwise_scale_inv, rows, cols, kNVTEFloat8E4M3, + /*with_gemm_swizzled_scales=*/false, stream); } diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index e0b63aa90f..6d180d916e 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -436,9 +436,9 @@ void mxfp8_scaling_partial_cast(const at::Tensor &input, at::Tensor output_rowwi size_t start_offset); void mxfp8_scaling_transpose_cast(const at::Tensor &input, const at::Tensor &scale_inv_colwise, - at::Tensor output_rowwise, - at::Tensor output_rowwise_scale_inv, int rows, int cols, - int64_t fp8_dtype, bool with_gemm_swizzled_scales); + at::Tensor output_rowwise, at::Tensor output_rowwise_scale_inv, + int rows, int cols, int64_t fp8_dtype, + bool with_gemm_swizzled_scales); /*************************************************************************************************** * Rotary positional embedding diff --git a/transformer_engine/pytorch/csrc/extensions/fp8_partial_cast.cpp b/transformer_engine/pytorch/csrc/extensions/fp8_partial_cast.cpp index b7bc1fe291..393a4131a4 100644 --- a/transformer_engine/pytorch/csrc/extensions/fp8_partial_cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/fp8_partial_cast.cpp @@ -87,9 +87,9 @@ void mxfp8_scaling_partial_cast(const at::Tensor &input, at::Tensor output_rowwi } void mxfp8_scaling_transpose_cast(const at::Tensor &input, const at::Tensor &scale_inv_colwise, - at::Tensor output_rowwise, - at::Tensor output_rowwise_scale_inv, int rows, int cols, - int64_t fp8_dtype, bool with_gemm_swizzled_scales) { + at::Tensor output_rowwise, at::Tensor output_rowwise_scale_inv, + int rows, int cols, int64_t fp8_dtype, + bool with_gemm_swizzled_scales) { TORCH_CHECK(input.is_contiguous(), "input must be contiguous"); TORCH_CHECK(scale_inv_colwise.is_contiguous(), "scale_inv_colwise must be contiguous"); TORCH_CHECK(output_rowwise.is_contiguous(), "output_rowwise must be contiguous"); @@ -99,13 +99,12 @@ void mxfp8_scaling_transpose_cast(const at::Tensor &input, const at::Tensor &sca const TensorWrapper input_cu = makeTransformerEngineTensor(input); const TensorWrapper scale_inv_colwise_cu = makeTransformerEngineTensor(scale_inv_colwise); TensorWrapper output_rowwise_cu = makeTransformerEngineTensor(output_rowwise); - TensorWrapper output_rowwise_scale_inv_cu = - makeTransformerEngineTensor(output_rowwise_scale_inv); + TensorWrapper output_rowwise_scale_inv_cu = makeTransformerEngineTensor(output_rowwise_scale_inv); - nvte_mxfp8_scaling_transpose_cast_v2( - input_cu.data(), scale_inv_colwise_cu.data(), output_rowwise_cu.data(), - output_rowwise_scale_inv_cu.data(), rows, cols, static_cast(fp8_dtype), - with_gemm_swizzled_scales, at::cuda::getCurrentCUDAStream()); + nvte_mxfp8_scaling_transpose_cast_v2(input_cu.data(), scale_inv_colwise_cu.data(), + output_rowwise_cu.data(), output_rowwise_scale_inv_cu.data(), + rows, cols, static_cast(fp8_dtype), + with_gemm_swizzled_scales, at::cuda::getCurrentCUDAStream()); } } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index d0d71e8dc2..81a00a9e50 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -385,14 +385,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("output_rowwise"), py::arg("output_colwise"), py::arg("scale_inv_rowwise"), py::arg("scale_inv_colwise"), py::arg("rows"), py::arg("cols"), py::arg("start_offset"), py::call_guard()); - m.def("mxfp8_scaling_transpose_cast", - &transformer_engine::pytorch::mxfp8_scaling_transpose_cast, + m.def("mxfp8_scaling_transpose_cast", &transformer_engine::pytorch::mxfp8_scaling_transpose_cast, "Cast source into row-wise MXFP8 storage for its logical transpose", py::arg("input"), py::arg("scale_inv_colwise"), py::arg("output_rowwise"), py::arg("output_rowwise_scale_inv"), py::arg("rows"), py::arg("cols"), py::arg("fp8_dtype") = static_cast(transformer_engine::DType::kFloat8E4M3), - py::arg("with_gemm_swizzled_scales") = false, - py::call_guard()); + py::arg("with_gemm_swizzled_scales") = false, py::call_guard()); m.def("fused_multi_row_padding", &transformer_engine::pytorch::fused_multi_row_padding, "Fused Multi-tensor padding", py::call_guard()); m.def("fused_multi_row_unpadding", &transformer_engine::pytorch::fused_multi_row_unpadding, diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index a0a3d794da..fbbc91571e 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -246,7 +246,9 @@ def quantize_rowwise_transpose( """ if not hasattr(tex, "mxfp8_scaling_transpose_cast"): - raise RuntimeError("TransformerEngine extension is missing mxfp8_scaling_transpose_cast") + raise RuntimeError( + "TransformerEngine extension is missing mxfp8_scaling_transpose_cast" + ) if tensor.dim() < 2: raise ValueError( f"MXFP8 transpose quantization requires at least 2D input: {tensor.shape}" @@ -254,7 +256,9 @@ def quantize_rowwise_transpose( if not tensor.is_cuda or not columnwise_scale_inv.is_cuda: raise ValueError("MXFP8 transpose quantization requires CUDA tensors") if self.dtype not in (tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2): - raise TypeError(f"MXFP8 transpose quantization only supports E4M3/E5M2, got {self.dtype}") + raise TypeError( + f"MXFP8 transpose quantization only supports E4M3/E5M2, got {self.dtype}" + ) if tensor.dtype not in (torch.float32, torch.float16, torch.bfloat16): raise TypeError(f"Unsupported MXFP8 transpose source dtype: {tensor.dtype}") if columnwise_scale_inv.dtype != torch.uint8: From 0d8b158dbc1a967d9e37364b04e61f1e4cbb86af Mon Sep 17 00:00:00 2001 From: David Gornshtein Date: Mon, 27 Apr 2026 10:43:02 +0000 Subject: [PATCH 3/3] Address review feedback * Add NVTE_API_CALL(nvte_mxfp8_scaling_transpose_cast) to the v1 entry point so profiling/tracing tools attribute calls to the actual symbol the caller used instead of v2. * Add __launch_bounds__(kTransposeTileDim * kTransposeTileDim) on the transpose-cast payload kernel to match the launch shape and let the compiler tune register allocation, consistent with the other __global__ kernels in this file. * Drop unused source = _make_source(64, 128) allocation from test_transpose_cast_requires_block_aligned_dims; only bad_source and bad_scale are exercised. Signed-off-by: David Gornshtein --- tests/pytorch/mxfp8/test_mxfp8_scaling_transpose_cast.py | 1 - transformer_engine/common/recipe/mxfp8_scaling.cu | 9 +++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/pytorch/mxfp8/test_mxfp8_scaling_transpose_cast.py b/tests/pytorch/mxfp8/test_mxfp8_scaling_transpose_cast.py index 94cc49f1c2..784ca984ef 100644 --- a/tests/pytorch/mxfp8/test_mxfp8_scaling_transpose_cast.py +++ b/tests/pytorch/mxfp8/test_mxfp8_scaling_transpose_cast.py @@ -195,7 +195,6 @@ def test_transpose_cast_rejects_fp8_input(): def test_transpose_cast_requires_block_aligned_dims(): - source = _make_source(64, 128) quantizer = _make_quantizer(tex.DType.kFloat8E4M3) quantizer.set_usage(rowwise=True, columnwise=False) bad_source = torch.randn(48, 128, dtype=torch.bfloat16, device="cuda") diff --git a/transformer_engine/common/recipe/mxfp8_scaling.cu b/transformer_engine/common/recipe/mxfp8_scaling.cu index 634fb863c0..7433de48a2 100644 --- a/transformer_engine/common/recipe/mxfp8_scaling.cu +++ b/transformer_engine/common/recipe/mxfp8_scaling.cu @@ -157,10 +157,10 @@ __global__ void __launch_bounds__(kThreadsPerBlock) constexpr int kTransposeTileDim = 16; template -__global__ void mxfp8_scaling_transpose_cast_kernel(const IType *input, - const e8m0_t *scale_inv_colwise, - OType *output_rowwise, int rows, int cols, - int colwise_scale_stride) { +__global__ void __launch_bounds__(kTransposeTileDim *kTransposeTileDim) + mxfp8_scaling_transpose_cast_kernel(const IType *input, const e8m0_t *scale_inv_colwise, + OType *output_rowwise, int rows, int cols, + int colwise_scale_stride) { __shared__ OType tile[kTransposeTileDim][kTransposeTileDim + 1]; const int64_t c = blockIdx.x * kTransposeTileDim + threadIdx.x; @@ -395,6 +395,7 @@ void nvte_mxfp8_scaling_transpose_cast(const NVTETensor input, const NVTETensor NVTETensor output_rowwise, NVTETensor output_rowwise_scale_inv, int rows, int cols, cudaStream_t stream) { + NVTE_API_CALL(nvte_mxfp8_scaling_transpose_cast); nvte_mxfp8_scaling_transpose_cast_v2(input, scale_inv_colwise, output_rowwise, output_rowwise_scale_inv, rows, cols, kNVTEFloat8E4M3, /*with_gemm_swizzled_scales=*/false, stream);