diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp index 27bc94f60..087a17f8e 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp @@ -63,19 +63,19 @@ bool ck_tile_grouped_gemm(const NVTETensor* A, // FP8 special handling. // // A_use/B_use and transA_use/transB_use have already gone through the - // upstream-style grouped GEMM normalization above. This block only rewrites - // that normalized presentation into the CK FP8 preferred NT presentation by selecting - // `columnwise_data` when needed. + // upstream-style grouped GEMM normalization above. CK FP8 grouped GEMM is + // compiled only for the preferred NT presentation: // - // CK FP8 target presentation: - // A_use: N - // B_use: T + // transA_use = false + // transB_use = true // - // The outer condition checks whether this NT presentation is possible: - // - A_use is already N, or can be made N using columnwise_data - // - B_use is already T, or can be made T using columnwise_data + // This block rewrites the normalized presentation into that NT form by + // selecting columnwise_data when needed. If the required columnwise_data view + // is unavailable, this CK FP8 backend cannot represent the GEMM in its + // supported layout form, so we fall back instead of compiling/running an + // unsupported layout variant. 
// - // Then each operand is rewritten independently only if needed: + // Rewrite cases: // NN -> rewrite B only // TN -> rewrite A and B // NT -> already in target form @@ -87,16 +87,23 @@ bool ck_tile_grouped_gemm(const NVTETensor* A, const bool has_a_col = A0_te->has_columnwise_data(); const bool has_b_col = B0_te->has_columnwise_data(); - if ((!transA_use || has_a_col) && (transB_use || has_b_col)) { - if (transA_use) { - use_a_colwise_data = true; - transA_use = false; - } + const bool can_make_a_nt = !transA_use || has_a_col; + const bool can_make_b_nt = transB_use || has_b_col; - if (!transB_use) { - use_b_colwise_data = true; - transB_use = true; - } + if (!can_make_a_nt || !can_make_b_nt) { + NVTE_WARN("ck_tile_grouped_gemm: FP8 grouped GEMM requires NT presentation. " + "Missing required columnwise_data for layout rewrite; falling back."); + return false; + } + + if (transA_use) { + use_a_colwise_data = true; + transA_use = false; + } + + if (!transB_use) { + use_b_colwise_data = true; + transB_use = true; } } @@ -164,7 +171,8 @@ bool ck_tile_grouped_gemm(const NVTETensor* A, A_use, B_use, D, - static_cast<int64_t>(n), + n, + detect_gpu_arch(), group_num, transA_use, transB_use, diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h index 94a301df7..de2bdb985 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h @@ -103,6 +103,8 @@ struct GroupedGemmRunContext { NVTETensor* D = nullptr; int64_t N = 0; + GPUArch arch = GPUArch::UNKNOWN; + int group_num = 0; bool transA = false; bool transB = false; diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_impl.h b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_impl.h index 4777e7b0f..efc074479 100644 --- 
a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_impl.h +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_impl.h @@ -38,6 +38,29 @@ struct TileCfg_256x256x64 { static constexpr ck_tile::index_t TilePartitionerM01 = 4; }; +struct TileCfg_256x256x64_WMMA { + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 64; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 32; + + static constexpr bool kPadM = true; + static constexpr bool kPadN = true; + static constexpr bool kPadK = true; + + static constexpr bool DoubleSmemBuffer = false; + + static constexpr ck_tile::index_t TilePartitionerGroupNum = 8; + static constexpr ck_tile::index_t TilePartitionerM01 = 4; +}; + struct TileCfg_256x128x64 : TileCfg_256x256x64 { static constexpr ck_tile::index_t N_Tile = 128; }; @@ -249,7 +272,9 @@ bool ck_tile_grouped_gemm_fp16_dispatch_layout(DType a_dtype, DType d_dtype, TRANSFORMER_ENGINE_SWITCH_CONDITION(need_m_pad, kPadM, { TRANSFORMER_ENGINE_SWITCH_CONDITION(need_k_pad, kPadK, { - if (ctx.N % 256 == 0) { + if (ctx.arch == GPUArch::GFX1250) { + MAKE_RUNNER(TileCfg_256x256x64_WMMA, true, true, true); + } else if (ctx.N % 256 == 0) { MAKE_RUNNER(TileCfg_256x256x64, kPadM, false, kPadK); } else if (ctx.N % 128 == 0) { MAKE_RUNNER(TileCfg_256x128x64, kPadM, false, kPadK); diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp index d07d7927c..f093a378e 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp +++ 
b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp @@ -40,6 +40,29 @@ struct TileCfg_256x256x128_16x16x128_2x2x1 { static constexpr ck_tile::index_t TilePartitionerM01 = 8; }; +struct TileCfg_128x128x128_16x16x64_2x2x1 { + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 64; + + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool DoubleSmemBuffer = false; + + static constexpr ck_tile::index_t TilePartitionerGroupNum = 16; + static constexpr ck_tile::index_t TilePartitionerM01 = 8; +}; + struct TileCfg_128x128x128_16x16x128_2x2x1 : TileCfg_256x256x128_16x16x128_2x2x1 { static constexpr ck_tile::index_t M_Tile = 128; @@ -302,6 +325,11 @@ struct FP8TileCfg { using type = TileCfg_128x128x128_16x16x128_2x2x1; }; +template <> +struct FP8TileCfg<GPUArch::GFX1250> { + using type = TileCfg_128x128x128_16x16x64_2x2x1; +}; + struct FP8GroupedShapeAlignment { bool all_n_256_aligned = true; bool all_n_128_aligned = true; @@ -449,13 +477,15 @@ bool ck_tile_grouped_gemm_fp8_dispatch(DType a_dtype, DType b_dtype, DType d_dtype, const GroupedGemmRunContext& ctx) { - switch (detect_gpu_arch()) { + switch (ctx.arch) { case GPUArch::GFX942: return ck_tile_grouped_gemm_fp8_dispatch_arch<GPUArch::GFX942>(a_dtype, b_dtype, d_dtype, ctx); case GPUArch::GFX950: return ck_tile_grouped_gemm_fp8_dispatch_arch<GPUArch::GFX950>(a_dtype, b_dtype, d_dtype, ctx); + case GPUArch::GFX1250: + return ck_tile_grouped_gemm_fp8_dispatch_arch<GPUArch::GFX1250>(a_dtype, b_dtype, d_dtype, ctx); default: - NVTE_ERROR("ck_tile_grouped_gemm: available architectures = 
{gfx942, gfx950}"); + NVTE_ERROR("ck_tile_grouped_gemm: available architectures = {gfx942, gfx950, gfx1250}"); return false; } } diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_mx_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_mx_grouped_gemm.cpp index 96726abb8..72d193591 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_mx_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_mx_grouped_gemm.cpp @@ -373,6 +373,7 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor *A, .B = B_use, .D = D, .N = 0, + .arch = detect_gpu_arch(), .group_num = group_num, .transA = transA_use, .transB = transB_use,