Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 200 additions & 0 deletions benchmarks/cpp/cast/bench_multi_cast_transpose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include "benchmark_utils.h"

#include "transformer_engine/padding_hip.h"
#include "transformer_engine/transpose_hip.h"
#include "transformer_engine/transformer_engine_hip.h"

Expand Down Expand Up @@ -225,6 +226,182 @@ static void BM_MultiCastTranspose(benchmark::State &state) {
HIP_CHECK(hipStreamDestroy(stream));
}

// Unfused baseline: separate padding kernel + cast_transpose
template <typename IType>
static void BM_PaddingThenMCT(benchmark::State &state) {
const size_t total_tokens = state.range(0);
const size_t cols = state.range(1);
const size_t num_experts = state.range(2);
const size_t top_k = state.range(3);
const size_t routing_mode = state.range(4);

uint64_t seed = derive_seed(total_tokens, cols, num_experts, top_k, routing_mode);

auto counts = (routing_mode == 0)
? simulate_topk_balanced(total_tokens, num_experts, top_k, seed)
: simulate_topk_skewed(total_tokens, num_experts, top_k, seed);

size_t sum_tok = std::accumulate(counts.begin(), counts.end(), size_t(0));

DType itype = std::is_same_v<IType, float> ? DType::kFloat32 :
std::is_same_v<IType, hip_bfloat16> ? DType::kBFloat16 :
DType::kFloat16;

std::string pfx = "pad_mct_" + std::to_string(total_tokens) + "_" + std::to_string(cols) + "_"
+ std::to_string(num_experts) + "_" + std::to_string(routing_mode);

std::vector<NVTETensor> nvte_in(num_experts), nvte_pad_out(num_experts),
nvte_mct_in(num_experts), nvte_mct_out(num_experts);

std::vector<int> padded_rows_list(num_experts);

for (size_t e = 0; e < num_experts; e++) {
size_t actual = std::max(counts[e], size_t(1));
size_t padded = ((actual + kPadMultiple - 1) / kPadMultiple) * kPadMultiple;
padded_rows_list[e] = static_cast<int>(padded);

auto &input = TensorCache::get_or_create(pfx + "_in_" + std::to_string(e), {actual, cols}, itype,
true, false, NVTE_DELAYED_TENSOR_SCALING, true);

auto &pad_out = TensorCache::get_or_create(pfx + "_pad_" + std::to_string(e), {padded, cols}, itype,
true, false, NVTE_DELAYED_TENSOR_SCALING, false);

auto &mct_out = TensorCache::get_or_create(pfx + "_out_" + std::to_string(e), {padded, cols}, DType::kFloat8E4M3,
true, true, NVTE_DELAYED_TENSOR_SCALING, false);

mct_out.set_scale(1.0f);

nvte_in[e] = input.data();
nvte_pad_out[e] = pad_out.data();
nvte_mct_in[e] = pad_out.data();
nvte_mct_out[e] = mct_out.data();
}

hipStream_t stream;
HIP_CHECK(hipStreamCreate(&stream));

hipEvent_t start, stop;
HIP_CHECK(hipEventCreate(&start));
HIP_CHECK(hipEventCreate(&stop));

warmup_gpu();

for (auto _ : state) {
HIP_CHECK(hipEventRecord(start, stream));
nvte_multi_padding(num_experts, nvte_in.data(), nvte_pad_out.data(), padded_rows_list.data(), stream);
nvte_multi_cast_transpose(num_experts, nvte_mct_in.data(), nvte_mct_out.data(), stream);
HIP_CHECK(hipEventRecord(stop, stream));
HIP_CHECK(hipEventSynchronize(stop));

float ms = 0;
HIP_CHECK(hipEventElapsedTime(&ms, start, stop));
state.SetIterationTime(ms / 1000.0);
}

HIP_CHECK(hipEventDestroy(start));
HIP_CHECK(hipEventDestroy(stop));

size_t total_bytes = 0;
for (size_t e = 0; e < num_experts; e++) {
size_t actual = std::max(counts[e], size_t(1));
size_t padded = padded_rows_list[e];
total_bytes += actual * cols * sizeof(IType);
total_bytes += padded * cols * sizeof(IType);
total_bytes += padded * cols * sizeof(IType);
total_bytes += padded * cols * sizeof(fp8_e4m3) * 2;
}
set_bytes_processed(state, total_bytes);

state.counters["experts"] = num_experts;
state.counters["cols"] = cols;
state.counters["avg_tok"] = static_cast<double>(sum_tok) / num_experts;

HIP_CHECK(hipStreamDestroy(stream));
}

