diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu index 7fa630033..42f997a68 100644 --- a/tests/cpp/operator/test_cublaslt_gemm.cu +++ b/tests/cpp/operator/test_cublaslt_gemm.cu @@ -463,6 +463,9 @@ void performTest(const TestParams& params) { const bool has_fp8 = isFp8Type(atype) || isFp8Type(btype); const bool use_mxfp8 = params.scaling_mode == NVTEScalingMode::NVTE_MXFP8_1D_SCALING; + cudaDeviceProp prop; + (void)cudaGetDeviceProperties(&prop, 0); + if (use_mxfp8) { if (!has_fp8) { @@ -471,14 +474,15 @@ void performTest(const TestParams& params) { if (params.m % 16 || params.n % 16) { GTEST_SKIP() << "MXFP8 requires M & N to be multiples of 16"; } - if (params.k % 128) { - GTEST_SKIP() << "MXFP8 requires K to be a multiple of 128"; + size_t required_k_multiple = 128; + #ifdef __HIP_PLATFORM_AMD__ + required_k_multiple = (prop.major == 12 && prop.minor == 5) ? 32 : 128; + #endif + if (params.k % required_k_multiple) { + GTEST_SKIP() << "MXFP8 requires K to be a multiple of " << required_k_multiple; } } - cudaDeviceProp prop; - (void)cudaGetDeviceProperties(&prop, 0); - #ifdef __HIP_PLATFORM_AMD__ #if HIP_VERSION < 70200000 @@ -695,16 +699,20 @@ void performDqTest(const TestParams ¶ms) { GTEST_ASSERT_TRUE(isFp8Type(atype) && isFp8Type(btype)) << "FP8/BF8 input datatype is expected"; GTEST_ASSERT_FALSE(isFp8Type(dtype)) << "Non FP8/BF8 output datatype is expected"; + cudaDeviceProp prop; + (void)cudaGetDeviceProperties(&prop, 0); + if (params.m % 16 || params.n % 16) { GTEST_SKIP() << "MXFP8 requires M & N to be multiples of 16"; } - if (params.k % 128) { - GTEST_SKIP() << "MXFP8 requires K to be a multiple of 128"; + size_t required_k_multiple = 128; +#ifdef __HIP_PLATFORM_AMD__ + required_k_multiple = (prop.major == 12 && prop.minor == 5) ? 32 : 128; +#endif + if (params.k % required_k_multiple) { + GTEST_SKIP() << "MXFP8 requires K to be a multiple of " << required_k_multiple; } - cudaDeviceProp prop; - (void)cudaGetDeviceProperties(&prop, 0); - bool mxfp8_supported = (prop.major == 9 && prop.minor >= 5) || prop.major >= 12; if (!mxfp8_supported) { GTEST_SKIP() << "MXFP8 is not supported in current config"; diff --git a/transformer_engine/common/gemm/rocm_gemm.cu b/transformer_engine/common/gemm/rocm_gemm.cu index 51d90e591..20d05e3db 100644 --- a/transformer_engine/common/gemm/rocm_gemm.cu +++ b/transformer_engine/common/gemm/rocm_gemm.cu @@ -27,6 +27,7 @@ #include #include "../common.h" +#include "../util/cuda_runtime.h" #include "../util/vectorized_pointwise.h" #include "../util/logging.h" @@ -1736,10 +1737,16 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, NVTE_CHECK((is_transb ? B0 : B1) == k, "GEMM inputs have incompatible dimensions (A is ", A0, "x", A1, ", B is ", B0, "x", B1, ")"); - // Check that K is a multiple of 128, and M/N are multiples of 16 for MXFP8 GEMM + // Check that K is compatible with the MXFP8 scale layout, and M/N are multiples of 16 if (inputA->scaling_mode == NVTE_MXFP8_1D_SCALING || inputB->scaling_mode == NVTE_MXFP8_1D_SCALING) { + const bool is_gfx1250 = cuda::sm_arch() == 125; + // TODO: Also use 32 for gfx950 once hipBLASLt (and TE) support MXFP8 GEMM with + // swizzled scales on that architecture. + const int required_k_multiple = is_gfx1250 ? 32 : 128; NVTE_CHECK(inputBias->data.dptr == nullptr, "MXFP8 GEMM does not yet support bias."); - NVTE_CHECK((k % 128) == 0, "GEMM K dimension must be multiple of 128 for MXFP8 scaling (got K=", k, ")"); + NVTE_CHECK((k % required_k_multiple) == 0, + "GEMM K dimension must be multiple of ", required_k_multiple, + " for MXFP8 scaling (got K=", k, ")"); NVTE_CHECK((m % 16) == 0, "GEMM M dimension must be multiple of 16 for MXFP8 scaling (got M=", m, ")"); NVTE_CHECK((n % 16) == 0, "GEMM N dimension must be multiple of 16 for MXFP8 scaling (got N=", n, ")"); }