-
Notifications
You must be signed in to change notification settings - Fork 32
gfx1250 mxfp8 gemm: loosen restrictions on K #627
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -463,6 +463,9 @@ void performTest(const TestParams& params) { | |
| const bool has_fp8 = isFp8Type(atype) || isFp8Type(btype); | ||
| const bool use_mxfp8 = params.scaling_mode == NVTEScalingMode::NVTE_MXFP8_1D_SCALING; | ||
|
|
||
| cudaDeviceProp prop; | ||
| (void)cudaGetDeviceProperties(&prop, 0); | ||
|
|
||
| if (use_mxfp8) | ||
| { | ||
| if (!has_fp8) { | ||
|
|
@@ -471,14 +474,15 @@ void performTest(const TestParams& params) { | |
| if (params.m % 16 || params.n % 16) { | ||
| GTEST_SKIP() << "MXFP8 requires M & N to be multiples of 16"; | ||
| } | ||
| if (params.k % 128) { | ||
| GTEST_SKIP() << "MXFP8 requires K to be a multiple of 128"; | ||
| size_t required_k_multiple = 128; | ||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| required_k_multiple = (prop.major == 12 && prop.minor == 5) ? 32 : 128; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: gfx1250 requires a multiple of the block size, not necessarily 32. I believe 16 may also be supported.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. TE defines mxfp8 block size as 32 (MX formats are 32, nvfp4 is 16). |
||
| #endif | ||
| if (params.k % required_k_multiple) { | ||
| GTEST_SKIP() << "MXFP8 requires K to be a multiple of " << required_k_multiple; | ||
| } | ||
| } | ||
|
|
||
| cudaDeviceProp prop; | ||
| (void)cudaGetDeviceProperties(&prop, 0); | ||
|
|
||
| #ifdef __HIP_PLATFORM_AMD__ | ||
|
|
||
| #if HIP_VERSION < 70200000 | ||
|
|
@@ -695,16 +699,20 @@ void performDqTest(const TestParams ¶ms) { | |
| GTEST_ASSERT_TRUE(isFp8Type(atype) && isFp8Type(btype)) << "FP8/BF8 input datatype is expected"; | ||
| GTEST_ASSERT_FALSE(isFp8Type(dtype)) << "Non FP8/BF8 output datatype is expected"; | ||
|
|
||
| cudaDeviceProp prop; | ||
| (void)cudaGetDeviceProperties(&prop, 0); | ||
|
|
||
| if (params.m % 16 || params.n % 16) { | ||
| GTEST_SKIP() << "MXFP8 requires M & N to be multiples of 16"; | ||
| } | ||
| if (params.k % 128) { | ||
| GTEST_SKIP() << "MXFP8 requires K to be a multiple of 128"; | ||
| size_t required_k_multiple = 128; | ||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| required_k_multiple = (prop.major == 12 && prop.minor == 5) ? 32 : 128; | ||
| #endif | ||
| if (params.k % required_k_multiple) { | ||
| GTEST_SKIP() << "MXFP8 requires K to be a multiple of " << required_k_multiple; | ||
| } | ||
|
|
||
| cudaDeviceProp prop; | ||
| (void)cudaGetDeviceProperties(&prop, 0); | ||
|
|
||
| bool mxfp8_supported = (prop.major == 9 && prop.minor >= 5) || prop.major >= 12; | ||
| if (!mxfp8_supported) { | ||
| GTEST_SKIP() << "MXFP8 is not supported in current config"; | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,6 +27,7 @@ | |
| #include <cstring> | ||
|
|
||
| #include "../common.h" | ||
| #include "../util/cuda_runtime.h" | ||
| #include "../util/vectorized_pointwise.h" | ||
| #include "../util/logging.h" | ||
|
|
||
|
|
@@ -1736,10 +1737,16 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, | |
| NVTE_CHECK((is_transb ? B0 : B1) == k, | ||
| "GEMM inputs have incompatible dimensions (A is ", A0, "x", A1, ", B is ", B0, "x", B1, | ||
| ")"); | ||
| // Check that K is a multiple of 128, and M/N are multiples of 16 for MXFP8 GEMM | ||
| // Check that K is compatible with the MXFP8 scale layout, and M/N are multiples of 16 | ||
| if (inputA->scaling_mode == NVTE_MXFP8_1D_SCALING || inputB->scaling_mode == NVTE_MXFP8_1D_SCALING) { | ||
| const bool is_gfx1250 = cuda::sm_arch() == 125; | ||
| // TODO: Also use 32 for gfx950 once hipBLASLt (and TE) support MXFP8 GEMM with | ||
| // swizzled scales on that architecture. | ||
| const int required_k_multiple = is_gfx1250 ? 32 : 128; | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add a TODO here to change this for gfx950 after scale preswizzle is in hipblasLt.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. added in 3a7dd8f |
||
| NVTE_CHECK(inputBias->data.dptr == nullptr, "MXFP8 GEMM does not yet support bias."); | ||
| NVTE_CHECK((k % 128) == 0, "GEMM K dimension must be multiple of 128 for MXFP8 scaling (got K=", k, ")"); | ||
| NVTE_CHECK((k % required_k_multiple) == 0, | ||
| "GEMM K dimension must be multiple of ", required_k_multiple, | ||
| " for MXFP8 scaling (got K=", k, ")"); | ||
| NVTE_CHECK((m % 16) == 0, "GEMM M dimension must be multiple of 16 for MXFP8 scaling (got M=", m, ")"); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think that hipblaslt supports arbitrary M/N for gfx1250?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd rather leave that for a future improvement (if it turns out to be necessary) - this config is currently untested. |
||
| NVTE_CHECK((n % 16) == 0, "GEMM N dimension must be multiple of 16 for MXFP8 scaling (got N=", n, ")"); | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we use the get_arch function here to avoid calling this for every test?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sm_arch isn't used in the tests; this change just moves where the properties are discovered.