Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -63,19 +63,19 @@ bool ck_tile_grouped_gemm(const NVTETensor* A,
// FP8 special handling.
//
// A_use/B_use and transA_use/transB_use have already gone through the
// upstream-style grouped GEMM normalization above. This block only rewrites
// that normalized presentation into the CK FP8 preferred NT presentation by selecting
// `columnwise_data` when needed.
// upstream-style grouped GEMM normalization above. CK FP8 grouped GEMM is
// compiled only for the preferred NT presentation:
//
// CK FP8 target presentation:
// A_use: N
// B_use: T
// transA_use = false
// transB_use = true
//
// The outer condition checks whether this NT presentation is possible:
// - A_use is already N, or can be made N using columnwise_data
// - B_use is already T, or can be made T using columnwise_data
// This block rewrites the normalized presentation into that NT form by
// selecting columnwise_data when needed. If the required columnwise_data view
// is unavailable, this CK FP8 backend cannot represent the GEMM in its
// supported layout form, so we fall back instead of compiling/running an
// unsupported layout variant.
//
// Then each operand is rewritten independently only if needed:
// Rewrite cases:
// NN -> rewrite B only
// TN -> rewrite A and B
// NT -> already in target form
Expand All @@ -87,16 +87,23 @@ bool ck_tile_grouped_gemm(const NVTETensor* A,
const bool has_a_col = A0_te->has_columnwise_data();
const bool has_b_col = B0_te->has_columnwise_data();

if ((!transA_use || has_a_col) && (transB_use || has_b_col)) {
if (transA_use) {
use_a_colwise_data = true;
transA_use = false;
}
const bool can_make_a_nt = !transA_use || has_a_col;
const bool can_make_b_nt = transB_use || has_b_col;

if (!transB_use) {
use_b_colwise_data = true;
transB_use = true;
}
if (!can_make_a_nt || !can_make_b_nt) {
NVTE_WARN("ck_tile_grouped_gemm: FP8 grouped GEMM requires NT presentation. "
"Missing required columnwise_data for layout rewrite; falling back.");
return false;
}

if (transA_use) {
use_a_colwise_data = true;
transA_use = false;
}

if (!transB_use) {
use_b_colwise_data = true;
transB_use = true;
}
}

Expand Down Expand Up @@ -164,7 +171,8 @@ bool ck_tile_grouped_gemm(const NVTETensor* A,
A_use,
B_use,
D,
static_cast<int>(n),
n,
detect_gpu_arch(),
group_num,
transA_use,
transB_use,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ struct GroupedGemmRunContext {
NVTETensor* D = nullptr;
int64_t N = 0;

GPUArch arch = GPUArch::UNKNOWN;

int group_num = 0;
bool transA = false;
bool transB = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,29 @@ struct TileCfg_256x256x64 {
static constexpr ck_tile::index_t TilePartitionerM01 = 4;
};

// Compile-time tile configuration picked by the FP16 grouped-GEMM layout
// dispatch when ctx.arch == GPUArch::GFX1250; that call site instantiates the
// runner with all three padding template arguments set to true.
// NOTE(review): the "_WMMA" suffix presumably denotes the WMMA instruction
// path on that architecture -- confirm against the CK kernel that consumes it.
struct TileCfg_256x256x64_WMMA {
  // Tile extents, M x N x K = 256 x 256 x 64.
  static constexpr ck_tile::index_t M_Tile{256};
  static constexpr ck_tile::index_t N_Tile{256};
  static constexpr ck_tile::index_t K_Tile{64};

  // Warp arrangement, M x N x K = 2 x 2 x 1.
  static constexpr ck_tile::index_t M_Warp{2};
  static constexpr ck_tile::index_t N_Warp{2};
  static constexpr ck_tile::index_t K_Warp{1};

  // Per-warp tile extents, M x N x K = 16 x 16 x 32.
  static constexpr ck_tile::index_t M_Warp_Tile{16};
  static constexpr ck_tile::index_t N_Warp_Tile{16};
  static constexpr ck_tile::index_t K_Warp_Tile{32};

  // Padding is enabled along every dimension for this configuration.
  static constexpr bool kPadM{true};
  static constexpr bool kPadN{true};
  static constexpr bool kPadK{true};

  // The DoubleSmemBuffer option is switched off.
  static constexpr bool DoubleSmemBuffer{false};

  // Tile-partitioner parameters (group count and M01 factor).
  static constexpr ck_tile::index_t TilePartitionerGroupNum{8};
  static constexpr ck_tile::index_t TilePartitionerM01{4};
};

struct TileCfg_256x128x64 : TileCfg_256x256x64 {
static constexpr ck_tile::index_t N_Tile = 128;
};
Expand Down Expand Up @@ -249,7 +272,9 @@ bool ck_tile_grouped_gemm_fp16_dispatch_layout(DType a_dtype, DType d_dtype,

TRANSFORMER_ENGINE_SWITCH_CONDITION(need_m_pad, kPadM, {
TRANSFORMER_ENGINE_SWITCH_CONDITION(need_k_pad, kPadK, {
if (ctx.N % 256 == 0) {
if (ctx.arch == GPUArch::GFX1250) {
MAKE_RUNNER(TileCfg_256x256x64_WMMA, true, true, true);
} else if (ctx.N % 256 == 0) {
MAKE_RUNNER(TileCfg_256x256x64, kPadM, false, kPadK);
} else if (ctx.N % 128 == 0) {
MAKE_RUNNER(TileCfg_256x128x64, kPadM, false, kPadK);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,29 @@ struct TileCfg_256x256x128_16x16x128_2x2x1 {
static constexpr ck_tile::index_t TilePartitionerM01 = 8;
};

// Compile-time FP8 tile configuration exposed as
// FP8TileCfg<GPUArch::GFX1250>::type. The name encodes the values below:
// 128x128x128 tile, 16x16x64 warp tile, 2x2x1 warp arrangement.
struct TileCfg_128x128x128_16x16x64_2x2x1 {
  // Tile extents, M x N x K = 128 x 128 x 128.
  static constexpr ck_tile::index_t M_Tile{128};
  static constexpr ck_tile::index_t N_Tile{128};
  static constexpr ck_tile::index_t K_Tile{128};

  // Warp arrangement, M x N x K = 2 x 2 x 1.
  static constexpr ck_tile::index_t M_Warp{2};
  static constexpr ck_tile::index_t N_Warp{2};
  static constexpr ck_tile::index_t K_Warp{1};

  // Per-warp tile extents, M x N x K = 16 x 16 x 64.
  static constexpr ck_tile::index_t M_Warp_Tile{16};
  static constexpr ck_tile::index_t N_Warp_Tile{16};
  static constexpr ck_tile::index_t K_Warp_Tile{64};

  // Padding is disabled along every dimension by default.
  // NOTE(review): callers may override these through their own template
  // arguments -- confirm at the runner instantiation site.
  static constexpr bool kPadM{false};
  static constexpr bool kPadN{false};
  static constexpr bool kPadK{false};

  // The DoubleSmemBuffer option is switched off.
  static constexpr bool DoubleSmemBuffer{false};

  // Tile-partitioner parameters (group count and M01 factor).
  static constexpr ck_tile::index_t TilePartitionerGroupNum{16};
  static constexpr ck_tile::index_t TilePartitionerM01{8};
};

struct TileCfg_128x128x128_16x16x128_2x2x1
: TileCfg_256x256x128_16x16x128_2x2x1 {
static constexpr ck_tile::index_t M_Tile = 128;
Expand Down Expand Up @@ -302,6 +325,11 @@ struct FP8TileCfg<GPUArch::GFX950> {
using type = TileCfg_128x128x128_16x16x128_2x2x1;
};

// FP8 tile-configuration selection for GPUArch::GFX1250: exposes
// TileCfg_128x128x128_16x16x64_2x2x1 as the nested `type` alias, mirroring
// the per-architecture specializations of FP8TileCfg defined alongside it.
template <>
struct FP8TileCfg<GPUArch::GFX1250> {
using type = TileCfg_128x128x128_16x16x64_2x2x1;
};

struct FP8GroupedShapeAlignment {
bool all_n_256_aligned = true;
bool all_n_128_aligned = true;
Expand Down Expand Up @@ -449,13 +477,15 @@ bool ck_tile_grouped_gemm_fp8_dispatch(DType a_dtype,
DType b_dtype,
DType d_dtype,
const GroupedGemmRunContext& ctx) {
switch (detect_gpu_arch()) {
switch (ctx.arch) {
case GPUArch::GFX942:
return ck_tile_grouped_gemm_fp8_dispatch_arch<GPUArch::GFX942>(a_dtype, b_dtype, d_dtype, ctx);
case GPUArch::GFX950:
return ck_tile_grouped_gemm_fp8_dispatch_arch<GPUArch::GFX950>(a_dtype, b_dtype, d_dtype, ctx);
case GPUArch::GFX1250:
return ck_tile_grouped_gemm_fp8_dispatch_arch<GPUArch::GFX1250>(a_dtype, b_dtype, d_dtype, ctx);
default:
NVTE_ERROR("ck_tile_grouped_gemm: available architectures = {gfx942, gfx950}");
NVTE_ERROR("ck_tile_grouped_gemm: available architectures = {gfx942, gfx950, gfx1250}");
return false;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,7 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor *A,
.B = B_use,
.D = D,
.N = 0,
.arch = detect_gpu_arch(),
.group_num = group_num,
.transA = transA_use,
.transB = transB_use,
Expand Down
Loading