From aae20a797deb532cf9086c25578ac04ead8c0fe3 Mon Sep 17 00:00:00 2001 From: Jorge Bellon Castro Date: Fri, 19 Jun 2026 16:44:40 +0100 Subject: [PATCH] Simplify GemmSpecialization construction --- csrc/ck_gemm_a8w8_blockscale/gen_instances.py | 46 +------------------ .../include/gemm_a8w8_blockscale_common.cuh | 37 +++++++++++++++ 2 files changed, 38 insertions(+), 45 deletions(-) diff --git a/csrc/ck_gemm_a8w8_blockscale/gen_instances.py b/csrc/ck_gemm_a8w8_blockscale/gen_instances.py index 29de526813..22b2cf2244 100644 --- a/csrc/ck_gemm_a8w8_blockscale/gen_instances.py +++ b/csrc/ck_gemm_a8w8_blockscale/gen_instances.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import argparse import os import sys @@ -70,50 +70,6 @@ def gen_ck_instance(self, k: KernelInstance): #include "gemm_a8w8_blockscale_common.cuh" -enum class GemmSpecialization {{ - Default = 0, - MPadding = 1, - NPadding = 2, - KPadding = 3, - MNPadding = 4, - MKPadding = 5, - NKPadding = 6, - MNKPadding = 7 -}}; - -static const std::unordered_map g_gemm_spec_names{{ - {{"", GemmSpecialization::Default}}, - {{"M", GemmSpecialization::MPadding}}, - {{"N", GemmSpecialization::NPadding}}, - {{"K", GemmSpecialization::KPadding}}, - {{"MN", GemmSpecialization::MNPadding}}, - {{"MK", GemmSpecialization::MKPadding}}, - {{"NK", GemmSpecialization::NKPadding}}, - {{"MNK", GemmSpecialization::MNKPadding}} -}}; - -static GemmSpecialization GetGemmSpec(const int64_t m, - const int64_t n, - const int64_t k, - const int64_t m_per_block, - const int64_t n_per_block, - const int64_t k_per_block) -{{ - auto IntegerDivideCeil = [](int x, int y) {{ - return (x + y - size_t{{1}}) / y; - }}; - - std::string spec = ""; - if (IntegerDivideCeil(m, m_per_block) * m_per_block - m != 0) - spec += "M"; - if (IntegerDivideCeil(n, n_per_block) * n_per_block - n != 0) - spec += "N"; - if (IntegerDivideCeil(k, k_per_block) * k_per_block - k != 0) - spec += "K"; - - return g_gemm_spec_names.at(spec); -}} - template torch::Tensor {k.name}( diff --git a/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_common.cuh b/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_common.cuh index d7f0a43d18..b570c9b9ed 100644 --- a/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_common.cuh +++ b/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_common.cuh @@ -69,6 +69,43 @@ using CDEElementOp = PassThrough; // static constexpr ck::index_t Scale_Block_N = 128; // static constexpr ck::index_t Scale_Block_K = 128; +enum class GemmSpecialization : unsigned int { + Default = 0, + MPadding = 1u<<0, + NPadding = 1u<<1, + KPadding = 1u<<2, + MNPadding = MPadding|NPadding, + MKPadding = MPadding|KPadding, + NKPadding = NPadding|KPadding, + MNKPadding = MPadding|NPadding|KPadding +}; + +static GemmSpecialization operator|(GemmSpecialization lhs, GemmSpecialization rhs) { + return static_cast(unsigned(lhs) | unsigned(rhs)); +} + +static GemmSpecialization GetGemmSpec(const int64_t m, + const int64_t n, + const int64_t k, + const int64_t m_per_block, + const int64_t n_per_block, + const int64_t k_per_block) +{ + auto HasPadding = [](int x, int y) { + return (x % y) != 0; + }; + + GemmSpecialization spec{}; + if (HasPadding(m, m_per_block)) + spec = spec | GemmSpecialization::M; + if (HasPadding(n, n_per_block)) + spec = spec | GemmSpecialization::N; + if (HasPadding(k, k_per_block)) + spec = spec | GemmSpecialization::K; + + return spec; +} + template