// Fused: single cast_transpose kernel with built-in padding
template <typename IType>
static void BM_FusedPaddingMCT(benchmark::State &state) {
const size_t total_tokens = state.range(0);
const size_t cols = state.range(1);
const size_t num_experts = state.range(2);
const size_t top_k = state.range(3);
const size_t routing_mode = state.range(4);

uint64_t seed = derive_seed(total_tokens, cols, num_experts, top_k, routing_mode);

auto counts = (routing_mode == 0)
? simulate_topk_balanced(total_tokens, num_experts, top_k, seed)
: simulate_topk_skewed(total_tokens, num_experts, top_k, seed);

size_t sum_tok = std::accumulate(counts.begin(), counts.end(), size_t(0));

DType itype = std::is_same_v<IType, float> ? DType::kFloat32 :
std::is_same_v<IType, hip_bfloat16> ? DType::kBFloat16 :
DType::kFloat16;

std::string pfx = "fused_mct_" + std::to_string(total_tokens) + "_" + std::to_string(cols) + "_"
+ std::to_string(num_experts) + "_" + std::to_string(routing_mode);

std::vector<NVTETensor> nvte_in(num_experts), nvte_out(num_experts);
std::vector<int> valid_rows_list(num_experts);

for (size_t e = 0; e < num_experts; e++) {
size_t actual = std::max(counts[e], size_t(1));
size_t padded = ((actual + kPadMultiple - 1) / kPadMultiple) * kPadMultiple;
valid_rows_list[e] = static_cast<int>(actual);

auto &input = TensorCache::get_or_create(pfx + "_in_" + std::to_string(e), {actual, cols}, itype,
true, false, NVTE_DELAYED_TENSOR_SCALING, true);

auto &output = TensorCache::get_or_create(pfx + "_out_" + std::to_string(e), {padded, cols}, DType::kFloat8E4M3,
true, true, NVTE_DELAYED_TENSOR_SCALING, false);

output.set_scale(1.0f);

nvte_in[e] = input.data();
nvte_out[e] = output.data();
}

hipStream_t stream;
HIP_CHECK(hipStreamCreate(&stream));

hipEvent_t start, stop;
HIP_CHECK(hipEventCreate(&start));
HIP_CHECK(hipEventCreate(&stop));

warmup_gpu();

for (auto _ : state) {
HIP_CHECK(hipEventRecord(start, stream));
nvte_multi_cast_transpose_with_padding(num_experts, nvte_in.data(), nvte_out.data(), valid_rows_list.data(), stream);
HIP_CHECK(hipEventRecord(stop, stream));
HIP_CHECK(hipEventSynchronize(stop));

float ms = 0;
HIP_CHECK(hipEventElapsedTime(&ms, start, stop));
state.SetIterationTime(ms / 1000.0);
}

HIP_CHECK(hipEventDestroy(start));
HIP_CHECK(hipEventDestroy(stop));

size_t total_bytes = 0;
for (size_t e = 0; e < num_experts; e++) {
size_t actual = std::max(counts[e], size_t(1));
size_t padded = ((actual + kPadMultiple - 1) / kPadMultiple) * kPadMultiple;
total_bytes += actual * cols * sizeof(IType);
total_bytes += padded * cols * sizeof(fp8_e4m3) * 2;
}
set_bytes_processed(state, total_bytes);

state.counters["experts"] = num_experts;
state.counters["cols"] = cols;
state.counters["avg_tok"] = static_cast<double>(sum_tok) / num_experts;

HIP_CHECK(hipStreamDestroy(stream));
}

} // namespace

