diff --git a/benchmarks/cpp/cast/bench_multi_cast_transpose.cpp b/benchmarks/cpp/cast/bench_multi_cast_transpose.cpp index 8e556fafb..46eed00d6 100644 --- a/benchmarks/cpp/cast/bench_multi_cast_transpose.cpp +++ b/benchmarks/cpp/cast/bench_multi_cast_transpose.cpp @@ -11,6 +11,7 @@ #include "benchmark_utils.h" +#include "transformer_engine/padding_hip.h" #include "transformer_engine/transpose_hip.h" #include "transformer_engine/transformer_engine_hip.h" @@ -225,6 +226,182 @@ static void BM_MultiCastTranspose(benchmark::State &state) { HIP_CHECK(hipStreamDestroy(stream)); } +// Unfused baseline: separate padding kernel + cast_transpose +template +static void BM_PaddingThenMCT(benchmark::State &state) { + const size_t total_tokens = state.range(0); + const size_t cols = state.range(1); + const size_t num_experts = state.range(2); + const size_t top_k = state.range(3); + const size_t routing_mode = state.range(4); + + uint64_t seed = derive_seed(total_tokens, cols, num_experts, top_k, routing_mode); + + auto counts = (routing_mode == 0) + ? simulate_topk_balanced(total_tokens, num_experts, top_k, seed) + : simulate_topk_skewed(total_tokens, num_experts, top_k, seed); + + size_t sum_tok = std::accumulate(counts.begin(), counts.end(), size_t(0)); + + DType itype = std::is_same_v ? DType::kFloat32 : + std::is_same_v ? DType::kBFloat16 : + DType::kFloat16; + + std::string pfx = "pad_mct_" + std::to_string(total_tokens) + "_" + std::to_string(cols) + "_" + + std::to_string(num_experts) + "_" + std::to_string(routing_mode); + + std::vector nvte_in(num_experts), nvte_pad_out(num_experts), + nvte_mct_in(num_experts), nvte_mct_out(num_experts); + + std::vector padded_rows_list(num_experts); + + for (size_t e = 0; e < num_experts; e++) { + size_t actual = std::max(counts[e], size_t(1)); + size_t padded = ((actual + kPadMultiple - 1) / kPadMultiple) * kPadMultiple; + padded_rows_list[e] = static_cast(padded); + + auto &input = TensorCache::get_or_create(pfx + "_in_" + std::to_string(e), {actual, cols}, itype, + true, false, NVTE_DELAYED_TENSOR_SCALING, true); + + auto &pad_out = TensorCache::get_or_create(pfx + "_pad_" + std::to_string(e), {padded, cols}, itype, + true, false, NVTE_DELAYED_TENSOR_SCALING, false); + + auto &mct_out = TensorCache::get_or_create(pfx + "_out_" + std::to_string(e), {padded, cols}, DType::kFloat8E4M3, + true, true, NVTE_DELAYED_TENSOR_SCALING, false); + + mct_out.set_scale(1.0f); + + nvte_in[e] = input.data(); + nvte_pad_out[e] = pad_out.data(); + nvte_mct_in[e] = pad_out.data(); + nvte_mct_out[e] = mct_out.data(); + } + + hipStream_t stream; + HIP_CHECK(hipStreamCreate(&stream)); + + hipEvent_t start, stop; + HIP_CHECK(hipEventCreate(&start)); + HIP_CHECK(hipEventCreate(&stop)); + + warmup_gpu(); + + for (auto _ : state) { + HIP_CHECK(hipEventRecord(start, stream)); + nvte_multi_padding(num_experts, nvte_in.data(), nvte_pad_out.data(), padded_rows_list.data(), stream); + nvte_multi_cast_transpose(num_experts, nvte_mct_in.data(), nvte_mct_out.data(), stream); + HIP_CHECK(hipEventRecord(stop, stream)); + HIP_CHECK(hipEventSynchronize(stop)); + + float ms = 0; + HIP_CHECK(hipEventElapsedTime(&ms, start, stop)); + state.SetIterationTime(ms / 1000.0); + } + + HIP_CHECK(hipEventDestroy(start)); + HIP_CHECK(hipEventDestroy(stop)); + + size_t total_bytes = 0; + for (size_t e = 0; e < num_experts; e++) { + size_t actual = std::max(counts[e], size_t(1)); + size_t padded = padded_rows_list[e]; + total_bytes += actual * cols * sizeof(IType); + total_bytes += padded * cols * sizeof(IType); + total_bytes += padded * cols * sizeof(IType); + total_bytes += padded * cols * sizeof(fp8_e4m3) * 2; + } + set_bytes_processed(state, total_bytes); + + state.counters["experts"] = num_experts; + state.counters["cols"] = cols; + state.counters["avg_tok"] = static_cast(sum_tok) / num_experts; + + HIP_CHECK(hipStreamDestroy(stream)); +} + +// Fused: single cast_transpose kernel with built-in padding +template +static void BM_FusedPaddingMCT(benchmark::State &state) { + const size_t total_tokens = state.range(0); + const size_t cols = state.range(1); + const size_t num_experts = state.range(2); + const size_t top_k = state.range(3); + const size_t routing_mode = state.range(4); + + uint64_t seed = derive_seed(total_tokens, cols, num_experts, top_k, routing_mode); + + auto counts = (routing_mode == 0) + ? simulate_topk_balanced(total_tokens, num_experts, top_k, seed) + : simulate_topk_skewed(total_tokens, num_experts, top_k, seed); + + size_t sum_tok = std::accumulate(counts.begin(), counts.end(), size_t(0)); + + DType itype = std::is_same_v ? DType::kFloat32 : + std::is_same_v ? DType::kBFloat16 : + DType::kFloat16; + + std::string pfx = "fused_mct_" + std::to_string(total_tokens) + "_" + std::to_string(cols) + "_" + + std::to_string(num_experts) + "_" + std::to_string(routing_mode); + + std::vector nvte_in(num_experts), nvte_out(num_experts); + std::vector valid_rows_list(num_experts); + + for (size_t e = 0; e < num_experts; e++) { + size_t actual = std::max(counts[e], size_t(1)); + size_t padded = ((actual + kPadMultiple - 1) / kPadMultiple) * kPadMultiple; + valid_rows_list[e] = static_cast(actual); + + auto &input = TensorCache::get_or_create(pfx + "_in_" + std::to_string(e), {actual, cols}, itype, + true, false, NVTE_DELAYED_TENSOR_SCALING, true); + + auto &output = TensorCache::get_or_create(pfx + "_out_" + std::to_string(e), {padded, cols}, DType::kFloat8E4M3, + true, true, NVTE_DELAYED_TENSOR_SCALING, false); + + output.set_scale(1.0f); + + nvte_in[e] = input.data(); + nvte_out[e] = output.data(); + } + + hipStream_t stream; + HIP_CHECK(hipStreamCreate(&stream)); + + hipEvent_t start, stop; + HIP_CHECK(hipEventCreate(&start)); + HIP_CHECK(hipEventCreate(&stop)); + + warmup_gpu(); + + for (auto _ : state) { + HIP_CHECK(hipEventRecord(start, stream)); + nvte_multi_cast_transpose_with_padding(num_experts, nvte_in.data(), nvte_out.data(), valid_rows_list.data(), stream); + HIP_CHECK(hipEventRecord(stop, stream)); + HIP_CHECK(hipEventSynchronize(stop)); + + float ms = 0; + HIP_CHECK(hipEventElapsedTime(&ms, start, stop)); + state.SetIterationTime(ms / 1000.0); + } + + HIP_CHECK(hipEventDestroy(start)); + HIP_CHECK(hipEventDestroy(stop)); + + size_t total_bytes = 0; + for (size_t e = 0; e < num_experts; e++) { + size_t actual = std::max(counts[e], size_t(1)); + size_t padded = ((actual + kPadMultiple - 1) / kPadMultiple) * kPadMultiple; + total_bytes += actual * cols * sizeof(IType); + total_bytes += padded * cols * sizeof(fp8_e4m3) * 2; + } + set_bytes_processed(state, total_bytes); + + state.counters["experts"] = num_experts; + state.counters["cols"] = cols; + state.counters["avg_tok"] = static_cast(sum_tok) / num_experts; + + HIP_CHECK(hipStreamDestroy(stream)); +} + } // namespace #define REGISTER_MCT(ITYPE, INAME) \ @@ -239,6 +416,29 @@ static void BM_MultiCastTranspose(benchmark::State &state) { ->Unit(benchmark::kMicrosecond) \ ->UseManualTime(); +#define REGISTER_PAD_MCT(ITYPE, INAME) \ + BENCHMARK_TEMPLATE(BM_PaddingThenMCT, ITYPE) \ + ->Name("BM_PaddingThenMCT/" INAME "_E4M3/moe") \ + MOE_BALANCED \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); \ + BENCHMARK_TEMPLATE(BM_PaddingThenMCT, ITYPE) \ + ->Name("BM_PaddingThenMCT/" INAME "_E4M3/moe_skewed") \ + MOE_SKEWED \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); \ + BENCHMARK_TEMPLATE(BM_FusedPaddingMCT, ITYPE) \ + ->Name("BM_FusedPaddingMCT/" INAME "_E4M3/moe") \ + MOE_BALANCED \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); \ + BENCHMARK_TEMPLATE(BM_FusedPaddingMCT, ITYPE) \ + ->Name("BM_FusedPaddingMCT/" INAME "_E4M3/moe_skewed") \ + MOE_SKEWED \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); + REGISTER_MCT(hip_bfloat16, "BF16") +REGISTER_PAD_MCT(hip_bfloat16, "BF16") BENCHMARK_MAIN(); diff --git a/tests/cpp/operator/test_multi_cast_transpose.cu b/tests/cpp/operator/test_multi_cast_transpose.cu index 22db1b5dc..8890c6b9e 100644 --- a/tests/cpp/operator/test_multi_cast_transpose.cu +++ b/tests/cpp/operator/test_multi_cast_transpose.cu @@ -155,6 +155,137 @@ void performTest() { } } +#ifdef __HIP_PLATFORM_AMD__ +template +void compute_ref_with_padding(const std::vector> &input_list, + std::vector> &output_c_list, + std::vector> &output_t_list, + const std::vector &scale_list, + std::vector &amax_list, + const std::vector &valid_height_list, + const std::vector &padded_height_list, + const std::vector &width_list) { + using compute_t = float; + for (size_t tensor_id = 0; tensor_id < input_list.size(); tensor_id++) { + const auto &input = input_list[tensor_id]; + auto &output_c = output_c_list[tensor_id]; + auto &output_t = output_t_list[tensor_id]; + const compute_t scale = scale_list[tensor_id]; + compute_t &amax = amax_list[tensor_id]; + const size_t valid_h = valid_height_list[tensor_id]; + const size_t padded_h = padded_height_list[tensor_id]; + const size_t width = width_list[tensor_id]; + amax = -1e100; + for (size_t i = 0; i < padded_h; i++) { + for (size_t j = 0; j < width; j++) { + if (i < valid_h) { + const compute_t x = static_cast(input[i * width + j]); + const OutputType y = static_cast(scale * x); + amax = fmaxf(amax, fabsf(x)); + output_c[i * width + j] = y; + output_t[j * padded_h + i] = y; + } else { + const OutputType zero = static_cast(0.0f); + output_c[i * width + j] = zero; + output_t[j * padded_h + i] = zero; + } + } + } + } +} + +template +void performTestWithPadding() { + using namespace test; + + const DType itype = TypeInfo::dtype; + const DType otype = TypeInfo::dtype; + + // (valid_rows, padded_rows, cols) + const std::vector> tensor_dims = { + {1, 16, 256}, + {13, 16, 256}, + {15, 16, 256}, + {100, 112, 768}, + {250, 256, 768}, + {33, 48, 512}, + {255, 256, 1024}, + }; + const size_t num_tensors = tensor_dims.size(); + + std::vector input_list, output_list; + std::vector> ref_input_list; + std::vector> ref_output_c_list, ref_output_t_list; + std::vector ref_scale_list(num_tensors), ref_amax_list(num_tensors); + std::vector ref_valid_h_list(num_tensors), ref_padded_h_list(num_tensors); + std::vector ref_width_list(num_tensors); + std::vector valid_num_rows(num_tensors); + + for (size_t tensor_id = 0; tensor_id < num_tensors; tensor_id++) { + const size_t valid_h = std::get<0>(tensor_dims[tensor_id]); + const size_t padded_h = std::get<1>(tensor_dims[tensor_id]); + const size_t width = std::get<2>(tensor_dims[tensor_id]); + + input_list.emplace_back("input_" + std::to_string(tensor_id), std::vector{valid_h, width}, itype); + output_list.emplace_back("output_" + std::to_string(tensor_id), std::vector{padded_h, width}, + otype, true, true); + + auto &input = input_list.back(); + auto &output = output_list.back(); + fillUniform(&input); + setRandomScale(&output); + + ref_input_list.emplace_back(valid_h * width); + ref_output_c_list.emplace_back(padded_h * width); + ref_output_t_list.emplace_back(width * padded_h); + + std::copy(input.rowwise_cpu_dptr(), input.rowwise_cpu_dptr() + valid_h * width, + ref_input_list.back().begin()); + + ref_scale_list[tensor_id] = output.scale(); + ref_valid_h_list[tensor_id] = valid_h; + ref_padded_h_list[tensor_id] = padded_h; + ref_width_list[tensor_id] = width; + valid_num_rows[tensor_id] = static_cast(valid_h); + } + + auto make_nvte_vector = [](std::vector &tensor_list) -> std::vector { + std::vector nvte_tensor_list; + for (auto &tensor : tensor_list) { + nvte_tensor_list.emplace_back(tensor.data()); + } + return nvte_tensor_list; + }; + nvte_multi_cast_transpose_with_padding(num_tensors, + make_nvte_vector(input_list).data(), + make_nvte_vector(output_list).data(), + valid_num_rows.data(), 0); + + compute_ref_with_padding(ref_input_list, ref_output_c_list, ref_output_t_list, ref_scale_list, + ref_amax_list, ref_valid_h_list, ref_padded_h_list, ref_width_list); + + (void)cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + for (size_t tensor_id = 0; tensor_id < num_tensors; tensor_id++) { + const size_t valid_h = ref_valid_h_list[tensor_id]; + const size_t padded_h = ref_padded_h_list[tensor_id]; + const size_t width = ref_width_list[tensor_id]; + + if (isFp8Type(otype) && valid_h > 0) { + auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); + compareResults("amax", output_list[tensor_id].amax(), ref_amax_list[tensor_id], atol_amax, rtol_amax); + compareResults("scale_inv", output_list[tensor_id].rowwise_scale_inv(), 1.f / output_list[tensor_id].scale(), + atol_amax, rtol_amax); + } + auto [atol, rtol] = getTolerances(otype); + compareResults("output_c", output_list[tensor_id], ref_output_c_list[tensor_id].data(), true, atol, rtol); + compareResults("output_t", output_list[tensor_id], ref_output_t_list[tensor_id].data(), false, atol, rtol); + } +} +#endif // #ifdef __HIP_PLATFORM_AMD__ + } // namespace class MultiCastTransposeTestSuite @@ -187,3 +318,35 @@ INSTANTIATE_TEST_SUITE_P( test::typeName(std::get<1>(info.param)); return name; }); + +#ifdef __HIP_PLATFORM_AMD__ +class MultiCastTransposeWithPaddingTestSuite + : public ::testing::TestWithParam> {}; + +TEST_P(MultiCastTransposeWithPaddingTestSuite, TestMultiCastTransposeWithPadding) { + using namespace transformer_engine; + using namespace test; + + const DType input_type = std::get<0>(GetParam()); + const DType output_type = std::get<1>(GetParam()); + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, + performTestWithPadding(); + ); + ); +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + MultiCastTransposeWithPaddingTestSuite, + ::testing::Combine( + ::testing::Values(DType::kBFloat16, DType::kFloat16), + ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2)), + [](const testing::TestParamInfo& info) { + std::string name = test::typeName(std::get<0>(info.param)) + "X" + + test::typeName(std::get<1>(info.param)); + return name; + }); +#endif diff --git a/tests/pytorch/test_sanity_hipified_cast_transpose.py b/tests/pytorch/test_sanity_hipified_cast_transpose.py index a6583cb77..5627295a9 100644 --- a/tests/pytorch/test_sanity_hipified_cast_transpose.py +++ b/tests/pytorch/test_sanity_hipified_cast_transpose.py @@ -64,3 +64,45 @@ def test_single_kernel_dispatch(shape, in_dtype, out_dtype, monkeypatch): assert len(ct_kernels) == 1, ( f"Expected exactly 1 cast_transpose kernel, got {len(ct_kernels)}: {ct_kernels}" ) + + +@pytest.mark.parametrize("num_experts", [4, 8, 64]) +@pytest.mark.parametrize("hidden_dim", [256, 1536, 4096]) +@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) +def test_fused_vs_unfused_padding_mct(num_experts, hidden_dim, fp8_dtype): + """Compare fused padding+MCT against unfused (pad BF16 then MCT).""" + actual_splits = [torch.randint(1, 200, (1,)).item() for _ in range(num_experts)] + align = 16 + padded_splits = [(m + align - 1) // align * align for m in actual_splits] + total_actual = sum(actual_splits) + total_padded = sum(padded_splits) + + inp_unpadded = _fill_uniform((total_actual, hidden_dim), torch.bfloat16) + + inp_padded = torch.zeros((total_padded, hidden_dim), dtype=torch.bfloat16, device="cuda") + src_offset = 0 + dst_offset = 0 + for actual, padded in zip(actual_splits, padded_splits): + inp_padded[dst_offset:dst_offset + actual] = inp_unpadded[src_offset:src_offset + actual] + src_offset += actual + dst_offset += padded + + scale = torch.tensor([1.0], dtype=torch.float32, device="cuda") + amax = torch.zeros(1, dtype=torch.float32, device="cuda") + + def make_quantizers(): + return [Float8Quantizer(scale=scale.clone(), amax=amax.clone(), fp8_dtype=fp8_dtype, rowwise=True, columnwise=True) + for _ in range(num_experts)] + + outputs_unfused = tex.split_quantize(inp_padded, padded_splits, make_quantizers()) + + outputs_fused = tex.split_quantize(inp_unpadded, padded_splits, make_quantizers(), valid_split_sections=actual_splits) + + assert len(outputs_unfused) == len(outputs_fused) == num_experts + for e in range(num_experts): + data_u, trans_u = outputs_unfused[e].get_data_tensors(rowwise_data=True, columnwise_data=True) + data_f, trans_f = outputs_fused[e].get_data_tensors(rowwise_data=True, columnwise_data=True) + assert data_u.shape == data_f.shape, f"Expert {e}: rowwise shape mismatch" + assert torch.equal(data_u, data_f), f"Expert {e}: rowwise data mismatch" + assert trans_u.shape == trans_f.shape, f"Expert {e}: columnwise shape mismatch" + assert torch.equal(trans_u, trans_f), f"Expert {e}: columnwise data mismatch" diff --git a/transformer_engine/common/include/transformer_engine/transpose.h b/transformer_engine/common/include/transformer_engine/transpose.h index 659a48d97..ebddad947 100644 --- a/transformer_engine/common/include/transformer_engine/transpose.h +++ b/transformer_engine/common/include/transformer_engine/transpose.h @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -93,6 +95,23 @@ void nvte_fp8_transpose_dbias(const NVTETensor input, NVTETensor transposed_outp void nvte_multi_cast_transpose(size_t num_tensors, const NVTETensor* input_list, NVTETensor* output_list, cudaStream_t stream); +#ifdef __HIP_PLATFORM_AMD__ +/*! \brief Cast and transpose multiple tensors with fused padding. + * + * Input tensors may have fewer rows than output tensors. Rows beyond + * valid_num_rows are zero-filled in the output and excluded from amax. + * + * \param[in] num_tensors Number of tensors. + * \param[in] input_list List of 2D input tensors (unpadded). + * \param[in,out] output_list List of casted tensors (padded). + * \param[in] valid_num_rows_list Per-tensor valid row count, or NULL + * for no padding (all rows valid). + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_multi_cast_transpose_with_padding(size_t num_tensors, const NVTETensor *input_list, NVTETensor *output_list, + const int *valid_num_rows_list, cudaStream_t stream); +#endif + /*! \brief Compute backward of GeLU operation on the input, then cast and transpose. * Additionally, reduce the result of the GeLU backward along the first dimension. * diff --git a/transformer_engine/common/transpose/multi_cast_transpose.cu b/transformer_engine/common/transpose/multi_cast_transpose.cu index 8e7dbd0d2..ce1f68b93 100644 --- a/transformer_engine/common/transpose/multi_cast_transpose.cu +++ b/transformer_engine/common/transpose/multi_cast_transpose.cu @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -200,7 +202,7 @@ __global__ void __launch_bounds__(threads_per_block) } // namespace void multi_cast_transpose(const std::vector input_list, std::vector output_list, - cudaStream_t stream) { + cudaStream_t stream, const int *valid_num_rows_list = nullptr) { // Check that number of tensors is valid NVTE_CHECK(output_list.size() == input_list.size(), "Number of input and output tensors must match"); @@ -226,17 +228,36 @@ void multi_cast_transpose(const std::vector input_list, std::vector(valid_num_rows_list[tensor_id]), + "Input rows ", input.data.shape[0], " != valid_num_rows ", + valid_num_rows_list[tensor_id]); + NVTE_CHECK(output.data.shape[0] >= input.data.shape[0], + "Output rows ", output.data.shape[0], " < input rows ", input.data.shape[0]); + NVTE_CHECK(output.data.shape[1] == input.data.shape[1], + "Output cols ", output.data.shape[1], " != input cols ", input.data.shape[1]); + } else +#endif + { + NVTE_CHECK(output.data.shape == input.data.shape, "C output tensor shape ", output.data.shape, + "does not match input tensor shape ", input.data.shape); + } NVTE_CHECK(output.columnwise_data.shape.size() == 2, "T output tensor shape ", output.columnwise_data.shape, "does not match input tensor shape ", input.data.shape); NVTE_CHECK(output.columnwise_data.shape[0] == input.data.shape[1], "T output tensor shape ", output.columnwise_data.shape, "does not match input tensor shape ", input.data.shape); +#ifdef __HIP_PLATFORM_AMD__ + NVTE_CHECK(output.columnwise_data.shape[1] == output.data.shape[0], "T output tensor shape ", + output.columnwise_data.shape, "does not match output rows ", + output.data.shape[0]); +#else NVTE_CHECK(output.columnwise_data.shape[1] == input.data.shape[0], "T output tensor shape ", output.columnwise_data.shape, "does not match input tensor shape ", input.data.shape); +#endif } #ifdef __HIP_PLATFORM_AMD__ @@ -255,6 +276,7 @@ void multi_cast_transpose(const std::vector input_list, std::vector sinv_ptrs(n); std::vector rows(n); std::vector cols(n); + std::vector valid_rows(n); for (size_t i = 0; i < n; i++) { in_ptrs[i] = reinterpret_cast(input_list[i]->data.dptr); @@ -263,13 +285,17 @@ void multi_cast_transpose(const std::vector input_list, std::vector(output_list[i]->scale.dptr); amax_ptrs[i] = reinterpret_cast(output_list[i]->amax.dptr); sinv_ptrs[i] = reinterpret_cast(output_list[i]->scale_inv.dptr); - rows[i] = input_list[i]->data.shape[0]; + rows[i] = output_list[i]->data.shape[0]; cols[i] = input_list[i]->data.shape[1]; + valid_rows[i] = valid_num_rows_list + ? static_cast(valid_num_rows_list[i]) + : input_list[i]->data.shape[0]; } - rocm_multi_cast_transpose_dispatch(n, in_ptrs.data(), out_c_ptrs.data(), - out_t_ptrs.data(), scale_ptrs.data(), amax_ptrs.data(), sinv_ptrs.data(), rows.data(), - cols.data(), stream); + rocm_multi_cast_transpose_dispatch( + n, in_ptrs.data(), out_c_ptrs.data(), out_t_ptrs.data(), + scale_ptrs.data(), amax_ptrs.data(), sinv_ptrs.data(), + rows.data(), cols.data(), valid_rows.data(), stream); ); // NOLINT(*) ); // NOLINT(*) NVTE_CHECK_CUDA(cudaGetLastError()); @@ -330,7 +356,7 @@ void multi_cast_transpose(const std::vector input_list, std::vector(input_list[tensor_id]->data.dptr); + kernel_args.input_list[pos] = const_cast(input_list[tensor_id]->data.dptr); kernel_args.output_c_list[pos] = output_list[tensor_id]->data.dptr; kernel_args.output_t_list[pos] = output_list[tensor_id]->columnwise_data.dptr; kernel_args.scale_list[pos] = output_list[tensor_id]->scale.dptr; @@ -384,3 +410,19 @@ void nvte_multi_cast_transpose(size_t num_tensors, const NVTETensor* input_list, } multi_cast_transpose(input_list_, output_list_, stream); } + +#ifdef __HIP_PLATFORM_AMD__ +void nvte_multi_cast_transpose_with_padding(size_t num_tensors, const NVTETensor *input_list, + NVTETensor *output_list, + const int *valid_num_rows_list, + cudaStream_t stream) { + NVTE_API_CALL(nvte_multi_cast_transpose_with_padding); + using namespace transformer_engine; + std::vector input_list_, output_list_; + for (size_t i = 0; i < num_tensors; i++) { + input_list_.push_back(convertNVTETensorCheck(input_list[i])); + output_list_.push_back(convertNVTETensorCheck(output_list[i])); + } + multi_cast_transpose(input_list_, output_list_, stream, valid_num_rows_list); +} +#endif diff --git a/transformer_engine/common/transpose/rocm_multi_cast_transpose.cuh b/transformer_engine/common/transpose/rocm_multi_cast_transpose.cuh index 4b668751b..4784bc586 100644 --- a/transformer_engine/common/transpose/rocm_multi_cast_transpose.cuh +++ b/transformer_engine/common/transpose/rocm_multi_cast_transpose.cuh @@ -19,6 +19,7 @@ struct RocmMultiCastTransposeArgs { void *amax_list[kMCTMaxTensors]; void *scale_inv_list[kMCTMaxTensors]; int num_rows_list[kMCTMaxTensors]; + int valid_num_rows_list[kMCTMaxTensors]; int row_length_list[kMCTMaxTensors]; int block_range[kMCTMaxTensors + 1]; int num_tensors; @@ -32,6 +33,7 @@ mct_cast_store( OType *__restrict__ output_c, const int row_length, const int num_rows, + const int valid_num_rows, const float scale, float &amax, NTVec (&local_t)[LOAD_SIZE / sizeof(IType)][ROCM_CT_WARP_SIZE / WARPS_PER_TILE], @@ -57,9 +59,11 @@ mct_cast_store( IVec in; OVecC out_c; - if (IS_EDGE && row >= num_rows) { + if (IS_EDGE && row >= valid_num_rows) { #pragma unroll - for (int j2 = 0; j2 < NVEC_IN; j2++) in.val[j2] = IType(0); + for (int j2 = 0; j2 < NVEC_IN; j2++) { + in.val[j2] = IType(0); + } } else { in.load(&input[row * row_length + col]); } @@ -72,9 +76,10 @@ mct_cast_store( const float v1 = (j2+1 < NVEC_IN) ? static_cast(in.val[j2+1]) : 0.0f; const float v2 = (j2+2 < NVEC_IN) ? static_cast(in.val[j2+2]) : 0.0f; const float v3 = (j2+3 < NVEC_IN) ? static_cast(in.val[j2+3]) : 0.0f; - if (!IS_EDGE || row < num_rows) + if (!IS_EDGE || row < valid_num_rows) { amax = fmaxf(amax, fmaxf(fmaxf(fabsf(v0), fabsf(v1)), fmaxf(fabsf(v2), fabsf(v3)))); + } uint32_t packed = rocm_pack_4xfloat8( v0 * scale, v1 * scale, v2 * scale, v3 * scale); uint8_t *bytes = reinterpret_cast(&packed); @@ -90,31 +95,28 @@ mct_cast_store( #pragma unroll for (int j2 = 0; j2 < NVEC_IN; j2++) { const float v = static_cast(in.val[j2]); - if (!IS_EDGE || row < num_rows) + if (!IS_EDGE || row < valid_num_rows) { amax = fmaxf(amax, fabsf(v)); + } const OType o = static_cast(v * scale); out_c.val[j2] = o; local_t[j2][iter].val[i2] = o; } } - if (!IS_EDGE || row < num_rows) + if (!IS_EDGE || row < num_rows) { out_c.nt_store(&output_c[row * row_length + col]); + } } } } template -__device__ __forceinline__ void -mct_transpose_store( - OType *__restrict__ output_t, - const int num_rows, +__device__ __forceinline__ void mct_transpose_store(OType *__restrict__ output_t, const int num_rows, NTVec (&smem)[ROCM_CT_WARP_SIZE][ROCM_CT_WARP_SIZE + 1], NTVec (&local_t)[LOAD_SIZE / sizeof(IType)][ROCM_CT_WARP_SIZE / WARPS_PER_TILE], - const int tidx, const int tidy, - const int row_base, const int col_base) -{ + const int tidx, const int tidy, const int row_base, const int col_base) { constexpr int NVEC_IN = LOAD_SIZE / sizeof(IType); constexpr int NVEC_OUT = STORE_SIZE / sizeof(OType); constexpr int NUM_ITERS = ROCM_CT_WARP_SIZE / WARPS_PER_TILE; @@ -177,6 +179,8 @@ rocm_multi_cast_transpose_kernel(RocmMultiCastTransposeArgs args) { const int num_rows = args.num_rows_list[tensor_id]; const int row_length = args.row_length_list[tensor_id]; + const int valid_num_rows = args.valid_num_rows_list[tensor_id]; + const IType *__restrict__ input = reinterpret_cast(args.input_list[tensor_id]); OType *__restrict__ output_c = reinterpret_cast(args.output_c_list[tensor_id]); OType *__restrict__ output_t = reinterpret_cast(args.output_t_list[tensor_id]); @@ -191,7 +195,7 @@ rocm_multi_cast_transpose_kernel(RocmMultiCastTransposeArgs args) { const int row_base = tile_m * TILE_ROWS; const int col_base = tile_n * TILE_COLS; - const bool is_edge = (row_base + TILE_ROWS > num_rows); + const bool is_edge = (row_base + TILE_ROWS > valid_num_rows); const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1.0f; float amax = 0.0f; @@ -202,14 +206,14 @@ rocm_multi_cast_transpose_kernel(RocmMultiCastTransposeArgs args) { if (is_edge) { mct_cast_store( - input, output_c, row_length, num_rows, scale, amax, local_t, + input, output_c, row_length, num_rows, valid_num_rows, scale, amax, local_t, tidx, tidy, row_base, col_base); mct_transpose_store( output_t, num_rows, smem, local_t, tidx, tidy, row_base, col_base); } else { mct_cast_store( - input, output_c, row_length, num_rows, scale, amax, local_t, + input, output_c, row_length, num_rows, valid_num_rows, scale, amax, local_t, tidx, tidy, row_base, col_base); mct_transpose_store( output_t, num_rows, smem, local_t, @@ -229,10 +233,12 @@ rocm_multi_cast_transpose_kernel(RocmMultiCastTransposeArgs args) { } template -void rocm_multi_cast_transpose_dispatch(size_t num_tensors, const IType *const *input_list, OType *const *output_c_list, - OType *const *output_t_list, const float *const *scale_list, float *const *amax_list, - float *const *scale_inv_list, const size_t *num_rows_list, - const size_t *row_length_list, hipStream_t stream) { +void rocm_multi_cast_transpose_dispatch(size_t num_tensors, const IType *const *input_list, + OType *const *output_c_list, OType *const *output_t_list, + const float *const *scale_list, float *const *amax_list, + float *const *scale_inv_list, const size_t *num_rows_list, + const size_t *row_length_list, const size_t *valid_num_rows_list, + hipStream_t stream) { constexpr int WPT = 16; constexpr int BLK = ROCM_CT_WARP_SIZE * WPT; constexpr int ISZ = sizeof(IType); @@ -280,14 +286,17 @@ void rocm_multi_cast_transpose_dispatch(size_t num_tensors, const IType *const * int tiles_n = cols / TILE_COLS; int tiles = tiles_m * tiles_n; - args.input_list[packed] = reinterpret_cast(input_list[i]); - args.output_c_list[packed] = reinterpret_cast(output_c_list[i]); - args.output_t_list[packed] = reinterpret_cast(output_t_list[i]); - args.scale_list[packed] = reinterpret_cast(scale_list[i]); - args.amax_list[packed] = amax_list[i]; - args.scale_inv_list[packed] = scale_inv_list[i]; - args.num_rows_list[packed] = rows; - args.row_length_list[packed] = cols; + int valid_rows = valid_num_rows_list ? valid_num_rows_list[i] : rows; + + args.input_list[packed] = reinterpret_cast(input_list[i]); + args.output_c_list[packed] = reinterpret_cast(output_c_list[i]); + args.output_t_list[packed] = reinterpret_cast(output_t_list[i]); + args.scale_list[packed] = reinterpret_cast(scale_list[i]); + args.amax_list[packed] = amax_list[i]; + args.scale_inv_list[packed] = scale_inv_list[i]; + args.num_rows_list[packed] = rows; + args.valid_num_rows_list[packed] = valid_rows; + args.row_length_list[packed] = cols; total_blocks += tiles; args.block_range[packed + 1] = total_blocks; packed++; diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index fd0013f3e..492df26b2 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -317,7 +317,11 @@ std::vector multi_tensor_quantize(const std::vector &ten std::vector split_quantize(const at::Tensor &tensor, const std::vector &split_sections, std::vector quantizer_list, - bool disable_bulk_allocation = false); + bool disable_bulk_allocation = false +#ifdef __HIP_PLATFORM_AMD__ + , std::optional> valid_split_sections = std::nullopt +#endif + ); /*************************************************************************************************** * Bias gradient fusions diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 9ed5502b7..d9a5d5870 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -7,6 +7,7 @@ ************************************************************************/ #include "transformer_engine/cast.h" +#include "transformer_engine/transpose.h" #include #include @@ -265,7 +266,11 @@ namespace { void multi_tensor_quantize_impl(const std::vector &input_list, std::vector &quantizer_py_list, std::vector> &quantizer_cpp_list, - std::vector &output_list) { + std::vector &output_list +#ifdef __HIP_PLATFORM_AMD__ + , const int *valid_num_rows = nullptr +#endif + ) { // Check number of tensors const size_t num_tensors = input_list.size(); NVTE_CHECK(quantizer_py_list.size() == num_tensors, "Expected ", num_tensors, @@ -300,8 +305,22 @@ void multi_tensor_quantize_impl(const std::vector &input_list, nvte_tensor_output_list.push_back(output_list[i].data()); } NVTE_SCOPED_GIL_RELEASE({ - nvte_multi_cast_transpose(nvte_tensor_input_list.size(), nvte_tensor_input_list.data(), - nvte_tensor_output_list.data(), at::cuda::getCurrentCUDAStream()); +#ifdef __HIP_PLATFORM_AMD__ + if (valid_num_rows != nullptr) { + nvte_multi_cast_transpose_with_padding(nvte_tensor_input_list.size(), + nvte_tensor_input_list.data(), + nvte_tensor_output_list.data(), + valid_num_rows, + at::cuda::getCurrentCUDAStream()); + } else { +#endif + nvte_multi_cast_transpose(nvte_tensor_input_list.size(), + nvte_tensor_input_list.data(), + nvte_tensor_output_list.data(), + at::cuda::getCurrentCUDAStream()); +#ifdef __HIP_PLATFORM_AMD__ + } +#endif }); } else { // Quantize kernels individually @@ -1283,7 +1302,11 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input, std::vector split_quantize(const at::Tensor &tensor, const std::vector &split_sections, std::vector quantizer_list, - bool disable_bulk_allocation) { + bool disable_bulk_allocation +#ifdef __HIP_PLATFORM_AMD__ + , std::optional> valid_split_sections +#endif + ) { init_extension(); // Check number of tensors @@ -1312,6 +1335,22 @@ std::vector split_quantize(const at::Tensor &tensor, size_t dim0_offset = 0; const size_t dim0_stride = input_shape[0] == 0 ? 0 : input_py.element_size() * input_size / input_shape[0]; +#ifdef __HIP_PLATFORM_AMD__ + const auto &input_sections = valid_split_sections.has_value() + ? *valid_split_sections : split_sections; + for (size_t i = 0; i < num_splits; i++) { + NVTE_CHECK(dim0_offset + input_sections[i] <= input_shape[0], + "Attempted to split tensor with shape=", input_shape, + " along dim 0 with split_sections=", split_sections); + void *split_dptr = static_cast(input_dptr + dim0_offset * dim0_stride); + std::vector in_shape = input_shape; + in_shape[0] = input_sections[i]; + input_list.emplace_back(makeTransformerEngineTensor(split_dptr, in_shape, input_dtype)); + split_shapes.push_back(input_shape); + split_shapes.back()[0] = split_sections[i]; + dim0_offset += input_sections[i]; + } +#else for (size_t i = 0; i < num_splits; ++i) { NVTE_CHECK(dim0_offset + split_sections[i] <= input_shape[0], "Attempted to split tensor with shape=", input_shape, @@ -1323,6 +1362,7 @@ std::vector split_quantize(const at::Tensor &tensor, input_list.emplace_back(makeTransformerEngineTensor(split_dptr, split_shape, input_dtype)); dim0_offset += split_sections[i]; } +#endif // Convert quantizers to C++ objects std::vector> quantizer_cpp_list; @@ -1428,9 +1468,25 @@ std::vector split_quantize(const at::Tensor &tensor, nvfp4_quantizers); break; } +#ifdef __HIP_PLATFORM_AMD__ + default: { + std::vector valid_num_rows_vec; + const int *valid_num_rows_ptr = nullptr; + if (valid_split_sections.has_value()) { + for (size_t s : *valid_split_sections) { + valid_num_rows_vec.push_back(static_cast(s)); + } + valid_num_rows_ptr = valid_num_rows_vec.data(); + } + multi_tensor_quantize_impl(input_list, quantizer_list, quantizer_cpp_list, + output_cpp_list, valid_num_rows_ptr); + break; + } +#else default: // General multi-tensor quantization multi_tensor_quantize_impl(input_list, quantizer_list, quantizer_cpp_list, output_cpp_list); +#endif } return output_py_list; diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 0ea760f89..eb42011b4 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -300,7 +300,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Multi-tensor quantize", py::arg("tensor_list"), py::arg("quantizer_list")); m.def("split_quantize", &transformer_engine::pytorch::split_quantize, "Split and multi-tensor quantize", py::arg("tensor"), py::arg("split_sections"), - py::arg("quantizer_list"), py::arg("disable_bulk_allocation") = false); + py::arg("quantizer_list"), py::arg("disable_bulk_allocation") = false +#ifdef __HIP_PLATFORM_AMD__ + , py::arg("valid_split_sections") = py::none() +#endif + ); m.def("te_general_grouped_gemm", &transformer_engine::pytorch::te_general_grouped_gemm, "Grouped GEMM"); m.def("te_general_grouped_gemm_for_grouped_tensor", diff --git a/transformer_engine/pytorch/module/fp8_padding.py b/transformer_engine/pytorch/module/fp8_padding.py index a208a0534..3a0073a49 100644 --- a/transformer_engine/pytorch/module/fp8_padding.py +++ b/transformer_engine/pytorch/module/fp8_padding.py @@ -99,6 +99,15 @@ def __init__( self.num_gemms = num_gemms self.align_size = align_size + def compute_padded_splits(self, m_splits: List[int]) -> List[int]: + """Compute padded split sizes without allocating or copying data.""" + if self.align_size is None: + recipe = FP8GlobalStateManager.get_fp8_recipe() + self.align_size = get_align_size_for_quantization(recipe) + return [ + (m + self.align_size - 1) // self.align_size * self.align_size for m in m_splits + ] + @no_torch_dynamo() def forward( self, diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index a47912b4c..b70d9a9bd 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -104,6 +104,8 @@ def forward( save_original_input, debug, m_splits_tensor, + actual_m_splits, + unpad_output, ) = non_tensor_args # Check if Triton kernel should be used @@ -159,12 +161,13 @@ def forward( # Disable bulk allocation when CPU offloading is active: offloading skips small # tensors (like scales), but bulk allocation shares storage across all tensors, # so if scales can't be offloaded, nothing in the group can be offloaded. + fused_padding_kwargs = {} + if actual_m_splits is not None and IS_HIP_EXTENSION \ + and inp_view.shape[0] == sum(actual_m_splits): + fused_padding_kwargs["valid_split_sections"] = actual_m_splits inputmats = tex.split_quantize( - inp_view, - m_splits, - input_quantizers, - disable_bulk_allocation=cpu_offloading, - ) + inp_view, m_splits, input_quantizers, + disable_bulk_allocation=cpu_offloading, **fused_padding_kwargs) elif debug: inputmats = DebugQuantizer.multi_tensor_quantize( inp_view, input_quantizers, m_splits, activation_dtype @@ -237,6 +240,17 @@ def forward( **kwargs, ) + + output_unpadded = False + if unpad_output and actual_m_splits is not None and IS_HIP_EXTENSION and actual_m_splits != m_splits: + out_unpadded = torch.empty( + [sum(actual_m_splits), out.shape[-1]], + dtype=out.dtype, device=out.device, + ) + tex.fused_multi_row_unpadding(out, out_unpadded, m_splits, actual_m_splits) + out = out_unpadded + output_unpadded = True + if fp8_calibration: for i in range(num_gemms): # amax of input @@ -307,6 +321,8 @@ def forward( ctx.device = device ctx.output_quantizers = output_quantizers ctx.m_splits = m_splits + ctx.actual_m_splits = actual_m_splits if IS_HIP_EXTENSION else None + ctx.output_unpadded = output_unpadded ctx.m_splits_tensor = m_splits_tensor ctx.num_gemms = num_gemms ctx.activation_dtype = activation_dtype @@ -360,11 +376,22 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Preprocess grad output grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1]) + + bwd_fused_kwargs = {} + if ctx.output_unpadded and ctx.actual_m_splits is not None: + bwd_fused_kwargs["valid_split_sections"] = ctx.actual_m_splits + # Fused pad+MCT produces both rowwise (dgrad) and columnwise (wgrad). + for q in ctx.grad_output_quantizers: + if q is not None: + q.set_usage(rowwise=True, columnwise=True) grad_output = [None] * ctx.num_gemms grad_biases = [None] * ctx.num_gemms if ctx.fp8 and not ctx.debug: if ctx.use_bias: - grad_output_mats = torch.split(grad_output_view, ctx.m_splits) + grad_output_mats = torch.split( + grad_output_view, + ctx.actual_m_splits if ctx.output_unpadded else ctx.m_splits, + ) recipe = ctx.fp8_recipe if recipe.delayed() or recipe.float8_current_scaling() or recipe.mxfp8(): # Fused bias grad + quantize kernel @@ -378,17 +405,13 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], for i in range(ctx.num_gemms): grad_biases[i] = grad_output_mats[i].sum(dim=0) grad_output = tex.split_quantize( - grad_output_view, - ctx.m_splits, - ctx.grad_output_quantizers, - ) + grad_output_view, ctx.m_splits, + ctx.grad_output_quantizers, **bwd_fused_kwargs) else: # Multi-tensor quantize grad_output = tex.split_quantize( - grad_output_view, - ctx.m_splits, - ctx.grad_output_quantizers, - ) + grad_output_view, ctx.m_splits, + ctx.grad_output_quantizers, **bwd_fused_kwargs) elif ctx.debug: grad_output_mats = torch.split(grad_output_view, ctx.m_splits) for i in range(ctx.num_gemms): @@ -457,6 +480,17 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], **kwargs, ) + if ctx.actual_m_splits is not None and ctx.actual_m_splits != ctx.m_splits \ + and not ctx.output_unpadded: + dgrad_unpadded = torch.empty( + (sum(ctx.actual_m_splits), dgrad.shape[-1]), + dtype=dgrad.dtype, device=dgrad.device, + ) + tex.fused_multi_row_unpadding( + dgrad, dgrad_unpadded, ctx.m_splits, ctx.actual_m_splits, + ) + dgrad = dgrad_unpadded + if ctx.weights_requires_grad: wgrad_gemm_use_split_accumulator = _2X_ACC_WGRAD if ctx.fp8: @@ -495,7 +529,13 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], input_quantizer.set_usage(rowwise=False, columnwise=True) inputmats: list if ctx.fp8 and not ctx.debug: - inputmats = tex.split_quantize(inp_view, ctx.m_splits, ctx.input_quantizers) + save_fused_kwargs = {} + if ctx.actual_m_splits is not None and IS_HIP_EXTENSION \ + and inp_view.shape[0] == sum(ctx.actual_m_splits): + save_fused_kwargs["valid_split_sections"] = ctx.actual_m_splits + inputmats = tex.split_quantize( + inp_view, ctx.m_splits, ctx.input_quantizers, + **save_fused_kwargs) elif ctx.debug: inputmats = DebugQuantizer.multi_tensor_quantize( inp_view, @@ -907,6 +947,8 @@ def forward( m_splits: List[int], is_first_microbatch: Optional[bool] = None, m_splits_tensor: Optional[torch.Tensor] = None, + actual_m_splits: Optional[List[int]] = None, + unpad_output: bool = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: """ Apply the linear transformation to the input. @@ -930,6 +972,14 @@ def forward( * it also allows skipping gradient accumulation during the first microbatch (since it is the first gradient being produced) + actual_m_splits : Optional[List[int]], default = None + Unpadded per-group row counts when inp is unpadded and + m_splits contains the padded sizes. Used by the ROCm + fused-pad-cast-transpose path; ignored on CUDA. + unpad_output : bool, default = False + When True, unpad the GEMM output from sum(m_splits) to + sum(actual_m_splits) rows before returning. Used by the + ROCm fused-pad-cast-transpose path; ignored on CUDA. """ debug = self.is_debug_iter() @@ -994,6 +1044,8 @@ def forward( self.save_original_input, debug, m_splits_tensor, + actual_m_splits, + unpad_output, ) out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors)