From 273a1dab0acc8c787c3b9e4ed7ad4610f4dc40ac Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Tue, 16 Jun 2026 14:31:40 +0000 Subject: [PATCH 1/2] add gfx1250 support to ck tile group gemm --- .../gemm/ck_grouped_gemm/ck_grouped_gemm.cpp | 45 +++++----- .../ck_grouped_gemm/ck_grouped_gemm_fp16.cpp | 82 ++++++++++++++----- .../ck_grouped_gemm_fp16_impl.h | 54 +++++++++--- .../ck_grouped_gemm_fp16_nn.cpp | 3 +- .../ck_grouped_gemm_fp16_nt.cpp | 3 +- .../ck_grouped_gemm_fp16_tn.cpp | 3 +- .../ck_grouped_gemm_fp16_tt.cpp | 3 +- .../ck_grouped_gemm/ck_grouped_gemm_fp8.cpp | 39 ++++++++- 8 files changed, 178 insertions(+), 54 deletions(-) diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp index 27bc94f60..240207423 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp @@ -63,19 +63,19 @@ bool ck_tile_grouped_gemm(const NVTETensor* A, // FP8 special handling. // // A_use/B_use and transA_use/transB_use have already gone through the - // upstream-style grouped GEMM normalization above. This block only rewrites - // that normalized presentation into the CK FP8 preferred NT presentation by selecting - // `columnwise_data` when needed. + // upstream-style grouped GEMM normalization above. CK FP8 grouped GEMM is + // compiled only for the preferred NT presentation: // - // CK FP8 target presentation: - // A_use: N - // B_use: T + // transA_use = false + // transB_use = true // - // The outer condition checks whether this NT presentation is possible: - // - A_use is already N, or can be made N using columnwise_data - // - B_use is already T, or can be made T using columnwise_data + // This block rewrites the normalized presentation into that NT form by + // selecting columnwise_data when needed. If the required columnwise_data view + // is unavailable, this CK FP8 backend cannot represent the GEMM in its + // supported layout form, so we fall back instead of compiling/running an + // unsupported layout variant. // - // Then each operand is rewritten independently only if needed: + // Rewrite cases: // NN -> rewrite B only // TN -> rewrite A and B // NT -> already in target form @@ -87,16 +87,23 @@ bool ck_tile_grouped_gemm(const NVTETensor* A, const bool has_a_col = A0_te->has_columnwise_data(); const bool has_b_col = B0_te->has_columnwise_data(); - if ((!transA_use || has_a_col) && (transB_use || has_b_col)) { - if (transA_use) { - use_a_colwise_data = true; - transA_use = false; - } + const bool can_make_a_nt = !transA_use || has_a_col; + const bool can_make_b_nt = transB_use || has_b_col; - if (!transB_use) { - use_b_colwise_data = true; - transB_use = true; - } + if (!can_make_a_nt || !can_make_b_nt) { + NVTE_WARN("ck_tile_grouped_gemm: FP8 grouped GEMM requires NT presentation. " + "Missing required columnwise_data for layout rewrite; falling back."); + return false; + } + + if (transA_use) { + use_a_colwise_data = true; + transA_use = false; + } + + if (!transB_use) { + use_b_colwise_data = true; + transB_use = true; } } diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp index 4d21250d6..6c2272844 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp @@ -11,14 +11,62 @@ namespace transformer_engine { namespace grouped_gemm { +template +bool ck_tile_grouped_gemm_fp16_dispatch_arch_impl( + DType a_dtype, + DType d_dtype, + bool need_m_pad, + bool need_k_pad, + const GroupedGemmRunContext& ctx) { + if (!ctx.transA && !ctx.transB) { + return ck_tile_grouped_gemm_fp16_dispatch_nn( + a_dtype, d_dtype, need_m_pad, need_k_pad, ctx); + } else if (!ctx.transA && ctx.transB) { + return ck_tile_grouped_gemm_fp16_dispatch_nt( + a_dtype, d_dtype, need_m_pad, need_k_pad, ctx); + } else if (ctx.transA && !ctx.transB) { + return ck_tile_grouped_gemm_fp16_dispatch_tn( + a_dtype, d_dtype, need_m_pad, need_k_pad, ctx); + } else { + return ck_tile_grouped_gemm_fp16_dispatch_tt( + a_dtype, d_dtype, need_m_pad, need_k_pad, ctx); + } +} + +bool ck_tile_grouped_gemm_fp16_dispatch_arch(DType a_dtype, + DType d_dtype, + bool need_m_pad, + bool need_k_pad, + const GroupedGemmRunContext& ctx) { + switch (detect_gpu_arch()) { +#if defined(__gfx942__) + case GPUArch::GFX942: + return ck_tile_grouped_gemm_fp16_dispatch_arch_impl( + a_dtype, d_dtype, need_m_pad, need_k_pad, ctx); +#endif +#if defined(__gfx950__) + case GPUArch::GFX950: + return ck_tile_grouped_gemm_fp16_dispatch_arch_impl( + a_dtype, d_dtype, need_m_pad, need_k_pad, ctx); +#endif +#if defined(__gfx1250__) + case GPUArch::GFX1250: + return ck_tile_grouped_gemm_fp16_dispatch_arch_impl( + a_dtype, d_dtype, need_m_pad, need_k_pad, ctx); +#endif + default: + return false; + } +} + bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype, DType b_dtype, DType d_dtype, const GroupedGemmRunContext& ctx) { // Check M and K alignment across all groups. // All tile configs share the same M_Tile (256) and K_Tile (64). - constexpr ck_tile::index_t M_Tile = TileCfg_256x256x64::M_Tile; - constexpr ck_tile::index_t K_Tile = TileCfg_256x256x64::K_Tile; + constexpr ck_tile::index_t M_Tile = Fp16GroupedGemmMTile; + constexpr ck_tile::index_t K_Tile = Fp16GroupedGemmKTile; bool need_m_pad = false; bool need_k_pad = false; @@ -31,12 +79,15 @@ bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype, const int64_t M = ctx.transA ? Ad1 : Ad0; const int64_t K = ctx.transA ? Ad0 : Ad1; - if (M % M_Tile != 0) + if (M % M_Tile != 0) { need_m_pad = true; - if (K % K_Tile != 0) + } + if (K % K_Tile != 0) { need_k_pad = true; - if (need_m_pad && need_k_pad) + } + if (need_m_pad && need_k_pad) { break; + } } } @@ -58,27 +109,18 @@ bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype, if (!all_have_columnwise) { return false; } + // Dispatch with B's columnwise buffer as RowMajor (transB=false). GroupedGemmRunContext ctx_b_colwise = ctx; ctx_b_colwise.transB = false; ctx_b_colwise.use_b_columnwise_data = true; - if (!ctx_b_colwise.transA) { - return ck_tile_grouped_gemm_fp16_dispatch_nn(a_dtype, d_dtype, need_m_pad, need_k_pad, ctx_b_colwise); - } else { - return ck_tile_grouped_gemm_fp16_dispatch_tn(a_dtype, d_dtype, need_m_pad, need_k_pad, ctx_b_colwise); - } + return ck_tile_grouped_gemm_fp16_dispatch_arch( + a_dtype, d_dtype, need_m_pad, need_k_pad, ctx_b_colwise); } - // Dispatch to per-layout translation unit. - if (!ctx.transA && !ctx.transB) { - return ck_tile_grouped_gemm_fp16_dispatch_nn(a_dtype, d_dtype, need_m_pad, need_k_pad, ctx); - } else if (!ctx.transA && ctx.transB) { - return ck_tile_grouped_gemm_fp16_dispatch_nt(a_dtype, d_dtype, need_m_pad, need_k_pad, ctx); - } else if (ctx.transA && !ctx.transB) { - return ck_tile_grouped_gemm_fp16_dispatch_tn(a_dtype, d_dtype, need_m_pad, need_k_pad, ctx); - } else { - return ck_tile_grouped_gemm_fp16_dispatch_tt(a_dtype, d_dtype, need_m_pad, need_k_pad, ctx); - } + // Dispatch to arch + per-layout translation unit. + return ck_tile_grouped_gemm_fp16_dispatch_arch( + a_dtype, d_dtype, need_m_pad, need_k_pad, ctx); } } // namespace grouped_gemm diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_impl.h b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_impl.h index 4777e7b0f..37aa0c244 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_impl.h +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_impl.h @@ -15,10 +15,13 @@ namespace grouped_gemm { // Tile configs: FP16/BF16 // ------------------------- -struct TileCfg_256x256x64 { - static constexpr ck_tile::index_t M_Tile = 256; +static constexpr ck_tile::index_t Fp16GroupedGemmMTile = 256; +static constexpr ck_tile::index_t Fp16GroupedGemmKTile = 64; + +struct TileCfg_256x256x64_MFMA { + static constexpr ck_tile::index_t M_Tile = Fp16GroupedGemmMTile; static constexpr ck_tile::index_t N_Tile = 256; - static constexpr ck_tile::index_t K_Tile = 64; + static constexpr ck_tile::index_t K_Tile = Fp16GroupedGemmKTile; static constexpr ck_tile::index_t M_Warp = 2; static constexpr ck_tile::index_t N_Warp = 2; @@ -38,10 +41,33 @@ struct TileCfg_256x256x64 { static constexpr ck_tile::index_t TilePartitionerM01 = 4; }; -struct TileCfg_256x128x64 : TileCfg_256x256x64 { +struct TileCfg_256x128x64_MFMA : TileCfg_256x256x64_MFMA { static constexpr ck_tile::index_t N_Tile = 128; }; +struct TileCfg_256x256x64_WMMA { + static constexpr ck_tile::index_t M_Tile = Fp16GroupedGemmMTile; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = Fp16GroupedGemmKTile; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 32; + + static constexpr bool kPadM = true; + static constexpr bool kPadN = true; + static constexpr bool kPadK = true; + + static constexpr bool DoubleSmemBuffer = false; + + static constexpr ck_tile::index_t TilePartitionerGroupNum = 8; + static constexpr ck_tile::index_t TilePartitionerM01 = 4; +}; + template struct WithPadding : Base { static constexpr bool kPadM = PadM_; @@ -232,7 +258,7 @@ class GroupedGemmRunner : public RunnerInterface { }) // Templated dispatch on A/B layouts, shared by all layout-specific .cpp files. -template +template bool ck_tile_grouped_gemm_fp16_dispatch_layout(DType a_dtype, DType d_dtype, bool need_m_pad, bool need_k_pad, const GroupedGemmRunContext& ctx) { @@ -249,12 +275,16 @@ bool ck_tile_grouped_gemm_fp16_dispatch_layout(DType a_dtype, DType d_dtype, TRANSFORMER_ENGINE_SWITCH_CONDITION(need_m_pad, kPadM, { TRANSFORMER_ENGINE_SWITCH_CONDITION(need_k_pad, kPadK, { - if (ctx.N % 256 == 0) { - MAKE_RUNNER(TileCfg_256x256x64, kPadM, false, kPadK); - } else if (ctx.N % 128 == 0) { - MAKE_RUNNER(TileCfg_256x128x64, kPadM, false, kPadK); + if constexpr (Arch == GPUArch::GFX1250) { + MAKE_RUNNER(TileCfg_256x256x64_WMMA, true, true, true); } else { - MAKE_RUNNER(TileCfg_256x128x64, kPadM, true, kPadK); + if (ctx.N % 256 == 0) { + MAKE_RUNNER(TileCfg_256x256x64_MFMA, kPadM, false, kPadK); + } else if (ctx.N % 128 == 0) { + MAKE_RUNNER(TileCfg_256x128x64_MFMA, kPadM, false, kPadK); + } else { + MAKE_RUNNER(TileCfg_256x128x64_MFMA, kPadM, true, kPadK); + } } }); }); @@ -269,15 +299,19 @@ bool ck_tile_grouped_gemm_fp16_dispatch_layout(DType a_dtype, DType d_dtype, // Per-layout dispatch function signature. // Each layout file (NN, NT, TN, TT) implements one of these. +template bool ck_tile_grouped_gemm_fp16_dispatch_nn(DType a_dtype, DType d_dtype, bool need_m_pad, bool need_k_pad, const GroupedGemmRunContext& ctx); +template bool ck_tile_grouped_gemm_fp16_dispatch_nt(DType a_dtype, DType d_dtype, bool need_m_pad, bool need_k_pad, const GroupedGemmRunContext& ctx); +template bool ck_tile_grouped_gemm_fp16_dispatch_tn(DType a_dtype, DType d_dtype, bool need_m_pad, bool need_k_pad, const GroupedGemmRunContext& ctx); +template bool ck_tile_grouped_gemm_fp16_dispatch_tt(DType a_dtype, DType d_dtype, bool need_m_pad, bool need_k_pad, const GroupedGemmRunContext& ctx); diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_nn.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_nn.cpp index 276859f35..3e832e030 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_nn.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_nn.cpp @@ -9,10 +9,11 @@ namespace transformer_engine { namespace grouped_gemm { +template bool ck_tile_grouped_gemm_fp16_dispatch_nn(DType a_dtype, DType d_dtype, bool need_m_pad, bool need_k_pad, const GroupedGemmRunContext& ctx) { - return ck_tile_grouped_gemm_fp16_dispatch_layout( + return ck_tile_grouped_gemm_fp16_dispatch_layout( a_dtype, d_dtype, need_m_pad, need_k_pad, ctx); } diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_nt.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_nt.cpp index 19a79820f..90a53e4f6 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_nt.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_nt.cpp @@ -9,10 +9,11 @@ namespace transformer_engine { namespace grouped_gemm { +template bool ck_tile_grouped_gemm_fp16_dispatch_nt(DType a_dtype, DType d_dtype, bool need_m_pad, bool need_k_pad, const GroupedGemmRunContext& ctx) { - return ck_tile_grouped_gemm_fp16_dispatch_layout( + return ck_tile_grouped_gemm_fp16_dispatch_layout( a_dtype, d_dtype, need_m_pad, need_k_pad, ctx); } diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_tn.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_tn.cpp index ef5715aa3..9c9e1c1f0 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_tn.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_tn.cpp @@ -9,10 +9,11 @@ namespace transformer_engine { namespace grouped_gemm { +template bool ck_tile_grouped_gemm_fp16_dispatch_tn(DType a_dtype, DType d_dtype, bool need_m_pad, bool need_k_pad, const GroupedGemmRunContext& ctx) { - return ck_tile_grouped_gemm_fp16_dispatch_layout( + return ck_tile_grouped_gemm_fp16_dispatch_layout( a_dtype, d_dtype, need_m_pad, need_k_pad, ctx); } diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_tt.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_tt.cpp index 039276450..ae706ef9e 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_tt.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_tt.cpp @@ -9,10 +9,11 @@ namespace transformer_engine { namespace grouped_gemm { +template bool ck_tile_grouped_gemm_fp16_dispatch_tt(DType a_dtype, DType d_dtype, bool need_m_pad, bool need_k_pad, const GroupedGemmRunContext& ctx) { - return ck_tile_grouped_gemm_fp16_dispatch_layout( + return ck_tile_grouped_gemm_fp16_dispatch_layout( a_dtype, d_dtype, need_m_pad, need_k_pad, ctx); } diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp index d07d7927c..121ffe21a 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp @@ -40,6 +40,29 @@ struct TileCfg_256x256x128_16x16x128_2x2x1 { static constexpr ck_tile::index_t TilePartitionerM01 = 8; }; +struct TileCfg_128x128x128_16x16x64_2x2x1 { + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 64; + + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool DoubleSmemBuffer = false; + + static constexpr ck_tile::index_t TilePartitionerGroupNum = 16; + static constexpr ck_tile::index_t TilePartitionerM01 = 8; +}; + struct TileCfg_128x128x128_16x16x128_2x2x1 : TileCfg_256x256x128_16x16x128_2x2x1 { static constexpr ck_tile::index_t M_Tile = 128; @@ -302,6 +325,11 @@ struct FP8TileCfg { using type = TileCfg_128x128x128_16x16x128_2x2x1; }; +template <> +struct FP8TileCfg { + using type = TileCfg_128x128x128_16x16x64_2x2x1; +}; + struct FP8GroupedShapeAlignment { bool all_n_256_aligned = true; bool all_n_128_aligned = true; @@ -450,12 +478,21 @@ bool ck_tile_grouped_gemm_fp8_dispatch(DType a_dtype, DType d_dtype, const GroupedGemmRunContext& ctx) { switch (detect_gpu_arch()) { +#if defined(__gfx942__) case GPUArch::GFX942: return ck_tile_grouped_gemm_fp8_dispatch_arch(a_dtype, b_dtype, d_dtype, ctx); +#endif +#if defined(__gfx950__) case GPUArch::GFX950: return ck_tile_grouped_gemm_fp8_dispatch_arch(a_dtype, b_dtype, d_dtype, ctx); +#endif +#if defined(__gfx1250__) + case GPUArch::GFX1250: + return ck_tile_grouped_gemm_fp8_dispatch_arch(a_dtype, b_dtype, d_dtype, ctx); +#endif + default: - NVTE_ERROR("ck_tile_grouped_gemm: available architectures = {gfx942, gfx950}"); + NVTE_ERROR("ck_tile_grouped_gemm: available architectures = {gfx942, gfx950, gfx1250}"); return false; } } From 2cb357b72c4f51e889450873702daeccfadd5c4a Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Tue, 16 Jun 2026 20:37:49 +0000 Subject: [PATCH 2/2] address pr comments: make arch runtime dependent for ck group gemm --- .../gemm/ck_grouped_gemm/ck_grouped_gemm.cpp | 3 +- .../ck_grouped_gemm/ck_grouped_gemm_common.h | 2 + .../ck_grouped_gemm/ck_grouped_gemm_fp16.cpp | 82 +++++-------------- .../ck_grouped_gemm_fp16_impl.h | 41 ++++------ .../ck_grouped_gemm_fp16_nn.cpp | 3 +- .../ck_grouped_gemm_fp16_nt.cpp | 3 +- .../ck_grouped_gemm_fp16_tn.cpp | 3 +- .../ck_grouped_gemm_fp16_tt.cpp | 3 +- .../ck_grouped_gemm/ck_grouped_gemm_fp8.cpp | 9 +- .../ck_grouped_gemm/ck_mx_grouped_gemm.cpp | 1 + 10 files changed, 46 insertions(+), 104 deletions(-) diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp index 240207423..087a17f8e 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp @@ -171,7 +171,8 @@ bool ck_tile_grouped_gemm(const NVTETensor* A, A_use, B_use, D, - static_cast(n), + n, + detect_gpu_arch(), group_num, transA_use, transB_use, diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h index 94a301df7..de2bdb985 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h @@ -103,6 +103,8 @@ struct GroupedGemmRunContext { NVTETensor* D = nullptr; int64_t N = 0; + GPUArch arch = GPUArch::UNKNOWN; + int group_num = 0; bool transA = false; bool transB = false; diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp index 6c2272844..4d21250d6 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp @@ -11,62 +11,14 @@ namespace transformer_engine { namespace grouped_gemm { -template -bool ck_tile_grouped_gemm_fp16_dispatch_arch_impl( - DType a_dtype, - DType d_dtype, - bool need_m_pad, - bool need_k_pad, - const GroupedGemmRunContext& ctx) { - if (!ctx.transA && !ctx.transB) { - return ck_tile_grouped_gemm_fp16_dispatch_nn( - a_dtype, d_dtype, need_m_pad, need_k_pad, ctx); - } else if (!ctx.transA && ctx.transB) { - return ck_tile_grouped_gemm_fp16_dispatch_nt( - a_dtype, d_dtype, need_m_pad, need_k_pad, ctx); - } else if (ctx.transA && !ctx.transB) { - return ck_tile_grouped_gemm_fp16_dispatch_tn( - a_dtype, d_dtype, need_m_pad, need_k_pad, ctx); - } else { - return ck_tile_grouped_gemm_fp16_dispatch_tt( - a_dtype, d_dtype, need_m_pad, need_k_pad, ctx); - } -} - -bool ck_tile_grouped_gemm_fp16_dispatch_arch(DType a_dtype, - DType d_dtype, - bool need_m_pad, - bool need_k_pad, - const GroupedGemmRunContext& ctx) { - switch (detect_gpu_arch()) { -#if defined(__gfx942__) - case GPUArch::GFX942: - return ck_tile_grouped_gemm_fp16_dispatch_arch_impl( - a_dtype, d_dtype, need_m_pad, need_k_pad, ctx); -#endif -#if defined(__gfx950__) - case GPUArch::GFX950: - return ck_tile_grouped_gemm_fp16_dispatch_arch_impl( - a_dtype, d_dtype, need_m_pad, need_k_pad, ctx); -#endif -#if defined(__gfx1250__) - case GPUArch::GFX1250: - return ck_tile_grouped_gemm_fp16_dispatch_arch_impl( - a_dtype, d_dtype, need_m_pad, need_k_pad, ctx); -#endif - default: - return false; - } -} - bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype, DType b_dtype, DType d_dtype, const GroupedGemmRunContext& ctx) { // Check M and K alignment across all groups. // All tile configs share the same M_Tile (256) and K_Tile (64). - constexpr ck_tile::index_t M_Tile = Fp16GroupedGemmMTile; - constexpr ck_tile::index_t K_Tile = Fp16GroupedGemmKTile; + constexpr ck_tile::index_t M_Tile = TileCfg_256x256x64::M_Tile; + constexpr ck_tile::index_t K_Tile = TileCfg_256x256x64::K_Tile; bool need_m_pad = false; bool need_k_pad = false; @@ -79,15 +31,12 @@ bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype, const int64_t M = ctx.transA ? Ad1 : Ad0; const int64_t K = ctx.transA ? Ad0 : Ad1; - if (M % M_Tile != 0) { + if (M % M_Tile != 0) need_m_pad = true; - } - if (K % K_Tile != 0) { + if (K % K_Tile != 0) need_k_pad = true; - } - if (need_m_pad && need_k_pad) { + if (need_m_pad && need_k_pad) break; - } } } @@ -109,18 +58,27 @@ bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype, if (!all_have_columnwise) { return false; } - // Dispatch with B's columnwise buffer as RowMajor (transB=false). GroupedGemmRunContext ctx_b_colwise = ctx; ctx_b_colwise.transB = false; ctx_b_colwise.use_b_columnwise_data = true; - return ck_tile_grouped_gemm_fp16_dispatch_arch( - a_dtype, d_dtype, need_m_pad, need_k_pad, ctx_b_colwise); + if (!ctx_b_colwise.transA) { + return ck_tile_grouped_gemm_fp16_dispatch_nn(a_dtype, d_dtype, need_m_pad, need_k_pad, ctx_b_colwise); + } else { + return ck_tile_grouped_gemm_fp16_dispatch_tn(a_dtype, d_dtype, need_m_pad, need_k_pad, ctx_b_colwise); + } } - // Dispatch to arch + per-layout translation unit. - return ck_tile_grouped_gemm_fp16_dispatch_arch( - a_dtype, d_dtype, need_m_pad, need_k_pad, ctx); + // Dispatch to per-layout translation unit. + if (!ctx.transA && !ctx.transB) { + return ck_tile_grouped_gemm_fp16_dispatch_nn(a_dtype, d_dtype, need_m_pad, need_k_pad, ctx); + } else if (!ctx.transA && ctx.transB) { + return ck_tile_grouped_gemm_fp16_dispatch_nt(a_dtype, d_dtype, need_m_pad, need_k_pad, ctx); + } else if (ctx.transA && !ctx.transB) { + return ck_tile_grouped_gemm_fp16_dispatch_tn(a_dtype, d_dtype, need_m_pad, need_k_pad, ctx); + } else { + return ck_tile_grouped_gemm_fp16_dispatch_tt(a_dtype, d_dtype, need_m_pad, need_k_pad, ctx); + } } } // namespace grouped_gemm diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_impl.h b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_impl.h index 37aa0c244..efc074479 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_impl.h +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_impl.h @@ -15,13 +15,10 @@ namespace grouped_gemm { // Tile configs: FP16/BF16 // ------------------------- -static constexpr ck_tile::index_t Fp16GroupedGemmMTile = 256; -static constexpr ck_tile::index_t Fp16GroupedGemmKTile = 64; - -struct TileCfg_256x256x64_MFMA { - static constexpr ck_tile::index_t M_Tile = Fp16GroupedGemmMTile; +struct TileCfg_256x256x64 { + static constexpr ck_tile::index_t M_Tile = 256; static constexpr ck_tile::index_t N_Tile = 256; - static constexpr ck_tile::index_t K_Tile = Fp16GroupedGemmKTile; + static constexpr ck_tile::index_t K_Tile = 64; static constexpr ck_tile::index_t M_Warp = 2; static constexpr ck_tile::index_t N_Warp = 2; @@ -41,14 +38,10 @@ struct TileCfg_256x256x64_MFMA { static constexpr ck_tile::index_t TilePartitionerM01 = 4; }; -struct TileCfg_256x128x64_MFMA : TileCfg_256x256x64_MFMA { - static constexpr ck_tile::index_t N_Tile = 128; -}; - struct TileCfg_256x256x64_WMMA { - static constexpr ck_tile::index_t M_Tile = Fp16GroupedGemmMTile; + static constexpr ck_tile::index_t M_Tile = 256; static constexpr ck_tile::index_t N_Tile = 256; - static constexpr ck_tile::index_t K_Tile = Fp16GroupedGemmKTile; + static constexpr ck_tile::index_t K_Tile = 64; static constexpr ck_tile::index_t M_Warp = 2; static constexpr ck_tile::index_t N_Warp = 2; @@ -68,6 +61,10 @@ struct TileCfg_256x256x64_WMMA { static constexpr ck_tile::index_t TilePartitionerM01 = 4; }; +struct TileCfg_256x128x64 : TileCfg_256x256x64 { + static constexpr ck_tile::index_t N_Tile = 128; +}; + template struct WithPadding : Base { static constexpr bool kPadM = PadM_; @@ -258,7 +255,7 @@ class GroupedGemmRunner : public RunnerInterface { }) // Templated dispatch on A/B layouts, shared by all layout-specific .cpp files. -template +template bool ck_tile_grouped_gemm_fp16_dispatch_layout(DType a_dtype, DType d_dtype, bool need_m_pad, bool need_k_pad, const GroupedGemmRunContext& ctx) { @@ -275,16 +272,14 @@ bool ck_tile_grouped_gemm_fp16_dispatch_layout(DType a_dtype, DType d_dtype, TRANSFORMER_ENGINE_SWITCH_CONDITION(need_m_pad, kPadM, { TRANSFORMER_ENGINE_SWITCH_CONDITION(need_k_pad, kPadK, { - if constexpr (Arch == GPUArch::GFX1250) { + if (ctx.arch == GPUArch::GFX1250) { MAKE_RUNNER(TileCfg_256x256x64_WMMA, true, true, true); + } else if (ctx.N % 256 == 0) { + MAKE_RUNNER(TileCfg_256x256x64, kPadM, false, kPadK); + } else if (ctx.N % 128 == 0) { + MAKE_RUNNER(TileCfg_256x128x64, kPadM, false, kPadK); } else { - if (ctx.N % 256 == 0) { - MAKE_RUNNER(TileCfg_256x256x64_MFMA, kPadM, false, kPadK); - } else if (ctx.N % 128 == 0) { - MAKE_RUNNER(TileCfg_256x128x64_MFMA, kPadM, false, kPadK); - } else { - MAKE_RUNNER(TileCfg_256x128x64_MFMA, kPadM, true, kPadK); - } + MAKE_RUNNER(TileCfg_256x128x64, kPadM, true, kPadK); } }); }); @@ -299,19 +294,15 @@ bool ck_tile_grouped_gemm_fp16_dispatch_layout(DType a_dtype, DType d_dtype, // Per-layout dispatch function signature. // Each layout file (NN, NT, TN, TT) implements one of these. -template bool ck_tile_grouped_gemm_fp16_dispatch_nn(DType a_dtype, DType d_dtype, bool need_m_pad, bool need_k_pad, const GroupedGemmRunContext& ctx); -template bool ck_tile_grouped_gemm_fp16_dispatch_nt(DType a_dtype, DType d_dtype, bool need_m_pad, bool need_k_pad, const GroupedGemmRunContext& ctx); -template bool ck_tile_grouped_gemm_fp16_dispatch_tn(DType a_dtype, DType d_dtype, bool need_m_pad, bool need_k_pad, const GroupedGemmRunContext& ctx); -template bool ck_tile_grouped_gemm_fp16_dispatch_tt(DType a_dtype, DType d_dtype, bool need_m_pad, bool need_k_pad, const GroupedGemmRunContext& ctx); diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_nn.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_nn.cpp index 3e832e030..276859f35 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_nn.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_nn.cpp @@ -9,11 +9,10 @@ namespace transformer_engine { namespace grouped_gemm { -template bool ck_tile_grouped_gemm_fp16_dispatch_nn(DType a_dtype, DType d_dtype, bool need_m_pad, bool need_k_pad, const GroupedGemmRunContext& ctx) { - return ck_tile_grouped_gemm_fp16_dispatch_layout( + return ck_tile_grouped_gemm_fp16_dispatch_layout( a_dtype, d_dtype, need_m_pad, need_k_pad, ctx); } diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_nt.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_nt.cpp index 90a53e4f6..19a79820f 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_nt.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_nt.cpp @@ -9,11 +9,10 @@ namespace transformer_engine { namespace grouped_gemm { -template bool ck_tile_grouped_gemm_fp16_dispatch_nt(DType a_dtype, DType d_dtype, bool need_m_pad, bool need_k_pad, const GroupedGemmRunContext& ctx) { - return ck_tile_grouped_gemm_fp16_dispatch_layout( + return ck_tile_grouped_gemm_fp16_dispatch_layout( a_dtype, d_dtype, need_m_pad, need_k_pad, ctx); } diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_tn.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_tn.cpp index 9c9e1c1f0..ef5715aa3 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_tn.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_tn.cpp @@ -9,11 +9,10 @@ namespace transformer_engine { namespace grouped_gemm { -template bool ck_tile_grouped_gemm_fp16_dispatch_tn(DType a_dtype, DType d_dtype, bool need_m_pad, bool need_k_pad, const GroupedGemmRunContext& ctx) { - return ck_tile_grouped_gemm_fp16_dispatch_layout( + return ck_tile_grouped_gemm_fp16_dispatch_layout( a_dtype, d_dtype, need_m_pad, need_k_pad, ctx); } diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_tt.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_tt.cpp index ae706ef9e..039276450 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_tt.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_tt.cpp @@ -9,11 +9,10 @@ namespace transformer_engine { namespace grouped_gemm { -template bool ck_tile_grouped_gemm_fp16_dispatch_tt(DType a_dtype, DType d_dtype, bool need_m_pad, bool need_k_pad, const GroupedGemmRunContext& ctx) { - return ck_tile_grouped_gemm_fp16_dispatch_layout( + return ck_tile_grouped_gemm_fp16_dispatch_layout( a_dtype, d_dtype, need_m_pad, need_k_pad, ctx); } diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp index 121ffe21a..f093a378e 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp @@ -477,20 +477,13 @@ bool ck_tile_grouped_gemm_fp8_dispatch(DType a_dtype, DType b_dtype, DType d_dtype, const GroupedGemmRunContext& ctx) { - switch (detect_gpu_arch()) { -#if defined(__gfx942__) + switch (ctx.arch) { case GPUArch::GFX942: return ck_tile_grouped_gemm_fp8_dispatch_arch(a_dtype, b_dtype, d_dtype, ctx); -#endif -#if defined(__gfx950__) case GPUArch::GFX950: return ck_tile_grouped_gemm_fp8_dispatch_arch(a_dtype, b_dtype, d_dtype, ctx); -#endif -#if defined(__gfx1250__) case GPUArch::GFX1250: return ck_tile_grouped_gemm_fp8_dispatch_arch(a_dtype, b_dtype, d_dtype, ctx); -#endif - default: NVTE_ERROR("ck_tile_grouped_gemm: available architectures = {gfx942, gfx950, gfx1250}"); return false; diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_mx_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_mx_grouped_gemm.cpp index 96726abb8..72d193591 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_mx_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_mx_grouped_gemm.cpp @@ -373,6 +373,7 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor *A, .B = B_use, .D = D, .N = 0, + .arch = detect_gpu_arch(), .group_num = group_num, .transA = transA_use, .transB = transB_use,