diff --git a/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh index 8cfada20b..d9d3dd0f9 100644 --- a/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh @@ -28,6 +28,7 @@ #ifdef __HIP_PLATFORM_AMD__ #include "./rocm_vectorized_2d.cuh" +#include "../../util/rocm_device_utils.cuh" #endif namespace transformer_engine { diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh index c657d930a..dca7752c7 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -32,6 +32,7 @@ #ifdef __HIP_PLATFORM_AMD__ #include "./rocm_vectorized_2d.cuh" +#include "../../util/rocm_device_utils.cuh" #endif namespace transformer_engine { diff --git a/transformer_engine/common/cast/mxfp8/rocm_gated_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/rocm_gated_mxfp8.cuh index fee3db3e0..f338fe8e0 100644 --- a/transformer_engine/common/cast/mxfp8/rocm_gated_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/rocm_gated_mxfp8.cuh @@ -176,7 +176,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // --- Act rowwise quantization --- { __builtin_assume(act_amax >= 0); - const float scale_amax = subwarp_reduce_max_broadcast(act_amax); + const float scale_amax = rocm_subwarp_allreduce(act_amax, rocm_op::max{}); const e8m0_t biased_exp = ptx::float_to_e8m0(scale_amax * Quantized_Limits::max_norm_rcp); const float scale_inv = ptx::exp2f_rcp(biased_exp); @@ -210,7 +210,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // --- Gate rowwise quantization (BWD only) --- if constexpr (IS_DGATED) { __builtin_assume(gate_amax >= 0); - const float scale_amax = subwarp_reduce_max_broadcast(gate_amax); + const float scale_amax = rocm_subwarp_allreduce(gate_amax, rocm_op::max{}); const e8m0_t biased_exp = ptx::float_to_e8m0(scale_amax * 
Quantized_Limits::max_norm_rcp); const float scale_inv = ptx::exp2f_rcp(biased_exp); @@ -333,7 +333,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // --- Act rowwise quantization --- { __builtin_assume(act_amax >= 0); - const float scale_amax = subwarp_reduce_max_broadcast(act_amax); + const float scale_amax = rocm_subwarp_allreduce(act_amax, rocm_op::max{}); const e8m0_t biased_exp = ptx::float_to_e8m0(scale_amax * Quantized_Limits::max_norm_rcp); const float scale_inv = ptx::exp2f_rcp(biased_exp); @@ -367,7 +367,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // --- Gate rowwise quantization (BWD only) --- if constexpr (IS_DGATED) { __builtin_assume(gate_amax >= 0); - const float scale_amax = subwarp_reduce_max_broadcast(gate_amax); + const float scale_amax = rocm_subwarp_allreduce(gate_amax, rocm_op::max{}); const e8m0_t biased_exp = ptx::float_to_e8m0(scale_amax * Quantized_Limits::max_norm_rcp); const float scale_inv = ptx::exp2f_rcp(biased_exp); diff --git a/transformer_engine/common/cast/mxfp8/rocm_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/rocm_quantize_mxfp8.cuh index 7a9a0d696..899038d22 100644 --- a/transformer_engine/common/cast/mxfp8/rocm_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/rocm_quantize_mxfp8.cuh @@ -7,8 +7,6 @@ // drop-in replacement for rocm quantize_mxfp8 kernels //#include "hip/hip_runtime.h" //dummy include to prevent hipification adding this header -#include "../../util/rocm_device_utils.cuh" - constexpr size_t MXFP8_CHUNK_DIM_Y = 64; constexpr size_t MXFP8_CHUNK_DIM_X = 64; constexpr size_t MXFP8_THREADS_PER_CHUNK = 64; @@ -163,7 +161,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) __builtin_assume(thread_amax >= 0); block_amax = fmaxf(block_amax, thread_amax); - const float subwarp_amax = subwarp_reduce_max_broadcast(thread_amax); + const float subwarp_amax = rocm_subwarp_allreduce(thread_amax, rocm_op::max{}); const e8m0_t biased_exponent = ptx::float_to_e8m0(subwarp_amax 
* Quantized_Limits::max_norm_rcp);
@@ -309,7 +307,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
     __builtin_assume(thread_amax >= 0);
     block_amax = fmaxf(block_amax, thread_amax);
 
-    const float subwarp_amax = subwarp_reduce_max_broadcast(thread_amax);
+    const float subwarp_amax = rocm_subwarp_allreduce(thread_amax, rocm_op::max{});
     const e8m0_t biased_exponent =
         ptx::float_to_e8m0(subwarp_amax * Quantized_Limits::max_norm_rcp);
 
diff --git a/transformer_engine/common/util/rocm_device_utils.cuh b/transformer_engine/common/util/rocm_device_utils.cuh
index 89c49b533..96a1b0759 100644
--- a/transformer_engine/common/util/rocm_device_utils.cuh
+++ b/transformer_engine/common/util/rocm_device_utils.cuh
@@ -135,6 +135,47 @@ __device__ __forceinline__ int rocm_upper_bound(const T* arr, int n, T val) {
   return lo;
 }
 
+// Binary reduction functors for rocm_subwarp_allreduce.
+// rocm_op is only a scope for the functor types; pass an instance, e.g. rocm_op::max{}.
+// Every functor takes and returns float, so arithmetic operands of any other type are
+// implicitly converted to float before being combined.
+struct rocm_op {
+  // Larger of the two operands, computed with fmaxf.
+  struct max {
+    __device__ __forceinline__ float operator()(float a, float b) const { return fmaxf(a, b); }
+  };
+
+  // Smaller of the two operands, computed with fminf.
+  struct min {
+    __device__ __forceinline__ float operator()(float a, float b) const { return fminf(a, b); }
+  };
+
+  // Sum of the two operands.
+  struct sum {
+    __device__ __forceinline__ float operator()(float a, float b) const { return a + b; }
+  };
+};
+
+// Butterfly all-reduce within a subwarp of WIDTH lanes. All lanes get the result.
+//
+// Each pass combines a lane's value with the one __shfl_xor returns for the current
+// offset, halving the offset from WIDTH / 2 down to 1.
+//
+// WIDTH: subwarp size forwarded to __shfl_xor; must be a power of two (enforced below).
+// val:   this lane's input value.
+// op:    binary functor applied as op(own, shuffled), e.g. rocm_op::max{}.
+// Returns the reduced value; with WIDTH == 1 the loop does not run and val is returned.
+template <int WIDTH, typename T, typename OP>
+__device__ __forceinline__ T rocm_subwarp_allreduce(T val, const OP &op) {
+  // The XOR butterfly only reaches every lane of the group when WIDTH is a power of two.
+  static_assert(WIDTH > 0 && (WIDTH & (WIDTH - 1)) == 0,
+                "rocm_subwarp_allreduce: WIDTH must be a positive power of two");
+#pragma unroll
+  for (int offset = WIDTH / 2; offset > 0; offset >>= 1) {
+    val = op(val, __shfl_xor(val, offset, WIDTH));
+  }
+  return val;
+}
+
 template
 __device__ __forceinline__ float rocm_block_reduce_max(float val, int warp_id) {
   __shared__ float staging[WARPS];