diff --git a/transformer_engine/pytorch/triton_kernels/gmm/configs/gfx1250-GMM.json b/transformer_engine/pytorch/triton_kernels/gmm/configs/gfx1250-GMM.json new file mode 100644 index 000000000..5da30880d --- /dev/null +++ b/transformer_engine/pytorch/triton_kernels/gmm/configs/gfx1250-GMM.json @@ -0,0 +1,203 @@ +{ + "gmm": { + "default": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 4, + "GRID_DIM": 256, + "num_warps": 8, + "num_stages": 2 + }, + "tiny_shapes": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE": 4, + "GRID_DIM": 256, + "num_warps": 4, + "num_stages": 1 + }, + "k_heavy": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 4, + "GRID_DIM": 256, + "num_warps": 8, + "num_stages": 2 + }, + "n_heavy": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 4, + "GRID_DIM": 256, + "num_warps": 8, + "num_stages": 2 + }, + "balanced_large_n": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 4, + "GRID_DIM": 256, + "num_warps": 8, + "num_stages": 2 + }, + "very_large_m": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 4, + "GRID_DIM": 256, + "num_warps": 8, + "num_stages": 2 + }, + "small_shapes": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 16, + "GRID_DIM": 256, + "num_warps": 8, + "num_stages": 2 + }, + "small_m_moderate_n": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 1, + "GRID_DIM": 256, + "num_warps": 8, + "num_stages": 2 + }, + "default_no_trans_rhs": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 4, + "GRID_DIM": 256, + "num_warps": 8, + "num_stages": 2 + }, + "k_heavy_bwd": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 4, + "GRID_DIM": 256, + "num_warps": 8, + "num_stages": 2 + }, + "very_large_m_small_n_bwd": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 4, + "GRID_DIM": 256, + "num_warps": 8, + "num_stages": 2 + }, + "balanced_large_n_bwd": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 4, + "GRID_DIM": 256, + "num_warps": 8, + "num_stages": 2 + }, + "n_very_heavy_bwd": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 8, + "GRID_DIM": 256, + "num_warps": 8, + "num_stages": 2 + }, + "small_k_large_n_bwd": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 4, + "GRID_DIM": 256, + "num_warps": 8, + "num_stages": 2 + } + }, + "ptgmm": { + "default": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_K": 256, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 4, + "GRID_DIM": 256, + "num_warps": 4, + "num_stages": 1 + }, + "high_group_count": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE": 1, + "GRID_DIM": 256, + "num_warps": 2, + "num_stages": 2 + }, + "small_n": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_K": 256, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE": 4, + "GRID_DIM": 256, + "num_warps": 4, + "num_stages": 1 + }, + "small_n_high_group": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE": 4, + "GRID_DIM": 256, + "num_warps": 4, + "num_stages": 2 + }, + "accumulate": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE": 1, + "GRID_DIM": 256, + "num_warps": 2, + "num_stages": 2 + } + }, + "nptgmm": { + "default": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_K": 256, + "BLOCK_SIZE_N": 256, + "GROUP_SIZE": 1, + "num_warps": 8, + "num_stages": 2 + }, + "small_n": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_K": 256, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE": 16, + "num_warps": 8, + "num_stages": 2 + }, + "accumulate": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE": 1, + "num_warps": 8, + "num_stages": 2 + } + } +}