Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 18 additions & 10 deletions tests/cpp/operator/test_cublaslt_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,9 @@ void performTest(const TestParams& params) {
const bool has_fp8 = isFp8Type(atype) || isFp8Type(btype);
const bool use_mxfp8 = params.scaling_mode == NVTEScalingMode::NVTE_MXFP8_1D_SCALING;

cudaDeviceProp prop;
(void)cudaGetDeviceProperties(&prop, 0);

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use the get_arch function here to avoid calling this for every test?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sm_arch isn't used in the tests; this change just moves where the properties are discovered.


if (use_mxfp8)
{
if (!has_fp8) {
Expand All @@ -471,14 +474,15 @@ void performTest(const TestParams& params) {
if (params.m % 16 || params.n % 16) {
GTEST_SKIP() << "MXFP8 requires M & N to be multiples of 16";
}
if (params.k % 128) {
GTEST_SKIP() << "MXFP8 requires K to be a multiple of 128";
size_t required_k_multiple = 128;
#ifdef __HIP_PLATFORM_AMD__
required_k_multiple = (prop.major == 12 && prop.minor == 5) ? 32 : 128;

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: gfx1250 requires a multiple of the block size, not necessarily 32. I believe 16 may also be supported.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TE defines mxfp8 block size as 32 (MX formats are 32, nvfp4 is 16).

#endif
if (params.k % required_k_multiple) {
GTEST_SKIP() << "MXFP8 requires K to be a multiple of " << required_k_multiple;
}
}

cudaDeviceProp prop;
(void)cudaGetDeviceProperties(&prop, 0);

#ifdef __HIP_PLATFORM_AMD__

#if HIP_VERSION < 70200000
Expand Down Expand Up @@ -695,16 +699,20 @@ void performDqTest(const TestParams &params) {
GTEST_ASSERT_TRUE(isFp8Type(atype) && isFp8Type(btype)) << "FP8/BF8 input datatype is expected";
GTEST_ASSERT_FALSE(isFp8Type(dtype)) << "Non FP8/BF8 output datatype is expected";

cudaDeviceProp prop;
(void)cudaGetDeviceProperties(&prop, 0);

if (params.m % 16 || params.n % 16) {
GTEST_SKIP() << "MXFP8 requires M & N to be multiples of 16";
}
if (params.k % 128) {
GTEST_SKIP() << "MXFP8 requires K to be a multiple of 128";
size_t required_k_multiple = 128;
#ifdef __HIP_PLATFORM_AMD__
required_k_multiple = (prop.major == 12 && prop.minor == 5) ? 32 : 128;
#endif
if (params.k % required_k_multiple) {
GTEST_SKIP() << "MXFP8 requires K to be a multiple of " << required_k_multiple;
}

cudaDeviceProp prop;
(void)cudaGetDeviceProperties(&prop, 0);

bool mxfp8_supported = (prop.major == 9 && prop.minor >= 5) || prop.major >= 12;
if (!mxfp8_supported) {
GTEST_SKIP() << "MXFP8 is not supported in current config";
Expand Down
11 changes: 9 additions & 2 deletions transformer_engine/common/gemm/rocm_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <cstring>

#include "../common.h"
#include "../util/cuda_runtime.h"
#include "../util/vectorized_pointwise.h"
#include "../util/logging.h"

Expand Down Expand Up @@ -1736,10 +1737,16 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
NVTE_CHECK((is_transb ? B0 : B1) == k,
"GEMM inputs have incompatible dimensions (A is ", A0, "x", A1, ", B is ", B0, "x", B1,
")");
// Check that K is a multiple of 128, and M/N are multiples of 16 for MXFP8 GEMM
// Check that K is compatible with the MXFP8 scale layout, and M/N are multiples of 16
if (inputA->scaling_mode == NVTE_MXFP8_1D_SCALING || inputB->scaling_mode == NVTE_MXFP8_1D_SCALING) {
const bool is_gfx1250 = cuda::sm_arch() == 125;
// TODO: Also use 32 for gfx950 once hipBLASLt (and TE) support MXFP8 GEMM with
// swizzled scales on that architecture.
const int required_k_multiple = is_gfx1250 ? 32 : 128;

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a TODO here to change this for gfx950 after scale preswizzle is in hipblasLt.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added in 3a7dd8f

NVTE_CHECK(inputBias->data.dptr == nullptr, "MXFP8 GEMM does not yet support bias.");
NVTE_CHECK((k % 128) == 0, "GEMM K dimension must be multiple of 128 for MXFP8 scaling (got K=", k, ")");
NVTE_CHECK((k % required_k_multiple) == 0,
"GEMM K dimension must be multiple of ", required_k_multiple,
" for MXFP8 scaling (got K=", k, ")");
NVTE_CHECK((m % 16) == 0, "GEMM M dimension must be multiple of 16 for MXFP8 scaling (got M=", m, ")");

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that hipblaslt supports arbitrary M/N for gfx1250?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rather leave that for a future improvement (if it turns out to be necessary) - this config is currently untested.

NVTE_CHECK((n % 16) == 0, "GEMM N dimension must be multiple of 16 for MXFP8 scaling (got N=", n, ")");
}
Expand Down