Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 1 addition & 8 deletions tests/cpp/operator/test_cublaslt_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,7 @@ static void swizzle_mxfp8_scales(test::Tensor &t, bool rowwise) {
nvte_swizzle_scaling_factors(input_tw.data(), output_tw.data(), 0);
NVTE_CHECK_CUDA(cudaDeviceSynchronize());
NVTE_CHECK_CUDA(cudaMemcpy(scale_ptr, d_tmp, num_scales, cudaMemcpyDeviceToDevice));
t.set_with_gemm_swizzled_scales(true);
NVTE_CHECK_CUDA(cudaFree(d_tmp));
}

Expand Down Expand Up @@ -907,14 +908,6 @@ class ProdGEMMTestSuite : public ::testing::TestWithParam<ProdGemmConfig> {};
TEST_P(ProdGEMMTestSuite, TestMxfp8Dq) {
const auto& config = GetParam();

cudaDeviceProp prop;
(void)cudaGetDeviceProperties(&prop, 0);
const bool is_tn = config.transa && !config.transb;
if (prop.major == 12 && prop.minor == 5 && !is_tn) {
GTEST_SKIP() << "hipBLASLt MXFP8 GEMM non-TN layout is not supported on gfx1250: "
<< config.label;
}

TestParams params = {.m = config.m, .k = config.k, .n = config.n,
.use_bias = false, .use_gelu = false,
.transa = config.transa, .transb = config.transb,
Expand Down
4 changes: 2 additions & 2 deletions transformer_engine/common/cast/mxfp8/rocm_vectorized_2d.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ __device__ inline void copy_2d_to_shared(T *sh_ptr_base, const T *g_ptr, size_t
size_t g_row = g_start_row + l_y;
size_t g_col_primitive_start = g_start_col + l_x_vec * N_VEC;

if (g_row < total_rows) {
if (g_row < total_rows && g_col_primitive_start < total_cols) {
const T* current_g_row_base_ptr = g_ptr + g_row * g_stride;
VectorizedLoader<T, N_VEC, ALIGNED_ACCESS>global_loader(current_g_row_base_ptr, total_cols);

Expand Down Expand Up @@ -72,7 +72,7 @@ __device__ inline void bulk_tensor_2d_shared_to_global(const T *sh_ptr_base, T *

shared_loader.load(l_x_vec, chunk_dim_x);

if (g_row < total_rows) {
if (g_row < total_rows && g_col_primitive_start < total_cols) {
global_storer.storage_.scratch_ = shared_loader.storage_.scratch_;
global_storer.store(g_col_primitive_start / N_VEC, total_cols);
}
Expand Down
253 changes: 250 additions & 3 deletions transformer_engine/common/gemm/rocm_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,169 @@ struct GemmParam {
int ldb = 0; // B column strides
};

constexpr int kMXFP8BlockSize = 32;
constexpr int kMXFP8ScaleGroupSize = 4;

__global__ void transpose_u8_kernel(const uint8_t* __restrict__ input,
uint8_t* __restrict__ output,
const size_t rows, const size_t cols) {
const size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
const size_t total = rows * cols;
if (idx >= total) return;

const size_t row = idx / cols;
const size_t col = idx % cols;
output[col * rows + row] = input[idx];
}

template <typename T>
__global__ void transpose_kernel(const T* __restrict__ input,
T* __restrict__ output,
const size_t rows, const size_t cols) {
const size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
const size_t total = rows * cols;
if (idx >= total) return;

const size_t row = idx / cols;
const size_t col = idx % cols;
output[col * rows + row] = input[idx];
}

__global__ void mxfp8_colwise_scale_to_rowwise_kernel(const uint8_t* __restrict__ input,
uint8_t* __restrict__ output,
const size_t k_scale,
const size_t padded_m,
const size_t valid_k_scale,
const size_t valid_m,
const bool input_swizzled,
const bool output_swizzled) {
const size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
const size_t total = k_scale * padded_m;
if (idx >= total) return;

const size_t m = idx / k_scale;
const size_t k = idx % k_scale;
const size_t group = k / kMXFP8ScaleGroupSize;
const size_t within = k % kMXFP8ScaleGroupSize;
const size_t swizzled_idx = group * (padded_m * kMXFP8ScaleGroupSize)
+ m * kMXFP8ScaleGroupSize + within;
const size_t in_idx = input_swizzled ? swizzled_idx : k * padded_m + m;
const size_t out_idx = output_swizzled ? swizzled_idx : m * k_scale + k;
output[out_idx] = (k < valid_k_scale && m < valid_m) ? input[in_idx] : 127;
}

void launch_transpose_u8(const void* input, void* output, const size_t rows, const size_t cols,
hipStream_t stream) {
constexpr int kBlockSize = 256;
const size_t total = rows * cols;
if (total == 0) return;
const int grid = static_cast<int>((total + kBlockSize - 1) / kBlockSize);
transpose_u8_kernel<<<grid, kBlockSize, 0, stream>>>(
reinterpret_cast<const uint8_t*>(input), reinterpret_cast<uint8_t*>(output), rows, cols);
NVTE_CHECK_CUDA(hipGetLastError());
}

template <typename T>
void launch_transpose_typed(const void* input, void* output, const size_t rows, const size_t cols,
hipStream_t stream) {
constexpr int kBlockSize = 256;
const size_t total = rows * cols;
if (total == 0) return;
const int grid = static_cast<int>((total + kBlockSize - 1) / kBlockSize);
transpose_kernel<T><<<grid, kBlockSize, 0, stream>>>(
reinterpret_cast<const T*>(input), reinterpret_cast<T*>(output), rows, cols);
NVTE_CHECK_CUDA(hipGetLastError());
}

void launch_transpose_output(const DType dtype, const void* input, void* output,
const size_t rows, const size_t cols, hipStream_t stream) {
switch (dtype) {
case DType::kFloat32:
launch_transpose_typed<uint32_t>(input, output, rows, cols, stream);
break;
case DType::kFloat16:
case DType::kBFloat16:
launch_transpose_typed<uint16_t>(input, output, rows, cols, stream);
break;
default:
NVTE_ERROR("Unsupported MXFP8 NT transpose output dtype: ", to_string(dtype));
}
}

void launch_mxfp8_colwise_scale_to_rowwise(const void* input, void* output,
const size_t k_scale, const size_t padded_m,
const size_t valid_k_scale, const size_t valid_m,
const bool input_swizzled,
const bool output_swizzled,
hipStream_t stream) {
constexpr int kBlockSize = 256;
const size_t total = k_scale * padded_m;
if (total == 0) return;
const int grid = static_cast<int>((total + kBlockSize - 1) / kBlockSize);
mxfp8_colwise_scale_to_rowwise_kernel<<<grid, kBlockSize, 0, stream>>>(
reinterpret_cast<const uint8_t*>(input), reinterpret_cast<uint8_t*>(output), k_scale,
padded_m, valid_k_scale, valid_m, input_swizzled, output_swizzled);
NVTE_CHECK_CUDA(hipGetLastError());
}

void* allocate_async_temp(std::vector<void*>& buffers, const size_t bytes, hipStream_t stream) {
if (bytes == 0) return nullptr;
void* ptr = nullptr;
NVTE_CHECK_CUDA(hipMallocAsync(&ptr, bytes, stream));
buffers.push_back(ptr);
return ptr;
}

void free_async_temps(const std::vector<void*>& buffers, hipStream_t stream) {
for (void* ptr : buffers) {
if (ptr != nullptr) {
NVTE_CHECK_CUDA(hipFreeAsync(ptr, stream));
}
}
}

Tensor make_mxfp8_rowwise_from_columnwise(const Tensor& input, std::vector<void*>& buffers,
const bool output_swizzled, hipStream_t stream) {
NVTE_CHECK(input.has_columnwise_data(), "MXFP8 transpose-to-TN requires column-wise data.");
NVTE_CHECK(input.columnwise_scale_inv.has_data(),
"MXFP8 transpose-to-TN requires column-wise scales.");

const auto& data_shape = input.columnwise_data.shape;
NVTE_CHECK(data_shape.size() >= 2, "MXFP8 transpose-to-TN expects at least 2D data.");
const size_t cols = data_shape.back();
const size_t rows = product(data_shape) / cols;
const size_t data_bytes = rows * cols * typeToSize(input.columnwise_data.dtype);
void* rowwise_data = allocate_async_temp(buffers, data_bytes, stream);
launch_transpose_u8(input.columnwise_data.dptr, rowwise_data, rows, cols, stream);

const auto& scale_shape = input.columnwise_scale_inv.shape;
NVTE_CHECK(scale_shape.size() == 2, "MXFP8 transpose-to-TN expects 2D column-wise scales.");
const size_t k_scale = scale_shape[0];
const size_t padded_m = scale_shape[1];
const size_t valid_k_scale = (rows + kMXFP8BlockSize - 1) / kMXFP8BlockSize;
const size_t valid_m = cols;
void* rowwise_scale = nullptr;
if (input.with_gemm_swizzled_scales && output_swizzled) {
rowwise_scale = input.columnwise_scale_inv.dptr;
} else {
const size_t scale_bytes = k_scale * padded_m * typeToSize(input.columnwise_scale_inv.dtype);
rowwise_scale = allocate_async_temp(buffers, scale_bytes, stream);
launch_mxfp8_colwise_scale_to_rowwise(input.columnwise_scale_inv.dptr, rowwise_scale, k_scale,
padded_m, valid_k_scale, valid_m,
input.with_gemm_swizzled_scales,
output_swizzled, stream);
}

Tensor output;
output.clear();
output.scaling_mode = NVTE_MXFP8_1D_SCALING;
output.data = SimpleTensor(rowwise_data, std::vector<size_t>{cols, rows}, input.columnwise_data.dtype);
output.scale_inv = SimpleTensor(rowwise_scale, std::vector<size_t>{padded_m, k_scale},
input.columnwise_scale_inv.dtype);
output.with_gemm_swizzled_scales = output_swizzled;
return output;
}

// FP4 e2m1 lookup table
__device__ constexpr float kFP4E2M1Table[16] = {
0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f,
Expand Down Expand Up @@ -1707,6 +1870,82 @@ void release_service_stream(hipStream_t stream, struct ServiceStreamCtl &ctl)
NVTE_CHECK_CUDA(hipEventDestroy(ctl.start_event));
}

bool try_mxfp8_non_tn_transpose_to_tn(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
const Tensor *inputBias, Tensor *outputPreGelu,
cublasOperation_t transa, cublasOperation_t transb,
bool grad, void *workspace, size_t workspaceSize,
float alpha, float beta, bool use_split_accumulator,
int math_sm_count, hipStream_t stream,
hipblasLtHandle_t handle, int m, int n, int k) {
if (inputA->scaling_mode != NVTE_MXFP8_1D_SCALING ||
inputB->scaling_mode != NVTE_MXFP8_1D_SCALING) {
return false;
}

if (cuda::sm_arch() != 125) {
return false;
}

if (!((transa == CUBLAS_OP_N && transb == CUBLAS_OP_N) ||
(transa == CUBLAS_OP_N && transb == CUBLAS_OP_T))) {
return false;
}
if (beta != 0.0f || inputBias->data.dptr != nullptr || outputPreGelu->data.dptr != nullptr) {
return false;
}

const bool output_swizzled = cuda::sm_arch() == 125;
std::vector<void*> temp_buffers;
bool launched = false;
try {
const Tensor A_tn = make_mxfp8_rowwise_from_columnwise(*inputA, temp_buffers, output_swizzled,
stream);
Tensor B_tn;
const Tensor *A_for_gemm = &A_tn;
const Tensor *B_for_gemm = nullptr;
Tensor D_tn;
Tensor *D_for_gemm = outputD;
int tn_m = m;
int tn_n = n;
int tn_ldd = m;

if (transb == CUBLAS_OP_T) {
B_tn = make_mxfp8_rowwise_from_columnwise(*inputB, temp_buffers, output_swizzled, stream);
A_for_gemm = &B_tn;
B_for_gemm = &A_tn;

const size_t d_bytes = static_cast<size_t>(m) * n * typeToSize(outputD->data.dtype);
void* d_tn_ptr = allocate_async_temp(temp_buffers, d_bytes, stream);
D_tn.clear();
D_tn.data = SimpleTensor(d_tn_ptr, std::vector<size_t>{static_cast<size_t>(m),
static_cast<size_t>(n)},
outputD->data.dtype);
D_for_gemm = &D_tn;
tn_m = n;
tn_n = m;
tn_ldd = n;
} else {
B_for_gemm = inputB;
}

Tensor empty_bias;
Tensor empty_pre_gelu;
hipblaslt_gemm(A_for_gemm, B_for_gemm, D_for_gemm, &empty_bias, &empty_pre_gelu, tn_m, tn_n, k,
k, k, tn_ldd, CUBLAS_OP_T, CUBLAS_OP_N, grad, workspace, workspaceSize, alpha,
0.0f, use_split_accumulator, math_sm_count, stream, handle);

if (transb == CUBLAS_OP_T) {
launch_transpose_output(outputD->data.dtype, D_tn.data.dptr, outputD->data.dptr, m, n, stream);
}
launched = true;
} catch (...) {
free_async_temps(temp_buffers, stream);
throw;
}
free_async_temps(temp_buffers, stream);
return launched;
}

} // namespace


Expand Down Expand Up @@ -1759,6 +1998,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
ServiceStreamCtl ss_ctl;
bool use_service_stream =
(math_sm_count != 0) ? get_service_stream(math_sm_count, stream, ss_ctl) : false;
hipStream_t gemm_stream = use_service_stream ? ss_ctl.stream : stream;

int num_streams = nvte_get_num_compute_streams();
NVTE_CHECK(compute_stream_offset >= -1 && compute_stream_offset < num_streams);
Expand All @@ -1773,9 +2013,16 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
handle = hipblaslt_handles[compute_stream_offset];
}

hipblaslt_gemm(inputA, inputB, outputD, inputBias, outputPreGelu, m, n, k, lda, ldb, ldd, transa,
transb, grad, workspace, workspaceSize, alpha, beta, use_split_accumulator,
math_sm_count, use_service_stream ? ss_ctl.stream : stream, handle);
bool handled = try_mxfp8_non_tn_transpose_to_tn(inputA, inputB, outputD, inputBias,
outputPreGelu, transa, transb, grad,
workspace, workspaceSize, alpha, beta,
use_split_accumulator, math_sm_count,
gemm_stream, handle, m, n, k);
if (!handled) {
hipblaslt_gemm(inputA, inputB, outputD, inputBias, outputPreGelu, m, n, k, lda, ldb, ldd, transa,
transb, grad, workspace, workspaceSize, alpha, beta, use_split_accumulator,
math_sm_count, gemm_stream, handle);
}

if (use_service_stream)
{
Expand Down
Loading