#define REGISTER_MCT(ITYPE, INAME) \
Expand All @@ -239,6 +416,29 @@ static void BM_MultiCastTranspose(benchmark::State &state) {
->Unit(benchmark::kMicrosecond) \
->UseManualTime();

#define REGISTER_PAD_MCT(ITYPE, INAME) \
BENCHMARK_TEMPLATE(BM_PaddingThenMCT, ITYPE) \
->Name("BM_PaddingThenMCT/" INAME "_E4M3/moe") \
MOE_BALANCED \
->Unit(benchmark::kMicrosecond) \
->UseManualTime(); \
BENCHMARK_TEMPLATE(BM_PaddingThenMCT, ITYPE) \
->Name("BM_PaddingThenMCT/" INAME "_E4M3/moe_skewed") \
MOE_SKEWED \
->Unit(benchmark::kMicrosecond) \
->UseManualTime(); \
BENCHMARK_TEMPLATE(BM_FusedPaddingMCT, ITYPE) \
->Name("BM_FusedPaddingMCT/" INAME "_E4M3/moe") \
MOE_BALANCED \
->Unit(benchmark::kMicrosecond) \
->UseManualTime(); \
BENCHMARK_TEMPLATE(BM_FusedPaddingMCT, ITYPE) \
->Name("BM_FusedPaddingMCT/" INAME "_E4M3/moe_skewed") \
MOE_SKEWED \
->Unit(benchmark::kMicrosecond) \
->UseManualTime();

REGISTER_MCT(hip_bfloat16, "BF16")
REGISTER_PAD_MCT(hip_bfloat16, "BF16")

BENCHMARK_MAIN();
163 changes: 163 additions & 0 deletions tests/cpp/operator/test_multi_cast_transpose.cu
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,137 @@ void performTest() {
}
}

#ifdef __HIP_PLATFORM_AMD__
template <typename InputType, typename OutputType>
void compute_ref_with_padding(const std::vector<std::vector<InputType>> &input_list,
std::vector<std::vector<OutputType>> &output_c_list,
std::vector<std::vector<OutputType>> &output_t_list,
const std::vector<float> &scale_list,
std::vector<float> &amax_list,
const std::vector<size_t> &valid_height_list,
const std::vector<size_t> &padded_height_list,
const std::vector<size_t> &width_list) {
using compute_t = float;
for (size_t tensor_id = 0; tensor_id < input_list.size(); tensor_id++) {
const auto &input = input_list[tensor_id];
auto &output_c = output_c_list[tensor_id];
auto &output_t = output_t_list[tensor_id];
const compute_t scale = scale_list[tensor_id];
compute_t &amax = amax_list[tensor_id];
const size_t valid_h = valid_height_list[tensor_id];
const size_t padded_h = padded_height_list[tensor_id];
const size_t width = width_list[tensor_id];
amax = -1e100;
for (size_t i = 0; i < padded_h; i++) {
for (size_t j = 0; j < width; j++) {
if (i < valid_h) {
const compute_t x = static_cast<compute_t>(input[i * width + j]);
const OutputType y = static_cast<OutputType>(scale * x);
amax = fmaxf(amax, fabsf(x));
output_c[i * width + j] = y;
output_t[j * padded_h + i] = y;
} else {
const OutputType zero = static_cast<OutputType>(0.0f);
output_c[i * width + j] = zero;
output_t[j * padded_h + i] = zero;
}
}
}
}
}

