Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 1 addition & 45 deletions csrc/ck_gemm_a8w8_blockscale/gen_instances.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
import argparse
import os
import sys
Expand Down Expand Up @@ -70,50 +70,6 @@ def gen_ck_instance(self, k: KernelInstance):

#include "gemm_a8w8_blockscale_common.cuh"

enum class GemmSpecialization {{
Default = 0,
MPadding = 1,
NPadding = 2,
KPadding = 3,
MNPadding = 4,
MKPadding = 5,
NKPadding = 6,
MNKPadding = 7
}};

static const std::unordered_map<std::string, GemmSpecialization> g_gemm_spec_names{{
{{"", GemmSpecialization::Default}},
{{"M", GemmSpecialization::MPadding}},
{{"N", GemmSpecialization::NPadding}},
{{"K", GemmSpecialization::KPadding}},
{{"MN", GemmSpecialization::MNPadding}},
{{"MK", GemmSpecialization::MKPadding}},
{{"NK", GemmSpecialization::NKPadding}},
{{"MNK", GemmSpecialization::MNKPadding}}
}};

static GemmSpecialization GetGemmSpec(const int64_t m,
const int64_t n,
const int64_t k,
const int64_t m_per_block,
const int64_t n_per_block,
const int64_t k_per_block)
{{
auto IntegerDivideCeil = [](int x, int y) {{
return (x + y - size_t{{1}}) / y;
}};

std::string spec = "";
if (IntegerDivideCeil(m, m_per_block) * m_per_block - m != 0)
spec += "M";
if (IntegerDivideCeil(n, n_per_block) * n_per_block - n != 0)
spec += "N";
if (IntegerDivideCeil(k, k_per_block) * k_per_block - k != 0)
spec += "K";

return g_gemm_spec_names.at(spec);
}}

template <typename DDataType, typename EDataType>
torch::Tensor
{k.name}(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,43 @@ using CDEElementOp = PassThrough;
// static constexpr ck::index_t Scale_Block_N = 128;
// static constexpr ck::index_t Scale_Block_K = 128;

enum class GemmSpecialization : unsigned int {
Default = 0,
MPadding = 1u<<0,
NPadding = 1u<<1,
KPadding = 1u<<2,
MNPadding = MPadding|NPadding,
MKPadding = MPadding|KPadding,
NKPadding = NPadding|KPadding,
MNKPadding = MPadding|NPadding|KPadding
};

static GemmSpecialization operator|(GemmSpecialization lhs, GemmSpecialization rhs) {
return static_cast<GemmSpecialization>(unsigned(lhs) | unsigned(rhs));
}

static GemmSpecialization GetGemmSpec(const int64_t m,
const int64_t n,
const int64_t k,
const int64_t m_per_block,
const int64_t n_per_block,
const int64_t k_per_block)
{
auto HasPadding = [](int x, int y) {
return (x % y) != 0;
};

GemmSpecialization spec{};
if (HasPadding(m, m_per_block))
spec = spec | GemmSpecialization::M;
if (HasPadding(n, n_per_block))
spec = spec | GemmSpecialization::N;
if (HasPadding(k, k_per_block))
spec = spec | GemmSpecialization::K;

return spec;
}

template <typename AB1DataType,
typename EDataType,
ck::index_t BlockSize,
Expand Down
Loading