From e74c2141c85dbeb43fc1d75b35e55892d6876867 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 18 Jun 2026 22:48:50 +0000 Subject: [PATCH] gfx1250: add (lightly-optimized) Triton GMM config --- .../gmm/configs/gfx1250-GMM.json | 203 ++++++++++++++++++ 1 file changed, 203 insertions(+) create mode 100644 transformer_engine/pytorch/triton_kernels/gmm/configs/gfx1250-GMM.json diff --git a/transformer_engine/pytorch/triton_kernels/gmm/configs/gfx1250-GMM.json b/transformer_engine/pytorch/triton_kernels/gmm/configs/gfx1250-GMM.json new file mode 100644 index 000000000..5da30880d --- /dev/null +++ b/transformer_engine/pytorch/triton_kernels/gmm/configs/gfx1250-GMM.json @@ -0,0 +1,203 @@ +{ + "gmm": { + "default": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 4, + "GRID_DIM": 256, + "num_warps": 8, + "num_stages": 2 + }, + "tiny_shapes": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE": 4, + "GRID_DIM": 256, + "num_warps": 4, + "num_stages": 1 + }, + "k_heavy": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 4, + "GRID_DIM": 256, + "num_warps": 8, + "num_stages": 2 + }, + "n_heavy": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 4, + "GRID_DIM": 256, + "num_warps": 8, + "num_stages": 2 + }, + "balanced_large_n": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 4, + "GRID_DIM": 256, + "num_warps": 8, + "num_stages": 2 + }, + "very_large_m": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 4, + "GRID_DIM": 256, + "num_warps": 8, + "num_stages": 2 + }, + "small_shapes": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 16, + "GRID_DIM": 256, + "num_warps": 8, + "num_stages": 2 + }, + "small_m_moderate_n": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 1, + "GRID_DIM": 256, + "num_warps": 8, + "num_stages": 2 + }, + "default_no_trans_rhs": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 4, + "GRID_DIM": 256, + "num_warps": 8, + "num_stages": 2 + }, + "k_heavy_bwd": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 4, + "GRID_DIM": 256, + "num_warps": 8, + "num_stages": 2 + }, + "very_large_m_small_n_bwd": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 4, + "GRID_DIM": 256, + "num_warps": 8, + "num_stages": 2 + }, + "balanced_large_n_bwd": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 4, + "GRID_DIM": 256, + "num_warps": 8, + "num_stages": 2 + }, + "n_very_heavy_bwd": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 8, + "GRID_DIM": 256, + "num_warps": 8, + "num_stages": 2 + }, + "small_k_large_n_bwd": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 4, + "GRID_DIM": 256, + "num_warps": 8, + "num_stages": 2 + } + }, + "ptgmm": { + "default": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_K": 256, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 4, + "GRID_DIM": 256, + "num_warps": 4, + "num_stages": 1 + }, + "high_group_count": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE": 1, + "GRID_DIM": 256, + "num_warps": 2, + "num_stages": 2 + }, + "small_n": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_K": 256, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE": 4, + "GRID_DIM": 256, + "num_warps": 4, + "num_stages": 1 + }, + "small_n_high_group": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE": 4, + "GRID_DIM": 256, + "num_warps": 4, + "num_stages": 2 + }, + "accumulate": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE": 1, + "GRID_DIM": 256, + "num_warps": 2, + "num_stages": 2 + } + }, + "nptgmm": { + "default": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_K": 256, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 1, + "num_warps": 8, + "num_stages": 2 + }, + "small_n": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_K": 256, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE": 16, + "num_warps": 8, + "num_stages": 2 + }, + "accumulate": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE": 1, + "num_warps": 8, + "num_stages": 2 + } + } +}