From b2427563955054f8b01ac535824b929d54c77751 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 16 Jun 2026 15:31:03 +0000 Subject: [PATCH 1/2] gfx1250 gemm: loosen restrictions on K --- tests/cpp/operator/test_cublaslt_gemm.cu | 22 +++++++++++---------- transformer_engine/common/gemm/rocm_gemm.cu | 9 +++++++-- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu index 7fa630033..bc57df2c0 100644 --- a/tests/cpp/operator/test_cublaslt_gemm.cu +++ b/tests/cpp/operator/test_cublaslt_gemm.cu @@ -463,6 +463,9 @@ void performTest(const TestParams& params) { const bool has_fp8 = isFp8Type(atype) || isFp8Type(btype); const bool use_mxfp8 = params.scaling_mode == NVTEScalingMode::NVTE_MXFP8_1D_SCALING; + cudaDeviceProp prop; + (void)cudaGetDeviceProperties(&prop, 0); + if (use_mxfp8) { if (!has_fp8) { @@ -471,14 +474,12 @@ void performTest(const TestParams& params) { if (params.m % 16 || params.n % 16) { GTEST_SKIP() << "MXFP8 requires M & N to be multiples of 16"; } - if (params.k % 128) { - GTEST_SKIP() << "MXFP8 requires K to be a multiple of 128"; + const size_t required_k_multiple = (prop.major == 12 && prop.minor == 5) ? 
32 : 128; + if (params.k % required_k_multiple) { + GTEST_SKIP() << "MXFP8 requires K to be a multiple of " << required_k_multiple; } } - cudaDeviceProp prop; - (void)cudaGetDeviceProperties(&prop, 0); - #ifdef __HIP_PLATFORM_AMD__ #if HIP_VERSION < 70200000 @@ -695,16 +696,17 @@ void performDqTest(const TestParams &params) { GTEST_ASSERT_TRUE(isFp8Type(atype) && isFp8Type(btype)) << "FP8/BF8 input datatype is expected"; GTEST_ASSERT_FALSE(isFp8Type(dtype)) << "Non FP8/BF8 output datatype is expected"; + cudaDeviceProp prop; + (void)cudaGetDeviceProperties(&prop, 0); + if (params.m % 16 || params.n % 16) { GTEST_SKIP() << "MXFP8 requires M & N to be multiples of 16"; } - if (params.k % 128) { - GTEST_SKIP() << "MXFP8 requires K to be a multiple of 128"; + const size_t required_k_multiple = (prop.major == 12 && prop.minor == 5) ? 32 : 128; + if (params.k % required_k_multiple) { + GTEST_SKIP() << "MXFP8 requires K to be a multiple of " << required_k_multiple; } - cudaDeviceProp prop; - (void)cudaGetDeviceProperties(&prop, 0); - bool mxfp8_supported = (prop.major == 9 && prop.minor >= 5) || prop.major >= 12; if (!mxfp8_supported) { GTEST_SKIP() << "MXFP8 is not supported in current config"; diff --git a/transformer_engine/common/gemm/rocm_gemm.cu b/transformer_engine/common/gemm/rocm_gemm.cu index 51d90e591..5574fb9cd 100644 --- a/transformer_engine/common/gemm/rocm_gemm.cu +++ b/transformer_engine/common/gemm/rocm_gemm.cu @@ -27,6 +27,7 @@ #include #include "../common.h" +#include "../util/cuda_runtime.h" #include "../util/vectorized_pointwise.h" #include "../util/logging.h" @@ -1736,10 +1737,14 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, NVTE_CHECK((is_transb ? 
B0 : B1) == k, "GEMM inputs have incompatible dimensions (A is ", A0, "x", A1, ", B is ", B0, "x", B1, ")"); - // Check that K is a multiple of 128, and M/N are multiples of 16 for MXFP8 GEMM + // Check that K is compatible with the MXFP8 scale layout, and M/N are multiples of 16 if (inputA->scaling_mode == NVTE_MXFP8_1D_SCALING || inputB->scaling_mode == NVTE_MXFP8_1D_SCALING) { + const bool is_gfx1250 = cuda::sm_arch() == 125; + const int required_k_multiple = is_gfx1250 ? 32 : 128; NVTE_CHECK(inputBias->data.dptr == nullptr, "MXFP8 GEMM does not yet support bias."); - NVTE_CHECK((k % 128) == 0, "GEMM K dimension must be multiple of 128 for MXFP8 scaling (got K=", k, ")"); + NVTE_CHECK((k % required_k_multiple) == 0, + "GEMM K dimension must be multiple of ", required_k_multiple, + " for MXFP8 scaling (got K=", k, ")"); NVTE_CHECK((m % 16) == 0, "GEMM M dimension must be multiple of 16 for MXFP8 scaling (got M=", m, ")"); NVTE_CHECK((n % 16) == 0, "GEMM N dimension must be multiple of 16 for MXFP8 scaling (got N=", n, ")"); } From 3a7dd8f4c521c11fe7ad8f14d1a3709be51fb3bc Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 16 Jun 2026 19:09:07 +0000 Subject: [PATCH 2/2] address review comments. --- tests/cpp/operator/test_cublaslt_gemm.cu | 10 ++++++++-- transformer_engine/common/gemm/rocm_gemm.cu | 2 ++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu index bc57df2c0..42f997a68 100644 --- a/tests/cpp/operator/test_cublaslt_gemm.cu +++ b/tests/cpp/operator/test_cublaslt_gemm.cu @@ -474,7 +474,10 @@ void performTest(const TestParams& params) { if (params.m % 16 || params.n % 16) { GTEST_SKIP() << "MXFP8 requires M & N to be multiples of 16"; } - const size_t required_k_multiple = (prop.major == 12 && prop.minor == 5) ? 32 : 128; + size_t required_k_multiple = 128; + #ifdef __HIP_PLATFORM_AMD__ + required_k_multiple = (prop.major == 12 && prop.minor == 5) ? 
32 : 128; + #endif if (params.k % required_k_multiple) { GTEST_SKIP() << "MXFP8 requires K to be a multiple of " << required_k_multiple; } @@ -702,7 +705,10 @@ void performDqTest(const TestParams &params) { if (params.m % 16 || params.n % 16) { GTEST_SKIP() << "MXFP8 requires M & N to be multiples of 16"; } - const size_t required_k_multiple = (prop.major == 12 && prop.minor == 5) ? 32 : 128; + size_t required_k_multiple = 128; +#ifdef __HIP_PLATFORM_AMD__ + required_k_multiple = (prop.major == 12 && prop.minor == 5) ? 32 : 128; +#endif if (params.k % required_k_multiple) { GTEST_SKIP() << "MXFP8 requires K to be a multiple of " << required_k_multiple; } diff --git a/transformer_engine/common/gemm/rocm_gemm.cu b/transformer_engine/common/gemm/rocm_gemm.cu index 5574fb9cd..20d05e3db 100644 --- a/transformer_engine/common/gemm/rocm_gemm.cu +++ b/transformer_engine/common/gemm/rocm_gemm.cu @@ -1740,6 +1740,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, // Check that K is compatible with the MXFP8 scale layout, and M/N are multiples of 16 if (inputA->scaling_mode == NVTE_MXFP8_1D_SCALING || inputB->scaling_mode == NVTE_MXFP8_1D_SCALING) { const bool is_gfx1250 = cuda::sm_arch() == 125; + // TODO: Also use 32 for gfx950 once hipBLASLt (and TE) support MXFP8 GEMM with + // swizzled scales on that architecture. const int required_k_multiple = is_gfx1250 ? 32 : 128; NVTE_CHECK(inputBias->data.dptr == nullptr, "MXFP8 GEMM does not yet support bias."); NVTE_CHECK((k % required_k_multiple) == 0,