From d6790482ce0a3c73732996369c19916bfdae2bd1 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 16 Jun 2026 18:52:39 +0000 Subject: [PATCH 1/2] gfx1250 mxfp8 gemm: add NN/NT transpose workaround --- transformer_engine/common/gemm/rocm_gemm.cu | 253 +++++++++++++++++++- 1 file changed, 250 insertions(+), 3 deletions(-) diff --git a/transformer_engine/common/gemm/rocm_gemm.cu b/transformer_engine/common/gemm/rocm_gemm.cu index 5574fb9cd..da9f2dffa 100644 --- a/transformer_engine/common/gemm/rocm_gemm.cu +++ b/transformer_engine/common/gemm/rocm_gemm.cu @@ -199,6 +199,169 @@ struct GemmParam { int ldb = 0; // B column strides }; +constexpr int kMXFP8BlockSize = 32; +constexpr int kMXFP8ScaleGroupSize = 4; + +__global__ void transpose_u8_kernel(const uint8_t* __restrict__ input, + uint8_t* __restrict__ output, + const size_t rows, const size_t cols) { + const size_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + const size_t total = rows * cols; + if (idx >= total) return; + + const size_t row = idx / cols; + const size_t col = idx % cols; + output[col * rows + row] = input[idx]; +} + +template +__global__ void transpose_kernel(const T* __restrict__ input, + T* __restrict__ output, + const size_t rows, const size_t cols) { + const size_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + const size_t total = rows * cols; + if (idx >= total) return; + + const size_t row = idx / cols; + const size_t col = idx % cols; + output[col * rows + row] = input[idx]; +} + +__global__ void mxfp8_colwise_scale_to_rowwise_kernel(const uint8_t* __restrict__ input, + uint8_t* __restrict__ output, + const size_t k_scale, + const size_t padded_m, + const size_t valid_k_scale, + const size_t valid_m, + const bool input_swizzled, + const bool output_swizzled) { + const size_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + const size_t total = k_scale * padded_m; + if (idx >= total) return; + + const size_t m = idx / k_scale; + const size_t k = idx % k_scale; + const size_t group = k / kMXFP8ScaleGroupSize; + const size_t within = k % kMXFP8ScaleGroupSize; + const size_t swizzled_idx = group * (padded_m * kMXFP8ScaleGroupSize) + + m * kMXFP8ScaleGroupSize + within; + const size_t in_idx = input_swizzled ? swizzled_idx : k * padded_m + m; + const size_t out_idx = output_swizzled ? swizzled_idx : m * k_scale + k; + output[out_idx] = (k < valid_k_scale && m < valid_m) ? input[in_idx] : 127; +} + +void launch_transpose_u8(const void* input, void* output, const size_t rows, const size_t cols, + hipStream_t stream) { + constexpr int kBlockSize = 256; + const size_t total = rows * cols; + if (total == 0) return; + const int grid = static_cast((total + kBlockSize - 1) / kBlockSize); + transpose_u8_kernel<<>>( + reinterpret_cast(input), reinterpret_cast(output), rows, cols); + NVTE_CHECK_CUDA(hipGetLastError()); +} + +template +void launch_transpose_typed(const void* input, void* output, const size_t rows, const size_t cols, + hipStream_t stream) { + constexpr int kBlockSize = 256; + const size_t total = rows * cols; + if (total == 0) return; + const int grid = static_cast((total + kBlockSize - 1) / kBlockSize); + transpose_kernel<<>>( + reinterpret_cast(input), reinterpret_cast(output), rows, cols); + NVTE_CHECK_CUDA(hipGetLastError()); +} + +void launch_transpose_output(const DType dtype, const void* input, void* output, + const size_t rows, const size_t cols, hipStream_t stream) { + switch (dtype) { + case DType::kFloat32: + launch_transpose_typed(input, output, rows, cols, stream); + break; + case DType::kFloat16: + case DType::kBFloat16: + launch_transpose_typed(input, output, rows, cols, stream); + break; + default: + NVTE_ERROR("Unsupported MXFP8 NT transpose output dtype: ", to_string(dtype)); + } +} + +void launch_mxfp8_colwise_scale_to_rowwise(const void* input, void* output, + const size_t k_scale, const size_t padded_m, + const size_t valid_k_scale, const size_t valid_m, + const bool input_swizzled, + const bool output_swizzled, + hipStream_t stream) { + constexpr int kBlockSize = 256; + const size_t total = k_scale * padded_m; + if (total == 0) return; + const int grid = static_cast((total + kBlockSize - 1) / kBlockSize); + mxfp8_colwise_scale_to_rowwise_kernel<<>>( + reinterpret_cast(input), reinterpret_cast(output), k_scale, + padded_m, valid_k_scale, valid_m, input_swizzled, output_swizzled); + NVTE_CHECK_CUDA(hipGetLastError()); +} + +void* allocate_async_temp(std::vector& buffers, const size_t bytes, hipStream_t stream) { + if (bytes == 0) return nullptr; + void* ptr = nullptr; + NVTE_CHECK_CUDA(hipMallocAsync(&ptr, bytes, stream)); + buffers.push_back(ptr); + return ptr; +} + +void free_async_temps(const std::vector& buffers, hipStream_t stream) { + for (void* ptr : buffers) { + if (ptr != nullptr) { + NVTE_CHECK_CUDA(hipFreeAsync(ptr, stream)); + } + } +} + +Tensor make_mxfp8_rowwise_from_columnwise(const Tensor& input, std::vector& buffers, + const bool output_swizzled, hipStream_t stream) { + NVTE_CHECK(input.has_columnwise_data(), "MXFP8 transpose-to-TN requires column-wise data."); + NVTE_CHECK(input.columnwise_scale_inv.has_data(), + "MXFP8 transpose-to-TN requires column-wise scales."); + + const auto& data_shape = input.columnwise_data.shape; + NVTE_CHECK(data_shape.size() >= 2, "MXFP8 transpose-to-TN expects at least 2D data."); + const size_t cols = data_shape.back(); + const size_t rows = product(data_shape) / cols; + const size_t data_bytes = rows * cols * typeToSize(input.columnwise_data.dtype); + void* rowwise_data = allocate_async_temp(buffers, data_bytes, stream); + launch_transpose_u8(input.columnwise_data.dptr, rowwise_data, rows, cols, stream); + + const auto& scale_shape = input.columnwise_scale_inv.shape; + NVTE_CHECK(scale_shape.size() == 2, "MXFP8 transpose-to-TN expects 2D column-wise scales."); + const size_t k_scale = scale_shape[0]; + const size_t padded_m = scale_shape[1]; + const size_t valid_k_scale = (rows + kMXFP8BlockSize - 1) / kMXFP8BlockSize; + const size_t valid_m = cols; + void* rowwise_scale = nullptr; + if (input.with_gemm_swizzled_scales && output_swizzled) { + rowwise_scale = input.columnwise_scale_inv.dptr; + } else { + const size_t scale_bytes = k_scale * padded_m * typeToSize(input.columnwise_scale_inv.dtype); + rowwise_scale = allocate_async_temp(buffers, scale_bytes, stream); + launch_mxfp8_colwise_scale_to_rowwise(input.columnwise_scale_inv.dptr, rowwise_scale, k_scale, + padded_m, valid_k_scale, valid_m, + input.with_gemm_swizzled_scales, + output_swizzled, stream); + } + + Tensor output; + output.clear(); + output.scaling_mode = NVTE_MXFP8_1D_SCALING; + output.data = SimpleTensor(rowwise_data, std::vector{cols, rows}, input.columnwise_data.dtype); + output.scale_inv = SimpleTensor(rowwise_scale, std::vector{padded_m, k_scale}, + input.columnwise_scale_inv.dtype); + output.with_gemm_swizzled_scales = output_swizzled; + return output; +} + // FP4 e2m1 lookup table __device__ constexpr float kFP4E2M1Table[16] = { 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f, @@ -1707,6 +1870,82 @@ void release_service_stream(hipStream_t stream, struct ServiceStreamCtl &ctl) NVTE_CHECK_CUDA(hipEventDestroy(ctl.start_event)); } +bool try_mxfp8_non_tn_transpose_to_tn(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, + const Tensor *inputBias, Tensor *outputPreGelu, + cublasOperation_t transa, cublasOperation_t transb, + bool grad, void *workspace, size_t workspaceSize, + float alpha, float beta, bool use_split_accumulator, + int math_sm_count, hipStream_t stream, + hipblasLtHandle_t handle, int m, int n, int k) { + if (inputA->scaling_mode != NVTE_MXFP8_1D_SCALING || + inputB->scaling_mode != NVTE_MXFP8_1D_SCALING) { + return false; + } + + if (cuda::sm_arch() != 125) { + return false; + } + + if (!((transa == CUBLAS_OP_N && transb == CUBLAS_OP_N) || + (transa == CUBLAS_OP_N && transb == CUBLAS_OP_T))) { + return false; + } + if (beta != 0.0f || inputBias->data.dptr != nullptr || outputPreGelu->data.dptr != nullptr) { + return false; + } + + const bool output_swizzled = cuda::sm_arch() == 125; + std::vector temp_buffers; + bool launched = false; + try { + const Tensor A_tn = make_mxfp8_rowwise_from_columnwise(*inputA, temp_buffers, output_swizzled, + stream); + Tensor B_tn; + const Tensor *A_for_gemm = &A_tn; + const Tensor *B_for_gemm = nullptr; + Tensor D_tn; + Tensor *D_for_gemm = outputD; + int tn_m = m; + int tn_n = n; + int tn_ldd = m; + + if (transb == CUBLAS_OP_T) { + B_tn = make_mxfp8_rowwise_from_columnwise(*inputB, temp_buffers, output_swizzled, stream); + A_for_gemm = &B_tn; + B_for_gemm = &A_tn; + + const size_t d_bytes = static_cast(m) * n * typeToSize(outputD->data.dtype); + void* d_tn_ptr = allocate_async_temp(temp_buffers, d_bytes, stream); + D_tn.clear(); + D_tn.data = SimpleTensor(d_tn_ptr, std::vector{static_cast(m), + static_cast(n)}, + outputD->data.dtype); + D_for_gemm = &D_tn; + tn_m = n; + tn_n = m; + tn_ldd = n; + } else { + B_for_gemm = inputB; + } + + Tensor empty_bias; + Tensor empty_pre_gelu; + hipblaslt_gemm(A_for_gemm, B_for_gemm, D_for_gemm, &empty_bias, &empty_pre_gelu, tn_m, tn_n, k, + k, k, tn_ldd, CUBLAS_OP_T, CUBLAS_OP_N, grad, workspace, workspaceSize, alpha, + 0.0f, use_split_accumulator, math_sm_count, stream, handle); + + if (transb == CUBLAS_OP_T) { + launch_transpose_output(outputD->data.dtype, D_tn.data.dptr, outputD->data.dptr, m, n, stream); + } + launched = true; + } catch (...) { + free_async_temps(temp_buffers, stream); + throw; + } + free_async_temps(temp_buffers, stream); + return launched; +} + } // namespace @@ -1759,6 +1998,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, ServiceStreamCtl ss_ctl; bool use_service_stream = (math_sm_count != 0) ? get_service_stream(math_sm_count, stream, ss_ctl) : false; + hipStream_t gemm_stream = use_service_stream ? ss_ctl.stream : stream; int num_streams = nvte_get_num_compute_streams(); NVTE_CHECK(compute_stream_offset >= -1 && compute_stream_offset < num_streams); @@ -1773,9 +2013,16 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, handle = hipblaslt_handles[compute_stream_offset]; } - hipblaslt_gemm(inputA, inputB, outputD, inputBias, outputPreGelu, m, n, k, lda, ldb, ldd, transa, - transb, grad, workspace, workspaceSize, alpha, beta, use_split_accumulator, - math_sm_count, use_service_stream ? ss_ctl.stream : stream, handle); + bool handled = try_mxfp8_non_tn_transpose_to_tn(inputA, inputB, outputD, inputBias, + outputPreGelu, transa, transb, grad, + workspace, workspaceSize, alpha, beta, + use_split_accumulator, math_sm_count, + gemm_stream, handle, m, n, k); + if (!handled) { + hipblaslt_gemm(inputA, inputB, outputD, inputBias, outputPreGelu, m, n, k, lda, ldb, ldd, transa, + transb, grad, workspace, workspaceSize, alpha, beta, use_split_accumulator, + math_sm_count, gemm_stream, handle); + } if (use_service_stream) { From 2ecb8864c02f6358220f33c7a9f1a9fc5ca50af4 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 16 Jun 2026 20:08:08 +0000 Subject: [PATCH 2/2] reenable NN/NT tests, fix OOB copy --- tests/cpp/operator/test_cublaslt_gemm.cu | 9 +-------- .../common/cast/mxfp8/rocm_vectorized_2d.cuh | 4 ++-- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu index bc57df2c0..65f80572d 100644 --- a/tests/cpp/operator/test_cublaslt_gemm.cu +++ b/tests/cpp/operator/test_cublaslt_gemm.cu @@ -422,6 +422,7 @@ static void swizzle_mxfp8_scales(test::Tensor &t, bool rowwise) { nvte_swizzle_scaling_factors(input_tw.data(), output_tw.data(), 0); NVTE_CHECK_CUDA(cudaDeviceSynchronize()); NVTE_CHECK_CUDA(cudaMemcpy(scale_ptr, d_tmp, num_scales, cudaMemcpyDeviceToDevice)); + t.set_with_gemm_swizzled_scales(true); NVTE_CHECK_CUDA(cudaFree(d_tmp)); } @@ -907,14 +908,6 @@ class ProdGEMMTestSuite : public ::testing::TestWithParam {}; TEST_P(ProdGEMMTestSuite, TestMxfp8Dq) { const auto& config = GetParam(); - cudaDeviceProp prop; - (void)cudaGetDeviceProperties(&prop, 0); - const bool is_tn = config.transa && !config.transb; - if (prop.major == 12 && prop.minor == 5 && !is_tn) { - GTEST_SKIP() << "hipBLASLt MXFP8 GEMM non-TN layout is not supported on gfx1250: " - << config.label; - } - TestParams params = {.m = config.m, .k = config.k, .n = config.n, .use_bias = false, .use_gelu = false, .transa = config.transa, .transb = config.transb, diff --git a/transformer_engine/common/cast/mxfp8/rocm_vectorized_2d.cuh b/transformer_engine/common/cast/mxfp8/rocm_vectorized_2d.cuh index 81dc46a85..e67440fb1 100644 --- a/transformer_engine/common/cast/mxfp8/rocm_vectorized_2d.cuh +++ b/transformer_engine/common/cast/mxfp8/rocm_vectorized_2d.cuh @@ -25,7 +25,7 @@ __device__ inline void copy_2d_to_shared(T *sh_ptr_base, const T *g_ptr, size_t size_t g_row = g_start_row + l_y; size_t g_col_primitive_start = g_start_col + l_x_vec * N_VEC; - if (g_row < total_rows) { + if (g_row < total_rows && g_col_primitive_start < total_cols) { const T* current_g_row_base_ptr = g_ptr + g_row * g_stride; VectorizedLoaderglobal_loader(current_g_row_base_ptr, total_cols); @@ -72,7 +72,7 @@ __device__ inline void bulk_tensor_2d_shared_to_global(const T *sh_ptr_base, T * shared_loader.load(l_x_vec, chunk_dim_x); - if (g_row < total_rows) { + if (g_row < total_rows && g_col_primitive_start < total_cols) { global_storer.storage_.scratch_ = shared_loader.storage_.scratch_; global_storer.store(g_col_primitive_start / N_VEC, total_cols); }