From e36286044947e81b4e76e519ba032620ce3a608b Mon Sep 17 00:00:00 2001
From: Abhishek
Date: Tue, 21 Apr 2026 23:54:14 -0700
Subject: [PATCH 1/8] feat: add support for grouped GEMM swizzling with
 variable shapes and update C++ operator interface

Signed-off-by: Abhishek
---
 tests/cpp/operator/test_swizzle.cu            | 145 ++++++++++-
 tests/cpp/test_common.cu                      |  14 +-
 .../include/transformer_engine/swizzle.h      |   2 +-
 transformer_engine/common/swizzle/swizzle.cu  | 228 +++++++++++++++++-
 .../pytorch/csrc/extensions/swizzle.cpp       |  14 +-
 5 files changed, 381 insertions(+), 22 deletions(-)

diff --git a/tests/cpp/operator/test_swizzle.cu b/tests/cpp/operator/test_swizzle.cu
index 806a2482ab..d74e12cb55 100644
--- a/tests/cpp/operator/test_swizzle.cu
+++ b/tests/cpp/operator/test_swizzle.cu
@@ -281,7 +281,7 @@ void performTestGroupedSwizzleMXFP8(const int num_tensors, const size_t M, const
   NVTE_CHECK_CUDA(cudaMemset(grouped_output.columnwise_scale_inv.get(), 0, num_tensors * col_numel));
 
   nvte_swizzle_grouped_scaling_factors(grouped_input.get_handle(),
-                                       grouped_output.get_handle(), 0);
+                                       grouped_output.get_handle(), nullptr, 0);
 
   std::vector<fp8e8m0> output_row(num_tensors * row_numel);
   std::vector<fp8e8m0> output_col(num_tensors * col_numel);
@@ -481,7 +481,7 @@ void performTestGroupedSwizzleUnswizzleRoundtrip(const int num_tensors, const si
   NVTE_CHECK_CUDA(cudaMemset(grouped_fin.scale_inv.get(), 0, num_tensors * row_numel));
   NVTE_CHECK_CUDA(cudaMemset(grouped_fin.columnwise_scale_inv.get(), 0, num_tensors * col_numel));
 
-  nvte_swizzle_grouped_scaling_factors(grouped_orig.get_handle(), grouped_mid.get_handle(), 0);
+  nvte_swizzle_grouped_scaling_factors(grouped_orig.get_handle(), grouped_mid.get_handle(), nullptr, 0);
   nvte_unswizzle_grouped_scaling_factors(grouped_mid.get_handle(), grouped_fin.get_handle(), 0);
 
   std::vector<fp8e8m0> result_row(num_tensors * row_numel);
@@ -506,6 +506,147 @@ void performTestGroupedSwizzleUnswizzleRoundtrip(const int num_tensors, const si
                  num_tensors * col_numel);
 }
 
+void performTestGroupedSwizzleMXFP8Variable(const std::vector<std::pair<size_t, size_t>>& shapes) {
+  using namespace transformer_engine;
+  using namespace test;
+
+  int num_tensors = shapes.size();
+  std::vector<std::unique_ptr<Tensor>> input_tensors;
+  std::vector<std::unique_ptr<Tensor>> output_tensors;
+  std::vector<Tensor*> input_ptrs;
+  std::vector<Tensor*> output_ptrs;
+  input_tensors.reserve(num_tensors);
+  output_tensors.reserve(num_tensors);
+  input_ptrs.reserve(num_tensors);
+  output_ptrs.reserve(num_tensors);
+
+  constexpr size_t BLOCK_SIZE = 32;
+  for (int i = 0; i < num_tensors; ++i) {
+    const std::vector<size_t> shape{shapes[i].first, shapes[i].second};
+    auto input = std::make_unique<Tensor>("input_" + std::to_string(i), shape,
+                                          DType::kFloat8E4M3, true, true,
+                                          NVTE_MXFP8_1D_SCALING);
+    auto output = std::make_unique<Tensor>("output_" + std::to_string(i), shape,
+                                           DType::kFloat8E4M3, true, true,
+                                           NVTE_MXFP8_1D_SCALING);
+    fillUniform(input.get());
+    fillUniform(output.get());
+
+    // Zero padding
+    input->to_cpu();
+    const NVTEShape rs = input->rowwise_scale_inv_shape();
+    zero_scale_inv_padding(input->rowwise_cpu_scale_inv_ptr(),
+                           rs.data[0], rs.data[1],
+                           shapes[i].first, (shapes[i].second + BLOCK_SIZE - 1) / BLOCK_SIZE);
+    const NVTEShape cs = input->columnwise_scale_inv_shape();
+    zero_scale_inv_padding(input->columnwise_cpu_scale_inv_ptr(),
+                           cs.data[0], cs.data[1],
+                           (shapes[i].first + BLOCK_SIZE - 1) / BLOCK_SIZE, shapes[i].second);
+    input->from_cpu();
+
+    input_ptrs.push_back(input.get());
+    output_ptrs.push_back(output.get());
+    input_tensors.emplace_back(std::move(input));
+    output_tensors.emplace_back(std::move(output));
+  }
+
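+  // Note: the padded region of each scale tensor is zeroed above so that the
+  // swizzled output can later be compared byte-for-byte against the CPU
+  // reference, which reads the same zeroed padding.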
+  GroupedBuffers grouped_input = build_grouped_tensor(input_ptrs, NVTE_MXFP8_1D_SCALING);
+  GroupedBuffers grouped_output = build_grouped_tensor(output_ptrs, NVTE_MXFP8_1D_SCALING);
+
+  const uint8_t input_swizzled = 0;
+  nvte_set_grouped_tensor_param(grouped_input.get_handle(),
+                                kNVTEGroupedWithGEMMSwizzledScales,
+                                &input_swizzled, sizeof(input_swizzled));
+  const uint8_t output_swizzled = 1;
+  nvte_set_grouped_tensor_param(grouped_output.get_handle(),
+                                kNVTEGroupedWithGEMMSwizzledScales,
+                                &output_swizzled, sizeof(output_swizzled));
+
+  // Workspace allocation
+  size_t num_int_elems = num_tensors + 3;
+  if (num_int_elems % 2 != 0) num_int_elems++;
+  size_t workspace_size = num_int_elems * sizeof(int) + (num_tensors + 1) * sizeof(size_t);
+  workspace_size = (workspace_size + 255) & ~255;  // round up to a multiple of 256
+  void* d_workspace;
+  NVTE_CHECK_CUDA(cudaMalloc(&d_workspace, workspace_size));
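+  // For example, with num_tensors = 2: num_int_elems = 2 + 3 = 5, padded to 6
+  // so the size_t array that follows stays 8-byte aligned; workspace_size =
+  // 6 * 4 + 3 * 8 = 48 bytes, rounded up to 256.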
+
+  nvte_swizzle_grouped_scaling_factors(grouped_input.get_handle(),
+                                       grouped_output.get_handle(),
+                                       d_workspace, 0);
+
+  cudaDeviceSynchronize();
+  NVTE_CHECK_CUDA(cudaGetLastError());
+
+  // Verification
+  size_t row_offset = 0;
+  size_t col_offset = 0;
+  for (int i = 0; i < num_tensors; ++i) {
+    const NVTEShape row_shape = input_tensors[i]->rowwise_scale_inv_shape();
+    const NVTEShape col_shape = input_tensors[i]->columnwise_scale_inv_shape();
+    const size_t row_numel = row_shape.data[0] * row_shape.data[1];
+    const size_t col_numel = col_shape.data[0] * col_shape.data[1];
+
+    std::vector<fp8e8m0> output_row_host(row_numel);
+    std::vector<fp8e8m0> output_col_host(col_numel);
+    NVTE_CHECK_CUDA(cudaMemcpy(output_row_host.data(),
+                               static_cast<const uint8_t*>(grouped_output.scale_inv.get()) + row_offset,
+                               row_numel, cudaMemcpyDeviceToHost));
+    NVTE_CHECK_CUDA(cudaMemcpy(output_col_host.data(),
+                               static_cast<const uint8_t*>(grouped_output.columnwise_scale_inv.get()) + col_offset,
+                               col_numel, cudaMemcpyDeviceToHost));
+
+    std::vector<fp8e8m0> ref_row(row_numel);
+    std::vector<fp8e8m0> ref_col(col_numel);
+    compute_ref_swizzle<128, 4, true>(input_tensors[i]->rowwise_cpu_scale_inv_ptr(),
+                                      ref_row.data(),
+                                      row_shape.data[0], row_shape.data[1]);
+    compute_ref_swizzle<128, 4, false>(
+        input_tensors[i]->columnwise_cpu_scale_inv_ptr(),
+        ref_col.data(),
+        col_shape.data[1], col_shape.data[0]);
+
+    compareResults("grouped_swizzle_variable_rowwise_" + std::to_string(i),
+                   output_row_host.data(), ref_row.data(), row_numel);
+    compareResults("grouped_swizzle_variable_colwise_" + std::to_string(i),
+                   output_col_host.data(), ref_col.data(), col_numel);
+
+    row_offset += row_numel;
+    col_offset += col_numel;
+  }
+  NVTE_CHECK_CUDA(cudaFree(d_workspace));
+}
+
+class SwizzleGroupedVariableTestSuite
+    : public ::testing::TestWithParam<std::vector<std::pair<size_t, size_t>>> {};
+
+TEST_P(SwizzleGroupedVariableTestSuite, TestGroupedSwizzleMXFP8Variable) {
+  const auto shapes = GetParam();
+  performTestGroupedSwizzleMXFP8Variable(shapes);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    OperatorTest,
+    SwizzleGroupedVariableTestSuite,
+    ::testing::Values(
+        // Case 1: num_tensors = 1 (n+3 = 4, even). Check simple alignment.
+        std::vector<std::pair<size_t, size_t>>{{1024, 1024}},
+
+        // Case 2: num_tensors = 2 (n+3 = 5, odd). Forces padding logic to trigger.
+        std::vector<std::pair<size_t, size_t>>{{128, 128}, {256, 256}},
+
+        // Case 3: Mixed small/irregular shapes.
+        std::vector<std::pair<size_t, size_t>>{{200, 160}, {33, 64}, {1, 32}},
+
+        // Case 4: Large workload to verify persistent grid (looping behavior).
+        // 10 tensors * (4096x4096 data) = 10 * (32x32 tiles) = 10,240 tiles total.
+        // This reliably exceeds the grid size on all modern GPUs.
+        std::vector<std::pair<size_t, size_t>>(10, {4096, 4096})
+    ),
+    [](const testing::TestParamInfo<SwizzleGroupedVariableTestSuite::ParamType>& info) {
+      return "VariableShapes_" + std::to_string(info.index) + "_N" + std::to_string(info.param.size());
+    }
+);
+
 class SwizzleGroupedTestSuite
     : public ::testing::TestWithParam> {};
 
diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu
index 5196684118..b8bc38935f 100644
--- a/tests/cpp/test_common.cu
+++ b/tests/cpp/test_common.cu
@@ -1099,7 +1099,7 @@ GroupedBuffers build_grouped_tensor(const std::vector& tensors,
   const bool same_last = std::all_of(last_dims.begin(), last_dims.end(),
                                      [&](int64_t v) { return v == last_dims[0]; });
 
-  std::vector<int64_t> offsets(num_tensors, 0);
+  std::vector<int64_t> offsets(num_tensors + 1, 0);
   auto random_padding = [&]() -> int64_t {
     // Random padding ensuring 16-byte alignment regardless of element size
     // cuBLAS requires aligned pointers for vectorized loads
@@ -1118,12 +1118,11 @@ GroupedBuffers build_grouped_tensor(const std::vector& tensors,
   const bool need_offsets = !same_first || !same_last;
   const bool use_random_padding = need_offsets && scaling_mode != NVTE_MXFP8_1D_SCALING;
   if (need_offsets) {
-    offsets[0] = 0;
-    for (size_t i = 1; i < num_tensors; ++i) {
+    for (size_t i = 1; i < num_tensors + 1; ++i) {
       offsets[i] = offsets[i - 1] + numel(i - 1) + (use_random_padding ? random_padding() : 0);
     }
   } else {
-    for (size_t i = 0; i < num_tensors; ++i) {
+    for (size_t i = 0; i < num_tensors + 1; ++i) {
       offsets[i] = static_cast<int64_t>(i) * numel(0);
     }
   }
@@ -1211,10 +1210,11 @@ GroupedBuffers build_grouped_tensor(const std::vector& tensors,
   }
 
   if (!same_first || !same_last) {
-    grouped.offsets_dev = cuda_alloc(num_tensors * sizeof(int64_t));
+    size_t num_off = num_tensors + 1;
+    grouped.offsets_dev = cuda_alloc(num_off * sizeof(int64_t));
     NVTE_CHECK_CUDA(cudaMemcpy(grouped.offsets_dev.get(), offsets.data(),
-                               num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice));
-    NVTEShape off_shape = nvte_make_shape(&num_tensors, 1);
+                               num_off * sizeof(int64_t), cudaMemcpyHostToDevice));
+    NVTEShape off_shape = nvte_make_shape(&num_off, 1);
     NVTEBasicTensor off_tensor{grouped.offsets_dev.get(), kNVTEInt64, off_shape};
     nvte_set_grouped_tensor_param(h, kNVTEGroupedTensorOffsets, &off_tensor, sizeof(off_tensor));
   }
diff --git a/transformer_engine/common/include/transformer_engine/swizzle.h b/transformer_engine/common/include/transformer_engine/swizzle.h
index 4e28de3beb..98d245f36d 100644
--- a/transformer_engine/common/include/transformer_engine/swizzle.h
+++ b/transformer_engine/common/include/transformer_engine/swizzle.h
@@ -105,7 +105,7 @@ void nvte_swizzle_block_scaling_to_mxfp8_scaling_factors(const NVTETensor input,
  *  - all tensors in the grouped tensor must have the same shape.
  */
 void nvte_swizzle_grouped_scaling_factors(const NVTEGroupedTensor input, NVTEGroupedTensor output,
-                                          cudaStream_t stream);
+                                          void* workspace, cudaStream_t stream);
 
 /*!
  \brief Unswizzling scaling factors from the interleaved GEMM layout back to row-major (grouped)
  *
diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu
index 6c59776245..0c4a3917d4 100644
--- a/transformer_engine/common/swizzle/swizzle.cu
+++ b/transformer_engine/common/swizzle/swizzle.cu
@@ -1552,10 +1552,157 @@ void nvte_multi_tensor_unswizzle_scaling_factors(const NVTETensor* inputs, NVTET
   multi_tensor_unswizzle_scaling_factors(input_list, output_list, stream);
 }
 
+
 namespace transformer_engine {
 
+template <int SF_TILE_DIM_M, int SF_TILE_DIM_K>
+__global__ void __launch_bounds__(TB_DIM* TB_DIM)
+    grouped_swizzle_scaling_variable_shape_kernel(
+        const void* input,
+        void* output,
+        const int64_t* m_array,
+        const int64_t* k_array,
+        const int* block_offsets,
+        const size_t* scale_offsets,
+        int* global_counter,
+        int num_tensors,
+        bool rowwise) {
+
+  __shared__ int linear_block_id;
+  while (true) {
+    if (threadIdx.x == 0 && threadIdx.y == 0) {
+      linear_block_id = atomicAdd(global_counter, 1);
+    }
+    __syncthreads();
+
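+    // Each thread block keeps claiming linear block IDs from the global
+    // counter until the schedule is exhausted; the binary search below maps a
+    // claimed ID to its tensor via the exclusive prefix sums in block_offsets.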
+    int tensor_id = -1;
+    int low = 0;
+    int high = num_tensors - 1;
+    while (low <= high) {
+      int mid = low + (high - low) / 2;
+      if (linear_block_id >= block_offsets[mid] && linear_block_id < block_offsets[mid + 1]) {
+        tensor_id = mid;
+        break;
+      } else if (linear_block_id < block_offsets[mid]) {
+        high = mid - 1;
+      } else {
+        low = mid + 1;
+      }
+    }
+
+    if (tensor_id == -1) return;
+
+    int local_block_id = linear_block_id - block_offsets[tensor_id];
+
+    size_t M = rowwise ? m_array[tensor_id] : k_array[tensor_id];
+    size_t K = rowwise ? k_array[tensor_id] : m_array[tensor_id];
+
+    size_t padded_m = round_up_to_multiple(M, 128);
+    size_t padded_k = round_up_to_multiple(DIVUP(K, static_cast<size_t>(MXFP8_BLOCK_SIZE)), 4);
+
+    int num_tiles_m = padded_m / SF_TILE_DIM_M;
+    int num_tiles_k = padded_k / SF_TILE_DIM_K;
+
+    int vec_load_size = (rowwise ? ((num_tiles_k - 1) % 4 + 1) : ((num_tiles_m - 1) % 4 + 1));
+    if (vec_load_size == 3) vec_load_size = 1;
+    int n_tiles_in_tb = TB_DIM * vec_load_size;
+
+    int grid_dim_x = rowwise ? DIVUP(num_tiles_k, n_tiles_in_tb) : DIVUP(num_tiles_k, TB_DIM);
+    int grid_dim_y = rowwise ? num_tiles_m : DIVUP(num_tiles_m, vec_load_size);
+
+    int block_x = local_block_id % grid_dim_x;
+    int block_y = local_block_id / grid_dim_x;
+
+    const uint8_t* input_base = reinterpret_cast<const uint8_t*>(input) + scale_offsets[tensor_id];
+    uint8_t* output_base = reinterpret_cast<uint8_t*>(output) + scale_offsets[tensor_id];
+
+    int original_M = static_cast<int>(M);
+    int original_K = static_cast<int>(DIVUP(K, static_cast<size_t>(MXFP8_BLOCK_SIZE)));
+
+    if (rowwise) {
+      if (vec_load_size == 4) {
+        swizzle_row_scaling_kernel_impl<4>(
+            input_base, output_base, padded_m, padded_k, original_M, original_K,
+            block_x, block_y, grid_dim_x, grid_dim_y);
+      } else if (vec_load_size == 2) {
+        swizzle_row_scaling_kernel_impl<2>(
+            input_base, output_base, padded_m, padded_k, original_M, original_K,
+            block_x, block_y, grid_dim_x, grid_dim_y);
+      } else {
+        swizzle_row_scaling_kernel_impl<1>(
+            input_base, output_base, padded_m, padded_k, original_M, original_K,
+            block_x, block_y, grid_dim_x, grid_dim_y);
+      }
+    } else {
+      if (vec_load_size == 4) {
+        swizzle_col_scaling_kernel_impl<4>(
+            input_base, output_base, padded_m, padded_k, original_M, original_K,
+            block_x, block_y, grid_dim_x, grid_dim_y);
+      } else if (vec_load_size == 2) {
+        swizzle_col_scaling_kernel_impl<2>(
+            input_base, output_base, padded_m, padded_k, original_M, original_K,
+            block_x, block_y, grid_dim_x, grid_dim_y);
+      } else {
+        swizzle_col_scaling_kernel_impl<1>(
+            input_base, output_base, padded_m, padded_k, original_M, original_K,
+            block_x, block_y, grid_dim_x, grid_dim_y);
+      }
+    }
+    __syncthreads();
+  }
+}
+
+__global__ void compute_grouped_swizzle_setup(
+    const int64_t* m_array,
+    const int64_t* k_array,
+    int* block_offsets,
+    size_t* scale_offsets,
+    int* total_blocks,
+    int* global_counter,
+    size_t num_tensors,
+    bool rowwise,
+    size_t scale_elem_size) {
+
+  if (blockIdx.x == 0 && threadIdx.x == 0) {
+    int current_block_offset = 0;
+    size_t current_scale_offset = 0;
+
+    for (size_t i = 0; i < num_tensors; ++i) {
+      block_offsets[i] = current_block_offset;
+      scale_offsets[i] = current_scale_offset;
+
+      size_t m = rowwise ? m_array[i] : k_array[i];
+      size_t k = rowwise ? k_array[i] : m_array[i];
+
+      size_t padded_m = round_up_to_multiple(m, 128);
+      size_t padded_k = round_up_to_multiple(DIVUP(k, static_cast<size_t>(MXFP8_BLOCK_SIZE)), 4);
+
+      int num_tiles_m = padded_m / 128;
+      int num_tiles_k = padded_k / 4;
+
+      int vec_load_size = (rowwise ? ((num_tiles_k - 1) % 4 + 1) : ((num_tiles_m - 1) % 4 + 1));
+      if (vec_load_size == 3) vec_load_size = 1;
+
+      int blocks_m = num_tiles_m;
+      int blocks_k = DIVUP(num_tiles_k, TB_DIM * vec_load_size);
+      if (!rowwise) {
+        blocks_m = DIVUP(num_tiles_m, vec_load_size);
+        blocks_k = DIVUP(num_tiles_k, TB_DIM);
+      }
+
+      current_block_offset += blocks_m * blocks_k;
+      current_scale_offset += padded_m * padded_k * scale_elem_size;
+    }
+
+    block_offsets[num_tensors] = current_block_offset;
+    scale_offsets[num_tensors] = current_scale_offset;
+    *total_blocks = current_block_offset;
+    *global_counter = 0;
+  }
+}
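+// Worked example for one 200x160 tensor (rowwise): padded_m = 256 and
+// padded_k = round_up(DIVUP(160, 32), 4) = 8, i.e. 2x2 scale tiles, so
+// vec_load_size = (2 - 1) % 4 + 1 = 2 and the tensor contributes
+// DIVUP(2, TB_DIM * 2) * 2 = 2 thread blocks to the schedule above.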
 
 void swizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor* output,
-                                     cudaStream_t stream) {
+                                     void* workspace, cudaStream_t stream) {
   // Check scaling mode
   NVTE_CHECK(input->scaling_mode == NVTE_MXFP8_1D_SCALING,
              "Grouped swizzle supports only MXFP8 scaling.");
@@ -1575,10 +1722,15 @@ void swizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor*
     return;
   }
 
-  // Only support uniform shapes for graph-safe grouped swizzle
-  NVTE_CHECK(input->all_same_shape(), "Grouped swizzle requires uniform tensor shapes.");
-  NVTE_CHECK(input->all_same_last_dim() && input->all_same_first_dim(),
-             "Grouped swizzle requires uniform tensor shapes.");
+  const int64_t* m_array = reinterpret_cast<const int64_t*>(input->first_dims.dptr);
+  const int64_t* k_array = reinterpret_cast<const int64_t*>(input->last_dims.dptr);
+  const bool is_variable_shape = (m_array != nullptr && k_array != nullptr);
+
+  if (!is_variable_shape) {
+    // Fallback to uniform shape implementation
+    NVTE_CHECK(input->all_same_shape(), "Grouped swizzle requires uniform tensor shapes.");
+    NVTE_CHECK(input->all_same_last_dim() && input->all_same_first_dim(),
+               "Grouped swizzle requires uniform tensor shapes.");
 
   // Assumption is that all the tensors share the same shapes and are contgiuous.
   // And so we dont need to pass array of input/output pointers(due to conttiguity)
@@ -1708,6 +1860,68 @@ void swizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor*
   if (has_columnwise_scale_inv) {
     launch_grouped_swizzle(false);
   }
+  } else {
+    // Variable shape implementation using Device-Side Block Scheduler
+    size_t num_tensors = input->num_tensors;
+    NVTE_CHECK(workspace != nullptr, "Workspace must be provided for variable shape grouped swizzle.");
+
+    size_t int_stride = num_tensors + 3;
+    if (int_stride % 2 != 0) int_stride++;
+    int* d_block_offsets = reinterpret_cast<int*>(workspace);
+    int* d_global_counter = d_block_offsets + num_tensors + 1;
+    int* d_total_blocks = d_global_counter + 1;
+    size_t* d_scale_offsets = reinterpret_cast<size_t*>(d_block_offsets + int_stride);
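+    // Workspace layout (matching the size computation on the PyTorch side):
+    // [ (num_tensors + 1) int block offsets | int global counter |
+    //   int total blocks | optional int pad | (num_tensors + 1) size_t scale offsets ]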
+
+    constexpr int SF_TILE_DIM_M = 128;
+    constexpr int SF_TILE_DIM_K = 4;
+    const dim3 block_size(TB_DIM, TB_DIM);
+    const int max_slm_size = TB_DIM * 4 * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);
+
+    auto launch_grouped_swizzle_variable = [&](bool rowwise) {
+      const size_t scale_elem_size = rowwise ? typeToSize(input->scale_inv.dtype)
+                                             : typeToSize(input->columnwise_scale_inv.dtype);
+
+      compute_grouped_swizzle_setup<<<1, 1, 0, stream>>>(
+          m_array, k_array, d_block_offsets, d_scale_offsets, d_total_blocks,
+          d_global_counter, num_tensors, rowwise, scale_elem_size);
+
+      NVTE_CHECK_CUDA(cudaFuncSetAttribute(
+          grouped_swizzle_scaling_variable_shape_kernel<SF_TILE_DIM_M, SF_TILE_DIM_K>,
+          cudaFuncAttributeMaxDynamicSharedMemorySize, max_slm_size));
+
+      int device_id;
+      cudaGetDevice(&device_id);
+      int num_SMs;
+      cudaDeviceGetAttribute(&num_SMs, cudaDevAttrMultiProcessorCount, device_id);
+      // Find out how many blocks of this specific kernel can fit on one SM
+      int max_active_blocks_per_sm;
+      cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+          &max_active_blocks_per_sm,
+          grouped_swizzle_scaling_variable_shape_kernel<SF_TILE_DIM_M, SF_TILE_DIM_K>,
+          TB_DIM * TB_DIM,  // block size
+          max_slm_size      // dynamic shared memory
+      );
+      int persistent_blocks = num_SMs * max_active_blocks_per_sm;
+      dim3 num_blocks(persistent_blocks);
+
+      const void* input_ptr = rowwise ? input->scale_inv.dptr : input->columnwise_scale_inv.dptr;
+      void* output_ptr = rowwise ? output->scale_inv.dptr : output->columnwise_scale_inv.dptr;
+
+      grouped_swizzle_scaling_variable_shape_kernel<SF_TILE_DIM_M, SF_TILE_DIM_K>
+          <<<num_blocks, block_size, max_slm_size, stream>>>(
+              input_ptr, output_ptr, m_array, k_array, d_block_offsets,
+              d_scale_offsets, d_global_counter, num_tensors, rowwise);
+
+      NVTE_CHECK_CUDA(cudaGetLastError());
+    };
+
+    if (has_rowwise_scale_inv) {
+      launch_grouped_swizzle_variable(true);
+    }
+    if (has_columnwise_scale_inv) {
+      launch_grouped_swizzle_variable(false);
+    }
+  }
 }
 
 void unswizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor* output,
@@ -1820,11 +2034,11 @@ void unswizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor
 }  // namespace transformer_engine
 
 void nvte_swizzle_grouped_scaling_factors(const NVTEGroupedTensor input, NVTEGroupedTensor output,
-                                          cudaStream_t stream) {
+                                          void* workspace, cudaStream_t stream) {
   NVTE_API_CALL(nvte_swizzle_grouped_scaling_factors);
   using namespace transformer_engine;
   swizzle_grouped_scaling_factors(convertNVTEGroupedTensorCheck(input),
-                                  convertNVTEGroupedTensorCheck(output), stream);
+                                  convertNVTEGroupedTensorCheck(output), workspace, stream);
 }
 
 void nvte_unswizzle_grouped_scaling_factors(const NVTEGroupedTensor input, NVTEGroupedTensor output,
diff --git a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp
index a6b4e7569d..d9eb29dff9 100644
--- a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp
@@ -357,11 +357,6 @@ std::optional maybe_swizzle_grouped_tensor(GroupedTensorW
   }
   const auto first_dims = input.get_first_dims();
   const auto last_dims = input.get_last_dims();
-  if (first_dims.data_ptr != nullptr || last_dims.data_ptr != nullptr) {
-    NVTE_ERROR(
-        "Grouped GEMM swizzle requires uniform shapes for now (first_dims/last_dims must be "
-        "absent).");
-  }
 
   std::optional<at::Tensor> rowwise_scales_pyt;
   std::optional<at::Tensor> columnwise_scales_pyt;
@@ -403,8 +398,17 @@ std::optional maybe_swizzle_grouped_tensor(GroupedTensorW
   }
 
   swizzle_output.set_with_gemm_swizzled_scales(true);
+
+  size_t num_tensors = input.num_tensors();
+  size_t num_int_elems = num_tensors + 3;       // n+1 block_offsets + gc + tb
+  if (num_int_elems % 2 != 0) num_int_elems++;  // pad to even for size_t alignment
+  size_t workspace_size = num_int_elems * sizeof(int) + (num_tensors + 1) * sizeof(size_t);
+  workspace_size = roundup(workspace_size, 256);
+  auto workspace =
+      allocateSpace(std::vector<size_t>{workspace_size}, transformer_engine::DType::kByte, false);
+
   NVTE_SCOPED_GIL_RELEASE({
     nvte_swizzle_grouped_scaling_factors(swizzle_input.data(), swizzle_output.data(),
+                                         getDataPtr(workspace),
                                          at::cuda::getCurrentCUDAStream());
   });
 

From ab16ec7a4f21eb8b94dfa4a7f4dc6bfe8bc99c61 Mon Sep 17 00:00:00 2001
From: Abhishek
Date: Wed, 22 Apr 2026 00:28:56 -0700
Subject: [PATCH 2/8] Add support for shapes that are uniform in one of the
 dimensions

Signed-off-by: Abhishek
---
 tests/cpp/operator/test_swizzle.cu           | 12 +++++---
 transformer_engine/common/swizzle/swizzle.cu | 31 ++++++++++++++------
 2 files changed, 30 insertions(+), 13 deletions(-)

diff --git a/tests/cpp/operator/test_swizzle.cu b/tests/cpp/operator/test_swizzle.cu
index d74e12cb55..1a8a7561f1 100644
--- a/tests/cpp/operator/test_swizzle.cu
+++ b/tests/cpp/operator/test_swizzle.cu
@@ -637,10 +637,14 @@ INSTANTIATE_TEST_SUITE_P(
         // Case 3: Mixed small/irregular shapes.
         std::vector<std::pair<size_t, size_t>>{{200, 160}, {33, 64}, {1, 32}},
 
-        // Case 4: Large workload to verify persistent grid (looping behavior).
-        // 10 tensors * (4096x4096 data) = 10 * (32x32 tiles) = 10,240 tiles total.
-        // This reliably exceeds the grid size on all modern GPUs.
-        std::vector<std::pair<size_t, size_t>>(10, {4096, 4096})
+        // Case 4: Large workload to verify persistent grid
+        std::vector<std::pair<size_t, size_t>>(10, {4096, 4096}),
+
+        // Case 5: Variable M, Uniform K (Semi-variable)
+        std::vector<std::pair<size_t, size_t>>{{128, 256}, {512, 256}, {64, 256}},
+
+        // Case 6: Uniform M, Variable K (Semi-variable)
+        std::vector<std::pair<size_t, size_t>>{{512, 128}, {512, 1024}, {512, 32}}
     ),
     [](const testing::TestParamInfo<SwizzleGroupedVariableTestSuite::ParamType>& info) {
       return "VariableShapes_" + std::to_string(info.index) + "_N" + std::to_string(info.param.size());
diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu
index 0c4a3917d4..640866a87d 100644
--- a/transformer_engine/common/swizzle/swizzle.cu
+++ b/transformer_engine/common/swizzle/swizzle.cu
@@ -1566,7 +1566,9 @@ __global__ void __launch_bounds__(TB_DIM* TB_DIM)
         const size_t* scale_offsets,
         int* global_counter,
         int num_tensors,
-        bool rowwise) {
+        bool rowwise,
+        size_t common_m,
+        size_t common_k) {
 
   __shared__ int linear_block_id;
   while (true) {
@@ -1594,8 +1596,10 @@ __global__ void __launch_bounds__(TB_DIM* TB_DIM)
 
     int local_block_id = linear_block_id - block_offsets[tensor_id];
 
-    size_t M = rowwise ? m_array[tensor_id] : k_array[tensor_id];
-    size_t K = rowwise ? k_array[tensor_id] : m_array[tensor_id];
+    size_t M = rowwise ? (m_array ? m_array[tensor_id] : common_m)
+                       : (k_array ? k_array[tensor_id] : common_k);
+    size_t K = rowwise ? (k_array ? k_array[tensor_id] : common_k)
+                       : (m_array ? m_array[tensor_id] : common_m);
 
     size_t padded_m = round_up_to_multiple(M, 128);
     size_t padded_k = round_up_to_multiple(DIVUP(K, static_cast<size_t>(MXFP8_BLOCK_SIZE)), 4);
@@ -1661,7 +1665,9 @@ __global__ void compute_grouped_swizzle_setup(
     int* global_counter,
     size_t num_tensors,
     bool rowwise,
-    size_t scale_elem_size) {
+    size_t scale_elem_size,
+    size_t common_m,
+    size_t common_k) {
 
   if (blockIdx.x == 0 && threadIdx.x == 0) {
     int current_block_offset = 0;
@@ -1671,8 +1677,10 @@ __global__ void compute_grouped_swizzle_setup(
       block_offsets[i] = current_block_offset;
       scale_offsets[i] = current_scale_offset;
 
-      size_t m = rowwise ? m_array[i] : k_array[i];
-      size_t k = rowwise ? k_array[i] : m_array[i];
+      size_t m = rowwise ? (m_array ? m_array[i] : common_m)
+                         : (k_array ? k_array[i] : common_k);
+      size_t k = rowwise ? (k_array ? k_array[i] : common_k)
+                         : (m_array ?
m_array[i] : common_m); size_t padded_m = round_up_to_multiple(m, 128); size_t padded_k = round_up_to_multiple(DIVUP(k, static_cast(MXFP8_BLOCK_SIZE)), 4); @@ -1724,7 +1732,7 @@ void swizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor* const int64_t* m_array = reinterpret_cast(input->first_dims.dptr); const int64_t* k_array = reinterpret_cast(input->last_dims.dptr); - const bool is_variable_shape = (m_array != nullptr && k_array != nullptr); + const bool is_variable_shape = !input->all_same_shape(); if (!is_variable_shape) { // Fallback to uniform shape implementation @@ -1881,9 +1889,13 @@ void swizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor* const size_t scale_elem_size = rowwise ? typeToSize(input->scale_inv.dtype) : typeToSize(input->columnwise_scale_inv.dtype); + size_t common_m = input->all_same_first_dim() ? input->get_common_first_dim() : 0; + size_t common_k = input->all_same_last_dim() ? input->get_common_last_dim() : 0; + compute_grouped_swizzle_setup<<<1, 1, 0, stream>>>( m_array, k_array, d_block_offsets, d_scale_offsets, d_total_blocks, - d_global_counter, num_tensors, rowwise, scale_elem_size); + d_global_counter, num_tensors, rowwise, scale_elem_size, + common_m, common_k); NVTE_CHECK_CUDA(cudaFuncSetAttribute( grouped_swizzle_scaling_variable_shape_kernel, @@ -1910,7 +1922,8 @@ void swizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor* grouped_swizzle_scaling_variable_shape_kernel <<>>( input_ptr, output_ptr, m_array, k_array, d_block_offsets, - d_scale_offsets, d_global_counter, num_tensors, rowwise); + d_scale_offsets, d_global_counter, num_tensors, rowwise, + common_m, common_k); NVTE_CHECK_CUDA(cudaGetLastError()); }; From 16735fd62c36e42236d9745fcc263077b954c0a4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 22 Apr 2026 07:42:29 +0000 Subject: [PATCH 3/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/cpp/operator/test_swizzle.cu | 20 +- transformer_engine/common/swizzle/swizzle.cu | 450 +++++++++--------- .../pytorch/csrc/extensions/swizzle.cpp | 12 +- 3 files changed, 234 insertions(+), 248 deletions(-) diff --git a/tests/cpp/operator/test_swizzle.cu b/tests/cpp/operator/test_swizzle.cu index 1a8a7561f1..d02017fc72 100644 --- a/tests/cpp/operator/test_swizzle.cu +++ b/tests/cpp/operator/test_swizzle.cu @@ -552,7 +552,7 @@ void performTestGroupedSwizzleMXFP8Variable(const std::vector output_row_host(row_numel); std::vector output_col_host(col_numel); - NVTE_CHECK_CUDA(cudaMemcpy(output_row_host.data(), - static_cast(grouped_output.scale_inv.get()) + row_offset, + NVTE_CHECK_CUDA(cudaMemcpy(output_row_host.data(), + static_cast(grouped_output.scale_inv.get()) + row_offset, row_numel, cudaMemcpyDeviceToHost)); - NVTE_CHECK_CUDA(cudaMemcpy(output_col_host.data(), - static_cast(grouped_output.columnwise_scale_inv.get()) + col_offset, + NVTE_CHECK_CUDA(cudaMemcpy(output_col_host.data(), + static_cast(grouped_output.columnwise_scale_inv.get()) + col_offset, col_numel, cudaMemcpyDeviceToHost)); std::vector ref_row(row_numel); @@ -605,9 +605,9 @@ void performTestGroupedSwizzleMXFP8Variable(const std::vector>{{1024, 1024}}, - + // Case 2: num_tensors = 2 (n+3 = 5, odd). Forces padding logic to trigger. std::vector>{{128, 128}, {256, 256}}, // Case 3: Mixed small/irregular shapes. 
std::vector>{{200, 160}, {33, 64}, {1, 32}}, - + // Case 4: Large workload to verify persistent grid std::vector>(10, {4096, 4096}), diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu index 640866a87d..6de1260aaf 100644 --- a/transformer_engine/common/swizzle/swizzle.cu +++ b/transformer_engine/common/swizzle/swizzle.cu @@ -1552,156 +1552,139 @@ void nvte_multi_tensor_unswizzle_scaling_factors(const NVTETensor* inputs, NVTET multi_tensor_unswizzle_scaling_factors(input_list, output_list, stream); } - namespace transformer_engine { template __global__ void __launch_bounds__(TB_DIM* TB_DIM) - grouped_swizzle_scaling_variable_shape_kernel( - const void* input, - void* output, - const int64_t* m_array, - const int64_t* k_array, - const int* block_offsets, - const size_t* scale_offsets, - int* global_counter, - int num_tensors, - bool rowwise, - size_t common_m, - size_t common_k) { - + grouped_swizzle_scaling_variable_shape_kernel(const void* input, void* output, + const int64_t* m_array, const int64_t* k_array, + const int* block_offsets, + const size_t* scale_offsets, int* global_counter, + int num_tensors, bool rowwise, size_t common_m, + size_t common_k) { __shared__ int linear_block_id; while (true) { - if (threadIdx.x == 0 && threadIdx.y == 0) { + if (threadIdx.x == 0 && threadIdx.y == 0) { linear_block_id = atomicAdd(global_counter, 1); - } - __syncthreads(); + } + __syncthreads(); - int tensor_id = -1; - int low = 0; - int high = num_tensors - 1; - while (low <= high) { + int tensor_id = -1; + int low = 0; + int high = num_tensors - 1; + while (low <= high) { int mid = low + (high - low) / 2; if (linear_block_id >= block_offsets[mid] && linear_block_id < block_offsets[mid + 1]) { - tensor_id = mid; - break; + tensor_id = mid; + break; } else if (linear_block_id < block_offsets[mid]) { - high = mid - 1; + high = mid - 1; } else { - low = mid + 1; + low = mid + 1; } - } + } - if (tensor_id == -1) return; - - int local_block_id = linear_block_id - block_offsets[tensor_id]; - - size_t M = rowwise ? (m_array ? m_array[tensor_id] : common_m) - : (k_array ? k_array[tensor_id] : common_k); - size_t K = rowwise ? (k_array ? k_array[tensor_id] : common_k) - : (m_array ? m_array[tensor_id] : common_m); - - size_t padded_m = round_up_to_multiple(M, 128); - size_t padded_k = round_up_to_multiple(DIVUP(K, static_cast(MXFP8_BLOCK_SIZE)), 4); - - int num_tiles_m = padded_m / SF_TILE_DIM_M; - int num_tiles_k = padded_k / SF_TILE_DIM_K; - - int vec_load_size = (rowwise ? ((num_tiles_k - 1) % 4 + 1) : ((num_tiles_m - 1) % 4 + 1)); - if (vec_load_size == 3) vec_load_size = 1; - int n_tiles_in_tb = TB_DIM * vec_load_size; + if (tensor_id == -1) return; + + int local_block_id = linear_block_id - block_offsets[tensor_id]; + + size_t M = rowwise ? (m_array ? m_array[tensor_id] : common_m) + : (k_array ? k_array[tensor_id] : common_k); + size_t K = rowwise ? (k_array ? k_array[tensor_id] : common_k) + : (m_array ? m_array[tensor_id] : common_m); + + size_t padded_m = round_up_to_multiple(M, 128); + size_t padded_k = round_up_to_multiple(DIVUP(K, static_cast(MXFP8_BLOCK_SIZE)), 4); + + int num_tiles_m = padded_m / SF_TILE_DIM_M; + int num_tiles_k = padded_k / SF_TILE_DIM_K; + + int vec_load_size = (rowwise ? ((num_tiles_k - 1) % 4 + 1) : ((num_tiles_m - 1) % 4 + 1)); + if (vec_load_size == 3) vec_load_size = 1; + int n_tiles_in_tb = TB_DIM * vec_load_size; + + int grid_dim_x = rowwise ? 
DIVUP(num_tiles_k, n_tiles_in_tb) : DIVUP(num_tiles_k, TB_DIM); + int grid_dim_y = rowwise ? num_tiles_m : DIVUP(num_tiles_m, vec_load_size); - int grid_dim_x = rowwise ? DIVUP(num_tiles_k, n_tiles_in_tb) : DIVUP(num_tiles_k, TB_DIM); - int grid_dim_y = rowwise ? num_tiles_m : DIVUP(num_tiles_m, vec_load_size); - - int block_x = local_block_id % grid_dim_x; - int block_y = local_block_id / grid_dim_x; + int block_x = local_block_id % grid_dim_x; + int block_y = local_block_id / grid_dim_x; - const uint8_t* input_base = reinterpret_cast(input) + scale_offsets[tensor_id]; - uint8_t* output_base = reinterpret_cast(output) + scale_offsets[tensor_id]; + const uint8_t* input_base = reinterpret_cast(input) + scale_offsets[tensor_id]; + uint8_t* output_base = reinterpret_cast(output) + scale_offsets[tensor_id]; - int original_M = static_cast(M); - int original_K = static_cast(DIVUP(K, static_cast(MXFP8_BLOCK_SIZE))); + int original_M = static_cast(M); + int original_K = static_cast(DIVUP(K, static_cast(MXFP8_BLOCK_SIZE))); - if (rowwise) { + if (rowwise) { if (vec_load_size == 4) { - swizzle_row_scaling_kernel_impl( - input_base, output_base, padded_m, padded_k, original_M, original_K, - block_x, block_y, grid_dim_x, grid_dim_y); + swizzle_row_scaling_kernel_impl( + input_base, output_base, padded_m, padded_k, original_M, original_K, block_x, block_y, + grid_dim_x, grid_dim_y); } else if (vec_load_size == 2) { - swizzle_row_scaling_kernel_impl( - input_base, output_base, padded_m, padded_k, original_M, original_K, - block_x, block_y, grid_dim_x, grid_dim_y); + swizzle_row_scaling_kernel_impl( + input_base, output_base, padded_m, padded_k, original_M, original_K, block_x, block_y, + grid_dim_x, grid_dim_y); } else { - swizzle_row_scaling_kernel_impl( - input_base, output_base, padded_m, padded_k, original_M, original_K, - block_x, block_y, grid_dim_x, grid_dim_y); + swizzle_row_scaling_kernel_impl( + input_base, output_base, padded_m, padded_k, original_M, original_K, block_x, block_y, + grid_dim_x, grid_dim_y); } - } else { + } else { if (vec_load_size == 4) { - swizzle_col_scaling_kernel_impl( - input_base, output_base, padded_m, padded_k, original_M, original_K, - block_x, block_y, grid_dim_x, grid_dim_y); + swizzle_col_scaling_kernel_impl( + input_base, output_base, padded_m, padded_k, original_M, original_K, block_x, block_y, + grid_dim_x, grid_dim_y); } else if (vec_load_size == 2) { - swizzle_col_scaling_kernel_impl( - input_base, output_base, padded_m, padded_k, original_M, original_K, - block_x, block_y, grid_dim_x, grid_dim_y); + swizzle_col_scaling_kernel_impl( + input_base, output_base, padded_m, padded_k, original_M, original_K, block_x, block_y, + grid_dim_x, grid_dim_y); } else { - swizzle_col_scaling_kernel_impl( - input_base, output_base, padded_m, padded_k, original_M, original_K, - block_x, block_y, grid_dim_x, grid_dim_y); + swizzle_col_scaling_kernel_impl( + input_base, output_base, padded_m, padded_k, original_M, original_K, block_x, block_y, + grid_dim_x, grid_dim_y); } - } - __syncthreads(); + } + __syncthreads(); } } -__global__ void compute_grouped_swizzle_setup( - const int64_t* m_array, - const int64_t* k_array, - int* block_offsets, - size_t* scale_offsets, - int* total_blocks, - int* global_counter, - size_t num_tensors, - bool rowwise, - size_t scale_elem_size, - size_t common_m, - size_t common_k) { - +__global__ void compute_grouped_swizzle_setup(const int64_t* m_array, const int64_t* k_array, + int* block_offsets, size_t* scale_offsets, + int* total_blocks, int* 
global_counter, + size_t num_tensors, bool rowwise, + size_t scale_elem_size, size_t common_m, + size_t common_k) { if (blockIdx.x == 0 && threadIdx.x == 0) { int current_block_offset = 0; size_t current_scale_offset = 0; - + for (size_t i = 0; i < num_tensors; ++i) { block_offsets[i] = current_block_offset; scale_offsets[i] = current_scale_offset; - - size_t m = rowwise ? (m_array ? m_array[i] : common_m) - : (k_array ? k_array[i] : common_k); - size_t k = rowwise ? (k_array ? k_array[i] : common_k) - : (m_array ? m_array[i] : common_m); - + + size_t m = rowwise ? (m_array ? m_array[i] : common_m) : (k_array ? k_array[i] : common_k); + size_t k = rowwise ? (k_array ? k_array[i] : common_k) : (m_array ? m_array[i] : common_m); + size_t padded_m = round_up_to_multiple(m, 128); size_t padded_k = round_up_to_multiple(DIVUP(k, static_cast(MXFP8_BLOCK_SIZE)), 4); - + int num_tiles_m = padded_m / 128; int num_tiles_k = padded_k / 4; - + int vec_load_size = (rowwise ? ((num_tiles_k - 1) % 4 + 1) : ((num_tiles_m - 1) % 4 + 1)); if (vec_load_size == 3) vec_load_size = 1; - + int blocks_m = num_tiles_m; int blocks_k = DIVUP(num_tiles_k, TB_DIM * vec_load_size); if (!rowwise) { - blocks_m = DIVUP(num_tiles_m, vec_load_size); - blocks_k = DIVUP(num_tiles_k, TB_DIM); + blocks_m = DIVUP(num_tiles_m, vec_load_size); + blocks_k = DIVUP(num_tiles_k, TB_DIM); } - + current_block_offset += blocks_m * blocks_k; current_scale_offset += padded_m * padded_k * scale_elem_size; } - + block_offsets[num_tensors] = current_block_offset; scale_offsets[num_tensors] = current_scale_offset; *total_blocks = current_block_offset; @@ -1740,141 +1723,146 @@ void swizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor* NVTE_CHECK(input->all_same_last_dim() && input->all_same_first_dim(), "Grouped swizzle requires uniform tensor shapes."); - // Assumption is that all the tensors share the same shapes and are contgiuous. - // And so we dont need to pass array of input/output pointers(due to conttiguity) - // as well as array of shapes(due to uniform shapes). - const size_t first_dim = input->get_common_first_dim(); - const size_t last_dim = input->get_common_last_dim(); + // Assumption is that all the tensors share the same shapes and are contgiuous. + // And so we dont need to pass array of input/output pointers(due to conttiguity) + // as well as array of shapes(due to uniform shapes). + const size_t first_dim = input->get_common_first_dim(); + const size_t last_dim = input->get_common_last_dim(); - constexpr int SF_TILE_DIM_M = 128; - constexpr int SF_TILE_DIM_K = 4; - const dim3 block_size(TB_DIM, TB_DIM); + constexpr int SF_TILE_DIM_M = 128; + constexpr int SF_TILE_DIM_K = 4; + const dim3 block_size(TB_DIM, TB_DIM); - auto launch_grouped_swizzle = [&](bool rowwise) { - const size_t m = rowwise ? first_dim : last_dim; - const size_t k = rowwise ? last_dim : first_dim; - const size_t padded_m = round_up_to_multiple(m, 128); - const size_t padded_k = - round_up_to_multiple(DIVUP(k, static_cast(MXFP8_BLOCK_SIZE)), 4); - const size_t scale_elems = padded_m * padded_k; + auto launch_grouped_swizzle = [&](bool rowwise) { + const size_t m = rowwise ? first_dim : last_dim; + const size_t k = rowwise ? last_dim : first_dim; + const size_t padded_m = round_up_to_multiple(m, 128); + const size_t padded_k = + round_up_to_multiple(DIVUP(k, static_cast(MXFP8_BLOCK_SIZE)), 4); + const size_t scale_elems = padded_m * padded_k; - const size_t scale_elem_size = rowwise ? 
typeToSize(input->scale_inv.dtype) - : typeToSize(input->columnwise_scale_inv.dtype); - const size_t scale_stride_bytes = scale_elems * scale_elem_size; + const size_t scale_elem_size = rowwise ? typeToSize(input->scale_inv.dtype) + : typeToSize(input->columnwise_scale_inv.dtype); + const size_t scale_stride_bytes = scale_elems * scale_elem_size; - if (rowwise) { - NVTE_CHECK(input->scale_inv.numel() == input->num_tensors * scale_elems, - "Grouped input scale_inv size does not match expected packed size."); - NVTE_CHECK(output->scale_inv.numel() == output->num_tensors * scale_elems, - "Grouped output scale_inv size does not match expected packed size."); - } else { - NVTE_CHECK(input->columnwise_scale_inv.numel() == input->num_tensors * scale_elems, - "Grouped input columnwise_scale_inv size does not match expected packed size."); - NVTE_CHECK(output->columnwise_scale_inv.numel() == output->num_tensors * scale_elems, - "Grouped output columnwise_scale_inv size does not match expected packed size."); - } + if (rowwise) { + NVTE_CHECK(input->scale_inv.numel() == input->num_tensors * scale_elems, + "Grouped input scale_inv size does not match expected packed size."); + NVTE_CHECK(output->scale_inv.numel() == output->num_tensors * scale_elems, + "Grouped output scale_inv size does not match expected packed size."); + } else { + NVTE_CHECK(input->columnwise_scale_inv.numel() == input->num_tensors * scale_elems, + "Grouped input columnwise_scale_inv size does not match expected packed size."); + NVTE_CHECK(output->columnwise_scale_inv.numel() == output->num_tensors * scale_elems, + "Grouped output columnwise_scale_inv size does not match expected packed size."); + } - const int num_tiles_m = padded_m / SF_TILE_DIM_M; - const int num_tiles_k = padded_k / SF_TILE_DIM_K; - int vec_load_size = (rowwise ? ((num_tiles_k - 1) % 4 + 1) : ((num_tiles_m - 1) % 4 + 1)); - if (vec_load_size == 3) vec_load_size = 1; - const int n_tiles_in_tb = TB_DIM * vec_load_size; + const int num_tiles_m = padded_m / SF_TILE_DIM_M; + const int num_tiles_k = padded_k / SF_TILE_DIM_K; + int vec_load_size = (rowwise ? ((num_tiles_k - 1) % 4 + 1) : ((num_tiles_m - 1) % 4 + 1)); + if (vec_load_size == 3) vec_load_size = 1; + const int n_tiles_in_tb = TB_DIM * vec_load_size; - dim3 num_blocks; - if (rowwise) { - num_blocks = dim3(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m, input->num_tensors); - } else { - num_blocks = - dim3(DIVUP(num_tiles_k, TB_DIM), DIVUP(num_tiles_m, vec_load_size), input->num_tensors); - } - const int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); + dim3 num_blocks; + if (rowwise) { + num_blocks = dim3(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m, input->num_tensors); + } else { + num_blocks = + dim3(DIVUP(num_tiles_k, TB_DIM), DIVUP(num_tiles_m, vec_load_size), input->num_tensors); + } + const int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); - const int original_M = static_cast(rowwise ? first_dim : last_dim); - const int original_K = static_cast(DIVUP(k, static_cast(MXFP8_BLOCK_SIZE))); - const void* input_ptr = rowwise ? input->scale_inv.dptr : input->columnwise_scale_inv.dptr; - void* output_ptr = rowwise ? output->scale_inv.dptr : output->columnwise_scale_inv.dptr; + const int original_M = static_cast(rowwise ? first_dim : last_dim); + const int original_K = static_cast(DIVUP(k, static_cast(MXFP8_BLOCK_SIZE))); + const void* input_ptr = rowwise ? input->scale_inv.dptr : input->columnwise_scale_inv.dptr; + void* output_ptr = rowwise ? 
output->scale_inv.dptr : output->columnwise_scale_inv.dptr; - if (rowwise) { - switch (vec_load_size) { - case 4: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - grouped_swizzle_row_scaling_uniform_shape_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - grouped_swizzle_row_scaling_uniform_shape_kernel - <<>>(input_ptr, output_ptr, padded_m, - padded_k, original_M, original_K, - scale_stride_bytes); - break; - case 2: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - grouped_swizzle_row_scaling_uniform_shape_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - grouped_swizzle_row_scaling_uniform_shape_kernel - <<>>(input_ptr, output_ptr, padded_m, - padded_k, original_M, original_K, - scale_stride_bytes); - break; - case 1: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - grouped_swizzle_row_scaling_uniform_shape_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - grouped_swizzle_row_scaling_uniform_shape_kernel - <<>>(input_ptr, output_ptr, padded_m, - padded_k, original_M, original_K, - scale_stride_bytes); - break; - default: - NVTE_ERROR("Not valid vec_load_size."); - } - } else { - switch (vec_load_size) { - case 4: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - grouped_swizzle_col_scaling_uniform_shape_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - grouped_swizzle_col_scaling_uniform_shape_kernel - <<>>(input_ptr, output_ptr, padded_m, - padded_k, original_M, original_K, - scale_stride_bytes); - break; - case 2: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - grouped_swizzle_col_scaling_uniform_shape_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - grouped_swizzle_col_scaling_uniform_shape_kernel - <<>>(input_ptr, output_ptr, padded_m, - padded_k, original_M, original_K, - scale_stride_bytes); - break; - case 1: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - grouped_swizzle_col_scaling_uniform_shape_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - grouped_swizzle_col_scaling_uniform_shape_kernel - <<>>(input_ptr, output_ptr, padded_m, - padded_k, original_M, original_K, - scale_stride_bytes); - break; - default: - NVTE_ERROR("Not valid vec_load_size."); + if (rowwise) { + switch (vec_load_size) { + case 4: + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + grouped_swizzle_row_scaling_uniform_shape_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + grouped_swizzle_row_scaling_uniform_shape_kernel + <<>>(input_ptr, output_ptr, padded_m, + padded_k, original_M, original_K, + scale_stride_bytes); + break; + case 2: + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + grouped_swizzle_row_scaling_uniform_shape_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + grouped_swizzle_row_scaling_uniform_shape_kernel + <<>>(input_ptr, output_ptr, padded_m, + padded_k, original_M, original_K, + scale_stride_bytes); + break; + case 1: + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + grouped_swizzle_row_scaling_uniform_shape_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + grouped_swizzle_row_scaling_uniform_shape_kernel + <<>>(input_ptr, output_ptr, padded_m, + padded_k, original_M, original_K, + scale_stride_bytes); + break; + default: + NVTE_ERROR("Not valid vec_load_size."); + } + } else { + switch (vec_load_size) { + case 4: + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + grouped_swizzle_col_scaling_uniform_shape_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + grouped_swizzle_col_scaling_uniform_shape_kernel + <<>>(input_ptr, output_ptr, padded_m, + padded_k, original_M, 
original_K, + scale_stride_bytes); + break; + case 2: + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + grouped_swizzle_col_scaling_uniform_shape_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + grouped_swizzle_col_scaling_uniform_shape_kernel + <<>>(input_ptr, output_ptr, padded_m, + padded_k, original_M, original_K, + scale_stride_bytes); + break; + case 1: + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + grouped_swizzle_col_scaling_uniform_shape_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + grouped_swizzle_col_scaling_uniform_shape_kernel + <<>>(input_ptr, output_ptr, padded_m, + padded_k, original_M, original_K, + scale_stride_bytes); + break; + default: + NVTE_ERROR("Not valid vec_load_size."); + } } - } - NVTE_CHECK_CUDA(cudaGetLastError()); - }; + NVTE_CHECK_CUDA(cudaGetLastError()); + }; - if (has_rowwise_scale_inv) { - launch_grouped_swizzle(true); - } - if (has_columnwise_scale_inv) { - launch_grouped_swizzle(false); - } + if (has_rowwise_scale_inv) { + launch_grouped_swizzle(true); + } + if (has_columnwise_scale_inv) { + launch_grouped_swizzle(false); + } } else { // Variable shape implementation using Device-Side Block Scheduler size_t num_tensors = input->num_tensors; - NVTE_CHECK(workspace != nullptr, "Workspace must be provided for variable shape grouped swizzle."); + NVTE_CHECK(workspace != nullptr, + "Workspace must be provided for variable shape grouped swizzle."); size_t int_stride = num_tensors + 3; - if (int_stride % 2 != 0) int_stride++; + if (int_stride % 2 != 0) int_stride++; int* d_block_offsets = reinterpret_cast(workspace); int* d_global_counter = d_block_offsets + num_tensors + 1; int* d_total_blocks = d_global_counter + 1; @@ -1888,19 +1876,18 @@ void swizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor* auto launch_grouped_swizzle_variable = [&](bool rowwise) { const size_t scale_elem_size = rowwise ? typeToSize(input->scale_inv.dtype) : typeToSize(input->columnwise_scale_inv.dtype); - + size_t common_m = input->all_same_first_dim() ? input->get_common_first_dim() : 0; size_t common_k = input->all_same_last_dim() ? 
input->get_common_last_dim() : 0; compute_grouped_swizzle_setup<<<1, 1, 0, stream>>>( - m_array, k_array, d_block_offsets, d_scale_offsets, d_total_blocks, - d_global_counter, num_tensors, rowwise, scale_elem_size, - common_m, common_k); + m_array, k_array, d_block_offsets, d_scale_offsets, d_total_blocks, d_global_counter, + num_tensors, rowwise, scale_elem_size, common_m, common_k); NVTE_CHECK_CUDA(cudaFuncSetAttribute( grouped_swizzle_scaling_variable_shape_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_slm_size)); - + int device_id; cudaGetDevice(&device_id); int num_SMs; @@ -1910,8 +1897,8 @@ void swizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor* cudaOccupancyMaxActiveBlocksPerMultiprocessor( &max_active_blocks_per_sm, grouped_swizzle_scaling_variable_shape_kernel, - TB_DIM * TB_DIM, // block size - max_slm_size // dynamic shared memory + TB_DIM * TB_DIM, // block size + max_slm_size // dynamic shared memory ); int persistent_blocks = num_SMs * max_active_blocks_per_sm; dim3 num_blocks(persistent_blocks); @@ -1921,9 +1908,8 @@ void swizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor* grouped_swizzle_scaling_variable_shape_kernel <<>>( - input_ptr, output_ptr, m_array, k_array, d_block_offsets, - d_scale_offsets, d_global_counter, num_tensors, rowwise, - common_m, common_k); + input_ptr, output_ptr, m_array, k_array, d_block_offsets, d_scale_offsets, + d_global_counter, num_tensors, rowwise, common_m, common_k); NVTE_CHECK_CUDA(cudaGetLastError()); }; diff --git a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp index d9eb29dff9..58d7aef0cc 100644 --- a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp +++ b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp @@ -398,18 +398,18 @@ std::optional maybe_swizzle_grouped_tensor(GroupedTensorW } swizzle_output.set_with_gemm_swizzled_scales(true); - + size_t num_tensors = input.num_tensors(); - size_t num_int_elems = num_tensors + 3; // n+1 block_offsets + gc + tb - if (num_int_elems % 2 != 0) num_int_elems++; // pad to even for size_t alignment + size_t num_int_elems = num_tensors + 3; // n+1 block_offsets + gc + tb + if (num_int_elems % 2 != 0) num_int_elems++; // pad to even for size_t alignment size_t workspace_size = num_int_elems * sizeof(int) + (num_tensors + 1) * sizeof(size_t); workspace_size = roundup(workspace_size, 256); - auto workspace = allocateSpace(std::vector{workspace_size}, transformer_engine::DType::kByte, false); + auto workspace = + allocateSpace(std::vector{workspace_size}, transformer_engine::DType::kByte, false); NVTE_SCOPED_GIL_RELEASE({ nvte_swizzle_grouped_scaling_factors(swizzle_input.data(), swizzle_output.data(), - getDataPtr(workspace), - at::cuda::getCurrentCUDAStream()); + getDataPtr(workspace), at::cuda::getCurrentCUDAStream()); }); if (swizzle_rowwise) { From 749692ce2f6d4d35378153f8f0b1468575e600c1 Mon Sep 17 00:00:00 2001 From: Abhishek Date: Sat, 25 Apr 2026 15:55:54 -0700 Subject: [PATCH 4/8] Using single kernel for variable m and k Signed-off-by: Abhishek --- tests/cpp/operator/test_swizzle.cu | 15 +- .../include/transformer_engine/swizzle.h | 2 +- transformer_engine/common/swizzle/swizzle.cu | 279 +++++++++--------- .../pytorch/csrc/extensions/swizzle.cpp | 12 +- 4 files changed, 138 insertions(+), 170 deletions(-) diff --git a/tests/cpp/operator/test_swizzle.cu b/tests/cpp/operator/test_swizzle.cu index d02017fc72..e8a600e1b3 100644 --- a/tests/cpp/operator/test_swizzle.cu 
+++ b/tests/cpp/operator/test_swizzle.cu
@@ -281,7 +281,7 @@ void performTestGroupedSwizzleMXFP8(const int num_tensors, const size_t M, const
   NVTE_CHECK_CUDA(cudaMemset(grouped_output.columnwise_scale_inv.get(), 0, num_tensors * col_numel));
 
   nvte_swizzle_grouped_scaling_factors(grouped_input.get_handle(),
-                                       grouped_output.get_handle(), nullptr, 0);
+                                       grouped_output.get_handle(), 0);
 
   std::vector<uint8_t> output_row(num_tensors * row_numel);
   std::vector<uint8_t> output_col(num_tensors * col_numel);
@@ -481,7 +481,7 @@ void performTestGroupedSwizzleUnswizzleRoundtrip(const int num_tensors, const si
   NVTE_CHECK_CUDA(cudaMemset(grouped_fin.scale_inv.get(), 0, num_tensors * row_numel));
   NVTE_CHECK_CUDA(cudaMemset(grouped_fin.columnwise_scale_inv.get(), 0, num_tensors * col_numel));
 
-  nvte_swizzle_grouped_scaling_factors(grouped_orig.get_handle(), grouped_mid.get_handle(), nullptr, 0);
+  nvte_swizzle_grouped_scaling_factors(grouped_orig.get_handle(), grouped_mid.get_handle(), 0);
   nvte_unswizzle_grouped_scaling_factors(grouped_mid.get_handle(), grouped_fin.get_handle(), 0);
 
   std::vector<uint8_t> result_row(num_tensors * row_numel);
@@ -562,17 +562,9 @@ void performTestGroupedSwizzleMXFP8Variable(const std::vector
diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu
--- a/transformer_engine/common/swizzle/swizzle.cu
+++ b/transformer_engine/common/swizzle/swizzle.cu
@@ -1556,139 +1556,133 @@ namespace transformer_engine {
 template <int TB_DIM, int SF_TILE_DIM_M, int SF_TILE_DIM_K>
 __global__ void __launch_bounds__(TB_DIM* TB_DIM)
-    grouped_swizzle_scaling_variable_shape_kernel(const void* input, void* output,
-                                                  const int64_t* m_array, const int64_t* k_array,
-                                                  const int* block_offsets,
-                                                  const size_t* scale_offsets, int* global_counter,
-                                                  int num_tensors, bool rowwise, size_t common_m,
-                                                  size_t common_k) {
-  __shared__ int linear_block_id;
-  while (true) {
-    if (threadIdx.x == 0 && threadIdx.y == 0) {
-      linear_block_id = atomicAdd(global_counter, 1);
-    }
-    __syncthreads();
-
-    int tensor_id = -1;
-    int low = 0;
-    int high = num_tensors - 1;
-    while (low <= high) {
-      int mid = low + (high - low) / 2;
-      if (linear_block_id >= block_offsets[mid] && linear_block_id < block_offsets[mid + 1]) {
-        tensor_id = mid;
-        break;
-      } else if (linear_block_id < block_offsets[mid]) {
-        high = mid - 1;
-      } else {
-        low = mid + 1;
+    grouped_swizzle_scaling_variable_shape_kernel(
+        const void* input,
+        void* output,
+        const int64_t* m_array,
+        const int64_t* k_array,
+        int num_tensors,
+        bool rowwise,
+        size_t scale_elem_size,
+        size_t common_m,
+        size_t common_k) {
+
+  extern __shared__ int s_metadata[];
+  int* s_total_blocks = &s_metadata[0];
+
+  // Warp reduction to compute total workload
+  if (threadIdx.x < 32 && threadIdx.y == 0) {
+      int local_blocks = 0;
+      for (int i = threadIdx.x; i < num_tensors; i += 32) {
+        size_t m = rowwise ? (m_array ? m_array[i] : common_m)
+                           : (k_array ? k_array[i] : common_k);
+        size_t k = rowwise ? (k_array ? k_array[i] : common_k)
+                           : (m_array ? m_array[i] : common_m);
+
+        size_t padded_m = round_up_to_multiple(m, 128);
+        size_t padded_k = round_up_to_multiple(DIVUP(k, static_cast<size_t>(MXFP8_BLOCK_SIZE)), 4);
+
+        int num_tiles_m = padded_m / SF_TILE_DIM_M;
+        int num_tiles_k = padded_k / SF_TILE_DIM_K;
+
+        int vec_load_size = (rowwise ? ((num_tiles_k - 1) % 4 + 1) : ((num_tiles_m - 1) % 4 + 1));
+        if (vec_load_size == 3) vec_load_size = 1;
+        int n_tiles_in_tb = TB_DIM * vec_load_size;
+
+        int grid_dim_x = rowwise ? DIVUP(num_tiles_k, n_tiles_in_tb) : DIVUP(num_tiles_k, TB_DIM);
+        int grid_dim_y = rowwise ? num_tiles_m : DIVUP(num_tiles_m, vec_load_size);
+        local_blocks += grid_dim_x * grid_dim_y;
       }
-    }
-
-    if (tensor_id == -1) return;
-
-    int local_block_id = linear_block_id - block_offsets[tensor_id];
-
-    size_t M = rowwise ? (m_array ? m_array[tensor_id] : common_m)
-                       : (k_array ? k_array[tensor_id] : common_k);
-    size_t K = rowwise ? (k_array ? k_array[tensor_id] : common_k)
-                       : (m_array ? m_array[tensor_id] : common_m);
-
-    size_t padded_m = round_up_to_multiple(M, 128);
-    size_t padded_k = round_up_to_multiple(DIVUP(K, static_cast<size_t>(MXFP8_BLOCK_SIZE)), 4);
-
-    int num_tiles_m = padded_m / SF_TILE_DIM_M;
-    int num_tiles_k = padded_k / SF_TILE_DIM_K;
-
-    int vec_load_size = (rowwise ? ((num_tiles_k - 1) % 4 + 1) : ((num_tiles_m - 1) % 4 + 1));
-    if (vec_load_size == 3) vec_load_size = 1;
-    int n_tiles_in_tb = TB_DIM * vec_load_size;
-
-    int grid_dim_x = rowwise ? DIVUP(num_tiles_k, n_tiles_in_tb) : DIVUP(num_tiles_k, TB_DIM);
-    int grid_dim_y = rowwise ? num_tiles_m : DIVUP(num_tiles_m, vec_load_size);
-
-    int block_x = local_block_id % grid_dim_x;
-    int block_y = local_block_id / grid_dim_x;
-
-    const uint8_t* input_base = reinterpret_cast<const uint8_t*>(input) + scale_offsets[tensor_id];
-    uint8_t* output_base = reinterpret_cast<uint8_t*>(output) + scale_offsets[tensor_id];
-
-    int original_M = static_cast<int>(M);
-    int original_K = static_cast<int>(DIVUP(K, static_cast<size_t>(MXFP8_BLOCK_SIZE)));
-
-    if (rowwise) {
-      if (vec_load_size == 4) {
-        swizzle_row_scaling_kernel_impl(
-            input_base, output_base, padded_m, padded_k, original_M, original_K, block_x, block_y,
-            grid_dim_x, grid_dim_y);
-      } else if (vec_load_size == 2) {
-        swizzle_row_scaling_kernel_impl(
-            input_base, output_base, padded_m, padded_k, original_M, original_K, block_x, block_y,
-            grid_dim_x, grid_dim_y);
-      } else {
-        swizzle_row_scaling_kernel_impl(
-            input_base, output_base, padded_m, padded_k, original_M, original_K, block_x, block_y,
-            grid_dim_x, grid_dim_y);
+
+      for (int offset = 16; offset > 0; offset /= 2) {
+        local_blocks += __shfl_down_sync(0xffffffff, local_blocks, offset);
       }
-    } else {
-      if (vec_load_size == 4) {
-        swizzle_col_scaling_kernel_impl(
-            input_base, output_base, padded_m, padded_k, original_M, original_K, block_x, block_y,
-            grid_dim_x, grid_dim_y);
-      } else if (vec_load_size == 2) {
-        swizzle_col_scaling_kernel_impl(
-            input_base, output_base, padded_m, padded_k, original_M, original_K, block_x, block_y,
-            grid_dim_x, grid_dim_y);
-      } else {
-        swizzle_col_scaling_kernel_impl(
-            input_base, output_base, padded_m, padded_k, original_M, original_K, block_x, block_y,
-            grid_dim_x, grid_dim_y);
-      }
-    }
-    __syncthreads();
+      if (threadIdx.x == 0) *s_total_blocks = local_blocks;
   }
-}
-
-__global__ void compute_grouped_swizzle_setup(const int64_t* m_array, const int64_t* k_array,
-                                              int* block_offsets, size_t* scale_offsets,
-                                              int* total_blocks, int* global_counter,
-                                              size_t num_tensors, bool rowwise,
-                                              size_t scale_elem_size, size_t common_m,
-                                              size_t common_k) {
-  if (blockIdx.x == 0 && threadIdx.x == 0) {
-    int current_block_offset = 0;
-    size_t current_scale_offset = 0;
-
-    for (size_t i = 0; i < num_tensors; ++i) {
-      block_offsets[i] = current_block_offset;
-      scale_offsets[i] = current_scale_offset;
+  __syncthreads();
 
-      size_t m = rowwise ? (m_array ? m_array[i] : common_m) : (k_array ? k_array[i] : common_k);
-      size_t k = rowwise ? (k_array ? k_array[i] : common_k) : (m_array ? m_array[i] : common_m);
+  const int total_blocks = *s_total_blocks;
+
+  // Persistent-grid loop
+  for (int linear_block_id = blockIdx.x; linear_block_id < total_blocks; linear_block_id += gridDim.x) {
+      // Discover tensor_id and local_block_id via linear scan
+      int tensor_id = 0;
+      int current_block_base = 0;
+      size_t current_scale_base = 0;
+      int grid_dim_x = 0;
+      int grid_dim_y = 0;
+      size_t M = 0, K = 0;
+      int vec_load_size = 0;
+
+      for (int i = 0; i < num_tensors; ++i) {
+        M = rowwise ? (m_array ? m_array[i] : common_m)
+                    : (k_array ? k_array[i] : common_k);
+        K = rowwise ? (k_array ? k_array[i] : common_k)
+                    : (m_array ? m_array[i] : common_m);
+
+        size_t padded_m = round_up_to_multiple(M, 128);
+        size_t padded_k = round_up_to_multiple(DIVUP(K, static_cast<size_t>(MXFP8_BLOCK_SIZE)), 4);
+
+        int num_tiles_m = padded_m / SF_TILE_DIM_M;
+        int num_tiles_k = padded_k / SF_TILE_DIM_K;
+
+        vec_load_size = (rowwise ? ((num_tiles_k - 1) % 4 + 1) : ((num_tiles_m - 1) % 4 + 1));
+        if (vec_load_size == 3) vec_load_size = 1;
+        int n_tiles_in_tb = TB_DIM * vec_load_size;
+
+        grid_dim_x = rowwise ? DIVUP(num_tiles_k, n_tiles_in_tb) : DIVUP(num_tiles_k, TB_DIM);
+        grid_dim_y = rowwise ? num_tiles_m : DIVUP(num_tiles_m, vec_load_size);
+        int blocks_i = grid_dim_x * grid_dim_y;
+
+        if (linear_block_id < current_block_base + blocks_i) {
+          tensor_id = i;
+          break;
+        }
+        current_block_base += blocks_i;
+        current_scale_base += padded_m * padded_k * scale_elem_size;
+      }
 
-      size_t padded_m = round_up_to_multiple(m, 128);
-      size_t padded_k = round_up_to_multiple(DIVUP(k, static_cast<size_t>(MXFP8_BLOCK_SIZE)), 4);
+      int local_block_id = linear_block_id - current_block_base;
+      int block_x = local_block_id % grid_dim_x;
+      int block_y = local_block_id / grid_dim_x;
 
-      int num_tiles_m = padded_m / 128;
-      int num_tiles_k = padded_k / 4;
+      const uint8_t* input_base = reinterpret_cast<const uint8_t*>(input) + current_scale_base;
+      uint8_t* output_base = reinterpret_cast<uint8_t*>(output) + current_scale_base;
 
-      int vec_load_size = (rowwise ? ((num_tiles_k - 1) % 4 + 1) : ((num_tiles_m - 1) % 4 + 1));
-      if (vec_load_size == 3) vec_load_size = 1;
+      size_t padded_m = round_up_to_multiple(M, 128);
+      size_t padded_k = round_up_to_multiple(DIVUP(K, static_cast<size_t>(MXFP8_BLOCK_SIZE)), 4);
+      int original_M = static_cast<int>(M);
+      int original_K = static_cast<int>(DIVUP(K, static_cast<size_t>(MXFP8_BLOCK_SIZE)));
 
-      int blocks_m = num_tiles_m;
-      int blocks_k = DIVUP(num_tiles_k, TB_DIM * vec_load_size);
-      if (!rowwise) {
-        blocks_m = DIVUP(num_tiles_m, vec_load_size);
-        blocks_k = DIVUP(num_tiles_k, TB_DIM);
+      if (rowwise) {
+        if (vec_load_size == 4) {
+          swizzle_row_scaling_kernel_impl(
+              input_base, output_base, padded_m, padded_k, original_M, original_K,
+              block_x, block_y, grid_dim_x, grid_dim_y);
+        } else if (vec_load_size == 2) {
+          swizzle_row_scaling_kernel_impl(
+              input_base, output_base, padded_m, padded_k, original_M, original_K,
+              block_x, block_y, grid_dim_x, grid_dim_y);
+        } else {
+          swizzle_row_scaling_kernel_impl(
+              input_base, output_base, padded_m, padded_k, original_M, original_K,
+              block_x, block_y, grid_dim_x, grid_dim_y);
+        }
+      } else {
+        if (vec_load_size == 4) {
+          swizzle_col_scaling_kernel_impl(
+              input_base, output_base, padded_m, padded_k, original_M, original_K,
+              block_x, block_y, grid_dim_x, grid_dim_y);
+        } else if (vec_load_size == 2) {
+          swizzle_col_scaling_kernel_impl(
+              input_base, output_base, padded_m, padded_k, original_M, original_K,
+              block_x, block_y, grid_dim_x, grid_dim_y);
+        } else {
+          swizzle_col_scaling_kernel_impl(
+              input_base, output_base, padded_m, padded_k, original_M, original_K,
+              block_x, block_y, grid_dim_x, grid_dim_y);
+        }
       }
-
-      current_block_offset += blocks_m * blocks_k;
-      current_scale_offset += padded_m * padded_k * scale_elem_size;
-    }
-
-    block_offsets[num_tensors] = current_block_offset;
-    scale_offsets[num_tensors] = current_scale_offset;
-    *total_blocks = current_block_offset;
-    *global_counter = 0;
   }
 }
 
 void swizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor* output,
-                                     void* workspace, cudaStream_t stream) {
+                                     cudaStream_t stream) {
   // Check scaling mode
   NVTE_CHECK(input->scaling_mode == NVTE_MXFP8_1D_SCALING,
              "Grouped swizzle supports only MXFP8 scaling.");
@@ -1858,15 +1852,6 @@ void swizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor*
   } else {
     // Variable shape implementation using Device-Side Block Scheduler
     size_t num_tensors = input->num_tensors;
-    NVTE_CHECK(workspace != nullptr,
-               "Workspace must be provided for variable shape grouped swizzle.");
-
-    size_t int_stride = num_tensors + 3;
-    if (int_stride % 2 != 0) int_stride++;
-    int* d_block_offsets = reinterpret_cast<int*>(workspace);
-    int* d_global_counter = d_block_offsets + num_tensors + 1;
-    int* d_total_blocks = d_global_counter + 1;
-    size_t* d_scale_offsets = reinterpret_cast<size_t*>(d_block_offsets + int_stride);
 
     constexpr int SF_TILE_DIM_M = 128;
     constexpr int SF_TILE_DIM_K = 4;
@@ -1880,25 +1865,25 @@ void swizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor*
       size_t common_m = input->all_same_first_dim() ? input->get_common_first_dim() : 0;
       size_t common_k = input->all_same_last_dim() ? input->get_common_last_dim() : 0;
 
-      compute_grouped_swizzle_setup<<<1, 1, 0, stream>>>(
-          m_array, k_array, d_block_offsets, d_scale_offsets, d_total_blocks, d_global_counter,
-          num_tensors, rowwise, scale_elem_size, common_m, common_k);
+      constexpr int SF_TILE_DIM_M = 128;
+      constexpr int SF_TILE_DIM_K = 4;
+      const int max_slm_size = TB_DIM * 4 * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);
+      const int metadata_shmem = sizeof(int); // s_total_blocks
 
       NVTE_CHECK_CUDA(cudaFuncSetAttribute(
           grouped_swizzle_scaling_variable_shape_kernel<TB_DIM, SF_TILE_DIM_M, SF_TILE_DIM_K>,
-          cudaFuncAttributeMaxDynamicSharedMemorySize, max_slm_size));
-
+          cudaFuncAttributeMaxDynamicSharedMemorySize, max_slm_size + metadata_shmem));
+      
       int device_id;
       cudaGetDevice(&device_id);
       int num_SMs;
       cudaDeviceGetAttribute(&num_SMs, cudaDevAttrMultiProcessorCount, device_id);
-      // Find out how many blocks of this specific kernel can fit on one SM
       int max_active_blocks_per_sm;
       cudaOccupancyMaxActiveBlocksPerMultiprocessor(
           &max_active_blocks_per_sm,
           grouped_swizzle_scaling_variable_shape_kernel<TB_DIM, SF_TILE_DIM_M, SF_TILE_DIM_K>,
-          TB_DIM * TB_DIM,  // block size
-          max_slm_size      // dynamic shared memory
+          TB_DIM * TB_DIM,
+          max_slm_size + metadata_shmem
       );
       int persistent_blocks = num_SMs * max_active_blocks_per_sm;
       dim3 num_blocks(persistent_blocks);
@@ -1907,9 +1892,9 @@ void swizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor*
       void* output_ptr = rowwise ? output->scale_inv.dptr : output->columnwise_scale_inv.dptr;
 
       grouped_swizzle_scaling_variable_shape_kernel<TB_DIM, SF_TILE_DIM_M, SF_TILE_DIM_K>
-          <<<num_blocks, block_size, max_slm_size, stream>>>(
-              input_ptr, output_ptr, m_array, k_array, d_block_offsets, d_scale_offsets,
-              d_global_counter, num_tensors, rowwise, common_m, common_k);
+          <<<num_blocks, block_size, max_slm_size + metadata_shmem, stream>>>(
+              input_ptr, output_ptr, m_array, k_array, num_tensors, rowwise,
+              scale_elem_size, common_m, common_k);
 
       NVTE_CHECK_CUDA(cudaGetLastError());
     };
@@ -2033,11 +2018,11 @@ void unswizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor
 }  // namespace transformer_engine
 
 void nvte_swizzle_grouped_scaling_factors(const NVTEGroupedTensor input, NVTEGroupedTensor output,
-                                          void* workspace, cudaStream_t stream) {
+                                          cudaStream_t stream) {
   NVTE_API_CALL(nvte_swizzle_grouped_scaling_factors);
   using namespace transformer_engine;
   swizzle_grouped_scaling_factors(convertNVTEGroupedTensorCheck(input),
-                                  convertNVTEGroupedTensorCheck(output), workspace, stream);
+                                  convertNVTEGroupedTensorCheck(output), stream);
 }
 
 void nvte_unswizzle_grouped_scaling_factors(const NVTEGroupedTensor input, NVTEGroupedTensor output,
diff --git a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp
index 58d7aef0cc..d0ba427edb 100644
--- a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp
@@ -398,18 +398,10 @@ std::optional maybe_swizzle_grouped_tensor(GroupedTensorW
   }
 
   swizzle_output.set_with_gemm_swizzled_scales(true);
-
-  size_t num_tensors = input.num_tensors();
-  size_t num_int_elems = num_tensors + 3;  // n+1 block_offsets + gc + tb
-  if (num_int_elems % 2 != 0) num_int_elems++;  // pad to even for size_t alignment
-  size_t workspace_size = num_int_elems * sizeof(int) + (num_tensors + 1) * sizeof(size_t);
-  workspace_size = roundup(workspace_size, 256);
-  auto workspace =
-      allocateSpace(std::vector<size_t>{workspace_size}, transformer_engine::DType::kByte, false);
-
+  
   NVTE_SCOPED_GIL_RELEASE({
     nvte_swizzle_grouped_scaling_factors(swizzle_input.data(), swizzle_output.data(),
-                                         getDataPtr(workspace), at::cuda::getCurrentCUDAStream());
+                                         at::cuda::getCurrentCUDAStream());
   });
 
   if (swizzle_rowwise) {
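Note on the scheduling change in the patch above: the rewritten kernel drops the separate compute_grouped_swizzle_setup pass and its workspace entirely; every thread block now re-derives the per-tensor block counts on the fly and converts a linear block id into a (tensor_id, local_block_id) pair by walking a running prefix sum. A minimal standalone sketch of that persistent-grid mapping follows; blocks_per_tensor() here is a hypothetical stand-in for the real tile arithmetic, not TE code.

    // Standalone sketch of a persistent-grid, device-side block scheduler.
    // Hypothetical: blocks_per_tensor() replaces the real swizzle tile math.
    #include <cstdint>
    #include <cstdio>
    #include <cuda_runtime.h>

    __device__ int blocks_per_tensor(const int64_t* rows, int i) {
      // Placeholder cost: one block per 128 rows, rounded up.
      return static_cast<int>((rows[i] + 127) / 128);
    }

    __global__ void persistent_scheduler(const int64_t* rows, int num_tensors) {
      // Every block recomputes the total workload; cheap for small num_tensors.
      int total = 0;
      for (int i = 0; i < num_tensors; ++i) total += blocks_per_tensor(rows, i);

      // Each physical block strides over the virtual grid.
      for (int vblock = blockIdx.x; vblock < total; vblock += gridDim.x) {
        int tensor_id = 0, base = 0;
        for (int i = 0; i < num_tensors; ++i) {  // linear prefix scan
          int b = blocks_per_tensor(rows, i);
          if (vblock < base + b) { tensor_id = i; break; }
          base += b;
        }
        int local_block = vblock - base;
        if (threadIdx.x == 0)
          printf("vblock %d -> tensor %d, local block %d\n", vblock, tensor_id, local_block);
      }
    }

    int main() {
      int64_t h_rows[3] = {300, 64, 1000};
      int64_t* d_rows;
      cudaMalloc(&d_rows, sizeof(h_rows));
      cudaMemcpy(d_rows, h_rows, sizeof(h_rows), cudaMemcpyHostToDevice);
      persistent_scheduler<<<4, 32>>>(d_rows, 3);  // 4 persistent blocks
      cudaDeviceSynchronize();
      cudaFree(d_rows);
      return 0;
    }

Recomputing the prefix scan inside the loop costs O(num_tensors) per virtual block, which the patch trades for eliminating the setup-kernel launch, the global atomic counter, and the device workspace allocation.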
From 5836d2c47cc952f4b416b92795667a0e630ead50 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Sat, 25 Apr 2026 23:37:31 +0000
Subject: [PATCH 5/8] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 transformer_engine/common/swizzle/swizzle.cu    | 223 +++++++++---------
 .../pytorch/csrc/extensions/swizzle.cpp         |   2 +-
 2 files changed, 107 insertions(+), 118 deletions(-)

diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu
index f67e1c20fd..c7b5dcdbda 100644
--- a/transformer_engine/common/swizzle/swizzle.cu
+++ b/transformer_engine/common/swizzle/swizzle.cu
@@ -1556,133 +1556,124 @@ namespace transformer_engine {
 template <int TB_DIM, int SF_TILE_DIM_M, int SF_TILE_DIM_K>
 __global__ void __launch_bounds__(TB_DIM* TB_DIM)
-    grouped_swizzle_scaling_variable_shape_kernel(
-        const void* input,
-        void* output,
-        const int64_t* m_array,
-        const int64_t* k_array,
-        int num_tensors,
-        bool rowwise,
-        size_t scale_elem_size,
-        size_t common_m,
-        size_t common_k) {
-
+    grouped_swizzle_scaling_variable_shape_kernel(const void* input, void* output,
+                                                  const int64_t* m_array, const int64_t* k_array,
+                                                  int num_tensors, bool rowwise,
+                                                  size_t scale_elem_size, size_t common_m,
+                                                  size_t common_k) {
   extern __shared__ int s_metadata[];
   int* s_total_blocks = &s_metadata[0];
-  
+
   // Warp reduction to compute total workload
   if (threadIdx.x < 32 && threadIdx.y == 0) {
-      int local_blocks = 0;
-      for (int i = threadIdx.x; i < num_tensors; i += 32) {
-        size_t m = rowwise ? (m_array ? m_array[i] : common_m)
-                           : (k_array ? k_array[i] : common_k);
-        size_t k = rowwise ? (k_array ? k_array[i] : common_k)
-                           : (m_array ? m_array[i] : common_m);
-
-        size_t padded_m = round_up_to_multiple(m, 128);
-        size_t padded_k = round_up_to_multiple(DIVUP(k, static_cast<size_t>(MXFP8_BLOCK_SIZE)), 4);
-
-        int num_tiles_m = padded_m / SF_TILE_DIM_M;
-        int num_tiles_k = padded_k / SF_TILE_DIM_K;
-
-        int vec_load_size = (rowwise ? ((num_tiles_k - 1) % 4 + 1) : ((num_tiles_m - 1) % 4 + 1));
-        if (vec_load_size == 3) vec_load_size = 1;
-        int n_tiles_in_tb = TB_DIM * vec_load_size;
-
-        int grid_dim_x = rowwise ? DIVUP(num_tiles_k, n_tiles_in_tb) : DIVUP(num_tiles_k, TB_DIM);
-        int grid_dim_y = rowwise ? num_tiles_m : DIVUP(num_tiles_m, vec_load_size);
-        local_blocks += grid_dim_x * grid_dim_y;
-      }
-
-      for (int offset = 16; offset > 0; offset /= 2) {
-        local_blocks += __shfl_down_sync(0xffffffff, local_blocks, offset);
-      }
-      if (threadIdx.x == 0) *s_total_blocks = local_blocks;
+    int local_blocks = 0;
+    for (int i = threadIdx.x; i < num_tensors; i += 32) {
+      size_t m = rowwise ? (m_array ? m_array[i] : common_m) : (k_array ? k_array[i] : common_k);
+      size_t k = rowwise ? (k_array ? k_array[i] : common_k) : (m_array ? m_array[i] : common_m);
+
+      size_t padded_m = round_up_to_multiple(m, 128);
+      size_t padded_k = round_up_to_multiple(DIVUP(k, static_cast<size_t>(MXFP8_BLOCK_SIZE)), 4);
+
+      int num_tiles_m = padded_m / SF_TILE_DIM_M;
+      int num_tiles_k = padded_k / SF_TILE_DIM_K;
+
+      int vec_load_size = (rowwise ? ((num_tiles_k - 1) % 4 + 1) : ((num_tiles_m - 1) % 4 + 1));
+      if (vec_load_size == 3) vec_load_size = 1;
+      int n_tiles_in_tb = TB_DIM * vec_load_size;
+
+      int grid_dim_x = rowwise ? DIVUP(num_tiles_k, n_tiles_in_tb) : DIVUP(num_tiles_k, TB_DIM);
+      int grid_dim_y = rowwise ? num_tiles_m : DIVUP(num_tiles_m, vec_load_size);
+      local_blocks += grid_dim_x * grid_dim_y;
+    }
+
+    for (int offset = 16; offset > 0; offset /= 2) {
+      local_blocks += __shfl_down_sync(0xffffffff, local_blocks, offset);
+    }
+    if (threadIdx.x == 0) *s_total_blocks = local_blocks;
   }
   __syncthreads();
 
   const int total_blocks = *s_total_blocks;
 
   // Persistent-grid loop
-  for (int linear_block_id = blockIdx.x; linear_block_id < total_blocks; linear_block_id += gridDim.x) {
-      // Discover tensor_id and local_block_id via linear scan
-      int tensor_id = 0;
-      int current_block_base = 0;
-      size_t current_scale_base = 0;
-      int grid_dim_x = 0;
-      int grid_dim_y = 0;
-      size_t M = 0, K = 0;
-      int vec_load_size = 0;
-
-      for (int i = 0; i < num_tensors; ++i) {
-        M = rowwise ? (m_array ? m_array[i] : common_m)
-                    : (k_array ? k_array[i] : common_k);
-        K = rowwise ? (k_array ? k_array[i] : common_k)
-                    : (m_array ? m_array[i] : common_m);
-
-        size_t padded_m = round_up_to_multiple(M, 128);
-        size_t padded_k = round_up_to_multiple(DIVUP(K, static_cast<size_t>(MXFP8_BLOCK_SIZE)), 4);
-
-        int num_tiles_m = padded_m / SF_TILE_DIM_M;
-        int num_tiles_k = padded_k / SF_TILE_DIM_K;
-
-        vec_load_size = (rowwise ? ((num_tiles_k - 1) % 4 + 1) : ((num_tiles_m - 1) % 4 + 1));
-        if (vec_load_size == 3) vec_load_size = 1;
-        int n_tiles_in_tb = TB_DIM * vec_load_size;
-
-        grid_dim_x = rowwise ? DIVUP(num_tiles_k, n_tiles_in_tb) : DIVUP(num_tiles_k, TB_DIM);
-        grid_dim_y = rowwise ? num_tiles_m : DIVUP(num_tiles_m, vec_load_size);
-        int blocks_i = grid_dim_x * grid_dim_y;
-
-        if (linear_block_id < current_block_base + blocks_i) {
-          tensor_id = i;
-          break;
-        }
-        current_block_base += blocks_i;
-        current_scale_base += padded_m * padded_k * scale_elem_size;
+  for (int linear_block_id = blockIdx.x; linear_block_id < total_blocks;
+       linear_block_id += gridDim.x) {
+    // Discover tensor_id and local_block_id via linear scan
+    int tensor_id = 0;
+    int current_block_base = 0;
+    size_t current_scale_base = 0;
+    int grid_dim_x = 0;
+    int grid_dim_y = 0;
+    size_t M = 0, K = 0;
+    int vec_load_size = 0;
+
+    for (int i = 0; i < num_tensors; ++i) {
+      M = rowwise ? (m_array ? m_array[i] : common_m) : (k_array ? k_array[i] : common_k);
+      K = rowwise ? (k_array ? k_array[i] : common_k) : (m_array ? m_array[i] : common_m);
+
+      size_t padded_m = round_up_to_multiple(M, 128);
+      size_t padded_k = round_up_to_multiple(DIVUP(K, static_cast<size_t>(MXFP8_BLOCK_SIZE)), 4);
+
+      int num_tiles_m = padded_m / SF_TILE_DIM_M;
+      int num_tiles_k = padded_k / SF_TILE_DIM_K;
+
+      vec_load_size = (rowwise ? ((num_tiles_k - 1) % 4 + 1) : ((num_tiles_m - 1) % 4 + 1));
+      if (vec_load_size == 3) vec_load_size = 1;
+      int n_tiles_in_tb = TB_DIM * vec_load_size;
+
+      grid_dim_x = rowwise ? DIVUP(num_tiles_k, n_tiles_in_tb) : DIVUP(num_tiles_k, TB_DIM);
+      grid_dim_y = rowwise ? num_tiles_m : DIVUP(num_tiles_m, vec_load_size);
+      int blocks_i = grid_dim_x * grid_dim_y;
+
+      if (linear_block_id < current_block_base + blocks_i) {
+        tensor_id = i;
+        break;
       }
+      current_block_base += blocks_i;
+      current_scale_base += padded_m * padded_k * scale_elem_size;
+    }
 
-      int local_block_id = linear_block_id - current_block_base;
-      int block_x = local_block_id % grid_dim_x;
-      int block_y = local_block_id / grid_dim_x;
+    int local_block_id = linear_block_id - current_block_base;
+    int block_x = local_block_id % grid_dim_x;
+    int block_y = local_block_id / grid_dim_x;
 
-      const uint8_t* input_base = reinterpret_cast<const uint8_t*>(input) + current_scale_base;
-      uint8_t* output_base = reinterpret_cast<uint8_t*>(output) + current_scale_base;
+    const uint8_t* input_base = reinterpret_cast<const uint8_t*>(input) + current_scale_base;
+    uint8_t* output_base = reinterpret_cast<uint8_t*>(output) + current_scale_base;
 
-      size_t padded_m = round_up_to_multiple(M, 128);
-      size_t padded_k = round_up_to_multiple(DIVUP(K, static_cast<size_t>(MXFP8_BLOCK_SIZE)), 4);
-      int original_M = static_cast<int>(M);
-      int original_K = static_cast<int>(DIVUP(K, static_cast<size_t>(MXFP8_BLOCK_SIZE)));
+    size_t padded_m = round_up_to_multiple(M, 128);
+    size_t padded_k = round_up_to_multiple(DIVUP(K, static_cast<size_t>(MXFP8_BLOCK_SIZE)), 4);
+    int original_M = static_cast<int>(M);
+    int original_K = static_cast<int>(DIVUP(K, static_cast<size_t>(MXFP8_BLOCK_SIZE)));
 
-      if (rowwise) {
-        if (vec_load_size == 4) {
-          swizzle_row_scaling_kernel_impl(
-              input_base, output_base, padded_m, padded_k, original_M, original_K,
-              block_x, block_y, grid_dim_x, grid_dim_y);
-        } else if (vec_load_size == 2) {
-          swizzle_row_scaling_kernel_impl(
-              input_base, output_base, padded_m, padded_k, original_M, original_K,
-              block_x, block_y, grid_dim_x, grid_dim_y);
-        } else {
-          swizzle_row_scaling_kernel_impl(
-              input_base, output_base, padded_m, padded_k, original_M, original_K,
-              block_x, block_y, grid_dim_x, grid_dim_y);
-        }
+    if (rowwise) {
+      if (vec_load_size == 4) {
+        swizzle_row_scaling_kernel_impl(
+            input_base, output_base, padded_m, padded_k, original_M, original_K, block_x, block_y,
+            grid_dim_x, grid_dim_y);
+      } else if (vec_load_size == 2) {
+        swizzle_row_scaling_kernel_impl(
+            input_base, output_base, padded_m, padded_k, original_M, original_K, block_x, block_y,
+            grid_dim_x, grid_dim_y);
       } else {
-        if (vec_load_size == 4) {
-          swizzle_col_scaling_kernel_impl(
-              input_base, output_base, padded_m, padded_k, original_M, original_K,
-              block_x, block_y, grid_dim_x, grid_dim_y);
-        } else if (vec_load_size == 2) {
-          swizzle_col_scaling_kernel_impl(
-              input_base, output_base, padded_m, padded_k, original_M, original_K,
-              block_x, block_y, grid_dim_x, grid_dim_y);
-        } else {
-          swizzle_col_scaling_kernel_impl(
-              input_base, output_base, padded_m, padded_k, original_M, original_K,
-              block_x, block_y, grid_dim_x, grid_dim_y);
-        }
+        swizzle_row_scaling_kernel_impl(
+            input_base, output_base, padded_m, padded_k, original_M, original_K, block_x, block_y,
+            grid_dim_x, grid_dim_y);
+      }
+    } else {
+      if (vec_load_size == 4) {
+        swizzle_col_scaling_kernel_impl(
+            input_base, output_base, padded_m, padded_k, original_M, original_K, block_x, block_y,
+            grid_dim_x, grid_dim_y);
+      } else if (vec_load_size == 2) {
+        swizzle_col_scaling_kernel_impl(
+            input_base, output_base, padded_m, padded_k, original_M, original_K, block_x, block_y,
+            grid_dim_x, grid_dim_y);
+      } else {
+        swizzle_col_scaling_kernel_impl(
+            input_base, output_base, padded_m, padded_k, original_M, original_K, block_x, block_y,
+            grid_dim_x, grid_dim_y);
       }
+    }
   }
 }
@@ -1868,12 +1859,12 @@ void swizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor*
       constexpr int SF_TILE_DIM_M = 128;
       constexpr int SF_TILE_DIM_K = 4;
       const int max_slm_size = TB_DIM * 4 * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);
-      const int metadata_shmem = sizeof(int); // s_total_blocks
+      const int metadata_shmem = sizeof(int);  // s_total_blocks
 
       NVTE_CHECK_CUDA(cudaFuncSetAttribute(
           grouped_swizzle_scaling_variable_shape_kernel<TB_DIM, SF_TILE_DIM_M, SF_TILE_DIM_K>,
           cudaFuncAttributeMaxDynamicSharedMemorySize, max_slm_size + metadata_shmem));
-      
+
       int device_id;
       cudaGetDevice(&device_id);
       int num_SMs;
@@ -1882,9 +1873,7 @@ void swizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor*
       cudaOccupancyMaxActiveBlocksPerMultiprocessor(
           &max_active_blocks_per_sm,
           grouped_swizzle_scaling_variable_shape_kernel<TB_DIM, SF_TILE_DIM_M, SF_TILE_DIM_K>,
-          TB_DIM * TB_DIM,
-          max_slm_size + metadata_shmem
-      );
+          TB_DIM * TB_DIM, max_slm_size + metadata_shmem);
       int persistent_blocks = num_SMs * max_active_blocks_per_sm;
       dim3 num_blocks(persistent_blocks);
@@ -1893,8 +1882,8 @@ void swizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor*
 
       grouped_swizzle_scaling_variable_shape_kernel<TB_DIM, SF_TILE_DIM_M, SF_TILE_DIM_K>
           <<<num_blocks, block_size, max_slm_size + metadata_shmem, stream>>>(
-              input_ptr, output_ptr, m_array, k_array, num_tensors, rowwise,
-              scale_elem_size, common_m, common_k);
+              input_ptr, output_ptr, m_array, k_array, num_tensors, rowwise, scale_elem_size,
+              common_m, common_k);
 
       NVTE_CHECK_CUDA(cudaGetLastError());
     };
diff --git a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp
index d0ba427edb..3325685a61 100644
--- a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp
@@ -398,7 +398,7 @@ std::optional maybe_swizzle_grouped_tensor(GroupedTensorW
   }
 
   swizzle_output.set_with_gemm_swizzled_scales(true);
-  
+
   NVTE_SCOPED_GIL_RELEASE({
     nvte_swizzle_grouped_scaling_factors(swizzle_input.data(), swizzle_output.data(),
                                          at::cuda::getCurrentCUDAStream());
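The total-workload count in the kernel above relies on the standard warp-level tree reduction: each __shfl_down_sync halves the active stride, so after five steps lane 0 holds the sum over all 32 lanes. A self-contained sketch of the idiom (it assumes one full, converged warp, which the kernel's threadIdx.x < 32 guard and full 0xffffffff mask arrange):

    #include <cstdio>
    #include <cuda_runtime.h>

    __global__ void warp_sum_demo() {
      // Each lane contributes its lane id; the tree reduction leaves the
      // warp-wide sum (0 + 1 + ... + 31 = 496) in lane 0.
      int v = threadIdx.x;
      for (int offset = 16; offset > 0; offset /= 2) {
        v += __shfl_down_sync(0xffffffff, v, offset);
      }
      if (threadIdx.x == 0) printf("warp sum = %d\n", v);
    }

    int main() {
      warp_sum_demo<<<1, 32>>>();
      cudaDeviceSynchronize();
      return 0;
    }

Only lane 0's value is meaningful after the loop, which is why the kernel writes s_total_blocks from threadIdx.x == 0 alone and then publishes it to the whole block via __syncthreads().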
From 057607047eb14c9552d93bbfabfab4048f5cd19b Mon Sep 17 00:00:00 2001
From: Abhishek
Date: Wed, 29 Apr 2026 18:25:29 -0700
Subject: [PATCH 6/8] Cached blocks per sm for device and removed redundant
 checks

Signed-off-by: Abhishek
---
 transformer_engine/common/swizzle/swizzle.cu    | 71 +++++++++++--------
 .../pytorch/csrc/extensions/swizzle.cpp         |  2 -
 2 files changed, 42 insertions(+), 31 deletions(-)

diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu
index 0bed1edd45..8619d91220 100644
--- a/transformer_engine/common/swizzle/swizzle.cu
+++ b/transformer_engine/common/swizzle/swizzle.cu
@@ -8,10 +8,13 @@
 #include <cuda.h>
 #include <cuda_runtime.h>
+#include <mutex>
 #include <transformer_engine/swizzle.h>
 #include <transformer_engine/transformer_engine.h>
+#include <vector>
 
 #include "../common.h"
+#include "../util/cuda_runtime.h"
 #include "../util/logging.h"
 #include "transformer_engine/transformer_engine.h"
 
@@ -2005,6 +2008,28 @@ __global__ void __launch_bounds__(TB_DIM* TB_DIM)
   }
 }
 
+template <int TB_DIM, int SF_TILE_DIM_M, int SF_TILE_DIM_K>
+int grouped_swizzle_variable_max_active_blocks_per_sm(int device_id) {
+  static std::vector<int> cache(cuda::num_devices(), -1);
+  static std::vector<std::once_flag> flags(cuda::num_devices());
+  NVTE_CHECK(0 <= device_id && device_id < cuda::num_devices(), "invalid CUDA device ID");
+
+  auto init = [&]() {
+    constexpr int metadata_shmem = sizeof(int);  // s_total_blocks
+    constexpr int dynamic_smem_size =
+        TB_DIM * 4 * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t) + metadata_shmem;
+    int max_active_blocks_per_sm;
+    NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+        &max_active_blocks_per_sm,
+        grouped_swizzle_scaling_variable_shape_kernel<TB_DIM, SF_TILE_DIM_M, SF_TILE_DIM_K>,
+        TB_DIM * TB_DIM, dynamic_smem_size));
+    NVTE_CHECK(max_active_blocks_per_sm > 0, "Occupancy query returned 0 blocks per SM.");
+    cache[device_id] = max_active_blocks_per_sm;
+  };
+  std::call_once(flags[device_id], init);
+  return cache[device_id];
+}
+
 void swizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor* output,
                                      cudaStream_t stream) {
   // Check scaling mode
@@ -2032,10 +2057,6 @@ void swizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor*
 
   if (!is_variable_shape) {
     // Fallback to uniform shape implementation
-    NVTE_CHECK(input->all_same_shape(), "Grouped swizzle requires uniform tensor shapes.");
-    NVTE_CHECK(input->all_same_last_dim() && input->all_same_first_dim(),
-               "Grouped swizzle requires uniform tensor shapes.");
-
     // Assumption is that all the tensors share the same shapes and are contgiuous.
     // And so we dont need to pass array of input/output pointers(due to conttiguity)
     // as well as array of shapes(due to uniform shapes).
@@ -2176,40 +2197,32 @@ void swizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor*
     constexpr int SF_TILE_DIM_K = 4;
     const dim3 block_size(TB_DIM, TB_DIM);
     const int max_slm_size = TB_DIM * 4 * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);
+    const int metadata_shmem = sizeof(int);  // s_total_blocks
+    const int dynamic_smem_size = max_slm_size + metadata_shmem;
+
+    size_t common_m = input->all_same_first_dim() ? input->get_common_first_dim() : 0;
+    size_t common_k = input->all_same_last_dim() ? input->get_common_last_dim() : 0;
+
+    NVTE_CHECK_CUDA(cudaFuncSetAttribute(
+        grouped_swizzle_scaling_variable_shape_kernel<TB_DIM, SF_TILE_DIM_M, SF_TILE_DIM_K>,
+        cudaFuncAttributeMaxDynamicSharedMemorySize, dynamic_smem_size));
+
+    const int device_id = cuda::current_device();
+    const int num_SMs = cuda::sm_count(device_id);
+    const int max_active_blocks_per_sm =
+        grouped_swizzle_variable_max_active_blocks_per_sm<TB_DIM, SF_TILE_DIM_M, SF_TILE_DIM_K>(device_id);
+    const int persistent_blocks = num_SMs * max_active_blocks_per_sm;
+    const dim3 num_blocks(persistent_blocks);
 
     auto launch_grouped_swizzle_variable = [&](bool rowwise) {
       const size_t scale_elem_size = rowwise ? typeToSize(input->scale_inv.dtype)
                                              : typeToSize(input->columnwise_scale_inv.dtype);
 
-      size_t common_m = input->all_same_first_dim() ? input->get_common_first_dim() : 0;
-      size_t common_k = input->all_same_last_dim() ? input->get_common_last_dim() : 0;
-
-      constexpr int SF_TILE_DIM_M = 128;
-      constexpr int SF_TILE_DIM_K = 4;
-      const int max_slm_size = TB_DIM * 4 * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);
-      const int metadata_shmem = sizeof(int);  // s_total_blocks
-
-      NVTE_CHECK_CUDA(cudaFuncSetAttribute(
-          grouped_swizzle_scaling_variable_shape_kernel<TB_DIM, SF_TILE_DIM_M, SF_TILE_DIM_K>,
-          cudaFuncAttributeMaxDynamicSharedMemorySize, max_slm_size + metadata_shmem));
-
-      int device_id;
-      cudaGetDevice(&device_id);
-      int num_SMs;
-      cudaDeviceGetAttribute(&num_SMs, cudaDevAttrMultiProcessorCount, device_id);
-      int max_active_blocks_per_sm;
-      cudaOccupancyMaxActiveBlocksPerMultiprocessor(
-          &max_active_blocks_per_sm,
-          grouped_swizzle_scaling_variable_shape_kernel<TB_DIM, SF_TILE_DIM_M, SF_TILE_DIM_K>,
-          TB_DIM * TB_DIM, max_slm_size + metadata_shmem);
-      int persistent_blocks = num_SMs * max_active_blocks_per_sm;
-      dim3 num_blocks(persistent_blocks);
-
       const void* input_ptr = rowwise ? input->scale_inv.dptr : input->columnwise_scale_inv.dptr;
       void* output_ptr = rowwise ? output->scale_inv.dptr : output->columnwise_scale_inv.dptr;
 
       grouped_swizzle_scaling_variable_shape_kernel<TB_DIM, SF_TILE_DIM_M, SF_TILE_DIM_K>
-          <<<num_blocks, block_size, max_slm_size + metadata_shmem, stream>>>(
+          <<<num_blocks, block_size, dynamic_smem_size, stream>>>(
              input_ptr, output_ptr, m_array, k_array, num_tensors, rowwise, scale_elem_size,
              common_m, common_k);
 
diff --git a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp
index eaa4e95ce9..7a198ca70b 100644
--- a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp
@@ -379,8 +379,6 @@ std::optional maybe_swizzle_grouped_tensor(GroupedTensorW
   if (!swizzle_rowwise && !swizzle_columnwise) {
     return std::nullopt;
   }
-  const auto first_dims = input.get_first_dims();
-  const auto last_dims = input.get_last_dims();
 
   std::optional<at::Tensor> rowwise_scales_pyt;
   std::optional<at::Tensor> columnwise_scales_pyt;
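The caching introduced in this patch follows a common host-side pattern: one std::once_flag per device guards the cudaOccupancyMaxActiveBlocksPerMultiprocessor query, so it runs once per device per kernel instantiation rather than on every launch. A reduced sketch of the same pattern with a placeholder kernel and block size (not the TE helper itself):

    #include <cstdio>
    #include <mutex>
    #include <vector>
    #include <cuda_runtime.h>

    __global__ void dummy_kernel() {}

    int cached_max_blocks_per_sm(int device_id, int num_devices) {
      static std::vector<int> cache(num_devices, -1);
      static std::vector<std::once_flag> flags(num_devices);
      std::call_once(flags[device_id], [&] {
        int blocks = 0;
        cudaOccupancyMaxActiveBlocksPerMultiprocessor(
            &blocks, dummy_kernel, /*blockSize=*/256, /*dynamicSMemSize=*/0);
        cache[device_id] = blocks;  // computed once per device, reused afterwards
      });
      return cache[device_id];
    }

    int main() {
      int n = 0;
      cudaGetDeviceCount(&n);
      for (int d = 0; d < n; ++d) {
        cudaSetDevice(d);
        printf("device %d: %d blocks/SM\n", d, cached_max_blocks_per_sm(d, n));
      }
      return 0;
    }

std::call_once makes the initialization thread-safe, which a plain static bool would not; the occupancy result is a pure function of device and kernel, so memoizing it is safe.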
From 04003c5f700dd0055caa35dfef3d364d9ee49882 Mon Sep 17 00:00:00 2001
From: Abhishek
Date: Wed, 29 Apr 2026 18:56:22 -0700
Subject: [PATCH 7/8] Updated the code with newer changes in main

Signed-off-by: Abhishek
---
 transformer_engine/common/swizzle/swizzle.cu | 35 +++++++++++---------
 1 file changed, 19 insertions(+), 16 deletions(-)

diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu
index 11dca2cb57..66fb11d014 100644
--- a/transformer_engine/common/swizzle/swizzle.cu
+++ b/transformer_engine/common/swizzle/swizzle.cu
@@ -2090,38 +2090,41 @@ __global__ void __launch_bounds__(TB_DIM* TB_DIM)
     const uint8_t* input_base = reinterpret_cast<const uint8_t*>(input) + current_scale_base;
     uint8_t* output_base = reinterpret_cast<uint8_t*>(output) + current_scale_base;
 
-    size_t padded_m = round_up_to_multiple(M, 128);
-    size_t padded_k = round_up_to_multiple(DIVUP(K, static_cast<size_t>(MXFP8_BLOCK_SIZE)), 4);
-    int original_M = static_cast<int>(M);
-    int original_K = static_cast<int>(DIVUP(K, static_cast<size_t>(MXFP8_BLOCK_SIZE)));
+    const int padded_m = static_cast<int>(round_up_to_multiple(M, 128));
+    const int padded_k =
+        static_cast<int>(round_up_to_multiple(DIVUP(K, static_cast<size_t>(MXFP8_BLOCK_SIZE)), 4));
+    const int original_M = static_cast<int>(M);
+    const int original_K = static_cast<int>(DIVUP(K, static_cast<size_t>(MXFP8_BLOCK_SIZE)));
+    const bool padding_m = (block_y == grid_dim_y - 1) && (original_M < padded_m);
+    const bool padding_k = (block_x == grid_dim_x - 1) && (original_K < padded_k);
 
     if (rowwise) {
       if (vec_load_size == 4) {
-        swizzle_row_scaling_kernel_impl(
+        dispatch_swizzle_row_scaling_kernel_impl(
             input_base, output_base, padded_m, padded_k, original_M, original_K, block_x, block_y,
-            grid_dim_x, grid_dim_y);
+            grid_dim_x, grid_dim_y, padding_k, padding_m);
       } else if (vec_load_size == 2) {
-        swizzle_row_scaling_kernel_impl(
+        dispatch_swizzle_row_scaling_kernel_impl(
            input_base, output_base, padded_m, padded_k, original_M, original_K, block_x, block_y,
-            grid_dim_x, grid_dim_y);
+            grid_dim_x, grid_dim_y, padding_k, padding_m);
       } else {
-        swizzle_row_scaling_kernel_impl(
+        dispatch_swizzle_row_scaling_kernel_impl(
            input_base, output_base, padded_m, padded_k, original_M, original_K, block_x, block_y,
-            grid_dim_x, grid_dim_y);
+            grid_dim_x, grid_dim_y, padding_k, padding_m);
       }
     } else {
       if (vec_load_size == 4) {
-        swizzle_col_scaling_kernel_impl(
+        dispatch_swizzle_col_scaling_kernel_impl(
            input_base, output_base, padded_m, padded_k, original_M, original_K, block_x, block_y,
-            grid_dim_x, grid_dim_y);
+            grid_dim_x, grid_dim_y, padding_k, padding_m);
       } else if (vec_load_size == 2) {
-        swizzle_col_scaling_kernel_impl(
+        dispatch_swizzle_col_scaling_kernel_impl(
            input_base, output_base, padded_m, padded_k, original_M, original_K, block_x, block_y,
-            grid_dim_x, grid_dim_y);
+            grid_dim_x, grid_dim_y, padding_k, padding_m);
       } else {
-        swizzle_col_scaling_kernel_impl(
+        dispatch_swizzle_col_scaling_kernel_impl(
            input_base, output_base, padded_m, padded_k, original_M, original_K, block_x, block_y,
-            grid_dim_x, grid_dim_y);
+            grid_dim_x, grid_dim_y, padding_k, padding_m);
       }
     }
   }
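The dispatch_swizzle_*_kernel_impl wrappers above receive runtime padding_k/padding_m flags; presumably they fold those bools into compile-time template parameters so interior tiles keep a branch-free fast path while only the last tile in each dimension pays for bounds checks. A generic sketch of that runtime-to-compile-time fold (all names here are illustrative, not the actual TE helpers):

    #include <cuda_runtime.h>

    // Core tile routine: the bounds check is compiled in only for the
    // PAD_K = true instantiations, so the interior path has no branch.
    template <bool PAD_K, bool PAD_M>
    __device__ void copy_tile_impl(int k, int original_K) {
      if (PAD_K && k >= original_K) return;  // dead code when PAD_K == false
      // ... move one scale-factor element ...
    }

    // Fold two runtime bools into the four possible instantiations.
    __device__ void dispatch_copy_tile(bool padding_k, bool padding_m,
                                       int k, int original_K) {
      if (padding_k) {
        if (padding_m) copy_tile_impl<true, true>(k, original_K);
        else           copy_tile_impl<true, false>(k, original_K);
      } else {
        if (padding_m) copy_tile_impl<false, true>(k, original_K);
        else           copy_tile_impl<false, false>(k, original_K);
      }
    }

    __global__ void demo(bool padding_k, bool padding_m) {
      dispatch_copy_tile(padding_k, padding_m, threadIdx.x, 60);
    }

    int main() {
      demo<<<1, 64>>>(true, false);  // last tile in K: checked path
      cudaDeviceSynchronize();
      return 0;
    }

Since padding_k/padding_m are uniform across a thread block, the dispatch branch itself does not diverge; each block runs exactly one instantiation.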
From 9bf8dcb935002316aa0aee71862ada6e74361833 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 30 Apr 2026 02:27:50 +0000
Subject: [PATCH 8/8] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 transformer_engine/common/swizzle/swizzle.cu | 24 ++++++++++----------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu
index 66fb11d014..c7ed407a59 100644
--- a/transformer_engine/common/swizzle/swizzle.cu
+++ b/transformer_engine/common/swizzle/swizzle.cu
@@ -2220,7 +2220,7 @@ void swizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor*
         padded_m;
 
     const size_t scale_elem_size = rowwise ? typeToSize(input->scale_inv.dtype)
-                                            : typeToSize(input->columnwise_scale_inv.dtype);
+                                           : typeToSize(input->columnwise_scale_inv.dtype);
 
     const size_t input_scale_numel =
         rowwise ? input->scale_inv.numel() : input->columnwise_scale_inv.numel();
@@ -2234,13 +2234,13 @@ void swizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor*
       input_is_compact = true;
     } else {
       NVTE_ERROR("Grouped input ", (rowwise ? "scale_inv" : "columnwise_scale_inv"),
-                  " size does not match expected packed size (got ", input_scale_numel,
-                  ", expected either ", input->num_tensors * padded_scale_elems,
-                  " (per-tensor padded) or ", compact_total_scale_elems, " (compact)).");
+                 " size does not match expected packed size (got ", input_scale_numel,
+                 ", expected either ", input->num_tensors * padded_scale_elems,
+                 " (per-tensor padded) or ", compact_total_scale_elems, " (compact)).");
     }
     NVTE_CHECK(output_scale_numel == input->num_tensors * padded_scale_elems, "Grouped output ",
-                (rowwise ? "scale_inv" : "columnwise_scale_inv"),
-                " size does not match expected per-tensor padded size.");
+               (rowwise ? "scale_inv" : "columnwise_scale_inv"),
+               " size does not match expected per-tensor padded size.");
 
     const size_t input_stride_bytes =
         (input_is_compact ? compact_scale_elems : padded_scale_elems) * scale_elem_size;
@@ -2272,9 +2272,9 @@ void swizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor*
             grouped_swizzle_row_scaling_uniform_shape_kernel<LType>,
             cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
         grouped_swizzle_row_scaling_uniform_shape_kernel<LType>
-            <<<num_blocks, block_size, slm_size, stream>>>(input_ptr, output_ptr, padded_m,
-                                                           padded_k, original_M, original_K,
-                                                           input_stride_bytes, output_stride_bytes);
+            <<<num_blocks, block_size, slm_size, stream>>>(
+                input_ptr, output_ptr, padded_m, padded_k, original_M, original_K,
+                input_stride_bytes, output_stride_bytes);
       });
     } else {
       TRANSFORMER_ENGINE_VECTORIZED_LOAD_INTEGER_TYPE_SWITCH(vec_load_size, LType, {
@@ -2282,9 +2282,9 @@ void swizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor*
             grouped_swizzle_col_scaling_uniform_shape_kernel<LType>,
             cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
         grouped_swizzle_col_scaling_uniform_shape_kernel<LType>
-            <<<num_blocks, block_size, slm_size, stream>>>(input_ptr, output_ptr, padded_m,
-                                                           padded_k, original_M, original_K,
-                                                           input_stride_bytes, output_stride_bytes);
+            <<<num_blocks, block_size, slm_size, stream>>>(
+                input_ptr, output_ptr, padded_m, padded_k, original_M, original_K,
+                input_stride_bytes, output_stride_bytes);
       });
     }
     NVTE_CHECK_CUDA(cudaGetLastError());