-
Notifications
You must be signed in to change notification settings - Fork 32
enable blockwise FP8 quantization on rocm #609
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
8335488
6226301
bdf905e
676d1f0
f8a0fc5
231e381
e158d3e
70c35df
7ede21d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Copyright
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. added |
|
ipanfilo marked this conversation as resolved.
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,6 @@ | ||
| /************************************************************************* | ||
| * This file was modified for portability to AMDGPU | ||
| * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. | ||
| * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| * | ||
| * See LICENSE for license information. | ||
|
|
@@ -28,7 +30,11 @@ namespace { | |
|
|
||
| // const values configuration | ||
|
|
||
| #if defined(__HIP_PLATFORM_AMD__) && !defined(__gfx1250__) | ||
| constexpr size_t kThreadsPerWarp = 64; | ||
|
ipanfilo marked this conversation as resolved.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is kThreadsPerWarp only used by device code and not any dispatch functions?
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is used to compute NUM_THREADS_Y_IN_WARP in L71 for constexpr computation. And other than this, it is only used in the device code. |
||
| #else | ||
| constexpr size_t kThreadsPerWarp = 32; | ||
| #endif | ||
| #ifdef TMA_HW_SUPPORTED | ||
| constexpr size_t BLOCK_TILE_DIM = 128; | ||
| constexpr size_t WARP_TILE_DIM_X = 32; | ||
|
|
@@ -40,8 +46,12 @@ constexpr size_t BLOCK_TILE_DIM = 128; | |
| constexpr size_t WARP_TILE_DIM_X = 64; | ||
| constexpr size_t WARP_TILE_DIM_Y = 32; | ||
| constexpr size_t THREAD_TILE_DIM_X = 8; | ||
| #if defined(__HIP_PLATFORM_AMD__) && !defined(__gfx1250__) | ||
| constexpr size_t THREAD_TILE_DIM_Y = 4; | ||
| #else | ||
| constexpr size_t THREAD_TILE_DIM_Y = 8; | ||
| #endif | ||
| #endif | ||
|
|
||
| #ifdef TMA_HW_SUPPORTED | ||
| constexpr size_t NUM_BYTES_PER_BANK = 4; | ||
|
|
@@ -62,6 +72,7 @@ constexpr size_t NUM_THREADS_Y_IN_WARP = kThreadsPerWarp / NUM_THREADS_X_IN_WARP | |
|
|
||
| #define MIN(a, b) (a < b ? a : b) | ||
|
|
||
| #ifndef __HIP_PLATFORM_AMD__ | ||
| template <bool kReturnTranspose, typename CType, typename IType, typename OType> | ||
| __global__ void __launch_bounds__(THREADS_PER_BLOCK) | ||
| block_scaled_cast_transpose_kernel(const IType* const input, OType* const output_c, | ||
|
|
@@ -247,6 +258,7 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) | |
| #endif | ||
| } | ||
| } | ||
| #endif // __HIP_PLATFORM_AMD__ | ||
|
|
||
| template <bool kReturnTranspose, typename CType, typename IType, typename OType> | ||
| __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose_kernel_notaligned( | ||
|
|
@@ -357,10 +369,18 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose | |
| } | ||
| } | ||
| // Reduce amax in the warp (32x32 tile) | ||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| #pragma unroll | ||
| for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2) { | ||
|
ipanfilo marked this conversation as resolved.
|
||
| warp_tile_amax = fmaxf(amax, __shfl_xor(amax, delta, kThreadsPerWarp)); | ||
| amax = warp_tile_amax; | ||
| } | ||
| #else | ||
| warp_tile_amax = warp_reduce_max<kThreadsPerWarp>(amax); | ||
| // broadcast the amax to all threads in a warp from the lane 0 | ||
| constexpr int lane_zero = 0; | ||
| warp_tile_amax = __shfl_sync(0xFFFFFFFF, warp_tile_amax, lane_zero); | ||
| #endif | ||
|
|
||
| // reduce warp_tile_amax across multiple warps in a thread block using shared mem | ||
| if (tid_in_warp == 0) { | ||
|
|
@@ -456,6 +476,7 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose | |
| } | ||
| } | ||
|
|
||
| #ifndef __HIP_PLATFORM_AMD__ | ||
| template <typename OutputType> | ||
| CUtensorMap get_tensor_map(const SimpleTensor& tensor, size_t global_dim_x, size_t global_dim_y) { | ||
| CUtensorMapDataType dataType; | ||
|
|
@@ -473,6 +494,7 @@ CUtensorMap get_tensor_map(const SimpleTensor& tensor, size_t global_dim_x, size | |
| /*stride_elems=*/global_dim_x, /*offset_elems=*/0, sizeof(OutputType) * 8); | ||
| return tensor_map_output_trans; | ||
| } | ||
| #endif // __HIP_PLATFORM_AMD__ | ||
|
|
||
| } // namespace | ||
| } // namespace transformer_engine | ||
|
|
@@ -543,6 +565,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor | |
| return_transpose, kReturnTranspose, | ||
|
|
||
| dim3 grid(num_blocks_x, num_blocks_y, 1); | ||
| #ifndef __HIP_PLATFORM_AMD__ | ||
| const bool full_tile = | ||
| row_length % BLOCK_TILE_DIM == 0 && num_rows % BLOCK_TILE_DIM == 0; | ||
|
|
||
|
alextmagro marked this conversation as resolved.
|
||
|
|
@@ -573,6 +596,18 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor | |
| scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, | ||
| pow_2_scale, noop_ptr); | ||
| } // full-tile | ||
| #else | ||
| block_scaled_cast_transpose_kernel_notaligned<kReturnTranspose, float, InputType, | ||
| OutputType> | ||
| <<<grid, THREADS_PER_BLOCK, 0, stream>>>( | ||
| reinterpret_cast<const InputType*>(input.dptr), | ||
| reinterpret_cast<OutputType*>(output.dptr), | ||
| reinterpret_cast<OutputType*>(output_t.dptr), | ||
| reinterpret_cast<float*>(scale_inv.dptr), | ||
| reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows, | ||
| scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, | ||
| pow_2_scale, noop_ptr); | ||
| #endif // __HIP_PLATFORM_AMD__ | ||
| ) // return_transpose | ||
| ) // OutputType | ||
| ) // InputType | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.