diff --git a/ci/pytorch.sh b/ci/pytorch.sh index 32fbf02f8..3bb282bc1 100755 --- a/ci/pytorch.sh +++ b/ci/pytorch.sh @@ -49,8 +49,9 @@ run_test_config(){ fi run 1 test_cuda_graphs.py run_default_fa 1 test_deferred_init.py - run_default_fa 1 test_quantized_tensor.py run_default_fa 1 test_float8_current_scaling_exact.py + run_default_fa 1 test_float8blockwisetensor.py + run_default_fa 1 test_quantized_tensor.py test $_fus_attn = auto -o $_fus_attn = ck && run 1 test_cpu_offloading.py test $_fus_attn = auto -o $_fus_attn = ck -o $_fus_attn = aotriton && NVTE_FLASH_ATTN=0 NVTE_CPU_OFFLOAD_V1=1 run 3 test_cpu_offloading_v1.py run_default_fa 1 test_fused_rope.py diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 901e5ec9f..13280028a 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -15,7 +15,7 @@ add_executable(test_operator test_cast_mxfp8.cu test_cast_mxfp8_grouped.cu test_cast_nvfp4_transpose.cu - test_cast_float8blockwise.cu #CUDA-only test + test_cast_float8blockwise.cu test_dequantize_mxfp8.cu test_transpose.cu test_cast_transpose.cu @@ -41,7 +41,6 @@ if(USE_ROCM) get_target_property(test_cuda_sources test_operator SOURCES) # Remove CUDA-only tests and add ROCm specific ones list(REMOVE_ITEM test_cuda_sources - test_cast_float8blockwise.cu test_swizzle.cu test_grouped_gemm.cu) list(APPEND test_cuda_sources diff --git a/tests/cpp/operator/test_cast_float8blockwise.cu b/tests/cpp/operator/test_cast_float8blockwise.cu index b43cc6bd8..1674339b3 100644 --- a/tests/cpp/operator/test_cast_float8blockwise.cu +++ b/tests/cpp/operator/test_cast_float8blockwise.cu @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -227,17 +229,28 @@ void ref_quantize_onedimensional_blocks(const ProcessingMethod processing_method } inline size_t scale_align_stride(size_t inner_elements) { +#ifdef __HIP_PLATFORM_AMD__ + return inner_elements; +#else return ((inner_elements + 4u - 1u) / 4u) * 4u; +#endif }; void compare_scaling_factors(const std::string& name, const float* test, const float* ref, const size_t row_blocks, const size_t col_blocks, const size_t test_stride, const size_t ref_stride) { +#ifdef __HIP_PLATFORM_AMD__ + const float atol = 1e-6f; +#endif for (int i = 0; i < row_blocks; ++i) { for (int j = 0; j < col_blocks; ++j) { const int test_idx = i * test_stride + j; const int ref_idx = i * ref_stride + j; +#ifdef __HIP_PLATFORM_AMD__ + ASSERT_FALSE(std::abs(test[test_idx] - ref[ref_idx]) > atol) +#else ASSERT_FALSE(test[test_idx] != ref[ref_idx]) +#endif << "Error in " << name << std::endl << "Mismatch: " << test[test_idx] << " vs " << ref[ref_idx] << " at index " << test_idx << "," << ref_idx; @@ -248,12 +261,19 @@ void compare_scaling_factors(const std::string& name, const float* test, const f void compare_scaling_factors_one_dimensional_blocks(const std::string& name, const float* test, const float* ref, const size_t rows, const size_t col_blocks) { +#ifdef __HIP_PLATFORM_AMD__ + const float atol = 1e-6f; +#endif const size_t test_stride = scale_align_stride(rows); for (int i = 0; i < rows; ++i) { for (int j = 0; j < col_blocks; ++j) { const int test_idx = i + test_stride * j; const int ref_idx = i + rows * j; +#ifdef __HIP_PLATFORM_AMD__ + ASSERT_FALSE(std::abs(test[test_idx] - ref[ref_idx]) > atol) +#else ASSERT_FALSE(test[test_idx] != ref[ref_idx]) +#endif << "Error in " << name << std::endl << "Mismatch: " << test[test_idx] << " vs " << ref[ref_idx] << " at index " << test_idx << "," << ref_idx; diff --git a/tests/pytorch/test_float8blockwisetensor.py b/tests/pytorch/test_float8blockwisetensor.py index 5fc6aa51c..ad996fb63 100644 --- a/tests/pytorch/test_float8blockwisetensor.py +++ b/tests/pytorch/test_float8blockwisetensor.py @@ -45,7 +45,10 @@ def _to_list(x: Union[Iterable, Any]) -> List: DimsType = Union[Iterable[int], int] # TODO replace with call to fp8.py when recipe added. -recipe_available = not IS_HIP_EXTENSION and (get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.8) +if IS_HIP_EXTENSION: + recipe_available = get_device_compute_capability() >= (9, 4) +else: + recipe_available = get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.8 reason_for_no_recipe = "Quantize kernels require TMA and are only relevant with GEMMS." diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 02eaaea93..33d9ba425 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -200,7 +200,7 @@ list(APPEND transformer_engine_cuda_sources transpose/cast_transpose_fusion.cu transpose/transpose_fusion.cu transpose/multi_cast_transpose.cu - transpose/quantize_transpose_vector_blockwise.cu #CUDA-only + transpose/quantize_transpose_vector_blockwise.cu transpose/swap_first_dims.cu dropout/dropout.cu fused_attn/flash_attn.cu @@ -233,7 +233,6 @@ list(APPEND transformer_engine_cuda_sources comm_gemm_overlap/userbuffers/userbuffers.cu) set(cuda_only_cuda_sources - transpose/quantize_transpose_vector_blockwise.cu fused_attn/fused_attn_f16_max512_seqlen.cu fused_attn/fused_attn_f16_arbitrary_seqlen.cu fused_attn/fused_attn_fp8.cu @@ -257,7 +256,7 @@ list(APPEND transformer_engine_cuda_arch_specific_sources multi_tensor/compute_scale.cu recipe/mxfp8_scaling.cu recipe/nvfp4.cu - transpose/quantize_transpose_square_blockwise.cu #CUDA-only + transpose/quantize_transpose_square_blockwise.cu transpose/quantize_transpose_vector_blockwise_fp4.cu) set(cuda_only_cuda_arch_specific_sources @@ -267,8 +266,7 @@ set(cuda_only_cuda_arch_specific_sources hadamard_transform/hadamard_transform_cast_fusion.cu hadamard_transform/group_hadamard_transform_cast_fusion.cu hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu - hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu - transpose/quantize_transpose_square_blockwise.cu) + hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu) # Compiling the files with the worst compilation time first to hopefully overlap # better with the faster-compiling cpp files diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index 579caee06..1b52a7c68 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -164,7 +164,6 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, #endif break; } -#ifndef __HIP_PLATFORM_AMD__ case NVTE_BLOCK_SCALING_2D: { // TODO(kwyss): IS_ACT, ParamOP, OP parameters support. NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for FWD NVTE_BLOCK_SCALING_2D"); @@ -196,7 +195,6 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, columnwise_option, force_pow_2_scales, noop_tensor->data, stream); break; } -#endif//#ifndef __HIP_PLATFORM_AMD__ default: NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); } @@ -317,7 +315,6 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens #endif break; } -#ifndef __HIP_PLATFORM_AMD__ case NVTE_BLOCK_SCALING_2D: { // TODO(kwyss): IS_BIAS, IS_DACT, ParamOP, OP parameters support. NVTE_CHECK((!IS_DBIAS && !IS_DACT), @@ -351,7 +348,6 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens columnwise_option, force_pow_2_scales, noop_tensor->data, stream); break; } -#endif //#ifndef __HIP_PLATFORM_AMD__ default: NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); } diff --git a/transformer_engine/common/recipe/recipe_common.cuh b/transformer_engine/common/recipe/recipe_common.cuh index 07839407a..e9ae0f1b2 100644 --- a/transformer_engine/common/recipe/recipe_common.cuh +++ b/transformer_engine/common/recipe/recipe_common.cuh @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -61,10 +63,16 @@ __device__ __forceinline__ float compute_scale_from_amax(float amax, float max_f template __device__ __forceinline__ float compute_scale_from_types(const float amax, const float eps, const float pow_2_scaling) { +// On AMD host, TypeExtrema::max is non-constexpr (runtime FNUZ detection) +#if defined(__HIP_PLATFORM_AMD__) && !defined(__HIP_DEVICE_COMPILE__) + const float fp8_max = detail::TypeExtrema::max; + const float value_for_inf = detail::TypeExtrema::max; +#else constexpr float fp8_max = TypeInfo::max_finite_value; // NOTE: We're relying on compute_scale_from_amax to have behavior where it // clips the mantissa of the max_finite_value if power of 2 scaling applies. constexpr float value_for_inf = TypeInfo::max_finite_value; +#endif return compute_scale_from_amax(amax, fp8_max, pow_2_scaling, eps, value_for_inf); } diff --git a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu index 3a8536587..c89eebb37 100644 --- a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -28,7 +30,11 @@ namespace { // const values configuration +#if defined(__HIP_PLATFORM_AMD__) && !defined(__gfx1250__) +constexpr size_t kThreadsPerWarp = 64; +#else constexpr size_t kThreadsPerWarp = 32; +#endif #ifdef TMA_HW_SUPPORTED constexpr size_t BLOCK_TILE_DIM = 128; constexpr size_t WARP_TILE_DIM_X = 32; @@ -40,8 +46,12 @@ constexpr size_t BLOCK_TILE_DIM = 128; constexpr size_t WARP_TILE_DIM_X = 64; constexpr size_t WARP_TILE_DIM_Y = 32; constexpr size_t THREAD_TILE_DIM_X = 8; +#if defined(__HIP_PLATFORM_AMD__) && !defined(__gfx1250__) +constexpr size_t THREAD_TILE_DIM_Y = 4; +#else constexpr size_t THREAD_TILE_DIM_Y = 8; #endif +#endif #ifdef TMA_HW_SUPPORTED constexpr size_t NUM_BYTES_PER_BANK = 4; @@ -62,6 +72,7 @@ constexpr size_t NUM_THREADS_Y_IN_WARP = kThreadsPerWarp / NUM_THREADS_X_IN_WARP #define MIN(a, b) (a < b ? a : b) +#ifndef __HIP_PLATFORM_AMD__ template __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose_kernel(const IType* const input, OType* const output_c, @@ -247,6 +258,7 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) #endif } } +#endif // __HIP_PLATFORM_AMD__ template __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose_kernel_notaligned( @@ -357,10 +369,18 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose } } // Reduce amax in the warp (32x32 tile) +#ifdef __HIP_PLATFORM_AMD__ +#pragma unroll + for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2) { + warp_tile_amax = fmaxf(amax, __shfl_xor(amax, delta, kThreadsPerWarp)); + amax = warp_tile_amax; + } +#else warp_tile_amax = warp_reduce_max(amax); // broadcast the amax to all threads in a warp from the lane 0 constexpr int lane_zero = 0; warp_tile_amax = __shfl_sync(0xFFFFFFFF, warp_tile_amax, lane_zero); +#endif // reduce warp_tile_amax across multiple warps in a thread block using shared mem if (tid_in_warp == 0) { @@ -456,6 +476,7 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose } } +#ifndef __HIP_PLATFORM_AMD__ template CUtensorMap get_tensor_map(const SimpleTensor& tensor, size_t global_dim_x, size_t global_dim_y) { CUtensorMapDataType dataType; @@ -473,6 +494,7 @@ CUtensorMap get_tensor_map(const SimpleTensor& tensor, size_t global_dim_x, size /*stride_elems=*/global_dim_x, /*offset_elems=*/0, sizeof(OutputType) * 8); return tensor_map_output_trans; } +#endif // __HIP_PLATFORM_AMD__ } // namespace } // namespace transformer_engine @@ -543,6 +565,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor return_transpose, kReturnTranspose, dim3 grid(num_blocks_x, num_blocks_y, 1); +#ifndef __HIP_PLATFORM_AMD__ const bool full_tile = row_length % BLOCK_TILE_DIM == 0 && num_rows % BLOCK_TILE_DIM == 0; @@ -573,6 +596,18 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, pow_2_scale, noop_ptr); } // full-tile +#else + block_scaled_cast_transpose_kernel_notaligned + <<>>( + reinterpret_cast(input.dptr), + reinterpret_cast(output.dptr), + reinterpret_cast(output_t.dptr), + reinterpret_cast(scale_inv.dptr), + reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, + scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, + pow_2_scale, noop_ptr); +#endif // __HIP_PLATFORM_AMD__ ) // return_transpose ) // OutputType ) // InputType diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu index df869b433..86ef355ba 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -145,14 +147,23 @@ Step 3 (if columnwise transpose is False, COMPACT format): Skip Transpose, cast */ // clang-format on +#if defined(__HIP_PLATFORM_AMD__) && !defined(__gfx1250__) +constexpr size_t kThreadsPerWarp = 64; +#else constexpr size_t kThreadsPerWarp = 32; +#endif // Hyperparameters for performance tuning constexpr int kTileDim = 128; // Fixed to 128 beacause we are using 1x128 and 128x1 quantization constexpr int kNVecIn = 8; // The number of elements each LDG touches constexpr int kNVecOut = 16; // The number of elements each STG touches constexpr int kNVecSMem = 2; // The number of elements each LDS/STS touches + +#if defined(__HIP_PLATFORM_AMD__) && !defined(__gfx1250__) +constexpr int kThreadsPerBlock = 512; // Thread block size, 8 warps (wave64) in total +#else constexpr int kThreadsPerBlock = 256; // Thread block size, 8 warps in total +#endif // Auto-calculated constants, do not modify directly) static_assert(kNVecIn % kNVecSMem == 0, "kNVecIn must be divisible by kNVecSMem"); @@ -166,6 +177,78 @@ static_assert(kNumThreadsLoad <= kThreadsPerWarp, "kNumThreadsLoad must be <= kT static_assert(kNumThreadsStore <= kThreadsPerWarp, "kNumThreadsStore must be <= kThreadsPerWarp"); constexpr int kNumWarps = kThreadsPerBlock / kThreadsPerWarp; +// gfx942 (MI300) has 64KB LDS; the full 128x128 fp32 staging tile overflows it. +#if defined(__HIP_PLATFORM_AMD__) && (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx942__)) +constexpr int kChunkCol = 64; +constexpr int kNumChunks = kTileDim / kChunkCol; +static_assert(kTileDim % kChunkCol == 0, "kTileDim must be divisible by kChunkCol"); +constexpr int kSMemColChunk = (kChunkCol / kNVecSMem) + 1; +constexpr int kSMemSizeChunk = kSMemRow * kSMemColChunk * kNVecSMem; +constexpr int kNumThreadsLoadChunk = kChunkCol / kNVecIn; +constexpr size_t kLdsLimitBytes = 64 * 1024; + +template +constexpr bool use_chunked_lds() { + return sizeof(IType) * static_cast(kSMemSize) > kLdsLimitBytes; +} + +template +size_t host_smem_bytes() { + if (transformer_engine::cuda::sm_arch() < 95 && + sizeof(IType) * static_cast(kSMemSize) > kLdsLimitBytes) { + return static_cast(kSMemSizeChunk) * sizeof(IType); + } + return static_cast(kSMemSize) * sizeof(IType); +} + +template +__device__ __forceinline__ void load_chunk_to_smem(Vec* smem, + const IType* const input, + const size_t row_length, const size_t num_rows, + const int chunk) { + using SMemVec = Vec; + union IVec { + Vec input_type; + Vec smem_type; + }; + constexpr int r_stride = kThreadsPerBlock / kNumThreadsLoadChunk; + constexpr int num_iterations = kTileDim / r_stride; + const int c_s = (threadIdx.x % kNumThreadsLoadChunk) * (kNVecIn / kNVecSMem); + int r_s = threadIdx.x / kNumThreadsLoadChunk; + const size_t c_g = static_cast(blockIdx.x) * kTileDim + + static_cast(chunk) * kChunkCol + c_s * kNVecSMem; + size_t r_g = static_cast(blockIdx.y) * kTileDim + r_s; + const size_t stride_g = static_cast(r_stride) * row_length; + const size_t num_ele = + c_g < row_length ? min(static_cast(kNVecIn), row_length - c_g) : 0; + const IType* input_g = &input[r_g * row_length + c_g]; +#pragma unroll + for (int iter = 0; iter < num_iterations; ++iter) { + IVec input_vec; + if constexpr (kAligned) { + input_vec.input_type.load_from(input_g); + } else { + if (r_g < num_rows) { + input_vec.input_type.load_from_elts(input_g, 0, num_ele); + } else { + input_vec.input_type.clear(); + } + } +#pragma unroll + for (int i = 0; i < kNVecIn / kNVecSMem; ++i) { + int c = c_s + i; + int r = r_s; + smem[r * kSMemColChunk + c] = input_vec.smem_type.data.elt[i]; + } + input_g += stride_g; + r_s += r_stride; + if constexpr (!kAligned) { + r_g += r_stride; + } + } +} +#endif // gfx942 helpers + template __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpose_kernel( const IType* const input, OType* const output_c, OType* const output_t, @@ -196,6 +279,9 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo SMemVec* smem = reinterpret_cast(&smem_base[0]); // Step 1: Load input to shared memory +#if defined(__HIP_PLATFORM_AMD__) && (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx942__)) + if constexpr (!use_chunked_lds()) +#endif { constexpr int r_stride = kThreadsPerBlock / kNumThreadsLoad; // stride in rows of shared memory constexpr int num_iterations = kTileDim / r_stride; @@ -241,6 +327,88 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo __syncthreads(); // Step 2: Cast and store to output_c +#if defined(__HIP_PLATFORM_AMD__) && (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx942__)) + if (return_rowwise && use_chunked_lds()) { + constexpr int r_stride = kThreadsPerBlock / kNumThreadsStore; + constexpr int num_iterations = kTileDim / r_stride; + + int r_s = threadIdx.x / kNumThreadsStore; + const size_t c_g = static_cast(blockIdx.x) * kTileDim + + static_cast(threadIdx.x % kNumThreadsStore) * kNVecOut; + size_t r_g = static_cast(blockIdx.y) * kTileDim + r_s; + const size_t stride_g = static_cast(r_stride) * row_length; + const size_t num_ele = + c_g < row_length ? min(static_cast(kNVecOut), row_length - c_g) : 0; + const IType* input_g = &input[r_g * row_length + c_g]; + OType* output_g = &output_c[r_g * row_length + c_g]; + const unsigned src_lane = (threadIdx.x % kThreadsPerWarp) / kNumThreadsStore * kNumThreadsStore; + const bool is_src_lane = (threadIdx.x % kNumThreadsStore) == 0; +#pragma unroll + for (int iter = 0; iter < num_iterations; ++iter) { + // Step 2.1: Load kNVecOut contiguous elements directly from global input + Vec in_vec; + if constexpr (kAligned) { + in_vec.load_from(input_g); + } else { + if (r_g < num_rows) { + in_vec.load_from_elts(input_g, 0, num_ele); + } else { + in_vec.clear(); + } + } + // Step 2.2: Compute local amax + CType amax = 0; +#pragma unroll + for (int j = 0; j < kNVecOut; ++j) { + __builtin_assume(amax >= 0); + amax = fmaxf(amax, fabsf(in_vec.data.elt[j])); + } + // Step 2.3: Reduce amax +#pragma unroll + for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) { + const float other_amax = __shfl_down(amax, delta, kNumThreadsStore); + __builtin_assume(amax >= 0); + __builtin_assume(other_amax >= 0); + amax = fmaxf(amax, other_amax); + } + amax = __shfl(amax, src_lane, kNumThreadsStore); + // Step 2.4: Compute scale + CType scale = compute_scale_from_types(amax, epsilon, pow_2_scaling); + // Step 2.5: Write scale_inv + bool write_scale_inv = is_src_lane; + if constexpr (!kAligned) { + write_scale_inv &= (r_g < num_rows); + } + if (write_scale_inv) { + CType scale_inv = 1.0 / scale; + size_t row_idx = static_cast(blockIdx.y) * kTileDim + r_s; + size_t col_idx = static_cast(blockIdx.x); + tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv; + } + // Step 2.6: Quantize + OVec output_vec; +#pragma unroll + for (int j = 0; j < kNVecOut; ++j) { + output_vec.data.elt[j] = static_cast(static_cast(in_vec.data.elt[j]) * scale); + } + // Step 2.7: Store output_c + if constexpr (kAligned) { + output_vec.store_to(output_g); + } else { + if (r_g < num_rows) { + output_vec.store_to_elts(output_g, 0, num_ele); + } + } + // Step 2.8: Advance + input_g += stride_g; + output_g += stride_g; + r_s += r_stride; + if constexpr (!kAligned) { + r_g += r_stride; + } + } + } else +#endif if (return_rowwise) { constexpr int r_stride = kThreadsPerBlock / kNumThreadsStore; // stride in rows of shared memory @@ -258,8 +426,10 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo // Each kNumThreadsStore threads form a warp process one row, we need to find the lane id of // the first thread to do the reduction. const unsigned src_lane = (threadIdx.x % kThreadsPerWarp) / kNumThreadsStore * kNumThreadsStore; +#ifndef __HIP_PLATFORM_AMD__ // This mask represents which threads should do the reduction together. const unsigned mask = ((1 << kNumThreadsStore) - 1) << src_lane; +#endif const bool is_src_lane = (threadIdx.x % kNumThreadsStore) == 0; #pragma unroll for (int iter = 0; iter < num_iterations; ++iter) { @@ -284,12 +454,20 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo // Step 2.3: Reduce amax #pragma unroll for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) { +#ifdef __HIP_PLATFORM_AMD__ + const float other_amax = __shfl_down(amax, delta, kNumThreadsStore); +#else const float other_amax = __shfl_down_sync(mask, amax, delta); +#endif __builtin_assume(amax >= 0); __builtin_assume(other_amax >= 0); amax = fmaxf(amax, other_amax); } +#ifdef __HIP_PLATFORM_AMD__ + amax = __shfl(amax, src_lane, kNumThreadsStore); +#else amax = __shfl_sync(mask, amax, src_lane); +#endif CType scale; // Step 2.4: Compute scale scale = compute_scale_from_types(amax, epsilon, pow_2_scaling); @@ -332,6 +510,77 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo } // Step 3 (return_columnwise_gemm_ready): Transpose, cast and store to output_t +#if defined(__HIP_PLATFORM_AMD__) && (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx942__)) + if (return_columnwise_gemm_ready && use_chunked_lds()) { + const int r_s = (threadIdx.x % kNumThreadsStore) * kNVecOut; + const int c_s = threadIdx.x / kNumThreadsStore; + const size_t c_g = static_cast(blockIdx.y) * kTileDim + r_s; + const size_t num_ele = + c_g < num_rows ? min(static_cast(kNVecOut), num_rows - c_g) : 0; + const unsigned src_lane = (threadIdx.x % kThreadsPerWarp) / kNumThreadsStore * kNumThreadsStore; + const bool is_src_lane = (threadIdx.x % kNumThreadsStore) == 0; + const bool col_active = c_s < (kChunkCol / kNVecSMem); +#pragma unroll + for (int chunk = 0; chunk < kNumChunks; ++chunk) { + if (chunk != 0) { + __syncthreads(); + } + load_chunk_to_smem(smem, input, row_length, num_rows, chunk); + __syncthreads(); + + const size_t r_g = static_cast(blockIdx.x) * kTileDim + + static_cast(chunk) * kChunkCol + c_s * kNVecSMem; + OType* output_g = &output_t[r_g * num_rows + c_g]; + if (col_active) { + SMemVec smem_vec[kNVecOut]; +#pragma unroll + for (int i = 0; i < kNVecOut; ++i) { + smem_vec[i] = smem[(r_s + i) * kSMemColChunk + c_s]; + } +#pragma unroll + for (int smem_idx = 0; smem_idx < kNVecSMem; ++smem_idx) { + CType amax = 0; +#pragma unroll + for (int i = 0; i < kNVecOut; ++i) { + amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[smem_idx])); + } +#pragma unroll + for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) { + const float other_amax = __shfl_down(amax, delta, kNumThreadsStore); + __builtin_assume(amax >= 0); + __builtin_assume(other_amax >= 0); + amax = fmaxf(amax, other_amax); + } + amax = __shfl(amax, src_lane, kNumThreadsStore); + CType scale = compute_scale_from_types(amax, epsilon, pow_2_scaling); + bool write_scale_inv = is_src_lane; + if constexpr (!kAligned) { + write_scale_inv &= (r_g + smem_idx < row_length); + } + if (write_scale_inv) { + CType scale_inv = 1.0 / scale; + size_t row_idx = r_g + smem_idx; + size_t col_idx = static_cast(blockIdx.y); + tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv; + } + OVec output_vec; +#pragma unroll + for (int i = 0; i < kNVecOut; ++i) { + output_vec.data.elt[i] = + static_cast(static_cast(smem_vec[i].data.elt[smem_idx]) * scale); + } + if constexpr (kAligned) { + output_vec.store_to(output_g + smem_idx * num_rows); + } else { + if (r_g + smem_idx < row_length) { + output_vec.store_to_elts(output_g + smem_idx * num_rows, 0, num_ele); + } + } + } + } + } + } else +#endif if (return_columnwise_gemm_ready) { constexpr int c_stride = kThreadsPerBlock / kNumThreadsStore; // Stride in columns of shared memory @@ -349,8 +598,10 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo // Each kNumThreadsStore threads form a warp process one row, we need to find the lane id of // the first thread to do the reduction. const unsigned src_lane = (threadIdx.x % kThreadsPerWarp) / kNumThreadsStore * kNumThreadsStore; +#ifndef __HIP_PLATFORM_AMD__ // This mask represents which threads should do the reduction together. const unsigned mask = ((1 << kNumThreadsStore) - 1) << src_lane; +#endif const bool is_src_lane = (threadIdx.x % kNumThreadsStore) == 0; #pragma unroll for (int iter = 0; iter < num_iterations; ++iter) { @@ -373,12 +624,20 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo // Step 3.3: Reduce amax #pragma unroll for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) { +#ifdef __HIP_PLATFORM_AMD__ + const float other_amax = __shfl_down(amax, delta, kNumThreadsStore); +#else const float other_amax = __shfl_down_sync(mask, amax, delta); +#endif __builtin_assume(amax >= 0); __builtin_assume(other_amax >= 0); amax = fmaxf(amax, other_amax); } +#ifdef __HIP_PLATFORM_AMD__ + amax = __shfl(amax, src_lane, kNumThreadsStore); +#else amax = __shfl_sync(mask, amax, src_lane); +#endif // Step 3.4: Compute scale CType scale; scale = compute_scale_from_types(amax, epsilon, pow_2_scaling); @@ -419,6 +678,93 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo } // Step 4 (return_columnwise_compact): cast in 128x1 style and store to output, skip transpose +#if defined(__HIP_PLATFORM_AMD__) && (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx942__)) + if (return_columnwise_compact && use_chunked_lds()) { + constexpr int kThreadTileRow = kTileDim / kThreadsPerWarp; + constexpr int kThreadTileCol = kNVecOut; + using RegVec = Vec; + using RegScaleVec = Vec; + constexpr int num_smem_reads = kNVecOut / kNVecSMem; + constexpr int warps_per_chunk = kNumWarps / kNumChunks; + const int thr_idx_in_warp = threadIdx.x % kThreadsPerWarp; + const int warp_idx = threadIdx.x / kThreadsPerWarp; + const int warp_in_chunk = warp_idx % warps_per_chunk; + const int r_s = thr_idx_in_warp * kThreadTileRow; + const int c_s = warp_in_chunk * num_smem_reads; + size_t r_g = static_cast(blockIdx.y) * kTileDim + r_s; +#pragma unroll + for (int chunk = 0; chunk < kNumChunks; ++chunk) { + if (chunk != 0) { + __syncthreads(); + } + load_chunk_to_smem(smem, input, row_length, num_rows, chunk); + __syncthreads(); + const bool warp_active = (warp_idx / warps_per_chunk) == chunk; + const size_t c_g = static_cast(blockIdx.x) * kTileDim + + static_cast(chunk) * kChunkCol + c_s * kNVecSMem; + const size_t num_ele = c_g < row_length + ? min(static_cast(kThreadTileCol), row_length - c_g) + : 0; + if (warp_active) { + RegVec reg_vec[kThreadTileRow]; + RegScaleVec thr_scale; +#pragma unroll + for (int i = 0; i < kThreadTileRow; ++i) { + int r = r_s + i; +#pragma unroll + for (int j = 0; j < num_smem_reads; ++j) { + int c = c_s + j; + SMemVec smem_vec = smem[r * kSMemColChunk + c]; +#pragma unroll + for (int k = 0; k < kNVecSMem; ++k) { + reg_vec[i].data.elt[j * kNVecSMem + k] = smem_vec.data.elt[k]; + } + } + } +#pragma unroll + for (int reg_idx = 0; reg_idx < kThreadTileCol; ++reg_idx) { + CType amax = 0; +#pragma unroll + for (int i = 0; i < kThreadTileRow; ++i) { + amax = fmaxf(amax, fabsf(reg_vec[i].data.elt[reg_idx])); + } + const bool is_src_lane = thr_idx_in_warp == 0; + amax = warp_reduce_max(amax); + constexpr int lane_zero = 0; + amax = __shfl(amax, lane_zero, kThreadsPerWarp); + CType scale = compute_scale_from_types(amax, epsilon, pow_2_scaling); + thr_scale.data.elt[reg_idx] = scale; + bool write_scale_inv = is_src_lane; + if constexpr (!kAligned) { + write_scale_inv &= (c_g + reg_idx < row_length); + } + if (write_scale_inv) { + CType scale_inv = 1.0 / scale; + size_t row_idx = static_cast(blockIdx.y); + size_t col_idx = c_g + reg_idx; + tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv; + } + } + for (int row_idx = 0; row_idx < kThreadTileRow; ++row_idx) { + OType* output_g = &output_t[(r_g + row_idx) * row_length + c_g]; + OVec output_vec; +#pragma unroll + for (int i = 0; i < kThreadTileCol; ++i) { + output_vec.data.elt[i] = static_cast( + static_cast(reg_vec[row_idx].data.elt[i]) * thr_scale.data.elt[i]); + } + if constexpr (kAligned) { + output_vec.store_to(output_g); + } else { + if (r_g + row_idx < num_rows) { + output_vec.store_to_elts(output_g, 0, num_ele); + } + } + } + } + } + } else +#endif if (return_columnwise_compact) { // thread tile should be 4x16, 16 means 8 smem reads constexpr int kThreadTileRow = kTileDim / kThreadsPerWarp; @@ -474,7 +820,11 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo const bool is_src_lane = thr_idx_in_warp == 0; amax = warp_reduce_max(amax); constexpr int lane_zero = 0; +#ifdef __HIP_PLATFORM_AMD__ + amax = __shfl(amax, lane_zero, kThreadsPerWarp); +#else amax = __shfl_sync(0xFFFFFFFF, amax, lane_zero); +#endif // Step 3.4: Compute scale CType scale; scale = compute_scale_from_types(amax, epsilon, pow_2_scaling); @@ -617,7 +967,11 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor TRANSFORMER_ENGINE_SWITCH_CONDITION( full_tile, kAligned, +#if defined(__HIP_PLATFORM_AMD__) && (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx942__)) + size_t smem_bytes = host_smem_bytes(); +#else size_t smem_bytes = kSMemSize * sizeof(InputType); +#endif // shared memory must be requested up if (smem_bytes >= 48 * 1024) { cudaError_t err = cudaFuncSetAttribute( diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index dcd12b7a0..860f79891 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -103,7 +103,10 @@ def check_nvfp4_support() -> Tuple[bool, str]: def check_fp8_block_scaling_support() -> Tuple[bool, str]: """Return if fp8 block scaling support is available""" if IS_HIP_EXTENSION: - return False, "FP8 block scaled gemm not yet supported for ROCm" + gpu_arch = get_device_compute_capability() + if gpu_arch >= (9, 4): + return True, "" + return False, "Device arch gfx94x or newer is required for FP8 block scaling execution." if get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.9: return True, "" return (