template <typename InputType, typename OutputType>
void performTestWithPadding() {
using namespace test;

const DType itype = TypeInfo<InputType>::dtype;
const DType otype = TypeInfo<OutputType>::dtype;

// (valid_rows, padded_rows, cols)
const std::vector<std::tuple<size_t, size_t, size_t>> tensor_dims = {
{1, 16, 256},
{13, 16, 256},
{15, 16, 256},
{100, 112, 768},
{250, 256, 768},
{33, 48, 512},
{255, 256, 1024},
};
const size_t num_tensors = tensor_dims.size();

std::vector<Tensor> input_list, output_list;
std::vector<std::vector<InputType>> ref_input_list;
std::vector<std::vector<OutputType>> ref_output_c_list, ref_output_t_list;
std::vector<float> ref_scale_list(num_tensors), ref_amax_list(num_tensors);
std::vector<size_t> ref_valid_h_list(num_tensors), ref_padded_h_list(num_tensors);
std::vector<size_t> ref_width_list(num_tensors);
std::vector<int> valid_num_rows(num_tensors);

for (size_t tensor_id = 0; tensor_id < num_tensors; tensor_id++) {
const size_t valid_h = std::get<0>(tensor_dims[tensor_id]);
const size_t padded_h = std::get<1>(tensor_dims[tensor_id]);
const size_t width = std::get<2>(tensor_dims[tensor_id]);

input_list.emplace_back("input_" + std::to_string(tensor_id), std::vector<size_t>{valid_h, width}, itype);
output_list.emplace_back("output_" + std::to_string(tensor_id), std::vector<size_t>{padded_h, width},
otype, true, true);

auto &input = input_list.back();
auto &output = output_list.back();
fillUniform(&input);
setRandomScale(&output);

ref_input_list.emplace_back(valid_h * width);
ref_output_c_list.emplace_back(padded_h * width);
ref_output_t_list.emplace_back(width * padded_h);

std::copy(input.rowwise_cpu_dptr<InputType>(), input.rowwise_cpu_dptr<InputType>() + valid_h * width,
ref_input_list.back().begin());

ref_scale_list[tensor_id] = output.scale();
ref_valid_h_list[tensor_id] = valid_h;
ref_padded_h_list[tensor_id] = padded_h;
ref_width_list[tensor_id] = width;
valid_num_rows[tensor_id] = static_cast<int>(valid_h);
Comment thread
alextmagro marked this conversation as resolved.
}

auto make_nvte_vector = [](std::vector<Tensor> &tensor_list) -> std::vector<NVTETensor> {
std::vector<NVTETensor> nvte_tensor_list;
for (auto &tensor : tensor_list) {
nvte_tensor_list.emplace_back(tensor.data());
}
return nvte_tensor_list;
};
nvte_multi_cast_transpose_with_padding(num_tensors,
make_nvte_vector(input_list).data(),
make_nvte_vector(output_list).data(),
valid_num_rows.data(), 0);

compute_ref_with_padding<InputType, OutputType>(ref_input_list, ref_output_c_list, ref_output_t_list, ref_scale_list,
ref_amax_list, ref_valid_h_list, ref_padded_h_list, ref_width_list);

(void)cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

for (size_t tensor_id = 0; tensor_id < num_tensors; tensor_id++) {
const size_t valid_h = ref_valid_h_list[tensor_id];
const size_t padded_h = ref_padded_h_list[tensor_id];
const size_t width = ref_width_list[tensor_id];

if (isFp8Type(otype) && valid_h > 0) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output_list[tensor_id].amax(), ref_amax_list[tensor_id], atol_amax, rtol_amax);
compareResults("scale_inv", output_list[tensor_id].rowwise_scale_inv(), 1.f / output_list[tensor_id].scale(),
atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output_list[tensor_id], ref_output_c_list[tensor_id].data(), true, atol, rtol);
compareResults("output_t", output_list[tensor_id], ref_output_t_list[tensor_id].data(), false, atol, rtol);
}
}
#endif // #ifdef __HIP_PLATFORM_AMD__

} // namespace

class MultiCastTransposeTestSuite
Expand Down Expand Up @@ -187,3 +318,35 @@ INSTANTIATE_TEST_SUITE_P(
test::typeName(std::get<1>(info.param));
return name;
});

#ifdef __HIP_PLATFORM_AMD__
class MultiCastTransposeWithPaddingTestSuite
: public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
transformer_engine::DType>> {};

TEST_P(MultiCastTransposeWithPaddingTestSuite, TestMultiCastTransposeWithPadding) {
using namespace transformer_engine;
using namespace test;

const DType input_type = std::get<0>(GetParam());
const DType output_type = std::get<1>(GetParam());

TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
performTestWithPadding<InputType, OutputType>();
);
);
}

INSTANTIATE_TEST_SUITE_P(
OperatorTest,
MultiCastTransposeWithPaddingTestSuite,
::testing::Combine(
::testing::Values(DType::kBFloat16, DType::kFloat16),
::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2)),
[](const testing::TestParamInfo<MultiCastTransposeWithPaddingTestSuite::ParamType>& info) {
std::string name = test::typeName(std::get<0>(info.param)) + "X" +
test::typeName(std::get<1>(info.param));
return name;
});
#endif
Loading
Loading