diff --git a/build_tools/hipify/hipify.py b/build_tools/hipify/hipify.py
index 9487e0bb2..ceb350648 100644
--- a/build_tools/hipify/hipify.py
+++ b/build_tools/hipify/hipify.py
@@ -52,7 +52,8 @@ def do_hipify(te_root: Union[Path, str], src_dir: Union[Path, str],
         project_directory=src_dir,
         output_directory=src_dir,
         includes=["*/common/*", str(Path(src_dir)/"*")],
-        ignores=["*/amd_detail/*", "*/aotriton/*", "*/ck_fused_attn/*", "*/rocshmem_api/*"],
+        ignores=["*/amd_detail/*", "*/aotriton/*", "*/ck_fused_attn/*", "*/rocshmem_api/*",
+                 "*/small_seq_kernels/*"],
         header_include_dirs=include_dirs,
         custom_map_list= te_root / "build_tools" / "hipify" / "custom_map.json",
         extra_files=[],
diff --git a/tests/cpp/small_seq_kernels/CMakeLists.txt b/tests/cpp/small_seq_kernels/CMakeLists.txt
new file mode 100644
index 000000000..1b635530f
--- /dev/null
+++ b/tests/cpp/small_seq_kernels/CMakeLists.txt
@@ -0,0 +1,221 @@
+cmake_minimum_required(VERSION 3.21)
+
+# Declare project with both CXX and HIP languages.
+# Requires hip-lang CMake package (available under ${ROCM_PATH}/lib/cmake).
+project(crossattn_hip_kernel LANGUAGES CXX HIP)
+
+# ---------------------------------------------------------------------------
+# ROCm / HIP setup
+# ---------------------------------------------------------------------------
+
+if(NOT DEFINED ROCM_PATH)
+    set(ROCM_PATH "/opt/rocm" CACHE PATH "Path to ROCm installation")
+endif()
+
+list(APPEND CMAKE_PREFIX_PATH "${ROCM_PATH}/lib/cmake")
+find_package(hip REQUIRED CONFIG)
+
+# GPU architecture(s). Defaults to gfx942; override with e.g.
+# -DGPU_TARGETS=gfx950, or pass a semicolon-separated list such as
+# -DGPU_TARGETS="gfx942;gfx950" to build for several architectures.
+set(GPU_TARGETS "gfx942" CACHE STRING "GPU architecture targets")
+set(CMAKE_HIP_ARCHITECTURES "${GPU_TARGETS}")
+
+# One --offload-arch flag per requested architecture. Writing
+# "--offload-arch=${GPU_TARGETS}" directly would only prefix the first list
+# element and hand the remaining entries to the compiler as stray arguments.
+set(SMALL_SEQ_OFFLOAD_FLAGS ${GPU_TARGETS})
+list(TRANSFORM SMALL_SEQ_OFFLOAD_FLAGS PREPEND "--offload-arch=")
+
+# ---------------------------------------------------------------------------
+# Language standards
+# ---------------------------------------------------------------------------
+
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+set(CMAKE_HIP_STANDARD 17)
+set(CMAKE_HIP_STANDARD_REQUIRED ON)
+
+# ---------------------------------------------------------------------------
+# Kernel headers
+#
+# The MFMA kernel headers (attn_*.h) are vendored once for the Transformer
+# Engine build under transformer_engine/common/fused_attn_rocm/small_seq_kernels.
+# These reference tests include them directly from that canonical location
+# instead of keeping a duplicate copy here.
+#
+# Paths are anchored at CMAKE_CURRENT_SOURCE_DIR so they stay correct even if
+# this directory is pulled in through add_subdirectory() by a parent project.
+# ---------------------------------------------------------------------------
+
+set(KERNEL_INCLUDE_DIR
+    "${CMAKE_CURRENT_SOURCE_DIR}/../../../transformer_engine/common/fused_attn_rocm/small_seq_kernels"
+    CACHE PATH "Path to the vendored small-seq MFMA kernel headers")
+
+# ---------------------------------------------------------------------------
+# CPU reference library
+#
+# Compiled as plain C++ (no HIP device code) and linked by all test
+# executables. Avoids duplicating CPU reference compilation.
+# ---------------------------------------------------------------------------
+
+add_library(attn_ref STATIC
+    ref/attn_fwd_ref.cpp
+    ref/attn_bwd_ref.cpp
+)
+# Mark as HIP so clang++ handles hip_bfloat16 conversions (float<->bfloat16)
+# correctly. No device kernels are compiled; clang just enables the host-side
+# HIP type support.
+set_source_files_properties(
+    ref/attn_fwd_ref.cpp
+    ref/attn_bwd_ref.cpp
+    PROPERTIES LANGUAGE HIP
+)
+target_include_directories(attn_ref PUBLIC
+    ${KERNEL_INCLUDE_DIR}
+    ${CMAKE_CURRENT_SOURCE_DIR}/ref
+)
+target_compile_options(attn_ref PRIVATE ${SMALL_SEQ_OFFLOAD_FLAGS})
+target_link_libraries(attn_ref PUBLIC hip::host)
+
+# ---------------------------------------------------------------------------
+# Test executables
+#
+# add_small_seq_test(<name>) builds tests/<name>.cpp into an executable called
+# <name>. The source is marked LANGUAGE HIP so clang++ uses -x hip and accepts
+# the <<< >>> launch syntax needed by the __global__ kernels in the included
+# headers. Every test shares the same include directories, links attn_ref and
+# hip::host, and is compiled with -O3 plus the offload-arch flags above.
+# ---------------------------------------------------------------------------
+
+function(add_small_seq_test name)
+    add_executable(${name} tests/${name}.cpp)
+    set_source_files_properties(tests/${name}.cpp PROPERTIES LANGUAGE HIP)
+    target_include_directories(${name} PRIVATE
+        ${KERNEL_INCLUDE_DIR}
+        ${CMAKE_CURRENT_SOURCE_DIR}/ref
+        ${CMAKE_CURRENT_SOURCE_DIR}/tests
+    )
+    target_link_libraries(${name} PRIVATE attn_ref hip::host)
+    target_compile_options(${name} PRIVATE -O3 ${SMALL_SEQ_OFFLOAD_FLAGS})
+endfunction()
+
+# test_fwd.cpp includes attn_fwd.h, which contains __global__ kernels.
+add_small_seq_test(test_fwd)
+add_small_seq_test(test_bwd)
+
+# Fused MFMA forward kernel (attn_fwd_mfma.h).
+add_small_seq_test(test_fwd_mfma)
+
+# Fused MFMA 16x16x16 forward kernel (attn_fwd_mfma_16x16.h).
+add_small_seq_test(test_fwd_mfma_16x16)
+
+# Forward MFMA 16x16 correctness for head_dim 128, 256, 512 (small config).
+add_small_seq_test(test_mfma_head_dims)
+
+# Multi-Q dispatch across 4x4x4 and 16x16x16 MFMA kernels.
+add_small_seq_test(test_fwd_mfma_multiq)
+
+# MFMA 16x16x16 backward kernels (attn_bwd_mfma_16x16.h).
+add_small_seq_test(test_bwd_mfma_16x16)
+
+# Unified varlen test for MFMA 16x16x16 forward + backward kernels.
+add_small_seq_test(test_varlen_mfma_16x16)
+
+# ---------------------------------------------------------------------------
+# test_small_seq_sweep executable
+#
+# Small-sequence sweep benchmark (seqlen 1..17, bs=2048, fwd+bwd, TE format).
+# ---------------------------------------------------------------------------
+
+add_executable(test_small_seq_sweep tests/test_small_seq_sweep.cpp)
+set_source_files_properties(tests/test_small_seq_sweep.cpp PROPERTIES LANGUAGE HIP)
+
+target_include_directories(test_small_seq_sweep PRIVATE
+    ${KERNEL_INCLUDE_DIR}
+    ${CMAKE_SOURCE_DIR}/ref
+    ${CMAKE_SOURCE_DIR}/tests
+)
+target_link_libraries(test_small_seq_sweep PRIVATE attn_ref hip::host)
+target_compile_options(test_small_seq_sweep PRIVATE -O3 --offload-arch=${GPU_TARGETS})
diff --git a/tests/cpp/small_seq_kernels/ref/attn_bwd_ref.cpp b/tests/cpp/small_seq_kernels/ref/attn_bwd_ref.cpp
new file mode 100644
index 000000000..7ac8ec2d4
--- /dev/null
+++ b/tests/cpp/small_seq_kernels/ref/attn_bwd_ref.cpp
@@ -0,0 +1,253 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+#include "attn_bwd_ref.h"
+
+#include <algorithm>
+#include <cmath>
+#include <cstring>
+#include <vector>
+
+// ---------------------------------------------------------------------------
+// Helper function implementations
+// ---------------------------------------------------------------------------
+
+// C = A @ B with A: [rows_a, cols_a], B: [cols_a, cols_b], C: [rows_a, cols_b].
+// Each dot product is accumulated in float and converted to T once at the end.
+template <typename T>
+void matmul(const T* A, const T* B, T* C, int rows_a, int cols_a, int cols_b)
+{
+    for(int i = 0; i < rows_a; i++)
+    {
+        for(int j = 0; j < cols_b; j++)
+        {
+            float sum = 0.0f;
+            for(int k = 0; k < cols_a; k++)
+            {
+                sum += float(A[i * cols_a + k]) * float(B[k * cols_b + j]);
+            }
+            C[i * cols_b + j] = T(sum);
+        }
+    }
+}
+
+// A_T = A^T with A: [rows, cols] and A_T: [cols, rows], both row-major.
+template <typename T>
+void transpose(const T* A, T* A_T, int rows, int cols)
+{
+    for(int i = 0; i < rows; i++)
+    {
+        for(int j = 0; j < cols; j++)
+        {
+            A_T[j * rows + i] = A[i * cols + j];
+        }
+    }
+}
+
+// sums[i] = sum_j A[i, j]; the row sum is accumulated in float.
+template <typename T>
+void sum_last_dim(const T* A, T* sums, int rows, int cols)
+{
+    for(int i = 0; i < rows; i++)
+    {
+        float sum = 0.0f;
+        for(int j = 0; j < cols; j++)
+        {
+            sum += float(A[i * cols + j]);
+        }
+        sums[i] = T(sum);
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Backward pass implementation
+// ---------------------------------------------------------------------------
+
+// CPU reference for the attention backward pass.
+//
+// Buffers are indexed as
+//   Q / grad_O / grad_Q       : [total_padded_q,      head_num, head_dim]
+//   K / V / grad_K / grad_V   : [total_padded_kv_seq, head_num, head_dim]
+//   attn_weights/dropout_mask : [total_padded_q,      head_num, max_kv_seq]
+//
+// For batch entry b the number of Q rows is cu_seqlens_q[b+1] - cu_seqlens_q[b]
+// and the number of KV rows is cu_seqlens_kv[b+1] - cu_seqlens_kv[b]; their
+// storage starts at cu_seqlens_q_padded[b] and cu_seqlens_kv_padded[b].
+// Entries with zero Q rows are skipped, so their gradients stay zero.
+//
+// grad_Q, grad_K and grad_V are zero-filled first and then overwritten for the
+// rows that are processed. dropout_mask may be nullptr; dropout is applied only
+// when dropout_p > 0 and a mask is given. When bf16_weights is true every
+// attention weight is rounded through hip_bfloat16 before it is used.
+//
+// max_seq_q is accepted for interface compatibility but is not needed: each Q
+// row uses its scratch only within its own iteration, so a single row of
+// scratch is enough regardless of how many Q rows a batch entry has.
+//
+// Precondition (not checked): every per-entry KV length is <= max_kv_seq.
+template <typename T>
+void attn_backward(const T* Q,
+                   const T* K,
+                   const T* V,
+                   const T* grad_O,
+                   const T* attn_weights,
+                   const T* dropout_mask,
+                   float dropout_p,
+                   T* grad_Q,
+                   T* grad_K,
+                   T* grad_V,
+                   int batch,
+                   int head_num,
+                   int max_kv_seq,
+                   int head_dim,
+                   CausalMaskType mask_type,
+                   const int* cu_seqlens_q,
+                   const int* cu_seqlens_q_padded,
+                   const int* cu_seqlens_kv,
+                   const int* cu_seqlens_kv_padded,
+                   int total_padded_q,
+                   int total_padded_kv_seq,
+                   int max_seq_q,
+                   bool bf16_weights)
+{
+    (void)max_seq_q; // see header comment: scratch is per Q row
+
+    const float scale         = 1.0f / std::sqrt(static_cast<float>(head_dim));
+    const float dropout_scale = (dropout_p > 0.0f) ? (1.0f / (1.0f - dropout_p)) : 1.0f;
+
+    // Contiguous [kv_seq, head_dim] staging buffers for one (batch, head) pair.
+    const size_t kv_tile = (size_t)max_kv_seq * head_dim;
+    std::vector<T> K_cont_buf(kv_tile);
+    std::vector<T> V_cont_buf(kv_tile);
+    std::vector<T> grad_K_cont_buf(kv_tile);
+    std::vector<T> grad_V_cont_buf(kv_tile);
+
+    // Scratch for a single Q row. Indexing by the KV position alone removes
+    // the out-of-bounds write that occurred when a batch entry had more Q rows
+    // than max_seq_q (whose default is 1).
+    std::vector<float> grad_attn(max_kv_seq);
+    std::vector<float> grad_scores(max_kv_seq);
+
+    // Zero all gradients. The element counts are formed in size_t so the
+    // product cannot overflow int before being widened.
+    const size_t row_elems = (size_t)head_num * head_dim;
+    std::memset(grad_Q, 0, (size_t)total_padded_q * row_elems * sizeof(T));
+    std::memset(grad_K, 0, (size_t)total_padded_kv_seq * row_elems * sizeof(T));
+    std::memset(grad_V, 0, (size_t)total_padded_kv_seq * row_elems * sizeof(T));
+
+    for(int b = 0; b < batch; b++)
+    {
+        // Skip batches where actual Q seq is 0
+        const int actual_q_seq = cu_seqlens_q[b + 1] - cu_seqlens_q[b];
+        if(actual_q_seq == 0)
+            continue;
+
+        const int kv_seq    = cu_seqlens_kv[b + 1] - cu_seqlens_kv[b];
+        const int q_off     = cu_seqlens_q_padded[b];  // padded Q storage offset
+        const int kv_off    = cu_seqlens_kv_padded[b]; // padded KV storage offset
+        const int kv_stride = head_num * head_dim;     // elements between KV rows
+
+        for(int h = 0; h < head_num; h++)
+        {
+            // K/V: [total_padded_seq_kv, head_num, head_dim]
+            const size_t offset_kv_base = (size_t)kv_off * head_num * head_dim + (size_t)h * head_dim;
+            const T* K_bh = K + offset_kv_base;
+            const T* V_bh = V + offset_kv_base;
+            T* grad_K_bh  = grad_K + offset_kv_base;
+            T* grad_V_bh  = grad_V + offset_kv_base;
+
+            // Flatten K/V into contiguous row-major buffers [kv_seq, head_dim]
+            for(int i = 0; i < kv_seq; i++)
+                for(int j = 0; j < head_dim; j++)
+                {
+                    K_cont_buf[i * head_dim + j] = K_bh[i * kv_stride + j];
+                    V_cont_buf[i * head_dim + j] = V_bh[i * kv_stride + j];
+                }
+
+            // Zero the grad_V / grad_K accumulators for the rows in use
+            std::fill(grad_V_cont_buf.begin(), grad_V_cont_buf.begin() + kv_seq * head_dim, T(0.0f));
+            std::fill(grad_K_cont_buf.begin(), grad_K_cont_buf.begin() + kv_seq * head_dim, T(0.0f));
+
+            // --- Process each Q row ---
+            for(int q_idx = 0; q_idx < actual_q_seq; q_idx++)
+            {
+                // Row index shared by Q/grad_O/grad_Q and attn_weights/dropout_mask
+                const size_t row = (size_t)(q_off + q_idx) * head_num + h;
+
+                // Q/grad_O/grad_Q: [total_padded_seq_q, head_num, head_dim]
+                const T* Q_bh_q      = Q + row * head_dim;
+                const T* grad_O_bh_q = grad_O + row * head_dim;
+                T* grad_Q_bh_q       = grad_Q + row * head_dim;
+
+                // attn_weights/dropout_mask: [total_padded_q, head_num, max_kv_seq]
+                const T* attn_bh_q    = attn_weights + row * max_kv_seq;
+                const T* dropout_bh_q = dropout_mask ? dropout_mask + row * max_kv_seq : nullptr;
+
+                // Step 1: grad_V[j,d] += attn[q,j] * grad_O[q,d] (accumulate over Q rows).
+                // The accumulator is stored as T, so it is rounded after every Q row.
+                for(int j = 0; j < kv_seq; j++)
+                {
+                    float aw = float(attn_bh_q[j]);
+                    if(bf16_weights) aw = float(hip_bfloat16(aw));
+                    for(int d = 0; d < head_dim; d++)
+                        grad_V_cont_buf[j * head_dim + d] =
+                            T(float(grad_V_cont_buf[j * head_dim + d]) + aw * float(grad_O_bh_q[d]));
+                }
+
+                // Step 2: grad_attn[j] = dot(grad_O[q,:], V[j,:])
+                for(int j = 0; j < kv_seq; j++)
+                {
+                    float s = 0.0f;
+                    for(int d = 0; d < head_dim; d++)
+                        s += float(grad_O_bh_q[d]) * float(V_cont_buf[j * head_dim + d]);
+                    grad_attn[j] = s;
+                }
+
+                // Step 3: Dropout backward
+                if(dropout_p > 0.0f && dropout_bh_q != nullptr)
+                    for(int j = 0; j < kv_seq; j++)
+                        grad_attn[j] *= float(dropout_bh_q[j]) * dropout_scale;
+
+                // Step 4: Softmax backward — per Q row, independent
+                // grad_score[j] = attn[q,j] * (grad_attn[j] - dot_sum)
+                float dot_sum = 0.0f;
+                for(int j = 0; j < kv_seq; j++)
+                {
+                    float aw = float(attn_bh_q[j]);
+                    if(bf16_weights) aw = float(hip_bfloat16(aw));
+                    dot_sum += grad_attn[j] * aw;
+                }
+                for(int j = 0; j < kv_seq; j++)
+                {
+                    float aw = float(attn_bh_q[j]);
+                    if(bf16_weights) aw = float(hip_bfloat16(aw));
+                    grad_scores[j] = aw * (grad_attn[j] - dot_sum);
+                }
+
+                // Step 5: Mask backward
+                // NOTE(review): only TOP_LEFT is handled here; other mask types
+                // leave grad_scores untouched. Confirm this matches the kernel
+                // under test before relying on it for other mask types.
+                if(mask_type == CausalMaskType::TOP_LEFT)
+                {
+                    for(int j = 0; j < kv_seq; j++)
+                        if(j > q_idx) grad_scores[j] = 0.0f;
+                }
+
+                // Step 6: grad_Q[q,d] = sum_j grad_scores[j] * K[j,d] * scale
+                for(int d = 0; d < head_dim; d++)
+                {
+                    float s = 0.0f;
+                    for(int j = 0; j < kv_seq; j++)
+                        s += grad_scores[j] * float(K_cont_buf[j * head_dim + d]);
+                    grad_Q_bh_q[d] = T(s * scale);
+                }
+
+                // Step 7: grad_K[j,d] += grad_scores[j] * Q[q,d] * scale (accumulate over Q rows)
+                for(int j = 0; j < kv_seq; j++)
+                    for(int d = 0; d < head_dim; d++)
+                    {
+                        float gs = grad_scores[j];
+                        grad_K_cont_buf[j * head_dim + d] =
+                            T(float(grad_K_cont_buf[j * head_dim + d])
+                              + gs * float(Q_bh_q[d]) * scale);
+                    }
+            }
+
+            // Copy grad_K and grad_V back to strided layout
+            for(int i = 0; i < kv_seq; i++)
+                for(int j = 0; j < head_dim; j++)
+                {
+                    grad_K_bh[i * kv_stride + j] = grad_K_cont_buf[i * head_dim + j];
+                    grad_V_bh[i * kv_stride + j] = grad_V_cont_buf[i * head_dim + j];
+                }
+        }
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Explicit instantiations
+// ---------------------------------------------------------------------------
+
+template void matmul(const float*, const float*, float*, int, int, int);
+template void matmul(const hip_bfloat16*, const hip_bfloat16*, hip_bfloat16*, int, int, int);
+
+template void transpose(const float*, float*, int, int);
+template void transpose(const hip_bfloat16*, hip_bfloat16*, int, int);
+
+template void sum_last_dim(const float*, float*, int, int);
+template void sum_last_dim(const hip_bfloat16*, hip_bfloat16*, int, int);
+
+template void attn_backward(const float*, const float*, const float*, const float*,
+                            const float*, const float*, float, float*, float*, float*,
+                            int, int, int, int, CausalMaskType,
+                            const int*, const int*, const int*, const int*, int, int,
+                            int, bool);
+template void attn_backward(const hip_bfloat16*, const hip_bfloat16*,
+                            const hip_bfloat16*, const hip_bfloat16*,
+                            const hip_bfloat16*, const hip_bfloat16*, float,
+                            hip_bfloat16*, hip_bfloat16*, hip_bfloat16*,
+                            int, int, int, int, CausalMaskType,
+                            const int*, const int*, const int*, const int*, int, int,
+                            int, bool);
diff --git a/tests/cpp/small_seq_kernels/ref/attn_bwd_ref.h b/tests/cpp/small_seq_kernels/ref/attn_bwd_ref.h
new file mode 100644
index 000000000..42f12da6d
--- /dev/null
+++ b/tests/cpp/small_seq_kernels/ref/attn_bwd_ref.h
@@ -0,0 +1,64 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT +#pragma once + +#include "attn_common.h" +#include +#include +#include + +// --------------------------------------------------------------------------- +// CPU helper functions used by attn_backward +// --------------------------------------------------------------------------- + +// Matrix multiplication C = A @ B +// A: [rows_a, cols_a], B: [cols_a, cols_b], C: [rows_a, cols_b] +template +void matmul(const T* A, const T* B, T* C, int rows_a, int cols_a, int cols_b); + +// Matrix transpose: A_T = A^T +// A: [rows, cols], A_T: [cols, rows] +template +void transpose(const T* A, T* A_T, int rows, int cols); + +// Sum along last dimension: sums[i] = sum_j A[i, j] +template +void sum_last_dim(const T* A, T* sums, int rows, int cols); + +// --------------------------------------------------------------------------- +// CPU backward reference +// --------------------------------------------------------------------------- + +/** + * Multi-Head Attention Backward Pass (CPU Reference Implementation) + * + * Q/grad_O/grad_Q layout: [total_padded_seq_q, head_num, head_dim] + * K/V/grad_K/grad_V layout: [total_padded_seq_kv, head_num, head_dim] + * attn_weights/dropout_mask: [total_padded_q, head_num, max_kv_seq] + * + * Batches where actual Q seq = 0 are skipped. 
+ */
+template <typename T>
+void attn_backward(const T* Q,
+                   const T* K,
+                   const T* V,
+                   const T* grad_O,
+                   const T* attn_weights,
+                   const T* dropout_mask,
+                   float dropout_p,
+                   T* grad_Q,
+                   T* grad_K,
+                   T* grad_V,
+                   int batch,
+                   int head_num,
+                   int max_kv_seq,
+                   int head_dim,
+                   CausalMaskType mask_type,
+                   const int* cu_seqlens_q,
+                   const int* cu_seqlens_q_padded,
+                   const int* cu_seqlens_kv,
+                   const int* cu_seqlens_kv_padded,
+                   int total_padded_q,
+                   int total_padded_kv_seq,
+                   int max_seq_q = 1,
+                   bool bf16_weights = false);
diff --git a/tests/cpp/small_seq_kernels/ref/attn_fwd_ref.cpp b/tests/cpp/small_seq_kernels/ref/attn_fwd_ref.cpp
new file mode 100644
index 000000000..53d28f9f8
--- /dev/null
+++ b/tests/cpp/small_seq_kernels/ref/attn_fwd_ref.cpp
@@ -0,0 +1,181 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+#include "attn_fwd_ref.h"
+
+#include <algorithm>
+#include <cmath>
+#include <cstring>
+#include <vector>
+
+// CPU reference for the attention forward pass.
+//
+// Buffers are indexed as
+//   Q / O                     : [total_padded_q,  head_num, head_dim]
+//   K / V                     : [total_padded_kv, head_num, head_dim]
+//   attn_weights/dropout_mask : [total_padded_q,  head_num, max_kv_seq]
+// where total_padded_q is read from cu_seqlens_q_padded[batch].
+//
+// For batch entry b the number of Q rows is cu_seqlens_q[b+1] - cu_seqlens_q[b]
+// and the number of KV rows is cu_seqlens_kv[b+1] - cu_seqlens_kv[b]; their
+// storage starts at cu_seqlens_q_padded[b] and cu_seqlens_kv_padded[b].
+// Entries with zero Q rows are skipped.
+//
+// O is zero-filled and then written for every processed Q row. attn_weights is
+// optional (nullptr to skip); when given it is zero-filled and receives the
+// probabilities after dropout, converted to T. dropout_mask is optional and is
+// only applied when dropout_p > 0. When bf16_weights is true the probabilities
+// are rounded through hip_bfloat16 before the final product with V (after the
+// attn_weights copy has been stored).
+//
+// Scores, softmax and accumulation are carried out in float.
+// Precondition (not checked): every per-entry KV length is <= max_kv_seq.
+template <typename T>
+void attn_forward(const T* Q,
+                  const T* K,
+                  const T* V,
+                  const T* dropout_mask,
+                  float dropout_p,
+                  T* O,
+                  T* attn_weights,
+                  int batch,
+                  int head_num,
+                  int max_kv_seq,
+                  int head_dim,
+                  CausalMaskType mask_type,
+                  const int* cu_seqlens_q,
+                  const int* cu_seqlens_q_padded,
+                  const int* cu_seqlens_kv,
+                  const int* cu_seqlens_kv_padded,
+                  bool bf16_weights)
+{
+    const float scale         = 1.0f / std::sqrt(static_cast<float>(head_dim));
+    const float dropout_scale = (dropout_p > 0.0f) ? (1.0f / (1.0f - dropout_p)) : 1.0f;
+
+    // Per-row scratch, kept in float (matching GPU kernel precision)
+    std::vector<float> scores(max_kv_seq);
+    std::vector<float> attn_probs(max_kv_seq);
+
+    // Total padded Q storage size
+    const int total_padded_q = cu_seqlens_q_padded[batch];
+
+    // Zero the outputs. Row counts are formed in size_t so the byte count
+    // cannot overflow int before being widened.
+    const size_t q_rows = (size_t)total_padded_q * head_num;
+    std::memset(O, 0, q_rows * head_dim * sizeof(T));
+    if(attn_weights != nullptr)
+    {
+        // attn_weights: [total_padded_q, head_num, max_kv_seq]
+        std::memset(attn_weights, 0, q_rows * max_kv_seq * sizeof(T));
+    }
+
+    // Process each batch and head
+    for(int b = 0; b < batch; b++)
+    {
+        // Skip batches where actual Q seq length is 0
+        const int actual_q_seq = cu_seqlens_q[b + 1] - cu_seqlens_q[b];
+        if(actual_q_seq == 0)
+            continue;
+
+        // Actual KV length and padded storage offsets for this batch entry
+        const int kv_seq    = cu_seqlens_kv[b + 1] - cu_seqlens_kv[b];
+        const int kv_offset = cu_seqlens_kv_padded[b];
+        const int q_offset  = cu_seqlens_q_padded[b];
+
+        for(int h = 0; h < head_num; h++)
+        {
+            // Every actual query row of this entry (empty entries were skipped above)
+            for(int q_idx = 0; q_idx < actual_q_seq; q_idx++)
+            {
+                // Q and O share one row index; attn_weights and dropout_mask share
+                // the same row index with a row length of max_kv_seq.
+                const size_t row = (size_t)(q_offset + q_idx) * head_num + h;
+
+                const T* Q_ptr       = Q + row * head_dim;
+                const T* dropout_ptr = dropout_mask ? dropout_mask + row * max_kv_seq : nullptr;
+
+                T* O_ptr    = O + row * head_dim;
+                T* attn_ptr = attn_weights ? attn_weights + row * max_kv_seq : nullptr;
+
+                // Step 1: Compute scores = Q @ K^T / sqrt(d_k)
+                // Q: [1, head_dim], K: [kv_seq, head_dim] -> scores: [1, kv_seq]
+                for(int kv_idx = 0; kv_idx < kv_seq; kv_idx++)
+                {
+                    const size_t k_offset = ((size_t)(kv_offset + kv_idx) * head_num + h) * head_dim;
+                    const T* K_ptr = K + k_offset;
+                    float sum = 0.0f;
+                    for(int d = 0; d < head_dim; d++)
+                        sum += float(Q_ptr[d]) * float(K_ptr[d]);
+                    scores[kv_idx] = sum * scale;
+                }
+
+                // Step 2: Apply causal mask
+                if(mask_type == CausalMaskType::TOP_LEFT)
+                {
+                    for(int j = 0; j < kv_seq; j++)
+                    {
+                        if(j > q_idx)
+                        {
+                            scores[j] = -1e9f;
+                        }
+                    }
+                }
+                else if(mask_type == CausalMaskType::BOTTOM_RIGHT)
+                {
+                    // NOTE(review): this hides keys *before* q_idx, which is a
+                    // no-op for the first query row. Confirm the intended
+                    // predicate against the kernel if entries with several
+                    // Q rows use this mask type.
+                    for(int j = 0; j < kv_seq; j++)
+                    {
+                        if(j < q_idx)
+                        {
+                            scores[j] = -1e9f;
+                        }
+                    }
+                }
+
+                // Step 3: Softmax (numerically stable, all in float)
+                float max_val = -1e9f;
+                for(int j = 0; j < kv_seq; j++)
+                {
+                    max_val = std::max(max_val, scores[j]);
+                }
+
+                float sum = 0.0f;
+                for(int j = 0; j < kv_seq; j++)
+                {
+                    attn_probs[j] = std::exp(scores[j] - max_val);
+                    sum += attn_probs[j];
+                }
+
+                for(int j = 0; j < kv_seq; j++)
+                {
+                    attn_probs[j] /= sum;
+                }
+
+                // Step 4: Apply dropout
+                if(dropout_p > 0.0f && dropout_ptr != nullptr)
+                {
+                    for(int i = 0; i < kv_seq; i++)
+                    {
+                        attn_probs[i] *= float(dropout_ptr[i]) * dropout_scale;
+                    }
+                }
+
+                // Save attention weights if requested (truncate to T for storage)
+                if(attn_ptr != nullptr)
+                {
+                    for(int j = 0; j < kv_seq; j++)
+                        attn_ptr[j] = T(attn_probs[j]);
+                }
+
+                // Truncate weights to bf16 (matches MFMA kernel: float→bhalf_t via SM_lds)
+                if(bf16_weights)
+                {
+                    for(int j = 0; j < kv_seq; j++)
+                        attn_probs[j] = float(hip_bfloat16(attn_probs[j]));
+                }
+
+                // Step 5: Compute output = attn_probs @ V
+                // attn_probs: [1, kv_seq], V: [kv_seq, head_dim] -> O: [1, head_dim]
+                for(int d = 0; d < head_dim; d++)
+                {
+                    float sum = 0.0f;
+                    for(int kv_idx = 0; kv_idx < kv_seq; kv_idx++)
+                    {
+                        const size_t v_offset = ((size_t)(kv_offset + kv_idx) * head_num + h) * head_dim;
+                        sum += attn_probs[kv_idx] * float(V[v_offset + d]);
+                    }
+                    O_ptr[d] = T(sum);
+                }
+            }
+        }
+    }
+}
+
+// Explicit instantiations
+template void attn_forward(const float*, const float*, const float*, const float*,
+                           float, float*, float*, int, int, int, int, CausalMaskType,
+                           const int*, const int*, const int*, const int*, bool);
+template void attn_forward(const hip_bfloat16*, const hip_bfloat16*,
+                           const hip_bfloat16*, const hip_bfloat16*, float,
+                           hip_bfloat16*, hip_bfloat16*, int, int, int, int,
+                           CausalMaskType, const int*, const int*, const int*,
+                           const int*, bool);
diff --git a/tests/cpp/small_seq_kernels/ref/attn_fwd_ref.h b/tests/cpp/small_seq_kernels/ref/attn_fwd_ref.h
new file mode 100644
index 000000000..ddb17eb42
--- /dev/null
+++ b/tests/cpp/small_seq_kernels/ref/attn_fwd_ref.h
@@ -0,0 +1,39 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+#pragma once
+
+#include "attn_common.h"
+#include
+#include
+#include
+
+/**
+ * Multi-Head Attention Forward Pass (CPU Reference Implementation)
+ *
+ * Q layout: [total_padded_seq_q, head_num, head_dim] (variable Q lengths, padded storage)
+ * K layout: [total_padded_seq_kv, head_num, head_dim]
+ * V layout: [total_padded_seq_kv, head_num, head_dim]
+ * O layout: [total_padded_seq_q, head_num, head_dim]
+ *
+ * For each batch b, actual Q seq length is (cu_seqlens_q[b+1] - cu_seqlens_q[b]), which is 0 or 1.
+ * Padded storage offset for Q in batch b starts at cu_seqlens_q_padded[b].
+ * Batches with actual Q seq = 0 are skipped (their padded slot is unused).
+ */ +template +void attn_forward(const T* Q, + const T* K, + const T* V, + const T* dropout_mask, + float dropout_p, + T* O, + T* attn_weights, + int batch, + int head_num, + int max_kv_seq, + int head_dim, + CausalMaskType mask_type, + const int* cu_seqlens_q, + const int* cu_seqlens_q_padded, + const int* cu_seqlens_kv, + const int* cu_seqlens_kv_padded, + bool bf16_weights = false); diff --git a/tests/cpp/small_seq_kernels/tests/test_bwd.cpp b/tests/cpp/small_seq_kernels/tests/test_bwd.cpp new file mode 100644 index 000000000..e7bbf2a79 --- /dev/null +++ b/tests/cpp/small_seq_kernels/tests/test_bwd.cpp @@ -0,0 +1,545 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +// +// clang-format off +// Build: cmake -B build && cmake --build build && ./build/test_bwd +// clang-format on + +#include "attn_bwd.h" +#include "attn_bwd_ref.h" +#include "test_utils.h" + +#include +#include +#include +#include +#include +#include + +// --------------------------------------------------------------------------- +// Main backward correctness + performance test +// --------------------------------------------------------------------------- + +template +void test_run_attn_bwd_kernel( + float dropout_p, int warmup_iters, int test_iters, bool check_correctness, bool dump_err) +{ + using Launcher = AttnBackwardKernelLauncher; + + constexpr int bs = Config::bs; + constexpr int head_num = Config::head_num; + constexpr int max_seq_kv = Config::max_seq_kv; + constexpr int head_dim = Config::head_dim; + + std::mt19937 gen(42); + std::uniform_real_distribution dis(-1.0f, 1.0f); + std::bernoulli_distribution q_present_dis(0.5); + + // --- Build cu_seqlens_q --- + std::vector h_cu_seqlens_q, h_cu_seqlens_q_padded, h_padded_q_to_batch; + int total_padded_q = build_cu_seqlens_q(bs, gen, h_cu_seqlens_q, h_cu_seqlens_q_padded, + h_padded_q_to_batch); + int total_actual_q = h_cu_seqlens_q[bs]; + + // --- Build cu_seqlens_kv --- + std::vector 
h_cu_seqlens_kv, h_cu_seqlens_kv_padded; + int total_actual_kv_seq, total_padded_kv_seq; + build_cu_seqlens_kv(bs, max_seq_kv, gen, h_cu_seqlens_kv, h_cu_seqlens_kv_padded, + total_actual_kv_seq, total_padded_kv_seq); + + // --- Buffer sizes --- + size_t size_Q = (size_t)total_padded_q * head_num * head_dim; + size_t size_K = (size_t)total_padded_kv_seq * head_num * head_dim; + size_t size_V = (size_t)total_padded_kv_seq * head_num * head_dim; + size_t size_grad_O = size_Q; + size_t size_attn_weights = (size_t)total_padded_q * head_num * max_seq_kv; + size_t size_dropout_mask = size_attn_weights; + + // --- Host allocations --- + std::vector h_Q(size_Q, DataType(0.0f)); + std::vector h_K(size_K, DataType(0.0f)); + std::vector h_V(size_V, DataType(0.0f)); + std::vector h_grad_O(size_grad_O, DataType(0.0f)); + std::vector h_attn_weights(size_attn_weights, DataType(0.0f)); + std::vector h_dropout_mask(size_dropout_mask, DataType(1.0f)); + std::vector h_grad_Q_gpu(size_Q, DataType(0.0f)); + std::vector h_grad_K_gpu(size_K, DataType(0.0f)); + std::vector h_grad_V_gpu(size_V, DataType(0.0f)); + std::vector h_grad_Q_cpu(size_Q, DataType(0.0f)); + std::vector h_grad_K_cpu(size_K, DataType(0.0f)); + std::vector h_grad_V_cpu(size_V, DataType(0.0f)); + + // Initialize Q and grad_O for active-Q batches + for(int b = 0; b < bs; b++) + { + if(h_cu_seqlens_q[b + 1] == h_cu_seqlens_q[b]) continue; + int q_off = h_cu_seqlens_q_padded[b]; + for(int h = 0; h < head_num; h++) + { + int base = (q_off * head_num + h) * head_dim; + for(int d = 0; d < head_dim; d++) + { + h_Q[base + d] = DataType(dis(gen)); + h_grad_O[base + d] = DataType(dis(gen)); + } + } + } + + // Initialize K/V + for(int b = 0; b < bs; b++) + { + int kv_seq = h_cu_seqlens_kv[b + 1] - h_cu_seqlens_kv[b]; + int kv_off = h_cu_seqlens_kv_padded[b]; + for(int h = 0; h < head_num; h++) + for(int s = 0; s < kv_seq; s++) + { + int base = (kv_off + s) * head_num * head_dim + h * head_dim; + for(int d = 0; d < head_dim; d++) + 
{ + h_K[base + d] = DataType(dis(gen)); + h_V[base + d] = DataType(dis(gen)); + } + } + } + + // Initialize attn_weights (normalized per row) + for(int b = 0; b < bs; b++) + { + if(h_cu_seqlens_q[b + 1] == h_cu_seqlens_q[b]) continue; + int kv_seq = h_cu_seqlens_kv[b + 1] - h_cu_seqlens_kv[b]; + int q_off = h_cu_seqlens_q_padded[b]; + for(int h = 0; h < head_num; h++) + { + int base = (q_off * head_num + h) * max_seq_kv; + float sum = 0.0f; + for(int j = 0; j < kv_seq; j++) + { + h_attn_weights[base + j] = DataType(std::abs(dis(gen))); + sum += float(h_attn_weights[base + j]); + } + for(int j = kv_seq; j < max_seq_kv; j++) + h_attn_weights[base + j] = DataType(0.0f); + if(sum > 0.0f) + for(int j = 0; j < kv_seq; j++) + h_attn_weights[base + j] = + DataType(float(h_attn_weights[base + j]) / sum); + } + } + + // Initialize dropout mask + for(size_t i = 0; i < size_dropout_mask; i++) + h_dropout_mask[i] = Config::enable_dropout_mask + ? DataType(dis(gen) > dropout_p ? 1.0f : 0.0f) + : DataType(1.0f); + + // --- CPU reference --- + float sqr_dk_scale = 1.0f / std::sqrt(static_cast(head_dim)); + if(check_correctness) + attn_backward(h_Q.data(), h_K.data(), h_V.data(), h_grad_O.data(), + h_attn_weights.data(), + Config::enable_dropout_mask ? h_dropout_mask.data() : nullptr, + dropout_p, h_grad_Q_cpu.data(), h_grad_K_cpu.data(), h_grad_V_cpu.data(), + bs, head_num, max_seq_kv, head_dim, Config::mask_type, + h_cu_seqlens_q.data(), h_cu_seqlens_q_padded.data(), + h_cu_seqlens_kv.data(), h_cu_seqlens_kv_padded.data(), + total_padded_q, total_padded_kv_seq); + + // --- Device allocations --- + DataType *d_Q, *d_K, *d_V, *d_grad_O, *d_attn_weights, *d_dropout_mask; + DataType *d_grad_Q, *d_grad_K, *d_grad_V, *d_workspace; + int *d_cu_seqlens_q, *d_cu_seqlens_q_padded; + int *d_cu_seqlens_kv, *d_cu_seqlens_kv_padded; + int* d_padded_q_to_batch; + + HIP_CHECK(hipMalloc(&d_Q, size_Q > 0 ? 
size_Q * sizeof(DataType) : 1)); + HIP_CHECK(hipMalloc(&d_K, size_K * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_V, size_V * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_grad_O, size_grad_O > 0 ? size_grad_O * sizeof(DataType) : 1)); + HIP_CHECK(hipMalloc(&d_attn_weights, size_attn_weights > 0 ? size_attn_weights * sizeof(DataType) : 1)); + HIP_CHECK(hipMalloc(&d_dropout_mask, size_dropout_mask > 0 ? size_dropout_mask * sizeof(DataType) : 1)); + HIP_CHECK(hipMalloc(&d_grad_Q, size_Q > 0 ? size_Q * sizeof(DataType) : 1)); + HIP_CHECK(hipMalloc(&d_grad_K, size_K * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_grad_V, size_V * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_cu_seqlens_q, (bs + 1) * sizeof(int))); + HIP_CHECK(hipMalloc(&d_cu_seqlens_q_padded, (bs + 1) * sizeof(int))); + HIP_CHECK(hipMalloc(&d_cu_seqlens_kv, (bs + 1) * sizeof(int))); + HIP_CHECK(hipMalloc(&d_cu_seqlens_kv_padded, (bs + 1) * sizeof(int))); + if(total_padded_q > 0) + HIP_CHECK(hipMalloc(&d_padded_q_to_batch, total_padded_q * sizeof(int))); + else + d_padded_q_to_batch = nullptr; + + size_t workspace_size = Launcher::calc_workspace_size(total_padded_q); + HIP_CHECK(hipMalloc(&d_workspace, workspace_size > 0 ? 
workspace_size : 1)); + + // --- Copy to device --- + if(size_Q > 0) + HIP_CHECK(hipMemcpy(d_Q, h_Q.data(), size_Q * sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_K, h_K.data(), size_K * sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_V, h_V.data(), size_V * sizeof(DataType), hipMemcpyHostToDevice)); + if(size_grad_O > 0) + HIP_CHECK(hipMemcpy(d_grad_O, h_grad_O.data(), size_grad_O * sizeof(DataType), hipMemcpyHostToDevice)); + if(size_attn_weights > 0) + HIP_CHECK(hipMemcpy(d_attn_weights, h_attn_weights.data(), + size_attn_weights * sizeof(DataType), hipMemcpyHostToDevice)); + if(size_dropout_mask > 0) + HIP_CHECK(hipMemcpy(d_dropout_mask, h_dropout_mask.data(), + size_dropout_mask * sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cu_seqlens_q, h_cu_seqlens_q.data(), (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cu_seqlens_q_padded, h_cu_seqlens_q_padded.data(), (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cu_seqlens_kv, h_cu_seqlens_kv.data(), (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cu_seqlens_kv_padded, h_cu_seqlens_kv_padded.data(), (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + if(total_padded_q > 0) + HIP_CHECK(hipMemcpy(d_padded_q_to_batch, h_padded_q_to_batch.data(), + total_padded_q * sizeof(int), hipMemcpyHostToDevice)); + + auto bwd_launch = [&]() { + Launcher::run_attn_bwd_kernel(d_Q, d_K, d_V, d_grad_O, d_attn_weights, + Config::enable_dropout_mask ? 
d_dropout_mask : nullptr, + dropout_p, sqr_dk_scale, + d_grad_Q, d_grad_K, d_grad_V, d_workspace, + d_cu_seqlens_q, d_cu_seqlens_q_padded, + d_cu_seqlens_kv, d_cu_seqlens_kv_padded, + d_padded_q_to_batch, total_padded_q); + }; + + for(int i = 0; i < warmup_iters; i++) bwd_launch(); + HIP_CHECK(hipDeviceSynchronize()); + + hipEvent_t start, stop; + HIP_CHECK(hipEventCreate(&start)); + HIP_CHECK(hipEventCreate(&stop)); + HIP_CHECK(hipEventRecord(start)); + for(int i = 0; i < test_iters; i++) bwd_launch(); + HIP_CHECK(hipEventRecord(stop)); + HIP_CHECK(hipEventSynchronize(stop)); + + float elapsed_ms = 0; + HIP_CHECK(hipEventElapsedTime(&elapsed_ms, start, stop)); + double avg_time_ms = elapsed_ms / test_iters; + + HIP_CHECK(hipMemcpy(h_grad_Q_gpu.data(), d_grad_Q, size_Q * sizeof(DataType), hipMemcpyDeviceToHost)); + HIP_CHECK(hipMemcpy(h_grad_K_gpu.data(), d_grad_K, size_K * sizeof(DataType), hipMemcpyDeviceToHost)); + HIP_CHECK(hipMemcpy(h_grad_V_gpu.data(), d_grad_V, size_V * sizeof(DataType), hipMemcpyDeviceToHost)); + + // --- Report --- + double avg_kv_seq = total_actual_kv_seq / double(bs); + double avg_padded_kv_seq = total_padded_kv_seq / double(bs); + double flops_per_active = 4.0 * 2.0 * avg_kv_seq * head_dim; + double total_flops = flops_per_active * total_actual_q * head_num; + double tflops = (total_flops / 1e12) / (avg_time_ms / 1000.0); + + size_t bytes_read = + (size_Q + size_K + size_V + size_grad_O + size_attn_weights) * sizeof(DataType); + if(Config::enable_dropout_mask) bytes_read += size_dropout_mask * sizeof(DataType); + size_t bytes_write = (size_Q + size_K + size_V) * sizeof(DataType); + size_t total_bytes = bytes_read + bytes_write; + double bandwidth_gbps = (total_bytes / 1e9) / (avg_time_ms / 1000.0); + + std::cout << "\n===== run_attn_bwd_kernel Test =====" << std::endl; + std::cout << "Configuration:" << std::endl; + std::cout << " Batch size: " << bs << std::endl; + std::cout << " Heads: " << head_num << std::endl; + std::cout << " 
Active Q batches: " << total_actual_q << " / " << bs << std::endl; + std::cout << " KV max: " << max_seq_kv << " KV avg: " << std::fixed << std::setprecision(2) + << avg_kv_seq << " KV avg padded: " << avg_padded_kv_seq << std::endl; + std::cout << " Head dimension: " << head_dim << std::endl; + std::cout << " Dropout: " << (Config::enable_dropout_mask ? "enabled" : "disabled") << std::endl; + std::cout << " Mask: " << CausalMaskTypeName[Config::mask_type] << std::endl; + std::cout << std::endl; + + if(check_correctness) + { + std::cout << "Correctness:" << std::endl; + check_grad_q(h_grad_Q_gpu, h_grad_Q_cpu, bs, head_num, head_dim, + h_cu_seqlens_q, h_cu_seqlens_q_padded, 1e-2f, 1e-2f, dump_err); + check_array(h_grad_K_gpu, h_grad_K_cpu, "grad_K", 1e-2f, 1e-2f, dump_err); + check_array(h_grad_V_gpu, h_grad_V_cpu, "grad_V", 1e-2f, 1e-2f, dump_err); + std::cout << std::endl; + } + + std::cout << "Memory:" << std::endl; + std::cout << " Total data read: " << std::fixed << std::setprecision(2) << bytes_read / 1e6 << " MB" << std::endl; + std::cout << " Total data write: " << bytes_write / 1e6 << " MB" << std::endl; + std::cout << " Total data transfer: " << total_bytes / 1e6 << " MB" << std::endl; + std::cout << " Workspace: " << workspace_size / 1e6 << " MB" << std::endl; + std::cout << std::endl; + + std::cout << "Performance:" << std::endl; + std::cout << " Average time: " << std::fixed << std::setprecision(3) << avg_time_ms << " ms" << std::endl; + std::cout << " Bandwidth: " << std::fixed << std::setprecision(2) << bandwidth_gbps << " GB/s" << std::endl; + std::cout << " TFLOPS: " << std::fixed << std::setprecision(2) << tflops << std::endl; + std::cout << "====================================\n" << std::endl; + + // --- Cleanup --- + HIP_CHECK(hipFree(d_Q)); HIP_CHECK(hipFree(d_K)); HIP_CHECK(hipFree(d_V)); + HIP_CHECK(hipFree(d_grad_O)); HIP_CHECK(hipFree(d_attn_weights)); + HIP_CHECK(hipFree(d_dropout_mask)); + HIP_CHECK(hipFree(d_grad_Q)); 
HIP_CHECK(hipFree(d_grad_K)); HIP_CHECK(hipFree(d_grad_V)); + HIP_CHECK(hipFree(d_workspace)); + HIP_CHECK(hipFree(d_cu_seqlens_q)); HIP_CHECK(hipFree(d_cu_seqlens_q_padded)); + HIP_CHECK(hipFree(d_cu_seqlens_kv)); HIP_CHECK(hipFree(d_cu_seqlens_kv_padded)); + if(d_padded_q_to_batch) HIP_CHECK(hipFree(d_padded_q_to_batch)); + HIP_CHECK(hipEventDestroy(start)); HIP_CHECK(hipEventDestroy(stop)); +} + +// --------------------------------------------------------------------------- +// Corner-case test: explicit Q seqlens provided by caller +// --------------------------------------------------------------------------- + +template +void test_run_attn_bwd_with_seqlens(const std::vector& h_cu_seqlens_q, + const std::vector& h_cu_seqlens_q_padded, + const std::vector& h_padded_q_to_batch, + int total_padded_q, + float dropout_p, + bool check_correctness, + bool dump_err, + const std::string& test_name) +{ + using Launcher = AttnBackwardKernelLauncher; + + constexpr int head_num = Config::head_num; + constexpr int max_seq_kv = Config::max_seq_kv; + constexpr int head_dim = Config::head_dim; + int bs = static_cast(h_cu_seqlens_q.size()) - 1; + + std::mt19937 gen(123); + std::uniform_real_distribution dis(-1.0f, 1.0f); + + std::vector h_cu_seqlens_kv, h_cu_seqlens_kv_padded; + int total_actual_kv_seq, total_padded_kv_seq; + build_cu_seqlens_kv(bs, max_seq_kv, gen, h_cu_seqlens_kv, h_cu_seqlens_kv_padded, + total_actual_kv_seq, total_padded_kv_seq); + + size_t size_Q = (size_t)total_padded_q * head_num * head_dim; + size_t size_K = (size_t)total_padded_kv_seq * head_num * head_dim; + size_t size_V = size_K; + size_t size_grad_O = size_Q; + size_t size_attn_weights = (size_t)total_padded_q * head_num * max_seq_kv; + size_t size_dropout_mask = size_attn_weights; + + std::vector h_Q(size_Q, DataType(0.0f)), h_K(size_K), h_V(size_K); + std::vector h_grad_O(size_grad_O), h_attn_weights(size_attn_weights); + std::vector h_dropout_mask(size_dropout_mask, DataType(1.0f)); + 
std::vector h_grad_Q_gpu(size_Q), h_grad_K_gpu(size_K), h_grad_V_gpu(size_V); + std::vector h_grad_Q_cpu(size_Q), h_grad_K_cpu(size_K), h_grad_V_cpu(size_V); + + for(size_t i = 0; i < size_Q; i++) h_Q[i] = DataType(dis(gen)); + for(size_t i = 0; i < size_K; i++) h_K[i] = DataType(dis(gen)); + for(size_t i = 0; i < size_V; i++) h_V[i] = DataType(dis(gen)); + for(size_t i = 0; i < size_grad_O; i++) h_grad_O[i] = DataType(dis(gen)); + + // Q and grad_O for active-Q batches only + for(int b = 0; b < bs; b++) + { + if(h_cu_seqlens_q[b + 1] == h_cu_seqlens_q[b]) continue; + int q_off = h_cu_seqlens_q_padded[b]; + for(int h = 0; h < head_num; h++) + { + int base = (q_off * head_num + h) * head_dim; + for(int d = 0; d < head_dim; d++) + { + h_Q[base + d] = DataType(dis(gen)); + h_grad_O[base + d] = DataType(dis(gen)); + } + } + } + + // K/V + for(int b = 0; b < bs; b++) + { + int kv_seq = h_cu_seqlens_kv[b + 1] - h_cu_seqlens_kv[b]; + int kv_off = h_cu_seqlens_kv_padded[b]; + for(int h = 0; h < head_num; h++) + for(int s = 0; s < kv_seq; s++) + { + int base = (kv_off + s) * head_num * head_dim + h * head_dim; + for(int d = 0; d < head_dim; d++) + { + h_K[base + d] = DataType(dis(gen)); + h_V[base + d] = DataType(dis(gen)); + } + } + } + + // attn_weights (normalized per row) + for(int b = 0; b < bs; b++) + { + if(h_cu_seqlens_q[b + 1] == h_cu_seqlens_q[b]) continue; + int kv_seq = h_cu_seqlens_kv[b + 1] - h_cu_seqlens_kv[b]; + int q_off = h_cu_seqlens_q_padded[b]; + for(int h = 0; h < head_num; h++) + { + int base = (q_off * head_num + h) * max_seq_kv; + float sum = 0.0f; + for(int j = 0; j < kv_seq; j++) + { + h_attn_weights[base + j] = DataType(std::abs(dis(gen))); + sum += float(h_attn_weights[base + j]); + } + for(int j = kv_seq; j < max_seq_kv; j++) h_attn_weights[base + j] = DataType(0.0f); + if(sum > 0.0f) + for(int j = 0; j < kv_seq; j++) + h_attn_weights[base + j] = DataType(float(h_attn_weights[base + j]) / sum); + } + } + + for(size_t i = 0; i < 
size_dropout_mask; i++) + h_dropout_mask[i] = Config::enable_dropout_mask + ? DataType(dis(gen) > dropout_p ? 1.0f : 0.0f) + : DataType(1.0f); + + float sqr_dk_scale = 1.0f / std::sqrt(static_cast(head_dim)); + if(check_correctness) + attn_backward(h_Q.data(), h_K.data(), h_V.data(), h_grad_O.data(), + h_attn_weights.data(), + Config::enable_dropout_mask ? h_dropout_mask.data() : nullptr, + dropout_p, h_grad_Q_cpu.data(), h_grad_K_cpu.data(), h_grad_V_cpu.data(), + bs, head_num, max_seq_kv, head_dim, Config::mask_type, + h_cu_seqlens_q.data(), h_cu_seqlens_q_padded.data(), + h_cu_seqlens_kv.data(), h_cu_seqlens_kv_padded.data(), + total_padded_q, total_padded_kv_seq); + + DataType *d_Q, *d_K, *d_V, *d_grad_O, *d_attn_weights, *d_dropout_mask; + DataType *d_grad_Q, *d_grad_K, *d_grad_V, *d_workspace; + int *d_cu_seqlens_q, *d_cu_seqlens_q_padded; + int *d_cu_seqlens_kv, *d_cu_seqlens_kv_padded; + int* d_padded_q_to_batch; + + HIP_CHECK(hipMalloc(&d_Q, size_Q > 0 ? size_Q * sizeof(DataType) : 1)); + HIP_CHECK(hipMalloc(&d_K, size_K * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_V, size_V * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_grad_O, size_grad_O > 0 ? size_grad_O * sizeof(DataType) : 1)); + HIP_CHECK(hipMalloc(&d_attn_weights, size_attn_weights > 0 ? size_attn_weights * sizeof(DataType) : 1)); + HIP_CHECK(hipMalloc(&d_dropout_mask, size_dropout_mask > 0 ? size_dropout_mask * sizeof(DataType) : 1)); + HIP_CHECK(hipMalloc(&d_grad_Q, size_Q > 0 ? 
size_Q * sizeof(DataType) : 1)); + HIP_CHECK(hipMalloc(&d_grad_K, size_K * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_grad_V, size_V * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_cu_seqlens_q, (bs + 1) * sizeof(int))); + HIP_CHECK(hipMalloc(&d_cu_seqlens_q_padded, (bs + 1) * sizeof(int))); + HIP_CHECK(hipMalloc(&d_cu_seqlens_kv, (bs + 1) * sizeof(int))); + HIP_CHECK(hipMalloc(&d_cu_seqlens_kv_padded, (bs + 1) * sizeof(int))); + if(total_padded_q > 0) + HIP_CHECK(hipMalloc(&d_padded_q_to_batch, total_padded_q * sizeof(int))); + else + d_padded_q_to_batch = nullptr; + + size_t workspace_size = Launcher::calc_workspace_size(total_padded_q); + HIP_CHECK(hipMalloc(&d_workspace, workspace_size > 0 ? workspace_size : 1)); + + HIP_CHECK(hipMemcpy(d_Q, h_Q.data(), size_Q * sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_K, h_K.data(), size_K * sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_V, h_V.data(), size_V * sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_grad_O, h_grad_O.data(), size_grad_O * sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_attn_weights, h_attn_weights.data(), size_attn_weights * sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_dropout_mask, h_dropout_mask.data(), size_dropout_mask * sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cu_seqlens_q, h_cu_seqlens_q.data(), (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cu_seqlens_q_padded, h_cu_seqlens_q_padded.data(), (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cu_seqlens_kv, h_cu_seqlens_kv.data(), (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cu_seqlens_kv_padded, h_cu_seqlens_kv_padded.data(), (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + if(total_padded_q > 0) + HIP_CHECK(hipMemcpy(d_padded_q_to_batch, h_padded_q_to_batch.data(), + total_padded_q * sizeof(int), hipMemcpyHostToDevice)); + + 
if(workspace_size > 0) HIP_CHECK(hipMemset(d_workspace, 0, workspace_size)); + + Launcher::run_attn_bwd_kernel(d_Q, d_K, d_V, d_grad_O, d_attn_weights, + Config::enable_dropout_mask ? d_dropout_mask : nullptr, + dropout_p, sqr_dk_scale, + d_grad_Q, d_grad_K, d_grad_V, d_workspace, + d_cu_seqlens_q, d_cu_seqlens_q_padded, + d_cu_seqlens_kv, d_cu_seqlens_kv_padded, + d_padded_q_to_batch, total_padded_q); + + HIP_CHECK(hipMemcpy(h_grad_Q_gpu.data(), d_grad_Q, size_Q * sizeof(DataType), hipMemcpyDeviceToHost)); + HIP_CHECK(hipMemcpy(h_grad_K_gpu.data(), d_grad_K, size_K * sizeof(DataType), hipMemcpyDeviceToHost)); + HIP_CHECK(hipMemcpy(h_grad_V_gpu.data(), d_grad_V, size_V * sizeof(DataType), hipMemcpyDeviceToHost)); + + if(check_correctness) + { + std::cout << "\n===== " << test_name << " =====" << std::endl; + std::cout << "Correctness:" << std::endl; + check_grad_q(h_grad_Q_gpu, h_grad_Q_cpu, bs, head_num, head_dim, + h_cu_seqlens_q, h_cu_seqlens_q_padded, 1e-2f, 1e-2f, dump_err); + check_array(h_grad_K_gpu, h_grad_K_cpu, "grad_K", 1e-2f, 1e-2f, dump_err); + check_array(h_grad_V_gpu, h_grad_V_cpu, "grad_V", 1e-2f, 1e-2f, dump_err); + std::cout << std::endl; + } + + HIP_CHECK(hipFree(d_Q)); HIP_CHECK(hipFree(d_K)); HIP_CHECK(hipFree(d_V)); + HIP_CHECK(hipFree(d_grad_O)); HIP_CHECK(hipFree(d_attn_weights)); + HIP_CHECK(hipFree(d_dropout_mask)); + HIP_CHECK(hipFree(d_grad_Q)); HIP_CHECK(hipFree(d_grad_K)); HIP_CHECK(hipFree(d_grad_V)); + HIP_CHECK(hipFree(d_workspace)); + HIP_CHECK(hipFree(d_cu_seqlens_q)); HIP_CHECK(hipFree(d_cu_seqlens_q_padded)); + HIP_CHECK(hipFree(d_cu_seqlens_kv)); HIP_CHECK(hipFree(d_cu_seqlens_kv_padded)); + if(d_padded_q_to_batch) HIP_CHECK(hipFree(d_padded_q_to_batch)); +} + +// --------------------------------------------------------------------------- +// main +// --------------------------------------------------------------------------- + +int main(int argc, char const* argv[]) +{ + std::cout << "\n========== Correctness Test (bs=30720, 
SEQ_KV=16) ==========" << std::endl; + using CorrConfig = FmhaKernelConfig<30720, 32, 16, 128, 128, false, CausalMaskType::DISABLE>; + test_run_attn_bwd_kernel(0, 10, 10, true, true); + + std::cout << "\n========== Performance Test (bfloat16, TOP_LEFT mask) ==========" << std::endl; + using PerfConfig = FmhaKernelConfig<30720, 32, 16, 128, 128, false, CausalMaskType::TOP_LEFT>; + test_run_attn_bwd_kernel(0, 3, 5, false, false); + + std::cout << "\n========== Mixed-Q Test (bs=128, 0/1 tokens) ==========" << std::endl; + using MixedConfig = FmhaKernelConfig<128, 4, 8, 64, 256, false, CausalMaskType::DISABLE>; + test_run_attn_bwd_kernel(0, 2, 5, true, true); + + std::cout << "\n========== Corner: Empty segments (even batches active, bs=128) ==========" + << std::endl; + { + const int corner_bs = 128; + std::vector h_cu_seqlens_q(corner_bs + 1); + std::vector h_cu_seqlens_q_padded(corner_bs + 1); + std::vector h_padded_q_to_batch(corner_bs / 2); + h_cu_seqlens_q[0] = h_cu_seqlens_q_padded[0] = 0; + for(int b = 0; b < corner_bs; b++) + { + int actual = (b % 2 == 0) ? 
1 : 0; + h_cu_seqlens_q[b + 1] = h_cu_seqlens_q[b] + actual; + h_cu_seqlens_q_padded[b + 1] = h_cu_seqlens_q_padded[b] + actual; + } + int total_padded_q = h_cu_seqlens_q_padded[corner_bs]; + for(int b = 0; b < corner_bs; b++) + if(h_cu_seqlens_q_padded[b + 1] > h_cu_seqlens_q_padded[b]) + h_padded_q_to_batch[h_cu_seqlens_q_padded[b]] = b; + + using CornerConfig = FmhaKernelConfig<128, 4, 8, 64, 256, false, CausalMaskType::DISABLE>; + test_run_attn_bwd_with_seqlens( + h_cu_seqlens_q, h_cu_seqlens_q_padded, h_padded_q_to_batch, + total_padded_q, 0.0f, true, true, "Empty segments"); + } + + std::cout << "\n========== Corner: Q padded > actual (2 slots per batch, bs=128) ==========" + << std::endl; + { + const int corner_bs = 128; + std::vector h_cu_seqlens_q(corner_bs + 1); + std::vector h_cu_seqlens_q_padded(corner_bs + 1); + std::vector h_padded_q_to_batch(256); + for(int b = 0; b <= corner_bs; b++) + { + h_cu_seqlens_q[b] = b; + h_cu_seqlens_q_padded[b] = b * 2; + } + for(int i = 0; i < 256; i++) h_padded_q_to_batch[i] = i / 2; + int total_padded_q = 256; + + using CornerConfig = FmhaKernelConfig<128, 4, 8, 64, 256, false, CausalMaskType::DISABLE>; + test_run_attn_bwd_with_seqlens( + h_cu_seqlens_q, h_cu_seqlens_q_padded, h_padded_q_to_batch, + total_padded_q, 0.0f, true, true, "Q padded > actual"); + } + + return 0; +} diff --git a/tests/cpp/small_seq_kernels/tests/test_bwd_mfma_16x16.cpp b/tests/cpp/small_seq_kernels/tests/test_bwd_mfma_16x16.cpp new file mode 100644 index 000000000..647f03f8c --- /dev/null +++ b/tests/cpp/small_seq_kernels/tests/test_bwd_mfma_16x16.cpp @@ -0,0 +1,381 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +// +// Test host for the MFMA 16x16x16 backward kernels (attn_bwd_mfma_16x16.h). +// +// 4 test cases: +// 1. sq∈[1,16], skv∈[2,16]; varlen + padding +// 2. sq=1 (fixed, no padding), skv∈[2,16] (varlen + padding) +// 3. sq=16, skv=16; fixed, no padding +// 4. 
sq=17, skv=17; fixed, no padding +// +// Build: cmake -B ck_arliu && cmake --build ck_arliu --target test_bwd_mfma_16x16 + +#include "attn_bwd_mfma_16x16.h" +#include "attn_fwd_mfma_16x16.h" +#include "attn_bwd_ref.h" +#include "test_utils.h" + +#include +#include +#include +#include +#include +#include + +// --------------------------------------------------------------------------- +// Q cu_seqlens builders +// --------------------------------------------------------------------------- + +// Varlen Q: actual in [1, max_seq_q] with random padding [0, 3] +inline int build_varlen_cu_seqlens_q(int bs, + int max_seq_q, + std::mt19937& gen, + std::vector& cu_seqlens_q, + std::vector& cu_seqlens_q_padded, + std::vector& padded_q_to_batch) +{ + std::uniform_int_distribution q_dist(1, max_seq_q); + std::uniform_int_distribution pad_dist(0, 3); + + cu_seqlens_q.resize(bs + 1); + cu_seqlens_q_padded.resize(bs + 1); + cu_seqlens_q[0] = cu_seqlens_q_padded[0] = 0; + + for(int b = 0; b < bs; b++) + { + int q_len = q_dist(gen); + int pad = pad_dist(gen); + int padded_len = std::min(q_len + pad, max_seq_q); + cu_seqlens_q[b + 1] = cu_seqlens_q[b] + q_len; + cu_seqlens_q_padded[b + 1] = cu_seqlens_q_padded[b] + padded_len; + } + + int total_padded_q = cu_seqlens_q_padded[bs]; + padded_q_to_batch.resize(total_padded_q); + for(int b = 0; b < bs; b++) + for(int q = cu_seqlens_q_padded[b]; q < cu_seqlens_q_padded[b + 1]; q++) + padded_q_to_batch[q] = b; + + return total_padded_q; +} + +// Fixed Q: all batches have exactly fix_sq tokens, no padding +inline int build_fixed_cu_seqlens_q(int bs, + int fix_sq, + std::vector& cu_seqlens_q, + std::vector& cu_seqlens_q_padded, + std::vector& padded_q_to_batch) +{ + cu_seqlens_q.resize(bs + 1); + cu_seqlens_q_padded.resize(bs + 1); + cu_seqlens_q[0] = cu_seqlens_q_padded[0] = 0; + + for(int b = 0; b < bs; b++) + { + cu_seqlens_q[b + 1] = cu_seqlens_q[b] + fix_sq; + cu_seqlens_q_padded[b + 1] = cu_seqlens_q_padded[b] + fix_sq; + } + + int 
total_padded_q = cu_seqlens_q_padded[bs]; + padded_q_to_batch.resize(total_padded_q); + for(int b = 0; b < bs; b++) + for(int q = cu_seqlens_q_padded[b]; q < cu_seqlens_q_padded[b + 1]; q++) + padded_q_to_batch[q] = b; + + return total_padded_q; +} + +// --------------------------------------------------------------------------- +// KV cu_seqlens builders +// --------------------------------------------------------------------------- + +// Fixed KV: all batches have exactly fix_skv tokens, no padding +inline void build_fixed_cu_seqlens_kv(int bs, + int fix_skv, + std::vector& cu_seqlens_kv, + std::vector& cu_seqlens_kv_padded, + int& total_padded_kv_seq) +{ + cu_seqlens_kv.resize(bs + 1); + cu_seqlens_kv_padded.resize(bs + 1); + cu_seqlens_kv[0] = cu_seqlens_kv_padded[0] = 0; + + for(int b = 0; b < bs; b++) + { + cu_seqlens_kv[b + 1] = cu_seqlens_kv[b] + fix_skv; + cu_seqlens_kv_padded[b + 1] = cu_seqlens_kv_padded[b] + fix_skv; + } + + total_padded_kv_seq = cu_seqlens_kv_padded[bs]; +} + +// --------------------------------------------------------------------------- +// Main backward correctness + performance test (MFMA 16x16x16 variant) +// --------------------------------------------------------------------------- + +template +void test_run_attn_bwd_mfma_16x16( + bool varlen_q, int fix_sq, + bool varlen_kv, int fix_skv, + const std::string& label, + int warmup_iters, int test_iters, + bool check_correctness, bool dump_err, + float cmp_rtol = 1e-2f, + float cmp_atol = 1e-2f) +{ + using BwdLauncher = AttnBackwardMfma16x16KernelLauncher; + using FwdLauncher = AttnForwardMfma16x16KernelLauncher; + + constexpr int bs = Config::bs; + constexpr int head_num = Config::head_num; + constexpr int max_seq_kv = Config::max_seq_kv; + constexpr int max_seq_q = Config::max_seq_q; + constexpr int head_dim = Config::head_dim; + + std::mt19937 gen(42); + std::uniform_real_distribution dis(-1.0f, 1.0f); + + // --- Build cu_seqlens --- + std::vector h_cu_seqlens_q, 
h_cu_seqlens_q_padded, h_padded_q_to_batch; + std::vector h_cu_seqlens_kv, h_cu_seqlens_kv_padded; + int total_padded_kv_seq; + int total_padded_q; + int total_actual_kv_seq; + + if(varlen_q) + total_padded_q = build_varlen_cu_seqlens_q( + bs, max_seq_q, gen, h_cu_seqlens_q, h_cu_seqlens_q_padded, h_padded_q_to_batch); + else + total_padded_q = build_fixed_cu_seqlens_q( + bs, fix_sq, h_cu_seqlens_q, h_cu_seqlens_q_padded, h_padded_q_to_batch); + + if(varlen_kv) + build_cu_seqlens_kv( + bs, max_seq_kv, gen, h_cu_seqlens_kv, h_cu_seqlens_kv_padded, + total_actual_kv_seq, total_padded_kv_seq); + else + build_fixed_cu_seqlens_kv( + bs, fix_skv, h_cu_seqlens_kv, h_cu_seqlens_kv_padded, total_padded_kv_seq); + + // --- Buffer sizes --- + size_t size_Q = (size_t)total_padded_q * head_num * head_dim; + size_t size_K = (size_t)total_padded_kv_seq * head_num * head_dim; + size_t size_V = size_K; + size_t size_grad_O = size_Q; + size_t size_attn_weights = (size_t)total_padded_q * head_num * max_seq_kv; + + // --- Host allocations --- + std::vector h_Q(size_Q), h_K(size_K), h_V(size_V); + std::vector h_grad_O(size_grad_O); + std::vector h_attn_weights(size_attn_weights, DataType(0.0f)); + std::vector h_grad_Q_gpu(size_Q, DataType(0.0f)); + std::vector h_grad_K_gpu(size_K, DataType(0.0f)); + std::vector h_grad_V_gpu(size_V, DataType(0.0f)); + std::vector h_grad_Q_cpu(size_Q, DataType(0.0f)); + std::vector h_grad_K_cpu(size_K, DataType(0.0f)); + std::vector h_grad_V_cpu(size_V, DataType(0.0f)); + + // Initialize Q, K, V, grad_O + for(size_t i = 0; i < size_Q; i++) h_Q[i] = DataType(dis(gen)); + for(size_t i = 0; i < size_K; i++) h_K[i] = DataType(dis(gen)); + for(size_t i = 0; i < size_V; i++) h_V[i] = DataType(dis(gen)); + for(size_t i = 0; i < size_grad_O; i++) h_grad_O[i] = DataType(dis(gen)); + + // Pre-round to bf16 precision + if constexpr(std::is_same::value) + { + for(size_t i = 0; i < size_Q; i++) h_Q[i] = float(hip_bfloat16(h_Q[i])); + for(size_t i = 0; i < size_K; 
i++) h_K[i] = float(hip_bfloat16(h_K[i])); + for(size_t i = 0; i < size_V; i++) h_V[i] = float(hip_bfloat16(h_V[i])); + for(size_t i = 0; i < size_grad_O; i++) h_grad_O[i] = float(hip_bfloat16(h_grad_O[i])); + } + + float sqr_dk_scale = 1.0f / std::sqrt(static_cast(head_dim)); + + // --- Device allocations --- + DataType *d_Q, *d_K, *d_V, *d_O, *d_grad_O; + float* d_softmax_lse; + DataType *d_grad_Q, *d_grad_K, *d_grad_V; + int *d_cu_seqlens_q, *d_cu_seqlens_q_padded; + int *d_cu_seqlens_kv, *d_cu_seqlens_kv_padded; + int* d_padded_q_to_batch; + + HIP_CHECK(hipMalloc(&d_Q, size_Q * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_K, size_K * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_V, size_V * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_O, total_padded_q * head_num * head_dim * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_grad_O, size_grad_O * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_softmax_lse, + FwdLauncher::calc_workspace_size(total_padded_q) > 0 + ? FwdLauncher::calc_workspace_size(total_padded_q) + : sizeof(float))); + HIP_CHECK(hipMalloc(&d_grad_Q, size_Q * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_grad_K, size_K * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_grad_V, size_V * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_cu_seqlens_q, (bs + 1) * sizeof(int))); + HIP_CHECK(hipMalloc(&d_cu_seqlens_q_padded, (bs + 1) * sizeof(int))); + HIP_CHECK(hipMalloc(&d_cu_seqlens_kv, (bs + 1) * sizeof(int))); + HIP_CHECK(hipMalloc(&d_cu_seqlens_kv_padded, (bs + 1) * sizeof(int))); + if(total_padded_q > 0) + HIP_CHECK(hipMalloc(&d_padded_q_to_batch, total_padded_q * sizeof(int))); + else + d_padded_q_to_batch = nullptr; + + // --- Copy to device --- + HIP_CHECK(hipMemcpy(d_Q, h_Q.data(), size_Q * sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_K, h_K.data(), size_K * sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_V, h_V.data(), size_V * sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_grad_O, h_grad_O.data(), 
size_grad_O * sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cu_seqlens_q, h_cu_seqlens_q.data(), (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cu_seqlens_q_padded, h_cu_seqlens_q_padded.data(), (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cu_seqlens_kv, h_cu_seqlens_kv.data(), (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cu_seqlens_kv_padded, h_cu_seqlens_kv_padded.data(), (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + if(total_padded_q > 0) + HIP_CHECK(hipMemcpy(d_padded_q_to_batch, h_padded_q_to_batch.data(), + total_padded_q * sizeof(int), hipMemcpyHostToDevice)); + + // MFMA forward writes softmax_lse used by Option A backward (must match GPU recomputation). + FwdLauncher::run_attn_fwd_kernel(d_Q, d_K, d_V, static_cast(nullptr), 0.0f, + sqr_dk_scale, d_O, d_softmax_lse, + d_cu_seqlens_q, d_cu_seqlens_q_padded, + d_cu_seqlens_kv, d_cu_seqlens_kv_padded, + d_padded_q_to_batch, total_padded_q); + HIP_CHECK(hipDeviceSynchronize()); + + std::vector h_softmax_lse(total_padded_q * head_num); + HIP_CHECK(hipMemcpy(h_softmax_lse.data(), d_softmax_lse, + h_softmax_lse.size() * sizeof(float), hipMemcpyDeviceToHost)); + + reference_attn_probs_bf16_dots_with_given_lse( + h_Q, h_K, bs, head_num, max_seq_kv, head_dim, sqr_dk_scale, Config::mask_type, + h_cu_seqlens_q, h_cu_seqlens_q_padded, h_cu_seqlens_kv, h_cu_seqlens_kv_padded, + h_softmax_lse, h_attn_weights); + + // --- CPU reference (bf16_weights=false: GPU Option A uses float P, not bf16-stored weights) + if(check_correctness) + attn_backward(h_Q.data(), h_K.data(), h_V.data(), h_grad_O.data(), + h_attn_weights.data(), static_cast(nullptr), 0.0f, + h_grad_Q_cpu.data(), h_grad_K_cpu.data(), h_grad_V_cpu.data(), + bs, head_num, max_seq_kv, head_dim, Config::mask_type, + h_cu_seqlens_q.data(), h_cu_seqlens_q_padded.data(), + h_cu_seqlens_kv.data(), h_cu_seqlens_kv_padded.data(), + total_padded_q, 
total_padded_kv_seq, + max_seq_q, false); + + auto bwd_launch = [&]() { + HIP_CHECK(hipMemset(d_grad_Q, 0, size_Q * sizeof(DataType))); + HIP_CHECK(hipMemset(d_grad_K, 0, size_K * sizeof(DataType))); + HIP_CHECK(hipMemset(d_grad_V, 0, size_V * sizeof(DataType))); + BwdLauncher::run_attn_bwd_kernel(d_Q, d_K, d_V, d_grad_O, d_softmax_lse, + d_grad_Q, d_grad_K, d_grad_V, sqr_dk_scale, + d_cu_seqlens_q, d_cu_seqlens_q_padded, + d_cu_seqlens_kv, d_cu_seqlens_kv_padded); + }; + + for(int i = 0; i < warmup_iters; i++) bwd_launch(); + HIP_CHECK(hipDeviceSynchronize()); + + hipEvent_t start, stop; + HIP_CHECK(hipEventCreate(&start)); + HIP_CHECK(hipEventCreate(&stop)); + HIP_CHECK(hipEventRecord(start)); + for(int i = 0; i < test_iters; i++) bwd_launch(); + HIP_CHECK(hipEventRecord(stop)); + HIP_CHECK(hipEventSynchronize(stop)); + + float elapsed_ms = 0; + HIP_CHECK(hipEventElapsedTime(&elapsed_ms, start, stop)); + double avg_time_ms = elapsed_ms / test_iters; + + HIP_CHECK(hipMemcpy(h_grad_Q_gpu.data(), d_grad_Q, size_Q * sizeof(DataType), hipMemcpyDeviceToHost)); + HIP_CHECK(hipMemcpy(h_grad_K_gpu.data(), d_grad_K, size_K * sizeof(DataType), hipMemcpyDeviceToHost)); + HIP_CHECK(hipMemcpy(h_grad_V_gpu.data(), d_grad_V, size_V * sizeof(DataType), hipMemcpyDeviceToHost)); + + // --- Report --- + std::cout << "\n===== " << label << " =====" << std::endl; + std::cout << "Configuration:" << std::endl; + std::cout << " Batch size: " << bs << std::endl; + std::cout << " Heads: " << head_num << std::endl; + std::cout << " max_seq_q: " << max_seq_q << " max_seq_kv: " << max_seq_kv << std::endl; + std::cout << " Head dimension: " << head_dim << std::endl; + std::cout << " Mask: " << CausalMaskTypeName[Config::mask_type] << std::endl; + std::cout << " Q mode: " << (varlen_q ? "varlen+padding" : "fixed") << std::endl; + std::cout << " KV mode: " << (varlen_kv ? 
"varlen+padding" : "fixed") << std::endl; + std::cout << " total_padded_q: " << total_padded_q + << " total_padded_kv: " << total_padded_kv_seq << std::endl; + std::cout << std::endl; + + if(check_correctness) + { + std::cout << "Correctness:" << std::endl; + check_output(h_grad_Q_gpu, h_grad_Q_cpu, bs, head_num, head_dim, + h_cu_seqlens_q, h_cu_seqlens_q_padded, "grad_Q", cmp_rtol, cmp_atol, dump_err); + check_output(h_grad_K_gpu, h_grad_K_cpu, bs, head_num, head_dim, + h_cu_seqlens_kv, h_cu_seqlens_kv_padded, "grad_K", cmp_rtol, cmp_atol, dump_err); + check_output(h_grad_V_gpu, h_grad_V_cpu, bs, head_num, head_dim, + h_cu_seqlens_kv, h_cu_seqlens_kv_padded, "grad_V", cmp_rtol, cmp_atol, dump_err); + std::cout << std::endl; + } + + std::cout << "Performance:" << std::endl; + std::cout << " Average time: " << std::fixed << std::setprecision(3) << avg_time_ms << " ms" << std::endl; + std::cout << "====================================\n" << std::endl; + + // --- Cleanup --- + HIP_CHECK(hipFree(d_Q)); HIP_CHECK(hipFree(d_K)); HIP_CHECK(hipFree(d_V)); + HIP_CHECK(hipFree(d_O)); HIP_CHECK(hipFree(d_grad_O)); HIP_CHECK(hipFree(d_softmax_lse)); + HIP_CHECK(hipFree(d_grad_Q)); HIP_CHECK(hipFree(d_grad_K)); HIP_CHECK(hipFree(d_grad_V)); + HIP_CHECK(hipFree(d_cu_seqlens_q)); HIP_CHECK(hipFree(d_cu_seqlens_q_padded)); + HIP_CHECK(hipFree(d_cu_seqlens_kv)); HIP_CHECK(hipFree(d_cu_seqlens_kv_padded)); + if(d_padded_q_to_batch) HIP_CHECK(hipFree(d_padded_q_to_batch)); + HIP_CHECK(hipEventDestroy(start)); HIP_CHECK(hipEventDestroy(stop)); +} + +// --------------------------------------------------------------------------- +// main +// --------------------------------------------------------------------------- + +int main(int argc, char const* argv[]) +{ + // Test 1: sq∈[1,16], skv∈[2,16]; varlen + padding + // Looser cmp_rtol/cmp_atol: CPU ref P uses chunked bf16 dots; MFMA S differs (see log max ~0.23). 
+ { + using Cfg = FmhaKernelConfig<2048, 8, 16, 128, 256, false, CausalMaskType::DISABLE, 16>; + test_run_attn_bwd_mfma_16x16( + true, 0, true, 0, + "Test 1: sq∈[1,16] varlen+pad, skv∈[2,16] varlen+pad", + 1, 1, true, true, + 0.12f, 0.23f); + } + + // Test 2: sq=1 (fixed, no padding), skv∈[2,16] (varlen + padding) + { + using Cfg = FmhaKernelConfig<2048, 8, 16, 128, 256, false, CausalMaskType::DISABLE, 1>; + test_run_attn_bwd_mfma_16x16( + false, 1, true, 0, + "Test 2: sq=1 fixed, skv∈[2,16] varlen+pad", + 1, 1, true, true, + 0.12f, 0.23f); + } + + // Test 3: sq=16, skv=16; fixed, no padding + { + using Cfg = FmhaKernelConfig<2048, 8, 16, 128, 256, false, CausalMaskType::DISABLE, 16>; + test_run_attn_bwd_mfma_16x16( + false, 16, false, 16, + "Test 3: sq=16 fixed, skv=16 fixed", + 1, 1, true, true); + } + + // Test 4: sq=17, skv=17; fixed, no padding + { + using Cfg = FmhaKernelConfig<2048, 8, 17, 128, 256, false, CausalMaskType::DISABLE, 17>; + test_run_attn_bwd_mfma_16x16( + false, 17, false, 17, + "Test 4: sq=17 fixed, skv=17 fixed", + 1, 1, true, true); + } + + return 0; +} diff --git a/tests/cpp/small_seq_kernels/tests/test_fwd.cpp b/tests/cpp/small_seq_kernels/tests/test_fwd.cpp new file mode 100644 index 000000000..dc84c751e --- /dev/null +++ b/tests/cpp/small_seq_kernels/tests/test_fwd.cpp @@ -0,0 +1,406 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT +// +// clang-format off +// Build: cmake -B build && cmake --build build && ./build/test_fwd +// clang-format on + +#include "attn_fwd.h" +#include "attn_fwd_ref.h" +#include "test_utils.h" + +#include +#include +#include +#include +#include +#include + +// --------------------------------------------------------------------------- +// Main forward correctness + performance test +// --------------------------------------------------------------------------- + +template +void test_run_attn_fwd_kernel( + float dropout_p, int warmup_iters, int test_iters, bool check_correctness, bool dump_err) +{ + using Launcher = AttnForwardKernelLauncher; + + constexpr int bs = Config::bs; + constexpr int head_num = Config::head_num; + constexpr int max_seq_kv = Config::max_seq_kv; + constexpr int head_dim = Config::head_dim; + + std::mt19937 gen(42); // Fixed seed for reproducibility + std::uniform_real_distribution dis(-1.0f, 1.0f); + + // --- Build cu_seqlens_q --- + std::vector h_cu_seqlens_q, h_cu_seqlens_q_padded, h_padded_q_to_batch; + int total_padded_q = build_cu_seqlens_q(bs, gen, h_cu_seqlens_q, h_cu_seqlens_q_padded, + h_padded_q_to_batch); + int total_actual_q_seq = h_cu_seqlens_q[bs]; + + // --- Build cu_seqlens_kv --- + std::vector h_cu_seqlens_kv, h_cu_seqlens_kv_padded; + int total_actual_kv_seq, total_padded_kv_seq; + build_cu_seqlens_kv(bs, max_seq_kv, gen, h_cu_seqlens_kv, h_cu_seqlens_kv_padded, + total_actual_kv_seq, total_padded_kv_seq); + + // --- Buffer sizes --- + size_t size_Q = (size_t)total_padded_q * head_num * head_dim; + size_t size_K = (size_t)total_padded_kv_seq * head_num * head_dim; + size_t size_V = (size_t)total_padded_kv_seq * head_num * head_dim; + size_t size_O = (size_t)total_padded_q * head_num * head_dim; + size_t size_dropout_mask = (size_t)bs * head_num * max_seq_kv; + + // --- Host allocations --- + std::vector h_Q(size_Q), h_K(size_K), h_V(size_V); + std::vector h_dropout_mask(size_dropout_mask); + 
std::vector h_O_gpu(size_O, DataType(0.0f)); + std::vector h_O_cpu(size_O, DataType(0.0f)); + + for(size_t i = 0; i < size_Q; i++) h_Q[i] = DataType(dis(gen)); + for(size_t i = 0; i < size_K; i++) h_K[i] = DataType(dis(gen)); + for(size_t i = 0; i < size_V; i++) h_V[i] = DataType(dis(gen)); + for(size_t i = 0; i < size_dropout_mask; i++) + h_dropout_mask[i] = Config::enable_dropout_mask + ? DataType(dis(gen) > dropout_p ? 1.0f : 0.0f) + : DataType(1.0f); + + // --- CPU reference --- + float sqr_dk_scale = 1.0f / std::sqrt(static_cast(head_dim)); + if(check_correctness) + attn_forward(h_Q.data(), h_K.data(), h_V.data(), + Config::enable_dropout_mask ? h_dropout_mask.data() : nullptr, + dropout_p, h_O_cpu.data(), static_cast(nullptr), + bs, head_num, max_seq_kv, head_dim, Config::mask_type, + h_cu_seqlens_q.data(), h_cu_seqlens_q_padded.data(), + h_cu_seqlens_kv.data(), h_cu_seqlens_kv_padded.data()); + + // --- Device allocations --- + DataType *d_Q, *d_K, *d_V, *d_dropout_mask, *d_O, *d_workspace; + int *d_cu_seqlens_q, *d_cu_seqlens_q_padded; + int *d_cu_seqlens_kv, *d_cu_seqlens_kv_padded; + int* d_padded_q_to_batch; + + HIP_CHECK(hipMalloc(&d_Q, size_Q * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_K, size_K * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_V, size_V * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_dropout_mask, size_dropout_mask * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_O, size_O * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_cu_seqlens_q, (bs + 1) * sizeof(int))); + HIP_CHECK(hipMalloc(&d_cu_seqlens_q_padded, (bs + 1) * sizeof(int))); + HIP_CHECK(hipMalloc(&d_cu_seqlens_kv, (bs + 1) * sizeof(int))); + HIP_CHECK(hipMalloc(&d_cu_seqlens_kv_padded, (bs + 1) * sizeof(int))); + if(total_padded_q > 0) + HIP_CHECK(hipMalloc(&d_padded_q_to_batch, total_padded_q * sizeof(int))); + else + d_padded_q_to_batch = nullptr; + + size_t workspace_size = Launcher::calc_workspace_size(total_padded_q); + HIP_CHECK(hipMalloc(&d_workspace, workspace_size > 0 ? 
workspace_size : 1)); + + // --- Copy to device --- + HIP_CHECK(hipMemcpy(d_Q, h_Q.data(), size_Q * sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_K, h_K.data(), size_K * sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_V, h_V.data(), size_V * sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_dropout_mask, h_dropout_mask.data(), + size_dropout_mask * sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cu_seqlens_q, h_cu_seqlens_q.data(), + (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cu_seqlens_q_padded, h_cu_seqlens_q_padded.data(), + (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cu_seqlens_kv, h_cu_seqlens_kv.data(), + (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cu_seqlens_kv_padded, h_cu_seqlens_kv_padded.data(), + (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + if(total_padded_q > 0) + HIP_CHECK(hipMemcpy(d_padded_q_to_batch, h_padded_q_to_batch.data(), + total_padded_q * sizeof(int), hipMemcpyHostToDevice)); + + auto launch = [&]() { + Launcher::run_attn_fwd_kernel(d_Q, d_K, d_V, + Config::enable_dropout_mask ? 
d_dropout_mask : nullptr, + dropout_p, sqr_dk_scale, d_O, d_workspace, + d_cu_seqlens_q, d_cu_seqlens_q_padded, + d_cu_seqlens_kv, d_cu_seqlens_kv_padded, + d_padded_q_to_batch, total_padded_q); + }; + + for(int i = 0; i < warmup_iters; i++) launch(); + HIP_CHECK(hipDeviceSynchronize()); + + hipEvent_t start, stop; + HIP_CHECK(hipEventCreate(&start)); + HIP_CHECK(hipEventCreate(&stop)); + HIP_CHECK(hipEventRecord(start)); + for(int i = 0; i < test_iters; i++) launch(); + HIP_CHECK(hipEventRecord(stop)); + HIP_CHECK(hipEventSynchronize(stop)); + + float elapsed_ms = 0; + HIP_CHECK(hipEventElapsedTime(&elapsed_ms, start, stop)); + double avg_time_ms = elapsed_ms / test_iters; + + HIP_CHECK(hipMemcpy(h_O_gpu.data(), d_O, size_O * sizeof(DataType), hipMemcpyDeviceToHost)); + + // --- Report --- + double avg_kv_seq = static_cast(total_actual_kv_seq) / bs; + double active_q = static_cast(total_actual_q_seq); + double flops_per_batch_head = 2.0 * avg_kv_seq * head_dim + 2.0 * head_dim * avg_kv_seq; + double total_flops = flops_per_batch_head * active_q * head_num; + double tflops = (total_flops / 1e12) / (avg_time_ms / 1000.0); + + size_t bytes_read = (size_Q + size_K + size_V) * sizeof(DataType); + if(Config::enable_dropout_mask) bytes_read += size_dropout_mask * sizeof(DataType); + size_t bytes_write = size_O * sizeof(DataType); + size_t total_bytes = bytes_read + bytes_write; + double bandwidth_gbps = (total_bytes / 1e9) / (avg_time_ms / 1000.0); + + std::cout << "\n===== run_attn_fwd_kernel Test =====" << std::endl; + std::cout << "Configuration:" << std::endl; + std::cout << " Batch size: " << bs << std::endl; + std::cout << " Heads: " << head_num << std::endl; + std::cout << " Q seq (active/total): " << total_actual_q_seq << "/" << bs << std::endl; + std::cout << " KV seq (avg): " << std::fixed << std::setprecision(2) << avg_kv_seq + << " (max: " << max_seq_kv << ")" << std::endl; + std::cout << " Head dimension: " << head_dim << std::endl; + std::cout << " Dropout: 
" << (Config::enable_dropout_mask ? "enabled" : "disabled") << std::endl; + std::cout << " Mask: " << CausalMaskTypeName[Config::mask_type] << std::endl; + std::cout << std::endl; + + if(check_correctness) + { + std::cout << "Correctness:" << std::endl; + check_output(h_O_gpu, h_O_cpu, bs, head_num, head_dim, + h_cu_seqlens_q, h_cu_seqlens_q_padded, "Output", 1e-2f, 1e-2f, dump_err); + std::cout << std::endl; + } + + std::cout << "Memory:" << std::endl; + std::cout << " Total data read: " << std::fixed << std::setprecision(2) + << bytes_read / 1e6 << " MB" << std::endl; + std::cout << " Total data write: " << bytes_write / 1e6 << " MB" << std::endl; + std::cout << " Total data transfer: " << total_bytes / 1e6 << " MB" << std::endl; + std::cout << " Workspace: " << workspace_size / 1e6 << " MB" << std::endl; + std::cout << std::endl; + + std::cout << "Performance:" << std::endl; + std::cout << " Average time: " << std::fixed << std::setprecision(3) << avg_time_ms + << " ms" << std::endl; + std::cout << " Bandwidth: " << std::fixed << std::setprecision(2) << bandwidth_gbps + << " GB/s" << std::endl; + std::cout << " TFLOPS: " << std::fixed << std::setprecision(2) << tflops << std::endl; + std::cout << "====================================\n" << std::endl; + + // --- Cleanup --- + HIP_CHECK(hipFree(d_Q)); HIP_CHECK(hipFree(d_K)); HIP_CHECK(hipFree(d_V)); + HIP_CHECK(hipFree(d_dropout_mask)); HIP_CHECK(hipFree(d_O)); HIP_CHECK(hipFree(d_workspace)); + HIP_CHECK(hipFree(d_cu_seqlens_q)); HIP_CHECK(hipFree(d_cu_seqlens_q_padded)); + HIP_CHECK(hipFree(d_cu_seqlens_kv)); HIP_CHECK(hipFree(d_cu_seqlens_kv_padded)); + if(d_padded_q_to_batch) HIP_CHECK(hipFree(d_padded_q_to_batch)); + HIP_CHECK(hipEventDestroy(start)); HIP_CHECK(hipEventDestroy(stop)); +} + +// --------------------------------------------------------------------------- +// Corner-case test: explicit Q seqlens provided by caller +// --------------------------------------------------------------------------- + 
+template +void test_run_attn_fwd_with_seqlens(const std::vector& h_cu_seqlens_q, + const std::vector& h_cu_seqlens_q_padded, + const std::vector& h_padded_q_to_batch, + int total_padded_q, + float dropout_p, + bool check_correctness, + bool dump_err, + const std::string& test_name) +{ + using Launcher = AttnForwardKernelLauncher; + + constexpr int head_num = Config::head_num; + constexpr int max_seq_kv = Config::max_seq_kv; + constexpr int head_dim = Config::head_dim; + int bs = static_cast(h_cu_seqlens_q.size()) - 1; + + std::mt19937 gen(123); + std::uniform_real_distribution dis(-1.0f, 1.0f); + + std::vector h_cu_seqlens_kv, h_cu_seqlens_kv_padded; + int total_actual_kv_seq, total_padded_kv_seq; + build_cu_seqlens_kv(bs, max_seq_kv, gen, h_cu_seqlens_kv, h_cu_seqlens_kv_padded, + total_actual_kv_seq, total_padded_kv_seq); + + size_t size_Q = (size_t)total_padded_q * head_num * head_dim; + size_t size_K = (size_t)total_padded_kv_seq * head_num * head_dim; + size_t size_V = size_K; + size_t size_O = size_Q; + size_t size_dropout_mask = (size_t)total_padded_q * head_num * max_seq_kv; + size_t size_workspace = size_dropout_mask; + + std::vector h_Q(size_Q), h_K(size_K), h_V(size_V), h_dropout_mask(size_dropout_mask); + std::vector h_O_gpu(size_O, DataType(0.0f)), h_O_cpu(size_O, DataType(0.0f)); + + for(size_t i = 0; i < size_Q; i++) h_Q[i] = DataType(dis(gen)); + for(size_t i = 0; i < size_K; i++) h_K[i] = DataType(dis(gen)); + for(size_t i = 0; i < size_V; i++) h_V[i] = DataType(dis(gen)); + for(size_t i = 0; i < size_dropout_mask; i++) + h_dropout_mask[i] = Config::enable_dropout_mask + ? DataType(dis(gen) > dropout_p ? 1.0f : 0.0f) + : DataType(1.0f); + + float sqr_dk_scale = 1.0f / std::sqrt(static_cast(head_dim)); + if(check_correctness) + attn_forward(h_Q.data(), h_K.data(), h_V.data(), + Config::enable_dropout_mask ? 
h_dropout_mask.data() : nullptr, + dropout_p, h_O_cpu.data(), static_cast(nullptr), + bs, head_num, max_seq_kv, head_dim, Config::mask_type, + h_cu_seqlens_q.data(), h_cu_seqlens_q_padded.data(), + h_cu_seqlens_kv.data(), h_cu_seqlens_kv_padded.data()); + + DataType *d_Q, *d_K, *d_V, *d_dropout_mask, *d_O, *d_workspace; + int *d_cu_seqlens_q, *d_cu_seqlens_q_padded; + int *d_cu_seqlens_kv, *d_cu_seqlens_kv_padded; + int* d_padded_q_to_batch; + + HIP_CHECK(hipMalloc(&d_Q, size_Q * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_K, size_K * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_V, size_V * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_dropout_mask, size_dropout_mask * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_O, size_O * sizeof(DataType))); + size_t workspace_size = Launcher::calc_workspace_size(total_padded_q); + HIP_CHECK(hipMalloc(&d_workspace, workspace_size > 0 ? workspace_size : 1)); + HIP_CHECK(hipMalloc(&d_cu_seqlens_q, (bs + 1) * sizeof(int))); + HIP_CHECK(hipMalloc(&d_cu_seqlens_q_padded, (bs + 1) * sizeof(int))); + HIP_CHECK(hipMalloc(&d_cu_seqlens_kv, (bs + 1) * sizeof(int))); + HIP_CHECK(hipMalloc(&d_cu_seqlens_kv_padded, (bs + 1) * sizeof(int))); + if(total_padded_q > 0) + HIP_CHECK(hipMalloc(&d_padded_q_to_batch, total_padded_q * sizeof(int))); + else + d_padded_q_to_batch = nullptr; + + // Init workspace to -1e9f so padding rows are defined for softmax + std::vector h_workspace_init(size_workspace, DataType(-1e9f)); + HIP_CHECK(hipMemcpy(d_workspace, h_workspace_init.data(), + size_workspace * sizeof(DataType), hipMemcpyHostToDevice)); + + HIP_CHECK(hipMemcpy(d_Q, h_Q.data(), size_Q * sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_K, h_K.data(), size_K * sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_V, h_V.data(), size_V * sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_dropout_mask, h_dropout_mask.data(), + size_dropout_mask * sizeof(DataType), hipMemcpyHostToDevice)); + 
HIP_CHECK(hipMemcpy(d_cu_seqlens_q, h_cu_seqlens_q.data(), + (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cu_seqlens_q_padded, h_cu_seqlens_q_padded.data(), + (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cu_seqlens_kv, h_cu_seqlens_kv.data(), + (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cu_seqlens_kv_padded, h_cu_seqlens_kv_padded.data(), + (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + if(total_padded_q > 0) + HIP_CHECK(hipMemcpy(d_padded_q_to_batch, h_padded_q_to_batch.data(), + total_padded_q * sizeof(int), hipMemcpyHostToDevice)); + + Launcher::run_attn_fwd_kernel(d_Q, d_K, d_V, + Config::enable_dropout_mask ? d_dropout_mask : nullptr, + dropout_p, sqr_dk_scale, d_O, d_workspace, + d_cu_seqlens_q, d_cu_seqlens_q_padded, + d_cu_seqlens_kv, d_cu_seqlens_kv_padded, + d_padded_q_to_batch, total_padded_q); + + HIP_CHECK(hipMemcpy(h_O_gpu.data(), d_O, size_O * sizeof(DataType), hipMemcpyDeviceToHost)); + + if(check_correctness) + { + check_output(h_O_gpu, h_O_cpu, bs, head_num, head_dim, + h_cu_seqlens_q, h_cu_seqlens_q_padded, + test_name + " Output", 1e-2f, 1e-2f, dump_err); + } + + HIP_CHECK(hipFree(d_Q)); HIP_CHECK(hipFree(d_K)); HIP_CHECK(hipFree(d_V)); + HIP_CHECK(hipFree(d_dropout_mask)); HIP_CHECK(hipFree(d_O)); HIP_CHECK(hipFree(d_workspace)); + HIP_CHECK(hipFree(d_cu_seqlens_q)); HIP_CHECK(hipFree(d_cu_seqlens_q_padded)); + HIP_CHECK(hipFree(d_cu_seqlens_kv)); HIP_CHECK(hipFree(d_cu_seqlens_kv_padded)); + if(d_padded_q_to_batch) HIP_CHECK(hipFree(d_padded_q_to_batch)); +} + +// --------------------------------------------------------------------------- +// Functor for TestRunner +// --------------------------------------------------------------------------- + +struct RunFwd { + template + void operator()(float dropout_p, int warmup, int iters, bool check, bool dump) const { + test_run_attn_fwd_kernel(dropout_p, warmup, iters, check, dump); + } +}; + +// 
--------------------------------------------------------------------------- +// main +// --------------------------------------------------------------------------- + +int main(int argc, char const* argv[]) +{ + std::cout << "\n========== Correctness Test (SEQ_KV 2..16, bs=30720) ==========" << std::endl; + + TestRunner<2, 16>::run( + RunFwd{}, 0.0f, 1, 1, true, true); + + std::cout << "\n========== Correctness Test (mixed Q=0/1, SEQ_KV=8, bs=128) ==========" << std::endl; + { + using MixedConfig = FmhaKernelConfig<128, 8, 8, 128, 256, false, CausalMaskType::DISABLE>; + test_run_attn_fwd_kernel(0, 1, 1, true, true); + } + + std::cout << "\n========== Performance Test (bfloat16, SEQ_KV 2..16) ==========" << std::endl; + TestRunner<2, 16>::run( + RunFwd{}, 0.0f, 3, 5, false, false); + + std::cout << "\n========== Corner: Empty segments (even batches active, bs=128) ==========" + << std::endl; + { + const int corner_bs = 128; + std::vector h_cu_seqlens_q(corner_bs + 1); + std::vector h_cu_seqlens_q_padded(corner_bs + 1); + std::vector h_padded_q_to_batch(corner_bs / 2); + h_cu_seqlens_q[0] = h_cu_seqlens_q_padded[0] = 0; + for(int b = 0; b < corner_bs; b++) + { + int actual = (b % 2 == 0) ? 
1 : 0; + h_cu_seqlens_q[b + 1] = h_cu_seqlens_q[b] + actual; + h_cu_seqlens_q_padded[b + 1] = h_cu_seqlens_q_padded[b] + actual; + } + int total_padded_q = h_cu_seqlens_q_padded[corner_bs]; + for(int b = 0; b < corner_bs; b++) + if(h_cu_seqlens_q_padded[b + 1] > h_cu_seqlens_q_padded[b]) + h_padded_q_to_batch[h_cu_seqlens_q_padded[b]] = b; + + using CornerConfig = FmhaKernelConfig<128, 8, 8, 128, 256, false, CausalMaskType::DISABLE>; + test_run_attn_fwd_with_seqlens( + h_cu_seqlens_q, h_cu_seqlens_q_padded, h_padded_q_to_batch, + total_padded_q, 0.0f, true, true, "Empty segments"); + } + + std::cout << "\n========== Corner: Q padded > actual (2 slots per batch, bs=128) ==========" + << std::endl; + { + const int corner_bs = 128; + std::vector h_cu_seqlens_q(corner_bs + 1); + std::vector h_cu_seqlens_q_padded(corner_bs + 1); + std::vector h_padded_q_to_batch(256); + for(int b = 0; b <= corner_bs; b++) + { + h_cu_seqlens_q[b] = b; + h_cu_seqlens_q_padded[b] = b * 2; + } + for(int i = 0; i < 256; i++) h_padded_q_to_batch[i] = i / 2; + int total_padded_q = 256; + + using CornerConfig = FmhaKernelConfig<128, 8, 8, 128, 256, false, CausalMaskType::DISABLE>; + test_run_attn_fwd_with_seqlens( + h_cu_seqlens_q, h_cu_seqlens_q_padded, h_padded_q_to_batch, + total_padded_q, 0.0f, true, true, "Q padded > actual"); + } + + return 0; +} diff --git a/tests/cpp/small_seq_kernels/tests/test_fwd_mfma.cpp b/tests/cpp/small_seq_kernels/tests/test_fwd_mfma.cpp new file mode 100644 index 000000000..b69d67c6b --- /dev/null +++ b/tests/cpp/small_seq_kernels/tests/test_fwd_mfma.cpp @@ -0,0 +1,420 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +// +// Test host for the fused MFMA forward kernel (attn_fwd_mfma.h). +// Uses the same test infrastructure and CPU reference as test_fwd.cpp. 
+// +// Build: cmake -B build && cmake --build build && ./build/test_fwd_mfma + +#include "attn_fwd_mfma.h" +#include "attn_fwd_ref.h" +#include "test_utils.h" + +#include +#include +#include +#include +#include +#include + +// --------------------------------------------------------------------------- +// Main forward correctness + performance test (MFMA variant) +// --------------------------------------------------------------------------- + +template +void test_run_attn_fwd_mfma_kernel( + float dropout_p, int warmup_iters, int test_iters, bool check_correctness, bool dump_err) +{ + using Launcher = AttnForwardMfmaKernelLauncher; + + constexpr int bs = Config::bs; + constexpr int head_num = Config::head_num; + constexpr int max_seq_kv = Config::max_seq_kv; + constexpr int head_dim = Config::head_dim; + + std::mt19937 gen(42); // Fixed seed for reproducibility + std::uniform_real_distribution dis(-1.0f, 1.0f); + + // --- Build cu_seqlens_q --- + std::vector h_cu_seqlens_q, h_cu_seqlens_q_padded, h_padded_q_to_batch; + int total_padded_q = build_cu_seqlens_q(bs, gen, h_cu_seqlens_q, h_cu_seqlens_q_padded, + h_padded_q_to_batch); + int total_actual_q_seq = h_cu_seqlens_q[bs]; + + // --- Build cu_seqlens_kv --- + std::vector h_cu_seqlens_kv, h_cu_seqlens_kv_padded; + int total_actual_kv_seq, total_padded_kv_seq; + build_cu_seqlens_kv(bs, max_seq_kv, gen, h_cu_seqlens_kv, h_cu_seqlens_kv_padded, + total_actual_kv_seq, total_padded_kv_seq); + + // --- Buffer sizes --- + size_t size_Q = (size_t)total_padded_q * head_num * head_dim; + size_t size_K = (size_t)total_padded_kv_seq * head_num * head_dim; + size_t size_V = (size_t)total_padded_kv_seq * head_num * head_dim; + size_t size_O = (size_t)total_padded_q * head_num * head_dim; + size_t size_dropout_mask = (size_t)bs * head_num * max_seq_kv; + + // --- Host allocations --- + std::vector h_Q(size_Q), h_K(size_K), h_V(size_V); + std::vector h_dropout_mask(size_dropout_mask); + std::vector h_O_gpu(size_O, 
DataType(0.0f)); + std::vector h_O_cpu(size_O, DataType(0.0f)); + + for(size_t i = 0; i < size_Q; i++) h_Q[i] = DataType(dis(gen)); + for(size_t i = 0; i < size_K; i++) h_K[i] = DataType(dis(gen)); + for(size_t i = 0; i < size_V; i++) h_V[i] = DataType(dis(gen)); + for(size_t i = 0; i < size_dropout_mask; i++) + h_dropout_mask[i] = Config::enable_dropout_mask + ? DataType(dis(gen) > dropout_p ? 1.0f : 0.0f) + : DataType(1.0f); + + // Pre-round to bf16 precision so CPU reference and MFMA kernel + // (which converts to bf16 internally) see identical input values + if constexpr(std::is_same::value) + { + for(size_t i = 0; i < size_Q; i++) h_Q[i] = float(hip_bfloat16(h_Q[i])); + for(size_t i = 0; i < size_K; i++) h_K[i] = float(hip_bfloat16(h_K[i])); + for(size_t i = 0; i < size_V; i++) h_V[i] = float(hip_bfloat16(h_V[i])); + } + + // --- CPU reference --- + float sqr_dk_scale = 1.0f / std::sqrt(static_cast(head_dim)); + if(check_correctness) + attn_forward(h_Q.data(), h_K.data(), h_V.data(), + Config::enable_dropout_mask ? 
h_dropout_mask.data() : nullptr, + dropout_p, h_O_cpu.data(), static_cast(nullptr), + bs, head_num, max_seq_kv, head_dim, Config::mask_type, + h_cu_seqlens_q.data(), h_cu_seqlens_q_padded.data(), + h_cu_seqlens_kv.data(), h_cu_seqlens_kv_padded.data(), + true); + + // --- Device allocations --- + DataType *d_Q, *d_K, *d_V, *d_dropout_mask, *d_O, *d_workspace; + int *d_cu_seqlens_q, *d_cu_seqlens_q_padded; + int *d_cu_seqlens_kv, *d_cu_seqlens_kv_padded; + int* d_padded_q_to_batch; + + HIP_CHECK(hipMalloc(&d_Q, size_Q * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_K, size_K * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_V, size_V * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_dropout_mask, size_dropout_mask * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_O, size_O * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_cu_seqlens_q, (bs + 1) * sizeof(int))); + HIP_CHECK(hipMalloc(&d_cu_seqlens_q_padded, (bs + 1) * sizeof(int))); + HIP_CHECK(hipMalloc(&d_cu_seqlens_kv, (bs + 1) * sizeof(int))); + HIP_CHECK(hipMalloc(&d_cu_seqlens_kv_padded, (bs + 1) * sizeof(int))); + if(total_padded_q > 0) + HIP_CHECK(hipMalloc(&d_padded_q_to_batch, total_padded_q * sizeof(int))); + else + d_padded_q_to_batch = nullptr; + + size_t workspace_size = Launcher::calc_workspace_size(total_padded_q); + HIP_CHECK(hipMalloc(&d_workspace, workspace_size > 0 ? 
workspace_size : 1)); + + // --- Copy to device --- + HIP_CHECK(hipMemcpy(d_Q, h_Q.data(), size_Q * sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_K, h_K.data(), size_K * sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_V, h_V.data(), size_V * sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_dropout_mask, h_dropout_mask.data(), + size_dropout_mask * sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cu_seqlens_q, h_cu_seqlens_q.data(), + (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cu_seqlens_q_padded, h_cu_seqlens_q_padded.data(), + (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cu_seqlens_kv, h_cu_seqlens_kv.data(), + (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cu_seqlens_kv_padded, h_cu_seqlens_kv_padded.data(), + (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + if(total_padded_q > 0) + HIP_CHECK(hipMemcpy(d_padded_q_to_batch, h_padded_q_to_batch.data(), + total_padded_q * sizeof(int), hipMemcpyHostToDevice)); + + auto launch = [&]() { + Launcher::run_attn_fwd_kernel(d_Q, d_K, d_V, + Config::enable_dropout_mask ? 
d_dropout_mask : nullptr, + dropout_p, sqr_dk_scale, d_O, d_workspace, + d_cu_seqlens_q, d_cu_seqlens_q_padded, + d_cu_seqlens_kv, d_cu_seqlens_kv_padded, + d_padded_q_to_batch, total_padded_q); + }; + + for(int i = 0; i < warmup_iters; i++) launch(); + HIP_CHECK(hipDeviceSynchronize()); + + hipEvent_t start, stop; + HIP_CHECK(hipEventCreate(&start)); + HIP_CHECK(hipEventCreate(&stop)); + HIP_CHECK(hipEventRecord(start)); + for(int i = 0; i < test_iters; i++) launch(); + HIP_CHECK(hipEventRecord(stop)); + HIP_CHECK(hipEventSynchronize(stop)); + + float elapsed_ms = 0; + HIP_CHECK(hipEventElapsedTime(&elapsed_ms, start, stop)); + double avg_time_ms = elapsed_ms / test_iters; + + HIP_CHECK(hipMemcpy(h_O_gpu.data(), d_O, size_O * sizeof(DataType), hipMemcpyDeviceToHost)); + + // --- Report --- + double avg_kv_seq = static_cast(total_actual_kv_seq) / bs; + double active_q = static_cast(total_actual_q_seq); + double flops_per_batch_head = 2.0 * avg_kv_seq * head_dim + 2.0 * head_dim * avg_kv_seq; + double total_flops = flops_per_batch_head * active_q * head_num; + double tflops = (total_flops / 1e12) / (avg_time_ms / 1000.0); + + size_t bytes_read = (size_Q + size_K + size_V) * sizeof(DataType); + if(Config::enable_dropout_mask) bytes_read += size_dropout_mask * sizeof(DataType); + size_t bytes_write = size_O * sizeof(DataType); + size_t total_bytes = bytes_read + bytes_write; + double bandwidth_gbps = (total_bytes / 1e9) / (avg_time_ms / 1000.0); + + std::cout << "\n===== run_attn_fwd_mfma_kernel Test =====" << std::endl; + std::cout << "Configuration:" << std::endl; + std::cout << " Batch size: " << bs << std::endl; + std::cout << " Heads: " << head_num << std::endl; + std::cout << " Q seq (active/total): " << total_actual_q_seq << "/" << bs << std::endl; + std::cout << " KV seq (avg): " << std::fixed << std::setprecision(2) << avg_kv_seq + << " (max: " << max_seq_kv << ")" << std::endl; + std::cout << " Head dimension: " << head_dim << std::endl; + std::cout << " 
Dropout: " << (Config::enable_dropout_mask ? "enabled" : "disabled") << std::endl; + std::cout << " Mask: " << CausalMaskTypeName[Config::mask_type] << std::endl; + std::cout << std::endl; + + if(check_correctness) + { + std::cout << "Correctness:" << std::endl; + check_output(h_O_gpu, h_O_cpu, bs, head_num, head_dim, + h_cu_seqlens_q, h_cu_seqlens_q_padded, "Output", 1e-2f, 1e-2f, dump_err); + std::cout << std::endl; + } + + std::cout << "Memory:" << std::endl; + std::cout << " Total data read: " << std::fixed << std::setprecision(2) + << bytes_read / 1e6 << " MB" << std::endl; + std::cout << " Total data write: " << bytes_write / 1e6 << " MB" << std::endl; + std::cout << " Total data transfer: " << total_bytes / 1e6 << " MB" << std::endl; + std::cout << " Workspace: " << workspace_size / 1e6 << " MB" << std::endl; + std::cout << std::endl; + + std::cout << "Performance:" << std::endl; + std::cout << " Average time: " << std::fixed << std::setprecision(3) << avg_time_ms + << " ms" << std::endl; + std::cout << " Bandwidth: " << std::fixed << std::setprecision(2) << bandwidth_gbps + << " GB/s" << std::endl; + std::cout << " TFLOPS: " << std::fixed << std::setprecision(2) << tflops << std::endl; + std::cout << "====================================\n" << std::endl; + + // --- Cleanup --- + HIP_CHECK(hipFree(d_Q)); HIP_CHECK(hipFree(d_K)); HIP_CHECK(hipFree(d_V)); + HIP_CHECK(hipFree(d_dropout_mask)); HIP_CHECK(hipFree(d_O)); HIP_CHECK(hipFree(d_workspace)); + HIP_CHECK(hipFree(d_cu_seqlens_q)); HIP_CHECK(hipFree(d_cu_seqlens_q_padded)); + HIP_CHECK(hipFree(d_cu_seqlens_kv)); HIP_CHECK(hipFree(d_cu_seqlens_kv_padded)); + if(d_padded_q_to_batch) HIP_CHECK(hipFree(d_padded_q_to_batch)); + HIP_CHECK(hipEventDestroy(start)); HIP_CHECK(hipEventDestroy(stop)); +} + +// --------------------------------------------------------------------------- +// Corner-case test: explicit Q seqlens provided by caller (MFMA variant) +// 
--------------------------------------------------------------------------- + +template +void test_run_attn_fwd_mfma_with_seqlens(const std::vector& h_cu_seqlens_q, + const std::vector& h_cu_seqlens_q_padded, + const std::vector& h_padded_q_to_batch, + int total_padded_q, + float dropout_p, + bool check_correctness, + bool dump_err, + const std::string& test_name) +{ + using Launcher = AttnForwardMfmaKernelLauncher; + + constexpr int head_num = Config::head_num; + constexpr int max_seq_kv = Config::max_seq_kv; + constexpr int head_dim = Config::head_dim; + int bs = static_cast(h_cu_seqlens_q.size()) - 1; + + std::mt19937 gen(123); + std::uniform_real_distribution dis(-1.0f, 1.0f); + + std::vector h_cu_seqlens_kv, h_cu_seqlens_kv_padded; + int total_actual_kv_seq, total_padded_kv_seq; + build_cu_seqlens_kv(bs, max_seq_kv, gen, h_cu_seqlens_kv, h_cu_seqlens_kv_padded, + total_actual_kv_seq, total_padded_kv_seq); + + size_t size_Q = (size_t)total_padded_q * head_num * head_dim; + size_t size_K = (size_t)total_padded_kv_seq * head_num * head_dim; + size_t size_V = size_K; + size_t size_O = size_Q; + size_t size_dropout_mask = (size_t)total_padded_q * head_num * max_seq_kv; + size_t size_workspace = size_dropout_mask; + + std::vector h_Q(size_Q), h_K(size_K), h_V(size_V), h_dropout_mask(size_dropout_mask); + std::vector h_O_gpu(size_O, DataType(0.0f)), h_O_cpu(size_O, DataType(0.0f)); + + for(size_t i = 0; i < size_Q; i++) h_Q[i] = DataType(dis(gen)); + for(size_t i = 0; i < size_K; i++) h_K[i] = DataType(dis(gen)); + for(size_t i = 0; i < size_V; i++) h_V[i] = DataType(dis(gen)); + for(size_t i = 0; i < size_dropout_mask; i++) + h_dropout_mask[i] = Config::enable_dropout_mask + ? DataType(dis(gen) > dropout_p ? 
1.0f : 0.0f) + : DataType(1.0f); + + // Pre-round to bf16 precision so CPU reference and MFMA kernel + // (which converts to bf16 internally) see identical input values + if constexpr(std::is_same::value) + { + for(size_t i = 0; i < size_Q; i++) h_Q[i] = float(hip_bfloat16(h_Q[i])); + for(size_t i = 0; i < size_K; i++) h_K[i] = float(hip_bfloat16(h_K[i])); + for(size_t i = 0; i < size_V; i++) h_V[i] = float(hip_bfloat16(h_V[i])); + } + + float sqr_dk_scale = 1.0f / std::sqrt(static_cast(head_dim)); + if(check_correctness) + attn_forward(h_Q.data(), h_K.data(), h_V.data(), + Config::enable_dropout_mask ? h_dropout_mask.data() : nullptr, + dropout_p, h_O_cpu.data(), static_cast(nullptr), + bs, head_num, max_seq_kv, head_dim, Config::mask_type, + h_cu_seqlens_q.data(), h_cu_seqlens_q_padded.data(), + h_cu_seqlens_kv.data(), h_cu_seqlens_kv_padded.data(), + true); + + DataType *d_Q, *d_K, *d_V, *d_dropout_mask, *d_O, *d_workspace; + int *d_cu_seqlens_q, *d_cu_seqlens_q_padded; + int *d_cu_seqlens_kv, *d_cu_seqlens_kv_padded; + int* d_padded_q_to_batch; + + HIP_CHECK(hipMalloc(&d_Q, size_Q * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_K, size_K * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_V, size_V * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_dropout_mask, size_dropout_mask * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_O, size_O * sizeof(DataType))); + size_t workspace_size = Launcher::calc_workspace_size(total_padded_q); + HIP_CHECK(hipMalloc(&d_workspace, workspace_size > 0 ? 
workspace_size : 1)); + HIP_CHECK(hipMalloc(&d_cu_seqlens_q, (bs + 1) * sizeof(int))); + HIP_CHECK(hipMalloc(&d_cu_seqlens_q_padded, (bs + 1) * sizeof(int))); + HIP_CHECK(hipMalloc(&d_cu_seqlens_kv, (bs + 1) * sizeof(int))); + HIP_CHECK(hipMalloc(&d_cu_seqlens_kv_padded, (bs + 1) * sizeof(int))); + if(total_padded_q > 0) + HIP_CHECK(hipMalloc(&d_padded_q_to_batch, total_padded_q * sizeof(int))); + else + d_padded_q_to_batch = nullptr; + + HIP_CHECK(hipMemcpy(d_Q, h_Q.data(), size_Q * sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_K, h_K.data(), size_K * sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_V, h_V.data(), size_V * sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_dropout_mask, h_dropout_mask.data(), + size_dropout_mask * sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cu_seqlens_q, h_cu_seqlens_q.data(), + (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cu_seqlens_q_padded, h_cu_seqlens_q_padded.data(), + (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cu_seqlens_kv, h_cu_seqlens_kv.data(), + (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cu_seqlens_kv_padded, h_cu_seqlens_kv_padded.data(), + (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + if(total_padded_q > 0) + HIP_CHECK(hipMemcpy(d_padded_q_to_batch, h_padded_q_to_batch.data(), + total_padded_q * sizeof(int), hipMemcpyHostToDevice)); + + Launcher::run_attn_fwd_kernel(d_Q, d_K, d_V, + Config::enable_dropout_mask ? 
d_dropout_mask : nullptr, + dropout_p, sqr_dk_scale, d_O, d_workspace, + d_cu_seqlens_q, d_cu_seqlens_q_padded, + d_cu_seqlens_kv, d_cu_seqlens_kv_padded, + d_padded_q_to_batch, total_padded_q); + + HIP_CHECK(hipMemcpy(h_O_gpu.data(), d_O, size_O * sizeof(DataType), hipMemcpyDeviceToHost)); + + if(check_correctness) + { + check_output(h_O_gpu, h_O_cpu, bs, head_num, head_dim, + h_cu_seqlens_q, h_cu_seqlens_q_padded, + test_name + " Output", 1e-2f, 1e-2f, dump_err); + } + + HIP_CHECK(hipFree(d_Q)); HIP_CHECK(hipFree(d_K)); HIP_CHECK(hipFree(d_V)); + HIP_CHECK(hipFree(d_dropout_mask)); HIP_CHECK(hipFree(d_O)); HIP_CHECK(hipFree(d_workspace)); + HIP_CHECK(hipFree(d_cu_seqlens_q)); HIP_CHECK(hipFree(d_cu_seqlens_q_padded)); + HIP_CHECK(hipFree(d_cu_seqlens_kv)); HIP_CHECK(hipFree(d_cu_seqlens_kv_padded)); + if(d_padded_q_to_batch) HIP_CHECK(hipFree(d_padded_q_to_batch)); +} + +// --------------------------------------------------------------------------- +// Functor for TestRunner +// --------------------------------------------------------------------------- + +struct RunFwdMfma { + template + void operator()(float dropout_p, int warmup, int iters, bool check, bool dump) const { + test_run_attn_fwd_mfma_kernel(dropout_p, warmup, iters, check, dump); + } +}; + +// --------------------------------------------------------------------------- +// main +// --------------------------------------------------------------------------- + +int main(int argc, char const* argv[]) +{ + std::cout << "\n========== MFMA Fwd Correctness (bf16, SEQ_KV 2..16, bs=30720) ==========" << std::endl; + + TestRunner<2, 16>::run( + RunFwdMfma{}, 0.0f, 1, 1, true, true); + + std::cout << "\n========== MFMA Fwd Correctness (bf16, mixed Q=0/1, SEQ_KV=8, bs=128) ==========" << std::endl; + { + using MixedConfig = FmhaKernelConfig<128, 8, 8, 128, 256, false, CausalMaskType::DISABLE>; + test_run_attn_fwd_mfma_kernel(0, 1, 1, true, true); + } + + std::cout << "\n========== MFMA Fwd Performance 
(bfloat16, SEQ_KV 2..16) ==========" << std::endl; + TestRunner<2, 16>::run( + RunFwdMfma{}, 0.0f, 3, 5, false, false); + + std::cout << "\n========== MFMA Corner: Empty segments (even batches active, bs=128) ==========" << std::endl; + { + const int corner_bs = 128; + std::vector h_cu_seqlens_q(corner_bs + 1); + std::vector h_cu_seqlens_q_padded(corner_bs + 1); + std::vector h_padded_q_to_batch(corner_bs / 2); + h_cu_seqlens_q[0] = h_cu_seqlens_q_padded[0] = 0; + for(int b = 0; b < corner_bs; b++) + { + int actual = (b % 2 == 0) ? 1 : 0; + h_cu_seqlens_q[b + 1] = h_cu_seqlens_q[b] + actual; + h_cu_seqlens_q_padded[b + 1] = h_cu_seqlens_q_padded[b] + actual; + } + int total_padded_q = h_cu_seqlens_q_padded[corner_bs]; + for(int b = 0; b < corner_bs; b++) + if(h_cu_seqlens_q_padded[b + 1] > h_cu_seqlens_q_padded[b]) + h_padded_q_to_batch[h_cu_seqlens_q_padded[b]] = b; + + using CornerConfig = FmhaKernelConfig<128, 8, 8, 128, 256, false, CausalMaskType::DISABLE>; + test_run_attn_fwd_mfma_with_seqlens( + h_cu_seqlens_q, h_cu_seqlens_q_padded, h_padded_q_to_batch, + total_padded_q, 0.0f, true, true, "Empty segments"); + } + + std::cout << "\n========== MFMA Corner: Q padded > actual (2 slots per batch, bs=128) ==========" << std::endl; + { + const int corner_bs = 128; + std::vector h_cu_seqlens_q(corner_bs + 1); + std::vector h_cu_seqlens_q_padded(corner_bs + 1); + std::vector h_padded_q_to_batch(256); + for(int b = 0; b <= corner_bs; b++) + { + h_cu_seqlens_q[b] = b; + h_cu_seqlens_q_padded[b] = b * 2; + } + for(int i = 0; i < 256; i++) h_padded_q_to_batch[i] = i / 2; + int total_padded_q = 256; + + using CornerConfig = FmhaKernelConfig<128, 8, 8, 128, 256, false, CausalMaskType::DISABLE>; + test_run_attn_fwd_mfma_with_seqlens( + h_cu_seqlens_q, h_cu_seqlens_q_padded, h_padded_q_to_batch, + total_padded_q, 0.0f, true, true, "Q padded > actual"); + } + + return 0; +} diff --git a/tests/cpp/small_seq_kernels/tests/test_fwd_mfma_16x16.cpp 
b/tests/cpp/small_seq_kernels/tests/test_fwd_mfma_16x16.cpp new file mode 100644 index 000000000..083ca9f80 --- /dev/null +++ b/tests/cpp/small_seq_kernels/tests/test_fwd_mfma_16x16.cpp @@ -0,0 +1,521 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +// +// Test host for the fused MFMA 16x16x16 forward kernel (attn_fwd_mfma_16x16.h). +// Uses the same test infrastructure and CPU reference as test_fwd_mfma.cpp. +// +// Build: cmake -B build && cmake --build build && ./build/test_fwd_mfma_16x16 + +#include "attn_fwd_mfma_16x16.h" +#include "attn_fwd_ref.h" +#include "test_utils.h" + +#include +#include +#include +#include +#include +#include + +// --------------------------------------------------------------------------- +// Main forward correctness + performance test (MFMA 16x16x16 variant) +// --------------------------------------------------------------------------- + +template +void test_run_attn_fwd_mfma_16x16_kernel( + float dropout_p, int warmup_iters, int test_iters, bool check_correctness, bool dump_err) +{ + using Launcher = AttnForwardMfma16x16KernelLauncher; + + constexpr int bs = Config::bs; + constexpr int head_num = Config::head_num; + constexpr int max_seq_kv = Config::max_seq_kv; + constexpr int head_dim = Config::head_dim; + + std::mt19937 gen(42); // Fixed seed for reproducibility + std::uniform_real_distribution dis(-1.0f, 1.0f); + + // --- Build cu_seqlens_q --- + std::vector h_cu_seqlens_q, h_cu_seqlens_q_padded, h_padded_q_to_batch; + int total_padded_q = build_cu_seqlens_q(bs, gen, h_cu_seqlens_q, h_cu_seqlens_q_padded, + h_padded_q_to_batch); + int total_actual_q_seq = h_cu_seqlens_q[bs]; + + // --- Build cu_seqlens_kv --- + std::vector h_cu_seqlens_kv, h_cu_seqlens_kv_padded; + int total_actual_kv_seq, total_padded_kv_seq; + build_cu_seqlens_kv(bs, max_seq_kv, gen, h_cu_seqlens_kv, h_cu_seqlens_kv_padded, + total_actual_kv_seq, total_padded_kv_seq); + + // --- Buffer sizes --- + size_t 
size_Q = (size_t)total_padded_q * head_num * head_dim; + size_t size_K = (size_t)total_padded_kv_seq * head_num * head_dim; + size_t size_V = (size_t)total_padded_kv_seq * head_num * head_dim; + size_t size_O = (size_t)total_padded_q * head_num * head_dim; + size_t size_dropout_mask = (size_t)bs * head_num * max_seq_kv; + + // --- Host allocations --- + std::vector h_Q(size_Q), h_K(size_K), h_V(size_V); + std::vector h_dropout_mask(size_dropout_mask); + std::vector h_O_gpu(size_O, DataType(0.0f)); + std::vector h_O_cpu(size_O, DataType(0.0f)); + + for(size_t i = 0; i < size_Q; i++) h_Q[i] = DataType(dis(gen)); + for(size_t i = 0; i < size_K; i++) h_K[i] = DataType(dis(gen)); + for(size_t i = 0; i < size_V; i++) h_V[i] = DataType(dis(gen)); + for(size_t i = 0; i < size_dropout_mask; i++) + h_dropout_mask[i] = Config::enable_dropout_mask + ? DataType(dis(gen) > dropout_p ? 1.0f : 0.0f) + : DataType(1.0f); + + // Pre-round to bf16 precision so CPU reference and MFMA kernel + // (which converts to bf16 internally) see identical input values + if constexpr(std::is_same::value) + { + for(size_t i = 0; i < size_Q; i++) h_Q[i] = float(hip_bfloat16(h_Q[i])); + for(size_t i = 0; i < size_K; i++) h_K[i] = float(hip_bfloat16(h_K[i])); + for(size_t i = 0; i < size_V; i++) h_V[i] = float(hip_bfloat16(h_V[i])); + } + + // --- CPU reference --- + float sqr_dk_scale = 1.0f / std::sqrt(static_cast(head_dim)); + if(check_correctness) + attn_forward(h_Q.data(), h_K.data(), h_V.data(), + Config::enable_dropout_mask ? 
h_dropout_mask.data() : nullptr, + dropout_p, h_O_cpu.data(), static_cast(nullptr), + bs, head_num, max_seq_kv, head_dim, Config::mask_type, + h_cu_seqlens_q.data(), h_cu_seqlens_q_padded.data(), + h_cu_seqlens_kv.data(), h_cu_seqlens_kv_padded.data(), + true); + + // --- Device allocations --- + DataType *d_Q, *d_K, *d_V, *d_dropout_mask, *d_O; + float* d_softmax_lse; + int *d_cu_seqlens_q, *d_cu_seqlens_q_padded; + int *d_cu_seqlens_kv, *d_cu_seqlens_kv_padded; + int* d_padded_q_to_batch; + + HIP_CHECK(hipMalloc(&d_Q, size_Q * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_K, size_K * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_V, size_V * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_dropout_mask, size_dropout_mask * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_O, size_O * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_cu_seqlens_q, (bs + 1) * sizeof(int))); + HIP_CHECK(hipMalloc(&d_cu_seqlens_q_padded, (bs + 1) * sizeof(int))); + HIP_CHECK(hipMalloc(&d_cu_seqlens_kv, (bs + 1) * sizeof(int))); + HIP_CHECK(hipMalloc(&d_cu_seqlens_kv_padded, (bs + 1) * sizeof(int))); + if(total_padded_q > 0) + HIP_CHECK(hipMalloc(&d_padded_q_to_batch, total_padded_q * sizeof(int))); + else + d_padded_q_to_batch = nullptr; + + size_t lse_bytes = Launcher::calc_workspace_size(total_padded_q); + HIP_CHECK(hipMalloc(&d_softmax_lse, lse_bytes > 0 ? 
lse_bytes : sizeof(float))); + + // --- Copy to device --- + HIP_CHECK(hipMemcpy(d_Q, h_Q.data(), size_Q * sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_K, h_K.data(), size_K * sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_V, h_V.data(), size_V * sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_dropout_mask, h_dropout_mask.data(), + size_dropout_mask * sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cu_seqlens_q, h_cu_seqlens_q.data(), + (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cu_seqlens_q_padded, h_cu_seqlens_q_padded.data(), + (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cu_seqlens_kv, h_cu_seqlens_kv.data(), + (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cu_seqlens_kv_padded, h_cu_seqlens_kv_padded.data(), + (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + if(total_padded_q > 0) + HIP_CHECK(hipMemcpy(d_padded_q_to_batch, h_padded_q_to_batch.data(), + total_padded_q * sizeof(int), hipMemcpyHostToDevice)); + + auto launch = [&]() { + Launcher::run_attn_fwd_kernel(d_Q, d_K, d_V, + Config::enable_dropout_mask ? 
d_dropout_mask : nullptr, + dropout_p, sqr_dk_scale, d_O, d_softmax_lse, + d_cu_seqlens_q, d_cu_seqlens_q_padded, + d_cu_seqlens_kv, d_cu_seqlens_kv_padded, + d_padded_q_to_batch, total_padded_q); + }; + + for(int i = 0; i < warmup_iters; i++) launch(); + HIP_CHECK(hipDeviceSynchronize()); + + hipEvent_t start, stop; + HIP_CHECK(hipEventCreate(&start)); + HIP_CHECK(hipEventCreate(&stop)); + HIP_CHECK(hipEventRecord(start)); + for(int i = 0; i < test_iters; i++) launch(); + HIP_CHECK(hipEventRecord(stop)); + HIP_CHECK(hipEventSynchronize(stop)); + + float elapsed_ms = 0; + HIP_CHECK(hipEventElapsedTime(&elapsed_ms, start, stop)); + double avg_time_ms = elapsed_ms / test_iters; + + HIP_CHECK(hipMemcpy(h_O_gpu.data(), d_O, size_O * sizeof(DataType), hipMemcpyDeviceToHost)); + + // --- Report --- + double avg_kv_seq = static_cast(total_actual_kv_seq) / bs; + double active_q = static_cast(total_actual_q_seq); + double flops_per_batch_head = 2.0 * avg_kv_seq * head_dim + 2.0 * head_dim * avg_kv_seq; + double total_flops = flops_per_batch_head * active_q * head_num; + double tflops = (total_flops / 1e12) / (avg_time_ms / 1000.0); + + size_t bytes_read = (size_Q + size_K + size_V) * sizeof(DataType); + if(Config::enable_dropout_mask) bytes_read += size_dropout_mask * sizeof(DataType); + size_t bytes_write = size_O * sizeof(DataType); + size_t total_bytes = bytes_read + bytes_write; + double bandwidth_gbps = (total_bytes / 1e9) / (avg_time_ms / 1000.0); + + std::cout << "\n===== run_attn_fwd_mfma_16x16_kernel Test =====" << std::endl; + std::cout << "Configuration:" << std::endl; + std::cout << " Batch size: " << bs << std::endl; + std::cout << " Heads: " << head_num << std::endl; + std::cout << " Q seq (active/total): " << total_actual_q_seq << "/" << bs << std::endl; + std::cout << " KV seq (avg): " << std::fixed << std::setprecision(2) << avg_kv_seq + << " (max: " << max_seq_kv << ")" << std::endl; + std::cout << " Head dimension: " << head_dim << std::endl; + std::cout 
<< " Dropout: " << (Config::enable_dropout_mask ? "enabled" : "disabled") << std::endl; + std::cout << " Mask: " << CausalMaskTypeName[Config::mask_type] << std::endl; + std::cout << std::endl; + + if(check_correctness) + { + std::cout << "Correctness:" << std::endl; + check_output(h_O_gpu, h_O_cpu, bs, head_num, head_dim, + h_cu_seqlens_q, h_cu_seqlens_q_padded, "Output", 1e-2f, 1e-2f, dump_err); + std::cout << std::endl; + } + + std::cout << "Memory:" << std::endl; + std::cout << " Total data read: " << std::fixed << std::setprecision(2) + << bytes_read / 1e6 << " MB" << std::endl; + std::cout << " Total data write: " << bytes_write / 1e6 << " MB" << std::endl; + std::cout << " Total data transfer: " << total_bytes / 1e6 << " MB" << std::endl; + std::cout << " softmax_lse buffer: " << lse_bytes / 1e6 << " MB" << std::endl; + std::cout << std::endl; + + std::cout << "Performance:" << std::endl; + std::cout << " Average time: " << std::fixed << std::setprecision(3) << avg_time_ms + << " ms" << std::endl; + std::cout << " Bandwidth: " << std::fixed << std::setprecision(2) << bandwidth_gbps + << " GB/s" << std::endl; + std::cout << " TFLOPS: " << std::fixed << std::setprecision(2) << tflops << std::endl; + std::cout << "====================================\n" << std::endl; + + // --- Cleanup --- + HIP_CHECK(hipFree(d_Q)); HIP_CHECK(hipFree(d_K)); HIP_CHECK(hipFree(d_V)); + HIP_CHECK(hipFree(d_dropout_mask)); HIP_CHECK(hipFree(d_O)); HIP_CHECK(hipFree(d_softmax_lse)); + HIP_CHECK(hipFree(d_cu_seqlens_q)); HIP_CHECK(hipFree(d_cu_seqlens_q_padded)); + HIP_CHECK(hipFree(d_cu_seqlens_kv)); HIP_CHECK(hipFree(d_cu_seqlens_kv_padded)); + if(d_padded_q_to_batch) HIP_CHECK(hipFree(d_padded_q_to_batch)); + HIP_CHECK(hipEventDestroy(start)); HIP_CHECK(hipEventDestroy(stop)); +} + +// --------------------------------------------------------------------------- +// Corner-case test: explicit Q seqlens provided by caller (MFMA 16x16 variant) +// 
--------------------------------------------------------------------------- + +template +void test_run_attn_fwd_mfma_16x16_with_seqlens(const std::vector& h_cu_seqlens_q, + const std::vector& h_cu_seqlens_q_padded, + const std::vector& h_padded_q_to_batch, + int total_padded_q, + float dropout_p, + bool check_correctness, + bool dump_err, + const std::string& test_name) +{ + using Launcher = AttnForwardMfma16x16KernelLauncher; + + constexpr int head_num = Config::head_num; + constexpr int max_seq_kv = Config::max_seq_kv; + constexpr int head_dim = Config::head_dim; + int bs = static_cast(h_cu_seqlens_q.size()) - 1; + + std::mt19937 gen(123); + std::uniform_real_distribution dis(-1.0f, 1.0f); + + std::vector h_cu_seqlens_kv, h_cu_seqlens_kv_padded; + int total_actual_kv_seq, total_padded_kv_seq; + build_cu_seqlens_kv(bs, max_seq_kv, gen, h_cu_seqlens_kv, h_cu_seqlens_kv_padded, + total_actual_kv_seq, total_padded_kv_seq); + + size_t size_Q = (size_t)total_padded_q * head_num * head_dim; + size_t size_K = (size_t)total_padded_kv_seq * head_num * head_dim; + size_t size_V = size_K; + size_t size_O = size_Q; + size_t size_dropout_mask = (size_t)total_padded_q * head_num * max_seq_kv; + size_t size_workspace = size_dropout_mask; + + std::vector h_Q(size_Q), h_K(size_K), h_V(size_V), h_dropout_mask(size_dropout_mask); + std::vector h_O_gpu(size_O, DataType(0.0f)), h_O_cpu(size_O, DataType(0.0f)); + + for(size_t i = 0; i < size_Q; i++) h_Q[i] = DataType(dis(gen)); + for(size_t i = 0; i < size_K; i++) h_K[i] = DataType(dis(gen)); + for(size_t i = 0; i < size_V; i++) h_V[i] = DataType(dis(gen)); + for(size_t i = 0; i < size_dropout_mask; i++) + h_dropout_mask[i] = Config::enable_dropout_mask + ? DataType(dis(gen) > dropout_p ? 
1.0f : 0.0f) + : DataType(1.0f); + + // Pre-round to bf16 precision so CPU reference and MFMA kernel + // (which converts to bf16 internally) see identical input values + if constexpr(std::is_same::value) + { + for(size_t i = 0; i < size_Q; i++) h_Q[i] = float(hip_bfloat16(h_Q[i])); + for(size_t i = 0; i < size_K; i++) h_K[i] = float(hip_bfloat16(h_K[i])); + for(size_t i = 0; i < size_V; i++) h_V[i] = float(hip_bfloat16(h_V[i])); + } + + float sqr_dk_scale = 1.0f / std::sqrt(static_cast(head_dim)); + if(check_correctness) + attn_forward(h_Q.data(), h_K.data(), h_V.data(), + Config::enable_dropout_mask ? h_dropout_mask.data() : nullptr, + dropout_p, h_O_cpu.data(), static_cast(nullptr), + bs, head_num, max_seq_kv, head_dim, Config::mask_type, + h_cu_seqlens_q.data(), h_cu_seqlens_q_padded.data(), + h_cu_seqlens_kv.data(), h_cu_seqlens_kv_padded.data(), + true); + + DataType *d_Q, *d_K, *d_V, *d_dropout_mask, *d_O; + float* d_softmax_lse; + int *d_cu_seqlens_q, *d_cu_seqlens_q_padded; + int *d_cu_seqlens_kv, *d_cu_seqlens_kv_padded; + int* d_padded_q_to_batch; + + HIP_CHECK(hipMalloc(&d_Q, size_Q * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_K, size_K * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_V, size_V * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_dropout_mask, size_dropout_mask * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_O, size_O * sizeof(DataType))); + size_t lse_bytes = Launcher::calc_workspace_size(total_padded_q); + HIP_CHECK(hipMalloc(&d_softmax_lse, lse_bytes > 0 ? 
lse_bytes : sizeof(float))); + HIP_CHECK(hipMalloc(&d_cu_seqlens_q, (bs + 1) * sizeof(int))); + HIP_CHECK(hipMalloc(&d_cu_seqlens_q_padded, (bs + 1) * sizeof(int))); + HIP_CHECK(hipMalloc(&d_cu_seqlens_kv, (bs + 1) * sizeof(int))); + HIP_CHECK(hipMalloc(&d_cu_seqlens_kv_padded, (bs + 1) * sizeof(int))); + if(total_padded_q > 0) + HIP_CHECK(hipMalloc(&d_padded_q_to_batch, total_padded_q * sizeof(int))); + else + d_padded_q_to_batch = nullptr; + + HIP_CHECK(hipMemcpy(d_Q, h_Q.data(), size_Q * sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_K, h_K.data(), size_K * sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_V, h_V.data(), size_V * sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_dropout_mask, h_dropout_mask.data(), + size_dropout_mask * sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cu_seqlens_q, h_cu_seqlens_q.data(), + (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cu_seqlens_q_padded, h_cu_seqlens_q_padded.data(), + (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cu_seqlens_kv, h_cu_seqlens_kv.data(), + (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cu_seqlens_kv_padded, h_cu_seqlens_kv_padded.data(), + (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + if(total_padded_q > 0) + HIP_CHECK(hipMemcpy(d_padded_q_to_batch, h_padded_q_to_batch.data(), + total_padded_q * sizeof(int), hipMemcpyHostToDevice)); + + Launcher::run_attn_fwd_kernel(d_Q, d_K, d_V, + Config::enable_dropout_mask ? 
d_dropout_mask : nullptr, + dropout_p, sqr_dk_scale, d_O, d_softmax_lse, + d_cu_seqlens_q, d_cu_seqlens_q_padded, + d_cu_seqlens_kv, d_cu_seqlens_kv_padded, + d_padded_q_to_batch, total_padded_q); + + HIP_CHECK(hipMemcpy(h_O_gpu.data(), d_O, size_O * sizeof(DataType), hipMemcpyDeviceToHost)); + + if(check_correctness) + { + check_output(h_O_gpu, h_O_cpu, bs, head_num, head_dim, + h_cu_seqlens_q, h_cu_seqlens_q_padded, + test_name + " Output", 1e-2f, 1e-2f, dump_err); + } + + HIP_CHECK(hipFree(d_Q)); HIP_CHECK(hipFree(d_K)); HIP_CHECK(hipFree(d_V)); + HIP_CHECK(hipFree(d_dropout_mask)); HIP_CHECK(hipFree(d_O)); HIP_CHECK(hipFree(d_softmax_lse)); + HIP_CHECK(hipFree(d_cu_seqlens_q)); HIP_CHECK(hipFree(d_cu_seqlens_q_padded)); + HIP_CHECK(hipFree(d_cu_seqlens_kv)); HIP_CHECK(hipFree(d_cu_seqlens_kv_padded)); + if(d_padded_q_to_batch) HIP_CHECK(hipFree(d_padded_q_to_batch)); +} + +// --------------------------------------------------------------------------- +// Functor for TestRunner +// --------------------------------------------------------------------------- + +struct RunFwdMfma16x16 { + template + void operator()(float dropout_p, int warmup, int iters, bool check, bool dump) const { + test_run_attn_fwd_mfma_16x16_kernel(dropout_p, warmup, iters, check, dump); + } +}; + +// --------------------------------------------------------------------------- +// main +// --------------------------------------------------------------------------- + +int main(int argc, char const* argv[]) +{ + std::cout << "\n========== MFMA 16x16 Fwd Correctness (bf16, SEQ_KV 2..16, bs=30720) ==========" << std::endl; + + TestRunner<2, 16>::run( + RunFwdMfma16x16{}, 0.0f, 1, 1, true, true); + + std::cout << "\n========== MFMA 16x16 Fwd Correctness (bf16, mixed Q=0/1, SEQ_KV=8, bs=128) ==========" << std::endl; + { + using MixedConfig = FmhaKernelConfig<128, 8, 8, 128, 256, false, CausalMaskType::DISABLE>; + test_run_attn_fwd_mfma_16x16_kernel(0, 1, 1, true, true); + } + + std::cout << 
"\n========== MFMA 16x16 Fwd Performance (bfloat16, SEQ_KV 2..16) ==========" << std::endl; + TestRunner<2, 16>::run( + RunFwdMfma16x16{}, 0.0f, 3, 5, false, false); + + std::cout << "\n========== MFMA 16x16 Corner: Empty segments (even batches active, bs=128) ==========" << std::endl; + { + const int corner_bs = 128; + std::vector h_cu_seqlens_q(corner_bs + 1); + std::vector h_cu_seqlens_q_padded(corner_bs + 1); + std::vector h_padded_q_to_batch(corner_bs / 2); + h_cu_seqlens_q[0] = h_cu_seqlens_q_padded[0] = 0; + for(int b = 0; b < corner_bs; b++) + { + int actual = (b % 2 == 0) ? 1 : 0; + h_cu_seqlens_q[b + 1] = h_cu_seqlens_q[b] + actual; + h_cu_seqlens_q_padded[b + 1] = h_cu_seqlens_q_padded[b] + actual; + } + int total_padded_q = h_cu_seqlens_q_padded[corner_bs]; + for(int b = 0; b < corner_bs; b++) + if(h_cu_seqlens_q_padded[b + 1] > h_cu_seqlens_q_padded[b]) + h_padded_q_to_batch[h_cu_seqlens_q_padded[b]] = b; + + using CornerConfig = FmhaKernelConfig<128, 8, 8, 128, 256, false, CausalMaskType::DISABLE>; + test_run_attn_fwd_mfma_16x16_with_seqlens( + h_cu_seqlens_q, h_cu_seqlens_q_padded, h_padded_q_to_batch, + total_padded_q, 0.0f, true, true, "Empty segments"); + } + + std::cout << "\n========== MFMA 16x16 Corner: Q padded > actual (2 slots per batch, bs=128) ==========" << std::endl; + { + const int corner_bs = 128; + std::vector h_cu_seqlens_q(corner_bs + 1); + std::vector h_cu_seqlens_q_padded(corner_bs + 1); + std::vector h_padded_q_to_batch(256); + for(int b = 0; b <= corner_bs; b++) + { + h_cu_seqlens_q[b] = b; + h_cu_seqlens_q_padded[b] = b * 2; + } + for(int i = 0; i < 256; i++) h_padded_q_to_batch[i] = i / 2; + int total_padded_q = 256; + + using CornerConfig = FmhaKernelConfig<128, 8, 8, 128, 256, false, CausalMaskType::DISABLE>; + test_run_attn_fwd_mfma_16x16_with_seqlens( + h_cu_seqlens_q, h_cu_seqlens_q_padded, h_padded_q_to_batch, + total_padded_q, 0.0f, true, true, "Q padded > actual"); + } + + // --- Tiling boundary tests --- + // Helper 
lambda to build uniform Q seqlens + auto build_uniform_q = [](int bs, int sq, + std::vector& csq, std::vector& csqp, + std::vector& q2b) { + csq.resize(bs + 1); + csqp.resize(bs + 1); + csq[0] = csqp[0] = 0; + for(int b = 0; b < bs; b++) { + csq[b + 1] = csq[b] + sq; + csqp[b + 1] = csqp[b] + sq; + } + int tot = csqp[bs]; + q2b.resize(tot); + for(int b = 0; b < bs; b++) + for(int i = csqp[b]; i < csqp[b + 1]; i++) + q2b[i] = b; + return tot; + }; + + // Test 1: sq=17, skv=17 — both Q and KV cross tile boundary (2x2 = 4 tiles) + std::cout << "\n========== Tiled Fwd: sq=17, skv=17 (2x2 tiles, bs=2048) ==========" << std::endl; + { + const int bs = 2048; + std::vector csq, csqp, q2b; + int tot = build_uniform_q(bs, 17, csq, csqp, q2b); + using Cfg = FmhaKernelConfig<2048, 8, 17, 128, 256, false, CausalMaskType::DISABLE, 17>; + test_run_attn_fwd_mfma_16x16_with_seqlens( + csq, csqp, q2b, tot, 0.0f, true, true, "sq17_skv17"); + } + + // Test 2: sq=17, skv=16 — only Q crosses tile boundary (2x1 tiles) + std::cout << "\n========== Tiled Fwd: sq=17, skv=16 (Q tiled, KV single, bs=2048) ==========" << std::endl; + { + const int bs = 2048; + std::vector csq, csqp, q2b; + int tot = build_uniform_q(bs, 17, csq, csqp, q2b); + using Cfg = FmhaKernelConfig<2048, 8, 16, 128, 256, false, CausalMaskType::DISABLE, 17>; + test_run_attn_fwd_mfma_16x16_with_seqlens( + csq, csqp, q2b, tot, 0.0f, true, true, "sq17_skv16"); + } + + // Test 3: sq=16, skv=17 — only KV crosses tile boundary (1x2 tiles) + std::cout << "\n========== Tiled Fwd: sq=16, skv=17 (Q single, KV tiled, bs=2048) ==========" << std::endl; + { + const int bs = 2048; + std::vector csq, csqp, q2b; + int tot = build_uniform_q(bs, 16, csq, csqp, q2b); + using Cfg = FmhaKernelConfig<2048, 8, 17, 128, 256, false, CausalMaskType::DISABLE, 16>; + test_run_attn_fwd_mfma_16x16_with_seqlens( + csq, csqp, q2b, tot, 0.0f, true, true, "sq16_skv17"); + } + + // Test 4: sq=1, skv=17 — single Q row, KV tiled + std::cout << "\n========== 
Tiled Fwd: sq=1, skv=17 (single Q, KV tiled, bs=2048) ==========" << std::endl; + { + const int bs = 2048; + std::vector csq, csqp, q2b; + int tot = build_uniform_q(bs, 1, csq, csqp, q2b); + using Cfg = FmhaKernelConfig<2048, 8, 17, 128, 256, false, CausalMaskType::DISABLE, 1>; + test_run_attn_fwd_mfma_16x16_with_seqlens( + csq, csqp, q2b, tot, 0.0f, true, true, "sq1_skv17"); + } + + // Test 5: mixed Q lengths (0..17), skv=17 — variable Q with tiling + std::cout << "\n========== Tiled Fwd: mixed Q=0..17, skv=17 (bs=2048) ==========" << std::endl; + { + const int bs = 2048; + std::mt19937 gen2(99); + std::vector csq, csqp, q2b; + csq.resize(bs + 1); csqp.resize(bs + 1); + csq[0] = csqp[0] = 0; + std::uniform_int_distribution qdist(0, 17); + for(int b = 0; b < bs; b++) { + int ql = qdist(gen2); + csq[b + 1] = csq[b] + ql; + csqp[b + 1] = csqp[b] + ql; + } + int tot = csqp[bs]; + q2b.resize(tot); + for(int b = 0; b < bs; b++) + for(int i = csqp[b]; i < csqp[b + 1]; i++) + q2b[i] = b; + using Cfg = FmhaKernelConfig<2048, 8, 17, 128, 256, false, CausalMaskType::DISABLE, 17>; + test_run_attn_fwd_mfma_16x16_with_seqlens( + csq, csqp, q2b, tot, 0.0f, true, true, "mixed_q0_17_skv17"); + } + + // Test 6: sq=skv=16 through tiled path — verify single-tile still correct + std::cout << "\n========== Tiled Fwd: sq=16, skv=16 (single tile, sanity, bs=2048) ==========" << std::endl; + { + const int bs = 2048; + std::vector csq, csqp, q2b; + int tot = build_uniform_q(bs, 16, csq, csqp, q2b); + using Cfg = FmhaKernelConfig<2048, 8, 16, 128, 256, false, CausalMaskType::DISABLE, 16>; + test_run_attn_fwd_mfma_16x16_with_seqlens( + csq, csqp, q2b, tot, 0.0f, true, true, "sq16_skv16"); + } + + return 0; +} diff --git a/tests/cpp/small_seq_kernels/tests/test_fwd_mfma_multiq.cpp b/tests/cpp/small_seq_kernels/tests/test_fwd_mfma_multiq.cpp new file mode 100644 index 000000000..f524b567f --- /dev/null +++ b/tests/cpp/small_seq_kernels/tests/test_fwd_mfma_multiq.cpp @@ -0,0 +1,310 @@ +// 
Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +// +// Test host for multi-Q MFMA forward kernels with dispatch. +// Tests both 4x4x4 (max_seq_q ≤ 4) and 16x16x16 (max_seq_q > 4) paths. +// +// Build: cmake -B build2 && cmake --build build2 && ./build2/test_fwd_mfma_multiq + +#include "attn_fwd_mfma_dispatch.h" +#include "attn_fwd_ref.h" +#include "test_utils.h" + +#include +#include +#include +#include +#include +#include + +// --------------------------------------------------------------------------- +// Main forward correctness + performance test (multi-Q dispatch) +// --------------------------------------------------------------------------- + +template +void test_run_attn_fwd_mfma_multiq_kernel( + float dropout_p, int warmup_iters, int test_iters, bool check_correctness, bool dump_err) +{ + using Launcher = AttnForwardMfmaDispatchLauncher; + + constexpr int bs = Config::bs; + constexpr int head_num = Config::head_num; + constexpr int max_seq_kv = Config::max_seq_kv; + constexpr int head_dim = Config::head_dim; + constexpr int max_seq_q = Config::max_seq_q; + + std::mt19937 gen(42); // Fixed seed for reproducibility + std::uniform_real_distribution dis(-1.0f, 1.0f); + + // --- Build cu_seqlens_q (multi-Q aware) --- + std::vector h_cu_seqlens_q, h_cu_seqlens_q_padded, h_padded_q_to_batch; + int total_padded_q = build_cu_seqlens_q(bs, gen, h_cu_seqlens_q, h_cu_seqlens_q_padded, + h_padded_q_to_batch, max_seq_q); + int total_actual_q_seq = h_cu_seqlens_q[bs]; + + // --- Build cu_seqlens_kv --- + std::vector h_cu_seqlens_kv, h_cu_seqlens_kv_padded; + int total_actual_kv_seq, total_padded_kv_seq; + build_cu_seqlens_kv(bs, max_seq_kv, gen, h_cu_seqlens_kv, h_cu_seqlens_kv_padded, + total_actual_kv_seq, total_padded_kv_seq); + + // --- Buffer sizes --- + size_t size_Q = (size_t)total_padded_q * head_num * head_dim; + size_t size_K = (size_t)total_padded_kv_seq * head_num * head_dim; + size_t size_V = 
(size_t)total_padded_kv_seq * head_num * head_dim; + size_t size_O = (size_t)total_padded_q * head_num * head_dim; + size_t size_dropout_mask = (size_t)total_padded_q * head_num * max_seq_kv; + + // --- Host allocations --- + std::vector h_Q(size_Q), h_K(size_K), h_V(size_V); + std::vector h_dropout_mask(size_dropout_mask); + std::vector h_O_gpu(size_O, DataType(0.0f)); + std::vector h_O_cpu(size_O, DataType(0.0f)); + + for(size_t i = 0; i < size_Q; i++) h_Q[i] = DataType(dis(gen)); + for(size_t i = 0; i < size_K; i++) h_K[i] = DataType(dis(gen)); + for(size_t i = 0; i < size_V; i++) h_V[i] = DataType(dis(gen)); + for(size_t i = 0; i < size_dropout_mask; i++) + h_dropout_mask[i] = Config::enable_dropout_mask + ? DataType(dis(gen) > dropout_p ? 1.0f : 0.0f) + : DataType(1.0f); + + // Pre-round to bf16 precision so CPU reference and MFMA kernel + // (which converts to bf16 internally) see identical input values + if constexpr(std::is_same::value) + { + for(size_t i = 0; i < size_Q; i++) h_Q[i] = float(hip_bfloat16(h_Q[i])); + for(size_t i = 0; i < size_K; i++) h_K[i] = float(hip_bfloat16(h_K[i])); + for(size_t i = 0; i < size_V; i++) h_V[i] = float(hip_bfloat16(h_V[i])); + } + + // --- CPU reference --- + float sqr_dk_scale = 1.0f / std::sqrt(static_cast(head_dim)); + if(check_correctness) + attn_forward(h_Q.data(), h_K.data(), h_V.data(), + Config::enable_dropout_mask ? 
h_dropout_mask.data() : nullptr, + dropout_p, h_O_cpu.data(), static_cast(nullptr), + bs, head_num, max_seq_kv, head_dim, Config::mask_type, + h_cu_seqlens_q.data(), h_cu_seqlens_q_padded.data(), + h_cu_seqlens_kv.data(), h_cu_seqlens_kv_padded.data(), + true); + + // --- Device allocations --- + DataType *d_Q, *d_K, *d_V, *d_dropout_mask, *d_O; + void* d_aux; + int *d_cu_seqlens_q, *d_cu_seqlens_q_padded; + int *d_cu_seqlens_kv, *d_cu_seqlens_kv_padded; + int* d_padded_q_to_batch; + + HIP_CHECK(hipMalloc(&d_Q, std::max(size_Q, (size_t)1) * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_K, size_K * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_V, size_V * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_dropout_mask, std::max(size_dropout_mask, (size_t)1) * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_O, std::max(size_O, (size_t)1) * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_cu_seqlens_q, (bs + 1) * sizeof(int))); + HIP_CHECK(hipMalloc(&d_cu_seqlens_q_padded, (bs + 1) * sizeof(int))); + HIP_CHECK(hipMalloc(&d_cu_seqlens_kv, (bs + 1) * sizeof(int))); + HIP_CHECK(hipMalloc(&d_cu_seqlens_kv_padded, (bs + 1) * sizeof(int))); + if(total_padded_q > 0) + HIP_CHECK(hipMalloc(&d_padded_q_to_batch, total_padded_q * sizeof(int))); + else + d_padded_q_to_batch = nullptr; + + size_t aux_bytes = Launcher::calc_workspace_size(total_padded_q); + HIP_CHECK(hipMalloc(&d_aux, aux_bytes > 0 ? aux_bytes : (sizeof(float) > sizeof(DataType) ? 
sizeof(float) : sizeof(DataType)))); + + // --- Copy to device --- + if(size_Q > 0) + HIP_CHECK(hipMemcpy(d_Q, h_Q.data(), size_Q * sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_K, h_K.data(), size_K * sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_V, h_V.data(), size_V * sizeof(DataType), hipMemcpyHostToDevice)); + if(size_dropout_mask > 0) + HIP_CHECK(hipMemcpy(d_dropout_mask, h_dropout_mask.data(), + size_dropout_mask * sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cu_seqlens_q, h_cu_seqlens_q.data(), + (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cu_seqlens_q_padded, h_cu_seqlens_q_padded.data(), + (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cu_seqlens_kv, h_cu_seqlens_kv.data(), + (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cu_seqlens_kv_padded, h_cu_seqlens_kv_padded.data(), + (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + if(total_padded_q > 0) + HIP_CHECK(hipMemcpy(d_padded_q_to_batch, h_padded_q_to_batch.data(), + total_padded_q * sizeof(int), hipMemcpyHostToDevice)); + + auto launch = [&]() { + Launcher::run_attn_fwd_kernel(d_Q, d_K, d_V, + Config::enable_dropout_mask ? 
d_dropout_mask : nullptr, + dropout_p, sqr_dk_scale, d_O, d_aux, + d_cu_seqlens_q, d_cu_seqlens_q_padded, + d_cu_seqlens_kv, d_cu_seqlens_kv_padded, + d_padded_q_to_batch, total_padded_q); + }; + + for(int i = 0; i < warmup_iters; i++) launch(); + HIP_CHECK(hipDeviceSynchronize()); + + hipEvent_t start, stop; + HIP_CHECK(hipEventCreate(&start)); + HIP_CHECK(hipEventCreate(&stop)); + HIP_CHECK(hipEventRecord(start)); + for(int i = 0; i < test_iters; i++) launch(); + HIP_CHECK(hipEventRecord(stop)); + HIP_CHECK(hipEventSynchronize(stop)); + + float elapsed_ms = 0; + HIP_CHECK(hipEventElapsedTime(&elapsed_ms, start, stop)); + double avg_time_ms = elapsed_ms / test_iters; + + if(size_O > 0) + HIP_CHECK(hipMemcpy(h_O_gpu.data(), d_O, size_O * sizeof(DataType), hipMemcpyDeviceToHost)); + + // --- Report --- + const char* kernel_name = (max_seq_q <= 4) ? "4x4x4" : "16x16x16"; + + std::cout << "\n===== Multi-Q Dispatch Test (kernel=" << kernel_name + << ", max_seq_q=" << max_seq_q << ") =====" << std::endl; + std::cout << "Configuration:" << std::endl; + std::cout << " Batch size: " << bs << std::endl; + std::cout << " Heads: " << head_num << std::endl; + std::cout << " Max Q seq: " << max_seq_q << std::endl; + std::cout << " Q seq (active tokens/batches): " << total_actual_q_seq << "/" << bs << std::endl; + std::cout << " KV seq (max): " << max_seq_kv << std::endl; + std::cout << " Head dimension: " << head_dim << std::endl; + std::cout << " Dropout: " << (Config::enable_dropout_mask ? 
"enabled" : "disabled") << std::endl; + std::cout << " Mask: " << CausalMaskTypeName[Config::mask_type] << std::endl; + std::cout << std::endl; + + if(check_correctness) + { + std::cout << "Correctness:" << std::endl; + check_output(h_O_gpu, h_O_cpu, bs, head_num, head_dim, + h_cu_seqlens_q, h_cu_seqlens_q_padded, "Output", 1e-2f, 1e-2f, dump_err); + std::cout << std::endl; + } + + double avg_kv_seq = static_cast(total_actual_kv_seq) / bs; + double active_q = static_cast(total_actual_q_seq); + double flops_per_batch_head = 2.0 * avg_kv_seq * head_dim + 2.0 * head_dim * avg_kv_seq; + double total_flops = flops_per_batch_head * active_q * head_num; + double tflops = (total_flops / 1e12) / (avg_time_ms / 1000.0); + + size_t bytes_read = (size_Q + size_K + size_V) * sizeof(DataType); + if(Config::enable_dropout_mask) bytes_read += size_dropout_mask * sizeof(DataType); + size_t bytes_write = size_O * sizeof(DataType); + size_t total_bytes = bytes_read + bytes_write; + double bandwidth_gbps = (total_bytes / 1e9) / (avg_time_ms / 1000.0); + + std::cout << "Performance:" << std::endl; + std::cout << " Average time: " << std::fixed << std::setprecision(3) << avg_time_ms + << " ms" << std::endl; + std::cout << " Bandwidth: " << std::fixed << std::setprecision(2) << bandwidth_gbps + << " GB/s" << std::endl; + std::cout << " TFLOPS: " << std::fixed << std::setprecision(2) << tflops << std::endl; + std::cout << "====================================\n" << std::endl; + + // --- Cleanup --- + HIP_CHECK(hipFree(d_Q)); HIP_CHECK(hipFree(d_K)); HIP_CHECK(hipFree(d_V)); + HIP_CHECK(hipFree(d_dropout_mask)); HIP_CHECK(hipFree(d_O)); HIP_CHECK(hipFree(d_aux)); + HIP_CHECK(hipFree(d_cu_seqlens_q)); HIP_CHECK(hipFree(d_cu_seqlens_q_padded)); + HIP_CHECK(hipFree(d_cu_seqlens_kv)); HIP_CHECK(hipFree(d_cu_seqlens_kv_padded)); + if(d_padded_q_to_batch) HIP_CHECK(hipFree(d_padded_q_to_batch)); + HIP_CHECK(hipEventDestroy(start)); HIP_CHECK(hipEventDestroy(stop)); +} + +// 
--------------------------------------------------------------------------- +// Functor for TestRunner +// --------------------------------------------------------------------------- + +struct RunFwdMfmaMultiQ { + template + void operator()(float dropout_p, int warmup, int iters, bool check, bool dump) const { + test_run_attn_fwd_mfma_multiq_kernel(dropout_p, warmup, iters, check, dump); + } +}; + +// --------------------------------------------------------------------------- +// main +// --------------------------------------------------------------------------- + +int main(int argc, char const* argv[]) +{ + // ===================================================================== + // Test A: Backward compat — seq_q=1 (dispatch → 4x4x4), varlen seq_kv + // ===================================================================== + std::cout << "\n========== Test A: Dispatch seq_q=1 (4x4x4), SEQ_KV 2..16, bs=30720 ==========" << std::endl; + + TestRunner<2, 16>::run( + RunFwdMfmaMultiQ{}, 0.0f, 1, 1, true, true); + + // ===================================================================== + // Test B: Multi-Q correctness — various max_seq_q via dispatch + // ===================================================================== + + // max_seq_q=1 (dispatch → 4x4x4) + std::cout << "\n========== Test B1: Multi-Q correctness, max_seq_q=1, SEQ_KV=8, bs=128 ==========" << std::endl; + { + using Cfg = FmhaKernelConfig<128, 8, 8, 128, 256, false, CausalMaskType::DISABLE, 1>; + test_run_attn_fwd_mfma_multiq_kernel(0, 1, 1, true, true); + } + + // max_seq_q=2 (dispatch → 4x4x4) + std::cout << "\n========== Test B2: Multi-Q correctness, max_seq_q=2, SEQ_KV=8, bs=128 ==========" << std::endl; + { + using Cfg = FmhaKernelConfig<128, 8, 8, 128, 256, false, CausalMaskType::DISABLE, 2>; + test_run_attn_fwd_mfma_multiq_kernel(0, 1, 1, true, true); + } + + // max_seq_q=4 (dispatch → 4x4x4) + std::cout << "\n========== Test B3: Multi-Q correctness, max_seq_q=4, SEQ_KV=8, bs=128 
==========" << std::endl; + { + using Cfg = FmhaKernelConfig<128, 8, 8, 128, 256, false, CausalMaskType::DISABLE, 4>; + test_run_attn_fwd_mfma_multiq_kernel(0, 1, 1, true, true); + } + + // max_seq_q=8 (dispatch → 16x16x16) + std::cout << "\n========== Test B4: Multi-Q correctness, max_seq_q=8, SEQ_KV=8, bs=128 ==========" << std::endl; + { + using Cfg = FmhaKernelConfig<128, 8, 8, 128, 256, false, CausalMaskType::DISABLE, 8>; + test_run_attn_fwd_mfma_multiq_kernel(0, 1, 1, true, true); + } + + // max_seq_q=16 (dispatch → 16x16x16) + std::cout << "\n========== Test B5: Multi-Q correctness, max_seq_q=16, SEQ_KV=16, bs=128 ==========" << std::endl; + { + using Cfg = FmhaKernelConfig<128, 8, 16, 128, 256, false, CausalMaskType::DISABLE, 16>; + test_run_attn_fwd_mfma_multiq_kernel(0, 1, 1, true, true); + } + + // ===================================================================== + // Test C: Performance (bfloat16, bs=30720, seq_kv=16) + // ===================================================================== + std::cout << "\n========== Test C: Performance, bfloat16, bs=30720, SEQ_KV=16 ==========" << std::endl; + + // max_seq_q=1 → 4x4x4 + { + using Cfg = FmhaKernelConfig<30720, 32, 16, 128, 256, false, CausalMaskType::DISABLE, 1>; + test_run_attn_fwd_mfma_multiq_kernel(0, 3, 5, false, false); + } + + // max_seq_q=4 → 4x4x4 + { + using Cfg = FmhaKernelConfig<30720, 32, 16, 128, 256, false, CausalMaskType::DISABLE, 4>; + test_run_attn_fwd_mfma_multiq_kernel(0, 3, 5, false, false); + } + + // max_seq_q=8 → 16x16x16 + { + using Cfg = FmhaKernelConfig<30720, 32, 16, 128, 256, false, CausalMaskType::DISABLE, 8>; + test_run_attn_fwd_mfma_multiq_kernel(0, 3, 5, false, false); + } + + // max_seq_q=16 → 16x16x16 + { + using Cfg = FmhaKernelConfig<30720, 32, 16, 128, 256, false, CausalMaskType::DISABLE, 16>; + test_run_attn_fwd_mfma_multiq_kernel(0, 3, 5, false, false); + } + + return 0; +} diff --git a/tests/cpp/small_seq_kernels/tests/test_mfma_head_dims.cpp 
b/tests/cpp/small_seq_kernels/tests/test_mfma_head_dims.cpp new file mode 100644 index 000000000..9502aab78 --- /dev/null +++ b/tests/cpp/small_seq_kernels/tests/test_mfma_head_dims.cpp @@ -0,0 +1,193 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +// +// MFMA 16x16x16 forward smoke + correctness for head_dim 128, 256, 512. +// Uses AttnForwardMfma16x16KernelLauncher (same path as test_fwd_mfma_16x16). +// +// Build / run: see ../README.md + +#include "attn_common.h" +#include "attn_fwd_mfma_16x16.h" +#include "attn_fwd_ref.h" +#include "test_utils.h" + +#include +#include +#include +#include +#include + +// Uniform Q length per batch, no extra padding (actual == padded). +static int build_uniform_q(int bs, + int sq, + std::vector& csq, + std::vector& csqp, + std::vector& q2b) +{ + csq.resize(bs + 1); + csqp.resize(bs + 1); + csq[0] = csqp[0] = 0; + for(int b = 0; b < bs; b++) + { + csq[b + 1] = csq[b] + sq; + csqp[b + 1] = csqp[b] + sq; + } + int tot = csqp[bs]; + q2b.resize(tot); + for(int b = 0; b < bs; b++) + for(int i = csqp[b]; i < csqp[b + 1]; i++) + q2b[i] = b; + return tot; +} + +template +void run_fwd_head_dim_case(const std::string& label, float dropout_p, bool dump_err) +{ + using Launcher = AttnForwardMfma16x16KernelLauncher; + + constexpr int head_num = Config::head_num; + constexpr int max_seq_kv = Config::max_seq_kv; + constexpr int head_dim = Config::head_dim; + const int bs = Config::bs; + + std::vector h_cu_seqlens_q, h_cu_seqlens_q_padded, h_padded_q_to_batch; + const int sq = Config::max_seq_q; + int total_padded_q = build_uniform_q(bs, sq, h_cu_seqlens_q, h_cu_seqlens_q_padded, + h_padded_q_to_batch); + + std::mt19937 gen(123); + std::uniform_real_distribution dis(-1.0f, 1.0f); + + std::vector h_cu_seqlens_kv, h_cu_seqlens_kv_padded; + int total_actual_kv_seq, total_padded_kv_seq; + build_cu_seqlens_kv(bs, max_seq_kv, gen, h_cu_seqlens_kv, h_cu_seqlens_kv_padded, + total_actual_kv_seq, 
total_padded_kv_seq); + + size_t size_Q = (size_t)total_padded_q * head_num * head_dim; + size_t size_K = (size_t)total_padded_kv_seq * head_num * head_dim; + size_t size_V = size_K; + size_t size_O = size_Q; + size_t size_dropout_mask = (size_t)total_padded_q * head_num * max_seq_kv; + + std::vector h_Q(size_Q), h_K(size_K), h_V(size_V), h_dropout_mask(size_dropout_mask); + std::vector h_O_gpu(size_O, DataType(0.0f)), h_O_cpu(size_O, DataType(0.0f)); + + for(size_t i = 0; i < size_Q; i++) h_Q[i] = DataType(dis(gen)); + for(size_t i = 0; i < size_K; i++) h_K[i] = DataType(dis(gen)); + for(size_t i = 0; i < size_V; i++) h_V[i] = DataType(dis(gen)); + for(size_t i = 0; i < size_dropout_mask; i++) + h_dropout_mask[i] = Config::enable_dropout_mask + ? DataType(dis(gen) > dropout_p ? 1.0f : 0.0f) + : DataType(1.0f); + + if constexpr(std::is_same::value) + { + for(size_t i = 0; i < size_Q; i++) h_Q[i] = float(hip_bfloat16(h_Q[i])); + for(size_t i = 0; i < size_K; i++) h_K[i] = float(hip_bfloat16(h_K[i])); + for(size_t i = 0; i < size_V; i++) h_V[i] = float(hip_bfloat16(h_V[i])); + } + + float sqr_dk_scale = 1.0f / std::sqrt(static_cast(head_dim)); + attn_forward(h_Q.data(), h_K.data(), h_V.data(), + Config::enable_dropout_mask ? 
h_dropout_mask.data() : nullptr, + dropout_p, h_O_cpu.data(), static_cast(nullptr), + bs, head_num, max_seq_kv, head_dim, Config::mask_type, + h_cu_seqlens_q.data(), h_cu_seqlens_q_padded.data(), + h_cu_seqlens_kv.data(), h_cu_seqlens_kv_padded.data(), + true); + + DataType *d_Q, *d_K, *d_V, *d_dropout_mask, *d_O; + float* d_softmax_lse; + int *d_cu_seqlens_q, *d_cu_seqlens_q_padded; + int *d_cu_seqlens_kv, *d_cu_seqlens_kv_padded; + int* d_padded_q_to_batch; + + HIP_CHECK(hipMalloc(&d_Q, size_Q * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_K, size_K * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_V, size_V * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_dropout_mask, size_dropout_mask * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_O, size_O * sizeof(DataType))); + size_t lse_bytes = Launcher::calc_workspace_size(total_padded_q); + HIP_CHECK(hipMalloc(&d_softmax_lse, lse_bytes > 0 ? lse_bytes : sizeof(float))); + HIP_CHECK(hipMalloc(&d_cu_seqlens_q, (bs + 1) * sizeof(int))); + HIP_CHECK(hipMalloc(&d_cu_seqlens_q_padded, (bs + 1) * sizeof(int))); + HIP_CHECK(hipMalloc(&d_cu_seqlens_kv, (bs + 1) * sizeof(int))); + HIP_CHECK(hipMalloc(&d_cu_seqlens_kv_padded, (bs + 1) * sizeof(int))); + HIP_CHECK(hipMalloc(&d_padded_q_to_batch, total_padded_q * sizeof(int))); + + HIP_CHECK(hipMemcpy(d_Q, h_Q.data(), size_Q * sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_K, h_K.data(), size_K * sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_V, h_V.data(), size_V * sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_dropout_mask, h_dropout_mask.data(), + size_dropout_mask * sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cu_seqlens_q, h_cu_seqlens_q.data(), + (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cu_seqlens_q_padded, h_cu_seqlens_q_padded.data(), + (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cu_seqlens_kv, h_cu_seqlens_kv.data(), + (bs + 1) * 
sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cu_seqlens_kv_padded, h_cu_seqlens_kv_padded.data(), + (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_padded_q_to_batch, h_padded_q_to_batch.data(), + total_padded_q * sizeof(int), hipMemcpyHostToDevice)); + + std::cout << "\n===== " << label << " (MFMA 16x16 fwd) =====" << std::endl; + std::cout << " bs=" << bs << " heads=" << head_num << " sq=" << sq << " max_seq_kv=" << max_seq_kv + << " head_dim=" << head_dim << std::endl; + + Launcher::run_attn_fwd_kernel(d_Q, d_K, d_V, + Config::enable_dropout_mask ? d_dropout_mask : nullptr, + dropout_p, sqr_dk_scale, d_O, d_softmax_lse, + d_cu_seqlens_q, d_cu_seqlens_q_padded, + d_cu_seqlens_kv, d_cu_seqlens_kv_padded, + d_padded_q_to_batch, total_padded_q); + + HIP_CHECK(hipMemcpy(h_O_gpu.data(), d_O, size_O * sizeof(DataType), hipMemcpyDeviceToHost)); + + check_output(h_O_gpu, h_O_cpu, bs, head_num, head_dim, + h_cu_seqlens_q, h_cu_seqlens_q_padded, + label + " Output", 1e-2f, 1e-2f, dump_err); + std::cout << " PASS (vs CPU ref)\n" << std::endl; + + HIP_CHECK(hipFree(d_Q)); + HIP_CHECK(hipFree(d_K)); + HIP_CHECK(hipFree(d_V)); + HIP_CHECK(hipFree(d_dropout_mask)); + HIP_CHECK(hipFree(d_O)); + HIP_CHECK(hipFree(d_softmax_lse)); + HIP_CHECK(hipFree(d_cu_seqlens_q)); + HIP_CHECK(hipFree(d_cu_seqlens_q_padded)); + HIP_CHECK(hipFree(d_cu_seqlens_kv)); + HIP_CHECK(hipFree(d_cu_seqlens_kv_padded)); + HIP_CHECK(hipFree(d_padded_q_to_batch)); +} + +int main() +{ + constexpr int bs = 4; + constexpr int heads = 4; + constexpr int max_seq_kv = 16; + constexpr int sq = 8; // max_seq_q; >4 so tests 16x16 path, not 4x4 dispatch + const float dropout_p = 0.0f; + const bool dump_err = true; + + std::cout << "MFMA 16x16 forward head-dimension sweep (CPU reference vs GPU kernel)\n"; + + run_fwd_head_dim_case>( + "head_dim_128", dropout_p, dump_err); + + run_fwd_head_dim_case>( + "head_dim_256", dropout_p, dump_err); + + run_fwd_head_dim_case>( + 
"head_dim_512", dropout_p, dump_err); + + std::cout << "All head-dimension MFMA forward tests finished successfully.\n"; + return 0; +} diff --git a/tests/cpp/small_seq_kernels/tests/test_small_seq_sweep.cpp b/tests/cpp/small_seq_kernels/tests/test_small_seq_sweep.cpp new file mode 100644 index 000000000..86e591017 --- /dev/null +++ b/tests/cpp/small_seq_kernels/tests/test_small_seq_sweep.cpp @@ -0,0 +1,428 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +// +// Small-sequence sweep benchmark matching the TransformerEngine benchmark: +// bs=2048, nheads=32, hdim=128, bfloat16 +// +// 1) Self-attention: seqlen_q == seqlen_kv = 1..17 +// Forward: mfma_4x4 for seq<=4, mfma_16x16 for all; backward: mfma_16x16. +// +// 2) Cross-attention: seqlen_q = 1, seqlen_kv = 2..16 (uniform per batch) +// Forward + backward: mfma_16x16 only (kernel compiled with max_seq_q=1, +// max_seq_kv=16). +// +// Outputs results in CSV format compatible with the TE benchmark CSV. 
+ +#include "attn_fwd_mfma.h" +#include "attn_fwd_mfma_16x16.h" +#include "attn_bwd_mfma_16x16.h" +#include "attn_fwd_ref.h" +#include "attn_common.h" +#include "test_utils.h" + +#include +#include +#include +#include +#include +#include + +// --------------------------------------------------------------------------- +// Build cu_seqlens with UNIFORM lengths (every batch has exactly seq_len) +// --------------------------------------------------------------------------- + +static void build_cu_seqlens_uniform(int bs, int seq_len, + std::vector& cu_seqlens, + std::vector& cu_seqlens_padded, + std::vector& padded_to_batch) +{ + cu_seqlens.resize(bs + 1); + cu_seqlens_padded.resize(bs + 1); + cu_seqlens[0] = cu_seqlens_padded[0] = 0; + for(int b = 0; b < bs; b++) + { + cu_seqlens[b + 1] = cu_seqlens[b] + seq_len; + cu_seqlens_padded[b + 1] = cu_seqlens_padded[b] + seq_len; + } + int total = cu_seqlens_padded[bs]; + padded_to_batch.resize(total); + for(int b = 0; b < bs; b++) + for(int i = cu_seqlens_padded[b]; i < cu_seqlens_padded[b + 1]; i++) + padded_to_batch[i] = b; +} + +// --------------------------------------------------------------------------- +// Benchmark result +// --------------------------------------------------------------------------- + +struct BenchResult { + double min_ms; + double median_ms; + double mean_ms; + double q1_ms; + double q3_ms; + double tflops; +}; + +static BenchResult compute_stats(std::vector& timings, double total_flops) +{ + std::sort(timings.begin(), timings.end()); + int n = timings.size(); + BenchResult r; + r.min_ms = timings[0]; + r.median_ms = timings[n / 2]; + r.q1_ms = timings[n / 4]; + r.q3_ms = timings[3 * n / 4]; + double sum = 0; + for(auto t : timings) sum += t; + r.mean_ms = sum / n; + r.tflops = (total_flops / 1e12) / (r.min_ms / 1000.0); + return r; +} + +static void print_csv_row(const char* mode, int bs, int sq, int skv, + int nheads, int hdim, const char* kernel, + const BenchResult& r) +{ + std::cout << mode 
<< ",bshd,bfloat16," << bs << "," + << sq << "," << skv << "," + << nheads << "," << hdim << ",1," << kernel << "," + << std::fixed << std::setprecision(3) + << r.min_ms << "," << r.median_ms << "," + << r.mean_ms << "," << r.q1_ms << "," << r.q3_ms << "," + << r.tflops << std::endl; +} + +// --------------------------------------------------------------------------- +// Forward benchmark +// --------------------------------------------------------------------------- + +template +BenchResult run_fwd_bench(int sq, int skv, int warmup, int niters) +{ + constexpr int bs = Config::bs; + constexpr int nh = Config::head_num; + constexpr int hd = Config::head_dim; + + std::mt19937 gen(42); + std::uniform_real_distribution dis(-1.0f, 1.0f); + + std::vector csq, csqp, q2b, cskv, cskvp; + build_cu_seqlens_uniform(bs, sq, csq, csqp, q2b); + cskv.resize(bs + 1); cskvp.resize(bs + 1); + cskv[0] = cskvp[0] = 0; + for(int b = 0; b < bs; b++) { + cskv[b+1] = cskv[b] + skv; + cskvp[b+1] = cskvp[b] + skv; + } + int tot_q = csqp[bs]; + int tot_kv = cskvp[bs]; + + size_t sQ = (size_t)tot_q * nh * hd; + size_t sK = (size_t)tot_kv * nh * hd; + size_t sO = sQ; + + std::vector hQ(sQ), hK(sK), hV(sK); + for(size_t i = 0; i < sQ; i++) hQ[i] = DataType(dis(gen)); + for(size_t i = 0; i < sK; i++) hK[i] = DataType(dis(gen)); + for(size_t i = 0; i < sK; i++) hV[i] = DataType(dis(gen)); + + using FwdAuxEl = typename Launcher::fwd_aux_buffer_scalar; + FwdAuxEl* dW; + DataType *dQ, *dK, *dV, *dO; + int *d_csq, *d_csqp, *d_cskv, *d_cskvp, *d_q2b; + HIP_CHECK(hipMalloc(&dQ, sQ * sizeof(DataType))); + HIP_CHECK(hipMalloc(&dK, sK * sizeof(DataType))); + HIP_CHECK(hipMalloc(&dV, sK * sizeof(DataType))); + HIP_CHECK(hipMalloc(&dO, sO * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_csq, (bs+1)*sizeof(int))); + HIP_CHECK(hipMalloc(&d_csqp, (bs+1)*sizeof(int))); + HIP_CHECK(hipMalloc(&d_cskv, (bs+1)*sizeof(int))); + HIP_CHECK(hipMalloc(&d_cskvp,(bs+1)*sizeof(int))); + HIP_CHECK(hipMalloc(&d_q2b, tot_q * 
sizeof(int))); + size_t ws = Launcher::calc_workspace_size(tot_q); + HIP_CHECK(hipMalloc(&dW, ws > 0 ? ws : sizeof(FwdAuxEl))); + + HIP_CHECK(hipMemcpy(dQ, hQ.data(), sQ*sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(dK, hK.data(), sK*sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(dV, hV.data(), sK*sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_csq, csq.data(), (bs+1)*sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_csqp, csqp.data(), (bs+1)*sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cskv, cskv.data(), (bs+1)*sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cskvp,cskvp.data(),(bs+1)*sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_q2b, q2b.data(), tot_q*sizeof(int), hipMemcpyHostToDevice)); + + float scale = 1.0f / std::sqrt((float)hd); + auto launch = [&]() { + Launcher::run_attn_fwd_kernel(dQ, dK, dV, + static_cast(nullptr), 0.0f, scale, dO, dW, + d_csq, d_csqp, d_cskv, d_cskvp, d_q2b, tot_q); + }; + + for(int i = 0; i < warmup; i++) launch(); + HIP_CHECK(hipDeviceSynchronize()); + + std::vector timings(niters); + for(int i = 0; i < niters; i++) { + hipEvent_t t0, t1; + HIP_CHECK(hipEventCreate(&t0)); HIP_CHECK(hipEventCreate(&t1)); + HIP_CHECK(hipEventRecord(t0)); + launch(); + HIP_CHECK(hipEventRecord(t1)); HIP_CHECK(hipEventSynchronize(t1)); + float ms; HIP_CHECK(hipEventElapsedTime(&ms, t0, t1)); + timings[i] = ms; + HIP_CHECK(hipEventDestroy(t0)); HIP_CHECK(hipEventDestroy(t1)); + } + + double flops = 4.0 * sq * skv * hd * bs * nh; + auto res = compute_stats(timings, flops); + + HIP_CHECK(hipFree(dQ)); HIP_CHECK(hipFree(dK)); HIP_CHECK(hipFree(dV)); + HIP_CHECK(hipFree(dO)); HIP_CHECK(hipFree(dW)); + HIP_CHECK(hipFree(d_csq)); HIP_CHECK(hipFree(d_csqp)); + HIP_CHECK(hipFree(d_cskv)); HIP_CHECK(hipFree(d_cskvp)); + HIP_CHECK(hipFree(d_q2b)); + return res; +} + +// --------------------------------------------------------------------------- +// 
Backward benchmark +// --------------------------------------------------------------------------- + +template +BenchResult run_bwd_bench(int sq, int skv, int warmup, int niters) +{ + using FwdLauncher = AttnForwardMfma16x16KernelLauncher; + + constexpr int bs = Config::bs; + constexpr int nh = Config::head_num; + constexpr int hd = Config::head_dim; + + std::mt19937 gen(42); + std::uniform_real_distribution dis(-1.0f, 1.0f); + + std::vector csq, csqp, q2b, cskv, cskvp; + build_cu_seqlens_uniform(bs, sq, csq, csqp, q2b); + cskv.resize(bs + 1); cskvp.resize(bs + 1); + cskv[0] = cskvp[0] = 0; + for(int b = 0; b < bs; b++) { + cskv[b+1] = cskv[b] + skv; + cskvp[b+1] = cskvp[b] + skv; + } + int tot_q = csqp[bs]; + int tot_kv = cskvp[bs]; + + size_t sQ = (size_t)tot_q * nh * hd; + size_t sK = (size_t)tot_kv * nh * hd; + + std::vector hQ(sQ), hK(sK), hV(sK), hGO(sQ); + for(size_t i = 0; i < sQ; i++) hQ[i] = DataType(dis(gen)); + for(size_t i = 0; i < sK; i++) hK[i] = DataType(dis(gen)); + for(size_t i = 0; i < sK; i++) hV[i] = DataType(dis(gen)); + for(size_t i = 0; i < sQ; i++) hGO[i] = DataType(dis(gen)); + + float scale = 1.0f / std::sqrt((float)hd); + + DataType *dQ, *dK, *dV, *dO, *dGO, *dGQ, *dGK, *dGV; + float* d_lse; + int *d_csq, *d_csqp, *d_cskv, *d_cskvp; + int* d_q2b; + HIP_CHECK(hipMalloc(&dQ, sQ * sizeof(DataType))); + HIP_CHECK(hipMalloc(&dK, sK * sizeof(DataType))); + HIP_CHECK(hipMalloc(&dV, sK * sizeof(DataType))); + HIP_CHECK(hipMalloc(&dO, sQ * sizeof(DataType))); + HIP_CHECK(hipMalloc(&dGO, sQ * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_lse, + FwdLauncher::calc_workspace_size(tot_q) > 0 + ? 
FwdLauncher::calc_workspace_size(tot_q) + : sizeof(float))); + HIP_CHECK(hipMalloc(&dGQ, sQ * sizeof(DataType))); + HIP_CHECK(hipMalloc(&dGK, sK * sizeof(DataType))); + HIP_CHECK(hipMalloc(&dGV, sK * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_csq, (bs+1)*sizeof(int))); + HIP_CHECK(hipMalloc(&d_csqp, (bs+1)*sizeof(int))); + HIP_CHECK(hipMalloc(&d_cskv, (bs+1)*sizeof(int))); + HIP_CHECK(hipMalloc(&d_cskvp,(bs+1)*sizeof(int))); + HIP_CHECK(hipMalloc(&d_q2b, tot_q * sizeof(int))); + + HIP_CHECK(hipMemcpy(dQ, hQ.data(), sQ *sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(dK, hK.data(), sK *sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(dV, hV.data(), sK *sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(dGO, hGO.data(), sQ *sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_csq, csq.data(), (bs+1)*sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_csqp, csqp.data(), (bs+1)*sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cskv, cskv.data(), (bs+1)*sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cskvp,cskvp.data(),(bs+1)*sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_q2b, q2b.data(), tot_q * sizeof(int), hipMemcpyHostToDevice)); + + FwdLauncher::run_attn_fwd_kernel(dQ, dK, dV, static_cast(nullptr), 0.0f, + scale, dO, d_lse, + d_csq, d_csqp, d_cskv, d_cskvp, d_q2b, tot_q); + HIP_CHECK(hipDeviceSynchronize()); + + auto launch = [&]() { + HIP_CHECK(hipMemset(dGQ, 0, sQ * sizeof(DataType))); + HIP_CHECK(hipMemset(dGK, 0, sK * sizeof(DataType))); + HIP_CHECK(hipMemset(dGV, 0, sK * sizeof(DataType))); + Launcher::run_attn_bwd_kernel(dQ, dK, dV, dGO, d_lse, + dGQ, dGK, dGV, scale, + d_csq, d_csqp, d_cskv, d_cskvp); + }; + + for(int i = 0; i < warmup; i++) launch(); + HIP_CHECK(hipDeviceSynchronize()); + + std::vector timings(niters); + for(int i = 0; i < niters; i++) { + hipEvent_t t0, t1; + HIP_CHECK(hipEventCreate(&t0)); HIP_CHECK(hipEventCreate(&t1)); + 
HIP_CHECK(hipEventRecord(t0)); + launch(); + HIP_CHECK(hipEventRecord(t1)); HIP_CHECK(hipEventSynchronize(t1)); + float ms; HIP_CHECK(hipEventElapsedTime(&ms, t0, t1)); + timings[i] = ms; + HIP_CHECK(hipEventDestroy(t0)); HIP_CHECK(hipEventDestroy(t1)); + } + + // Bwd FLOPS: ~8 * sq * skv * hd per batch per head (2x QK^T grad + 2x PV grad) + double flops = 8.0 * sq * skv * hd * bs * nh; + auto res = compute_stats(timings, flops); + + HIP_CHECK(hipFree(dQ)); HIP_CHECK(hipFree(dK)); HIP_CHECK(hipFree(dV)); + HIP_CHECK(hipFree(dO)); HIP_CHECK(hipFree(dGO)); HIP_CHECK(hipFree(d_lse)); + HIP_CHECK(hipFree(dGQ)); HIP_CHECK(hipFree(dGK)); HIP_CHECK(hipFree(dGV)); + HIP_CHECK(hipFree(d_csq)); HIP_CHECK(hipFree(d_csqp)); + HIP_CHECK(hipFree(d_cskv)); HIP_CHECK(hipFree(d_cskvp)); + HIP_CHECK(hipFree(d_q2b)); + return res; +} + +// --------------------------------------------------------------------------- +// Recursive template to sweep seqlen from SEQ to MAX_SEQ +// --------------------------------------------------------------------------- + +template +struct SeqSweep +{ + template + static void run(int warmup, int iters) + { + // --- Forward: 4x4x4 (only for SEQ <= 4) --- + if constexpr(SEQ <= 4) + { + using Cfg = FmhaKernelConfig; + using L = AttnForwardMfmaKernelLauncher; + auto r = run_fwd_bench(SEQ, SEQ, warmup, iters); + print_csv_row("fwd", BS, SEQ, SEQ, HEAD_NUM, HEAD_DIM, "mfma_4x4", r); + } + + // --- Forward: 16x16x16 --- + { + using Cfg = FmhaKernelConfig; + using L = AttnForwardMfma16x16KernelLauncher; + auto r = run_fwd_bench(SEQ, SEQ, warmup, iters); + print_csv_row("fwd", BS, SEQ, SEQ, HEAD_NUM, HEAD_DIM, "mfma_16x16", r); + } + + // --- Backward: 16x16x16 --- + { + using Cfg = FmhaKernelConfig; + using L = AttnBackwardMfma16x16KernelLauncher; + auto r = run_bwd_bench(SEQ, SEQ, warmup, iters); + print_csv_row("bwd", BS, SEQ, SEQ, HEAD_NUM, HEAD_DIM, "mfma_16x16", r); + } + + SeqSweep::template run(warmup, iters); + } +}; + +template +struct SeqSweep +{ + 
template + static void run(int warmup, int iters) + { + if constexpr(MAX_SEQ <= 4) + { + using Cfg = FmhaKernelConfig; + using L = AttnForwardMfmaKernelLauncher; + auto r = run_fwd_bench(MAX_SEQ, MAX_SEQ, warmup, iters); + print_csv_row("fwd", BS, MAX_SEQ, MAX_SEQ, HEAD_NUM, HEAD_DIM, "mfma_4x4", r); + } + + { + using Cfg = FmhaKernelConfig; + using L = AttnForwardMfma16x16KernelLauncher; + auto r = run_fwd_bench(MAX_SEQ, MAX_SEQ, warmup, iters); + print_csv_row("fwd", BS, MAX_SEQ, MAX_SEQ, HEAD_NUM, HEAD_DIM, "mfma_16x16", r); + } + + { + using Cfg = FmhaKernelConfig; + using L = AttnBackwardMfma16x16KernelLauncher; + auto r = run_bwd_bench(MAX_SEQ, MAX_SEQ, warmup, iters); + print_csv_row("bwd", BS, MAX_SEQ, MAX_SEQ, HEAD_NUM, HEAD_DIM, "mfma_16x16", r); + } + } +}; + +// --------------------------------------------------------------------------- +// Cross-attention sweep: sq=1, skv in [MIN_SKV, MAX_SKV] (MFMA 16x16 only) +// --------------------------------------------------------------------------- + +template +struct CrossSeqSweep +{ + template + static void run(int warmup, int iters) + { + using Cfg = FmhaKernelConfig; + using FwdL = AttnForwardMfma16x16KernelLauncher; + using BwdL = AttnBackwardMfma16x16KernelLauncher; + + auto rf = run_fwd_bench(1, SKV, warmup, iters); + print_csv_row("fwd", BS, 1, SKV, HEAD_NUM, HEAD_DIM, "mfma_16x16", rf); + auto rb = run_bwd_bench(1, SKV, warmup, iters); + print_csv_row("bwd", BS, 1, SKV, HEAD_NUM, HEAD_DIM, "mfma_16x16", rb); + + CrossSeqSweep::template run( + warmup, iters); + } +}; + +template +struct CrossSeqSweep +{ + template + static void run(int, int) + { + } +}; + +// --------------------------------------------------------------------------- +// main +// --------------------------------------------------------------------------- + +int main() +{ + std::cout << "mode,layout,dtype,batch_size,seqlen_q,seqlen_kv,nheads,dim," + << "gqa_ratio,kernel," + << "min_steptime_ms,median_steptime_ms,mean_steptime_ms," + << 
"q1_steptime_ms,q3_steptime_ms,tflops" + << std::endl; + + constexpr int warmup = 5; + constexpr int niters = 20; + + SeqSweep<1, 17>::run(warmup, niters); + CrossSeqSweep<2, 16>::run(warmup, niters); + + return 0; +} diff --git a/tests/cpp/small_seq_kernels/tests/test_utils.h b/tests/cpp/small_seq_kernels/tests/test_utils.h new file mode 100644 index 000000000..062224ba3 --- /dev/null +++ b/tests/cpp/small_seq_kernels/tests/test_utils.h @@ -0,0 +1,523 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#pragma once + +#include "attn_common.h" + +#include + +#include +#include +#include +#include +#include +#include + +// --------------------------------------------------------------------------- +// Build cu_seqlens for Q side +// +// max_seq_q == 1: Actual Q len per batch is 0 or 1 (bernoulli). +// max_seq_q > 1: Actual Q len per batch is uniform in [0, max_seq_q]. +// +// Returns total_padded_q (== total_actual_q in this scheme). +// --------------------------------------------------------------------------- + +inline int build_cu_seqlens_q(int bs, + std::mt19937& gen, + std::vector& cu_seqlens_q, + std::vector& cu_seqlens_q_padded, + std::vector& padded_q_to_batch, + int max_seq_q = 1) +{ + cu_seqlens_q.resize(bs + 1); + cu_seqlens_q_padded.resize(bs + 1); + cu_seqlens_q[0] = 0; + cu_seqlens_q_padded[0] = 0; + int total_actual_q = 0; + + // Different distributions based on max_seq_q + std::bernoulli_distribution q_bernoulli(0.5); + std::uniform_int_distribution q_uniform(0, max_seq_q); + + for(int b = 0; b < bs; b++) + { + int q_len; + if(max_seq_q == 1) + q_len = q_bernoulli(gen) ? 
1 : 0; + else + q_len = q_uniform(gen); + + total_actual_q += q_len; + cu_seqlens_q[b + 1] = total_actual_q; + cu_seqlens_q_padded[b + 1] = cu_seqlens_q_padded[b] + q_len; + } + + int total_padded_q = cu_seqlens_q_padded[bs]; + padded_q_to_batch.resize(total_padded_q); + for(int b = 0; b < bs; b++) + { + int q_start = cu_seqlens_q_padded[b]; + int q_end = cu_seqlens_q_padded[b + 1]; + for(int i = q_start; i < q_end; i++) + padded_q_to_batch[i] = b; + } + + return total_padded_q; +} + +// --------------------------------------------------------------------------- +// Build cu_seqlens for KV side (random lengths with optional padding) +// --------------------------------------------------------------------------- + +inline void build_cu_seqlens_kv(int bs, + int max_seq_kv, + std::mt19937& gen, + std::vector& cu_seqlens_kv, + std::vector& cu_seqlens_kv_padded, + int& total_actual_kv_seq, + int& total_padded_kv_seq) +{ + std::normal_distribution normal_dis(4.0f, 2.0f); + std::uniform_int_distribution pad_dis(0, 5); + + cu_seqlens_kv.resize(bs + 1); + cu_seqlens_kv_padded.resize(bs + 1); + cu_seqlens_kv[0] = 0; + cu_seqlens_kv_padded[0] = 0; + total_actual_kv_seq = 0; + total_padded_kv_seq = 0; + + for(int b = 0; b < bs; b++) + { + int kv_len = static_cast(std::round(normal_dis(gen))); + kv_len = std::max(2, std::min(max_seq_kv, kv_len)); + int random_pad = pad_dis(gen); + int padded_len = (kv_len + random_pad > max_seq_kv) ? max_seq_kv : kv_len + random_pad; + total_actual_kv_seq += kv_len; + total_padded_kv_seq += padded_len; + cu_seqlens_kv[b + 1] = total_actual_kv_seq; + cu_seqlens_kv_padded[b + 1] = total_padded_kv_seq; + } +} + +// --------------------------------------------------------------------------- +// Reference softmax P + LSE (matches MFMA 16x16 forward numerics; mask-aware). +// Used by Option A backward tests (recompute P from Q, K, LSE on GPU). 
+// --------------------------------------------------------------------------- + +template +inline void reference_probs_and_lse_from_qk( + const std::vector& Q, + const std::vector& K, + int bs, + int head_num, + int max_seq_kv, + int head_dim, + float scale, + CausalMaskType mask_type, + const std::vector& cu_seqlens_q, + const std::vector& cu_seqlens_q_padded, + const std::vector& cu_seqlens_kv, + const std::vector& cu_seqlens_kv_padded, + std::vector& softmax_lse, + std::vector& attn_probs) +{ + int total_padded_q = cu_seqlens_q_padded.back(); + softmax_lse.resize((size_t)total_padded_q * head_num); + attn_probs.assign((size_t)total_padded_q * head_num * max_seq_kv, T(0)); + + for(int b = 0; b < bs; ++b) + { + int q_off = cu_seqlens_q_padded[b]; + int kv_off = cu_seqlens_kv_padded[b]; + int actual_q = cu_seqlens_q[b + 1] - cu_seqlens_q[b]; + int seq_kv = cu_seqlens_kv[b + 1] - cu_seqlens_kv[b]; + + for(int qi = 0; qi < actual_q; ++qi) + { + for(int h = 0; h < head_num; ++h) + { + int q_row_g = q_off + qi; + size_t lse_i = (size_t)q_row_g * head_num + h; + + float row_max = -INFINITY; + std::vector scores((size_t)seq_kv); + for(int j = 0; j < seq_kv; ++j) + { + bool masked = false; + if(mask_type == CausalMaskType::TOP_LEFT) + { + if(j > qi) + masked = true; + } + float s = 0.0f; + if(!masked) + { + for(int d = 0; d < head_dim; ++d) + { + size_t qix = ((size_t)q_row_g * head_num + h) * head_dim + d; + size_t kix = + ((size_t)(kv_off + j) * head_num + h) * head_dim + d; + s += float(Q[qix]) * float(K[kix]); + } + s *= scale; + } + else + s = -INFINITY; + scores[(size_t)j] = s; + row_max = std::max(row_max, s); + } + + float row_sum = 0.0f; + for(int j = 0; j < seq_kv; ++j) + { + if(scores[(size_t)j] > -INFINITY / 4) + row_sum += std::exp(scores[(size_t)j] - row_max); + } + + float lse = (row_sum > 0.0f) ? 
(row_max + std::log(row_sum)) : -INFINITY; + softmax_lse[lse_i] = lse; + + for(int j = 0; j < seq_kv; ++j) + { + size_t p_i = lse_i * (size_t)max_seq_kv + (size_t)j; + if(scores[(size_t)j] > -INFINITY / 4) + attn_probs[p_i] = T(std::exp(scores[(size_t)j] - lse)); + else + attn_probs[p_i] = T(0); + } + } + } + } +} + +// --------------------------------------------------------------------------- +// Attention probs P_ij = exp(S_ij - LSE_row) with **bf16** dot products for S_ij (matches +// MFMA forward/backward numerics better than float dots) and **given** per-row LSE — use +// LSE copied from the MFMA forward GPU pass so CPU backward matches Option A GPU backward. +// --------------------------------------------------------------------------- + +template +inline void reference_attn_probs_bf16_dots_with_given_lse( + const std::vector& Q, + const std::vector& K, + int bs, + int head_num, + int max_seq_kv, + int head_dim, + float scale, + CausalMaskType mask_type, + const std::vector& cu_seqlens_q, + const std::vector& cu_seqlens_q_padded, + const std::vector& cu_seqlens_kv, + const std::vector& cu_seqlens_kv_padded, + const std::vector& softmax_lse, + std::vector& attn_probs) +{ + int total_padded_q = cu_seqlens_q_padded.back(); + attn_probs.assign((size_t)total_padded_q * head_num * max_seq_kv, T(0)); + + for(int b = 0; b < bs; ++b) + { + int q_off = cu_seqlens_q_padded[b]; + int kv_off = cu_seqlens_kv_padded[b]; + int actual_q = cu_seqlens_q[b + 1] - cu_seqlens_q[b]; + int seq_kv = cu_seqlens_kv[b + 1] - cu_seqlens_kv[b]; + + for(int qi = 0; qi < actual_q; ++qi) + { + for(int h = 0; h < head_num; ++h) + { + int q_row_g = q_off + qi; + size_t lse_i = (size_t)q_row_g * head_num + h; + const float lse = softmax_lse[lse_i]; + + for(int j = 0; j < seq_kv; ++j) + { + bool masked = false; + if(mask_type == CausalMaskType::TOP_LEFT) + { + if(j > qi) + masked = true; + } + size_t p_i = lse_i * (size_t)max_seq_kv + (size_t)j; + if(masked) + { + attn_probs[p_i] = T(0); + 
continue; + } + // Sum per 16-dim tile then add (matches MFMA k-loop float acc; not bf16-rounded P). + float s = 0.0f; + const int total_hd_tiles = (head_dim + 15) / 16; + for(int kt = 0; kt < total_hd_tiles; ++kt) + { + float partial = 0.0f; + const int d0 = kt * 16; + const int d1 = std::min(d0 + 16, head_dim); + for(int d = d0; d < d1; ++d) + { + size_t qix = ((size_t)q_row_g * head_num + h) * head_dim + d; + size_t kix = + ((size_t)(kv_off + j) * head_num + h) * head_dim + d; + float qf = static_cast(Q[qix]); + float kf = static_cast(K[kix]); + partial += float(hip_bfloat16(qf)) * float(hip_bfloat16(kf)); + } + s += partial; + } + s *= scale; + attn_probs[p_i] = T(std::exp(s - lse)); + } + } + } + } +} + +// --------------------------------------------------------------------------- +// Correctness check helpers +// +// Tolerance formula (numpy/PyTorch allclose, same as CK check_err): +// PASS when |x_test − x_ref| ≤ atol + rtol × |x_ref| +// +// bf16 MFMA defaults: rtol=1e-2, atol=1e-2 (matches CK FmhaBwdBf16/FmhaFwdBf16) +// --------------------------------------------------------------------------- + +// Check two arrays element-wise (all elements). 
// Compare every element of `gpu` against the reference `cpu` and print a
// summary (max abs/rel difference, number of elements out of tolerance,
// PASS/FAIL) to std::cout.
//
// An element passes when |gpu - cpu| <= atol + rtol * |cpu|, or when the two
// values compare exactly equal (so matching infinities pass).
//
//   gpu       values under test
//   cpu       reference values
//   name      label used in the printed report
//   rtol/atol relative / absolute tolerance
//   dump_err  when true, print one line per failing element
//
// A NaN difference is reported as a mismatch: NaN compares false against
// everything, so a plain `diff > tol` test would let NaN results pass.
// Elements of `gpu` that have no counterpart in a shorter `cpu` are counted
// as mismatches instead of reading past the end of the reference.
template <typename T>
void check_array(const std::vector<T>& gpu,
                 const std::vector<T>& cpu,
                 const std::string& name,
                 float rtol    = 1e-2f,
                 float atol    = 1e-2f,
                 bool dump_err = false)
{
    float max_diff     = 0.0f;
    float max_rel_diff = 0.0f;
    size_t diff_count  = 0;

    for(size_t i = 0; i < gpu.size(); i++)
    {
        if(i >= cpu.size())
        {
            // No reference value to compare against.
            if(dump_err)
                std::cout << name << " mismatch at " << i << ": GPU=" << float(gpu[i])
                          << " has no CPU reference element" << std::endl;
            diff_count++;
            continue;
        }

        const float gpu_val = float(gpu[i]);
        const float ref_val = float(cpu[i]);
        // inf - inf is NaN, so exact equality is checked separately.
        const bool exact     = (gpu_val == ref_val);
        const float diff     = exact ? 0.0f : std::abs(gpu_val - ref_val);
        const float tol      = atol + rtol * std::abs(ref_val);
        const float rel_diff = diff / (std::abs(ref_val) + 1e-12f);
        max_diff             = std::max(max_diff, diff);
        max_rel_diff         = std::max(max_rel_diff, rel_diff);

        // Negated <= so that a NaN difference lands in the failure branch.
        if(!exact && !(diff <= tol))
        {
            if(dump_err)
                std::cout << name << " mismatch at " << i
                          << ": GPU=" << gpu_val << " CPU=" << ref_val
                          << " abs=" << diff << " tol=" << tol << std::endl;
            diff_count++;
        }
    }

    bool pass = (diff_count == 0);
    std::cout << name << " check:" << std::endl;
    std::cout << "  Max abs diff: " << max_diff << "  Max rel diff: " << max_rel_diff << std::endl;
    std::cout << "  Exceeding tolerance (rtol=" << rtol << ", atol=" << atol
              << "): " << diff_count << " / " << gpu.size() << std::endl;
    std::cout << "  Status: " << (pass ? "PASS" : "FAIL") << std::endl;
}

// Check grad_Q only on active-Q slots (skip empty-Q batches).
// Compare grad_Q between `gpu` and the reference `cpu`, looking only at the
// rows that hold real query tokens, and print a summary to std::cout.
//
// For batch b the number of real rows is cu_seqlens_q[b+1] - cu_seqlens_q[b]
// and they start at row cu_seqlens_q_padded[b]; batches with zero real rows
// are skipped. Element (row, h, d) is read at flat index
// (row * head_num + h) * head_dim + d.
//
// An element passes when |gpu - cpu| <= atol + rtol * |cpu| or when both
// values compare exactly equal. A NaN difference is counted as a mismatch
// (a plain `diff > tol` test is false for NaN and would let it pass).
//
//   bs, head_num, head_dim  problem dimensions used for indexing
//   rtol/atol               relative / absolute tolerance
//   dump_err                when true, print one line per failing element
template <typename T>
void check_grad_q(const std::vector<T>& gpu,
                  const std::vector<T>& cpu,
                  int bs,
                  int head_num,
                  int head_dim,
                  const std::vector<int>& cu_seqlens_q,
                  const std::vector<int>& cu_seqlens_q_padded,
                  float rtol    = 1e-2f,
                  float atol    = 1e-2f,
                  bool dump_err = false)
{
    float max_diff      = 0.0f;
    float max_rel_diff  = 0.0f;
    size_t diff_count   = 0;
    size_t active_elems = 0;

    for(int b = 0; b < bs; b++)
    {
        int actual_q = cu_seqlens_q[b + 1] - cu_seqlens_q[b];
        if(actual_q == 0)
            continue;
        int q_off = cu_seqlens_q_padded[b];
        for(int q = 0; q < actual_q; q++)
        {
            for(int h = 0; h < head_num; h++)
            {
                // Widen before multiplying so the flat index cannot overflow int
                // (same arithmetic as check_output).
                const size_t base = ((size_t)(q_off + q) * head_num + h) * head_dim;
                for(int d = 0; d < head_dim; d++)
                {
                    const size_t idx    = base + (size_t)d;
                    const float gpu_val = float(gpu[idx]);
                    const float ref_val = float(cpu[idx]);
                    // inf - inf is NaN, so exact equality is checked separately.
                    const bool exact     = (gpu_val == ref_val);
                    const float diff     = exact ? 0.0f : std::abs(gpu_val - ref_val);
                    const float tol      = atol + rtol * std::abs(ref_val);
                    const float rel_diff = diff / (std::abs(ref_val) + 1e-12f);
                    max_diff             = std::max(max_diff, diff);
                    max_rel_diff         = std::max(max_rel_diff, rel_diff);

                    // Negated <= so that a NaN difference lands in the failure branch.
                    if(!exact && !(diff <= tol))
                    {
                        if(dump_err)
                            std::cout << "grad_Q mismatch at [b=" << b << ",q=" << q
                                      << ",h=" << h << ",d=" << d
                                      << "]: GPU=" << gpu_val << " CPU=" << ref_val
                                      << " abs=" << diff << " tol=" << tol << std::endl;
                        diff_count++;
                    }
                    active_elems++;
                }
            }
        }
    }

    bool pass = (diff_count == 0);
    std::cout << "grad_Q check (active slots only):" << std::endl;
    std::cout << "  Active Q elements: " << active_elems << std::endl;
    std::cout << "  Max abs diff: " << max_diff << "  Max rel diff: " << max_rel_diff << std::endl;
    std::cout << "  Exceeding tolerance (rtol=" << rtol << ", atol=" << atol
              << "): " << diff_count << " / " << active_elems << std::endl;
    std::cout << "  Status: " << (pass ? "PASS" : "FAIL") << std::endl;
}

// Check output on active-Q batch positions (supports multi-Q).
+// +// Tolerance: |diff| ≤ atol + rtol × |x_ref| +template +void check_output(const std::vector& gpu, + const std::vector& cpu, + int bs, + int head_num, + int head_dim, + const std::vector& cu_seqlens_q, + const std::vector& cu_seqlens_q_padded, + const std::string& name, + float rtol = 1e-2f, + float atol = 1e-2f, + bool dump_err = false) +{ + float max_diff = 0.0f; + float max_rel_diff = 0.0f; + size_t diff_count = 0; + size_t total_elems = 0; + + for(int b = 0; b < bs; b++) + { + int actual_q = cu_seqlens_q[b + 1] - cu_seqlens_q[b]; + if(actual_q == 0) + continue; + int q_off = cu_seqlens_q_padded[b]; + for(int q = 0; q < actual_q; q++) + { + for(int h = 0; h < head_num; h++) + { + for(int d = 0; d < head_dim; d++) + { + size_t idx = ((size_t)(q_off + q) * head_num + h) * head_dim + d; + float ref_val = float(cpu[idx]); + float diff = std::abs(float(gpu[idx]) - ref_val); + float tol = atol + rtol * std::abs(ref_val); + float rel_diff = diff / (std::abs(ref_val) + 1e-12f); + max_diff = std::max(max_diff, diff); + max_rel_diff = std::max(max_rel_diff, rel_diff); + total_elems++; + if(diff > tol) + { + if(dump_err) + std::cout << name << " mismatch at b=" << b << " q=" << q + << " h=" << h << " d=" << d + << ": GPU=" << float(gpu[idx]) << " CPU=" << ref_val + << " abs=" << diff << " tol=" << tol << std::endl; + diff_count++; + } + } + } + } + } + + bool pass = (diff_count == 0); + std::cout << name << " check:" << std::endl; + std::cout << " Max abs diff: " << max_diff << " Max rel diff: " << max_rel_diff << std::endl; + std::cout << " Exceeding tolerance (rtol=" << rtol << ", atol=" << atol + << "): " << diff_count << " / " << total_elems << std::endl; + std::cout << " Status: " << (pass ? 
"PASS" : "FAIL") << std::endl; +} + +// --------------------------------------------------------------------------- +// TestRunner: iterate over SEQ_KV values from SEQ_KV to MAX_SEQ_KV +// +// Usage: +// TestRunner::run(fn, args...); +// +// The Func callable must have the signature: +// template +// void fn(Args...); +// --------------------------------------------------------------------------- + +template +struct TestRunner +{ + template + static void run(Func fn, Args&&... args) + { + using KernelConfig = FmhaKernelConfig; + fn.template operator()(std::forward(args)...); + + TestRunner::template run(fn, std::forward(args)...); + } +}; + +// Termination specialisation +template +struct TestRunner +{ + template + static void run(Func fn, Args&&... args) + { + using KernelConfig = FmhaKernelConfig; + fn.template operator()(std::forward(args)...); + } +}; diff --git a/tests/cpp/small_seq_kernels/tests/test_varlen_mfma_16x16.cpp b/tests/cpp/small_seq_kernels/tests/test_varlen_mfma_16x16.cpp new file mode 100644 index 000000000..8f1010ac2 --- /dev/null +++ b/tests/cpp/small_seq_kernels/tests/test_varlen_mfma_16x16.cpp @@ -0,0 +1,460 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +// +// Unified varlen test for MFMA 16x16x16 forward + backward kernels. +// +// 4 test cases (each runs fwd + bwd): +// 1. sq∈[1,16], skv∈[2,16]; varlen + padding +// 2. sq=1 (fixed, no padding), skv∈[2,16] (varlen + padding) +// 3. sq=16, skv=16; fixed, no padding +// 4. 
sq=17, skv=17; fixed, no padding +// +// Build: cmake --build build --target test_varlen_mfma_16x16 + +#include "attn_fwd_mfma_16x16.h" +#include "attn_bwd_mfma_16x16.h" +#include "attn_fwd_ref.h" +#include "attn_bwd_ref.h" +#include "test_utils.h" + +#include +#include +#include +#include +#include + +// --------------------------------------------------------------------------- +// Q cu_seqlens builders +// --------------------------------------------------------------------------- + +// Varlen Q: actual in [1, max_seq_q] with random padding [0, 3] +inline int build_varlen_cu_seqlens_q(int bs, + int max_seq_q, + std::mt19937& gen, + std::vector& cu_seqlens_q, + std::vector& cu_seqlens_q_padded, + std::vector& padded_q_to_batch) +{ + std::uniform_int_distribution q_dist(1, max_seq_q); + std::uniform_int_distribution pad_dist(0, 3); + + cu_seqlens_q.resize(bs + 1); + cu_seqlens_q_padded.resize(bs + 1); + cu_seqlens_q[0] = cu_seqlens_q_padded[0] = 0; + + for(int b = 0; b < bs; b++) + { + int q_len = q_dist(gen); + int pad = pad_dist(gen); + int padded_len = std::min(q_len + pad, max_seq_q); + cu_seqlens_q[b + 1] = cu_seqlens_q[b] + q_len; + cu_seqlens_q_padded[b + 1] = cu_seqlens_q_padded[b] + padded_len; + } + + int total_padded_q = cu_seqlens_q_padded[bs]; + padded_q_to_batch.resize(total_padded_q); + for(int b = 0; b < bs; b++) + for(int q = cu_seqlens_q_padded[b]; q < cu_seqlens_q_padded[b + 1]; q++) + padded_q_to_batch[q] = b; + + return total_padded_q; +} + +// Fixed Q: all batches have exactly fix_sq tokens, no padding +inline int build_fixed_cu_seqlens_q(int bs, + int fix_sq, + std::vector& cu_seqlens_q, + std::vector& cu_seqlens_q_padded, + std::vector& padded_q_to_batch) +{ + cu_seqlens_q.resize(bs + 1); + cu_seqlens_q_padded.resize(bs + 1); + cu_seqlens_q[0] = cu_seqlens_q_padded[0] = 0; + + for(int b = 0; b < bs; b++) + { + cu_seqlens_q[b + 1] = cu_seqlens_q[b] + fix_sq; + cu_seqlens_q_padded[b + 1] = cu_seqlens_q_padded[b] + fix_sq; + } + + int 
total_padded_q = cu_seqlens_q_padded[bs]; + padded_q_to_batch.resize(total_padded_q); + for(int b = 0; b < bs; b++) + for(int q = cu_seqlens_q_padded[b]; q < cu_seqlens_q_padded[b + 1]; q++) + padded_q_to_batch[q] = b; + + return total_padded_q; +} + +// --------------------------------------------------------------------------- +// KV cu_seqlens builder (fixed) +// --------------------------------------------------------------------------- + +inline void build_fixed_cu_seqlens_kv(int bs, + int fix_skv, + std::vector& cu_seqlens_kv, + std::vector& cu_seqlens_kv_padded, + int& total_padded_kv_seq) +{ + cu_seqlens_kv.resize(bs + 1); + cu_seqlens_kv_padded.resize(bs + 1); + cu_seqlens_kv[0] = cu_seqlens_kv_padded[0] = 0; + + for(int b = 0; b < bs; b++) + { + cu_seqlens_kv[b + 1] = cu_seqlens_kv[b] + fix_skv; + cu_seqlens_kv_padded[b + 1] = cu_seqlens_kv_padded[b] + fix_skv; + } + + total_padded_kv_seq = cu_seqlens_kv_padded[bs]; +} + +// --------------------------------------------------------------------------- +// Forward test +// --------------------------------------------------------------------------- + +template +bool test_fwd(bool varlen_q, int fix_sq, + bool varlen_kv, int fix_skv, + const std::string& label, + const std::vector& h_cu_seqlens_q, + const std::vector& h_cu_seqlens_q_padded, + const std::vector& h_padded_q_to_batch, + const std::vector& h_cu_seqlens_kv, + const std::vector& h_cu_seqlens_kv_padded, + int total_padded_q, + int total_padded_kv_seq) +{ + using Launcher = AttnForwardMfma16x16KernelLauncher; + + constexpr int bs = Config::bs; + constexpr int head_num = Config::head_num; + constexpr int max_seq_kv = Config::max_seq_kv; + constexpr int max_seq_q = Config::max_seq_q; + constexpr int head_dim = Config::head_dim; + + std::mt19937 gen(42); + std::uniform_real_distribution dis(-1.0f, 1.0f); + + size_t size_Q = (size_t)total_padded_q * head_num * head_dim; + size_t size_K = (size_t)total_padded_kv_seq * head_num * head_dim; + size_t size_V = 
size_K; + size_t size_O = size_Q; + + std::vector h_Q(size_Q), h_K(size_K), h_V(size_V); + std::vector h_O_gpu(size_O, DataType(0.0f)); + std::vector h_O_cpu(size_O, DataType(0.0f)); + + for(size_t i = 0; i < size_Q; i++) h_Q[i] = DataType(dis(gen)); + for(size_t i = 0; i < size_K; i++) h_K[i] = DataType(dis(gen)); + for(size_t i = 0; i < size_V; i++) h_V[i] = DataType(dis(gen)); + + if constexpr(std::is_same::value) + { + for(size_t i = 0; i < size_Q; i++) h_Q[i] = float(hip_bfloat16(h_Q[i])); + for(size_t i = 0; i < size_K; i++) h_K[i] = float(hip_bfloat16(h_K[i])); + for(size_t i = 0; i < size_V; i++) h_V[i] = float(hip_bfloat16(h_V[i])); + } + + float sqr_dk_scale = 1.0f / std::sqrt(static_cast(head_dim)); + + attn_forward(h_Q.data(), h_K.data(), h_V.data(), + static_cast(nullptr), 0.0f, + h_O_cpu.data(), static_cast(nullptr), + bs, head_num, max_seq_kv, head_dim, Config::mask_type, + h_cu_seqlens_q.data(), h_cu_seqlens_q_padded.data(), + h_cu_seqlens_kv.data(), h_cu_seqlens_kv_padded.data(), + true); + + DataType *d_Q, *d_K, *d_V, *d_O; + float* d_softmax_lse; + int *d_cu_sq, *d_cu_sqp, *d_cu_skv, *d_cu_skvp; + + HIP_CHECK(hipMalloc(&d_Q, size_Q * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_K, size_K * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_V, size_V * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_O, size_O * sizeof(DataType))); + size_t ws_size = Launcher::calc_workspace_size(total_padded_q); + HIP_CHECK(hipMalloc(&d_softmax_lse, ws_size > 0 ? 
ws_size : sizeof(float))); + HIP_CHECK(hipMalloc(&d_cu_sq, (bs + 1) * sizeof(int))); + HIP_CHECK(hipMalloc(&d_cu_sqp, (bs + 1) * sizeof(int))); + HIP_CHECK(hipMalloc(&d_cu_skv, (bs + 1) * sizeof(int))); + HIP_CHECK(hipMalloc(&d_cu_skvp, (bs + 1) * sizeof(int))); + + HIP_CHECK(hipMemcpy(d_Q, h_Q.data(), size_Q * sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_K, h_K.data(), size_K * sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_V, h_V.data(), size_V * sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemset(d_O, 0, size_O * sizeof(DataType))); + HIP_CHECK(hipMemcpy(d_cu_sq, h_cu_seqlens_q.data(), (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cu_sqp, h_cu_seqlens_q_padded.data(), (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cu_skv, h_cu_seqlens_kv.data(), (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cu_skvp, h_cu_seqlens_kv_padded.data(), (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + + Launcher::run_attn_fwd_kernel(d_Q, d_K, d_V, + static_cast(nullptr), + 0.0f, sqr_dk_scale, d_O, d_softmax_lse, + d_cu_sq, d_cu_sqp, d_cu_skv, d_cu_skvp, + nullptr, total_padded_q); + HIP_CHECK(hipDeviceSynchronize()); + + HIP_CHECK(hipMemcpy(h_O_gpu.data(), d_O, size_O * sizeof(DataType), hipMemcpyDeviceToHost)); + + std::cout << " [FWD] "; + check_output(h_O_gpu, h_O_cpu, bs, head_num, head_dim, + h_cu_seqlens_q, h_cu_seqlens_q_padded, "Output", 1e-2f, 1e-2f, false); + + HIP_CHECK(hipFree(d_Q)); HIP_CHECK(hipFree(d_K)); HIP_CHECK(hipFree(d_V)); + HIP_CHECK(hipFree(d_O)); HIP_CHECK(hipFree(d_softmax_lse)); + HIP_CHECK(hipFree(d_cu_sq)); HIP_CHECK(hipFree(d_cu_sqp)); + HIP_CHECK(hipFree(d_cu_skv)); HIP_CHECK(hipFree(d_cu_skvp)); + + return true; +} + +// --------------------------------------------------------------------------- +// Backward test +// --------------------------------------------------------------------------- + +template +bool 
test_bwd(bool varlen_q, int fix_sq, + bool varlen_kv, int fix_skv, + const std::string& label, + const std::vector& h_cu_seqlens_q, + const std::vector& h_cu_seqlens_q_padded, + const std::vector& h_padded_q_to_batch, + const std::vector& h_cu_seqlens_kv, + const std::vector& h_cu_seqlens_kv_padded, + int total_padded_q, + int total_padded_kv_seq) +{ + using BwdLauncher = AttnBackwardMfma16x16KernelLauncher; + using FwdLauncher = AttnForwardMfma16x16KernelLauncher; + + constexpr int bs = Config::bs; + constexpr int head_num = Config::head_num; + constexpr int max_seq_kv = Config::max_seq_kv; + constexpr int max_seq_q = Config::max_seq_q; + constexpr int head_dim = Config::head_dim; + + std::mt19937 gen(42); + std::uniform_real_distribution dis(-1.0f, 1.0f); + + size_t size_Q = (size_t)total_padded_q * head_num * head_dim; + size_t size_K = (size_t)total_padded_kv_seq * head_num * head_dim; + size_t size_V = size_K; + size_t size_dO = size_Q; + size_t size_P = (size_t)total_padded_q * head_num * max_seq_kv; + + std::vector h_Q(size_Q), h_K(size_K), h_V(size_V); + std::vector h_grad_O(size_dO); + std::vector h_P(size_P, DataType(0.0f)); + std::vector h_grad_Q_gpu(size_Q, DataType(0.0f)); + std::vector h_grad_K_gpu(size_K, DataType(0.0f)); + std::vector h_grad_V_gpu(size_V, DataType(0.0f)); + std::vector h_grad_Q_cpu(size_Q, DataType(0.0f)); + std::vector h_grad_K_cpu(size_K, DataType(0.0f)); + std::vector h_grad_V_cpu(size_V, DataType(0.0f)); + + for(size_t i = 0; i < size_Q; i++) h_Q[i] = DataType(dis(gen)); + for(size_t i = 0; i < size_K; i++) h_K[i] = DataType(dis(gen)); + for(size_t i = 0; i < size_V; i++) h_V[i] = DataType(dis(gen)); + for(size_t i = 0; i < size_dO; i++) h_grad_O[i] = DataType(dis(gen)); + + if constexpr(std::is_same::value) + { + for(size_t i = 0; i < size_Q; i++) h_Q[i] = float(hip_bfloat16(h_Q[i])); + for(size_t i = 0; i < size_K; i++) h_K[i] = float(hip_bfloat16(h_K[i])); + for(size_t i = 0; i < size_V; i++) h_V[i] = 
float(hip_bfloat16(h_V[i])); + for(size_t i = 0; i < size_dO; i++) h_grad_O[i] = float(hip_bfloat16(h_grad_O[i])); + } + + float sqr_dk_scale = 1.0f / std::sqrt(static_cast(head_dim)); + + DataType *d_Q, *d_K, *d_V, *d_O, *d_dO; + float* d_softmax_lse; + DataType *d_dQ, *d_dK, *d_dV; + int *d_cu_sq, *d_cu_sqp, *d_cu_skv, *d_cu_skvp; + int* d_padded_q_to_batch; + + HIP_CHECK(hipMalloc(&d_Q, size_Q * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_K, size_K * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_V, size_V * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_O, total_padded_q * head_num * head_dim * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_dO, size_dO * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_softmax_lse, + FwdLauncher::calc_workspace_size(total_padded_q) > 0 + ? FwdLauncher::calc_workspace_size(total_padded_q) + : sizeof(float))); + HIP_CHECK(hipMalloc(&d_dQ, size_Q * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_dK, size_K * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_dV, size_V * sizeof(DataType))); + HIP_CHECK(hipMalloc(&d_cu_sq, (bs + 1) * sizeof(int))); + HIP_CHECK(hipMalloc(&d_cu_sqp, (bs + 1) * sizeof(int))); + HIP_CHECK(hipMalloc(&d_cu_skv, (bs + 1) * sizeof(int))); + HIP_CHECK(hipMalloc(&d_cu_skvp, (bs + 1) * sizeof(int))); + if(total_padded_q > 0) + HIP_CHECK(hipMalloc(&d_padded_q_to_batch, total_padded_q * sizeof(int))); + else + d_padded_q_to_batch = nullptr; + + HIP_CHECK(hipMemcpy(d_Q, h_Q.data(), size_Q * sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_K, h_K.data(), size_K * sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_V, h_V.data(), size_V * sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_dO, h_grad_O.data(), size_dO * sizeof(DataType), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cu_sq, h_cu_seqlens_q.data(), (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cu_sqp, h_cu_seqlens_q_padded.data(), (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + 
HIP_CHECK(hipMemcpy(d_cu_skv, h_cu_seqlens_kv.data(), (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_cu_skvp, h_cu_seqlens_kv_padded.data(), (bs + 1) * sizeof(int), hipMemcpyHostToDevice)); + if(total_padded_q > 0) + HIP_CHECK(hipMemcpy(d_padded_q_to_batch, h_padded_q_to_batch.data(), + total_padded_q * sizeof(int), hipMemcpyHostToDevice)); + + FwdLauncher::run_attn_fwd_kernel(d_Q, d_K, d_V, static_cast(nullptr), 0.0f, + sqr_dk_scale, d_O, d_softmax_lse, + d_cu_sq, d_cu_sqp, d_cu_skv, d_cu_skvp, + d_padded_q_to_batch, total_padded_q); + HIP_CHECK(hipDeviceSynchronize()); + + std::vector h_softmax_lse(total_padded_q * head_num); + HIP_CHECK(hipMemcpy(h_softmax_lse.data(), d_softmax_lse, + h_softmax_lse.size() * sizeof(float), hipMemcpyDeviceToHost)); + + reference_attn_probs_bf16_dots_with_given_lse( + h_Q, h_K, bs, head_num, max_seq_kv, head_dim, sqr_dk_scale, Config::mask_type, + h_cu_seqlens_q, h_cu_seqlens_q_padded, h_cu_seqlens_kv, h_cu_seqlens_kv_padded, + h_softmax_lse, h_P); + + // bf16_weights=false: MFMA bwd recomputes P in float (Option A). 
+ attn_backward(h_Q.data(), h_K.data(), h_V.data(), h_grad_O.data(), + h_P.data(), static_cast(nullptr), 0.0f, + h_grad_Q_cpu.data(), h_grad_K_cpu.data(), h_grad_V_cpu.data(), + bs, head_num, max_seq_kv, head_dim, Config::mask_type, + h_cu_seqlens_q.data(), h_cu_seqlens_q_padded.data(), + h_cu_seqlens_kv.data(), h_cu_seqlens_kv_padded.data(), + total_padded_q, total_padded_kv_seq, + max_seq_q, false); + + HIP_CHECK(hipMemset(d_dQ, 0, size_Q * sizeof(DataType))); + HIP_CHECK(hipMemset(d_dK, 0, size_K * sizeof(DataType))); + HIP_CHECK(hipMemset(d_dV, 0, size_V * sizeof(DataType))); + + BwdLauncher::run_attn_bwd_kernel(d_Q, d_K, d_V, d_dO, d_softmax_lse, + d_dQ, d_dK, d_dV, sqr_dk_scale, + d_cu_sq, d_cu_sqp, d_cu_skv, d_cu_skvp); + HIP_CHECK(hipDeviceSynchronize()); + + HIP_CHECK(hipMemcpy(h_grad_Q_gpu.data(), d_dQ, size_Q * sizeof(DataType), hipMemcpyDeviceToHost)); + HIP_CHECK(hipMemcpy(h_grad_K_gpu.data(), d_dK, size_K * sizeof(DataType), hipMemcpyDeviceToHost)); + HIP_CHECK(hipMemcpy(h_grad_V_gpu.data(), d_dV, size_V * sizeof(DataType), hipMemcpyDeviceToHost)); + + std::cout << " [BWD] "; + check_output(h_grad_Q_gpu, h_grad_Q_cpu, bs, head_num, head_dim, + h_cu_seqlens_q, h_cu_seqlens_q_padded, "grad_Q", 1e-2f, 1e-2f, false); + std::cout << " [BWD] "; + check_output(h_grad_K_gpu, h_grad_K_cpu, bs, head_num, head_dim, + h_cu_seqlens_kv, h_cu_seqlens_kv_padded, "grad_K", 1e-2f, 1e-2f, false); + std::cout << " [BWD] "; + check_output(h_grad_V_gpu, h_grad_V_cpu, bs, head_num, head_dim, + h_cu_seqlens_kv, h_cu_seqlens_kv_padded, "grad_V", 1e-2f, 1e-2f, false); + + HIP_CHECK(hipFree(d_Q)); HIP_CHECK(hipFree(d_K)); HIP_CHECK(hipFree(d_V)); + HIP_CHECK(hipFree(d_O)); HIP_CHECK(hipFree(d_dO)); HIP_CHECK(hipFree(d_softmax_lse)); + HIP_CHECK(hipFree(d_dQ)); HIP_CHECK(hipFree(d_dK)); HIP_CHECK(hipFree(d_dV)); + HIP_CHECK(hipFree(d_cu_sq)); HIP_CHECK(hipFree(d_cu_sqp)); + HIP_CHECK(hipFree(d_cu_skv)); HIP_CHECK(hipFree(d_cu_skvp)); + if(d_padded_q_to_batch) 
HIP_CHECK(hipFree(d_padded_q_to_batch)); + + return true; +} + +// --------------------------------------------------------------------------- +// Run one test case (fwd + bwd) +// --------------------------------------------------------------------------- + +template +void run_test_case(bool varlen_q, int fix_sq, + bool varlen_kv, int fix_skv, + const std::string& label) +{ + constexpr int bs = Config::bs; + constexpr int max_seq_q = Config::max_seq_q; + constexpr int max_seq_kv = Config::max_seq_kv; + + std::mt19937 gen(42); + + std::vector h_cu_sq, h_cu_sqp, h_q2b; + std::vector h_cu_skv, h_cu_skvp; + int total_padded_kv_seq; + int total_padded_q; + int total_actual_kv_seq; + + if(varlen_q) + total_padded_q = build_varlen_cu_seqlens_q( + bs, max_seq_q, gen, h_cu_sq, h_cu_sqp, h_q2b); + else + total_padded_q = build_fixed_cu_seqlens_q( + bs, fix_sq, h_cu_sq, h_cu_sqp, h_q2b); + + if(varlen_kv) + build_cu_seqlens_kv( + bs, max_seq_kv, gen, h_cu_skv, h_cu_skvp, + total_actual_kv_seq, total_padded_kv_seq); + else + build_fixed_cu_seqlens_kv( + bs, fix_skv, h_cu_skv, h_cu_skvp, total_padded_kv_seq); + + std::cout << "\n===== " << label << " =====" << std::endl; + std::cout << " bs=" << bs + << " max_sq=" << max_seq_q << " max_skv=" << max_seq_kv + << " total_padded_q=" << total_padded_q + << " total_padded_kv=" << total_padded_kv_seq << std::endl; + + test_fwd( + varlen_q, fix_sq, varlen_kv, fix_skv, label, + h_cu_sq, h_cu_sqp, h_q2b, h_cu_skv, h_cu_skvp, + total_padded_q, total_padded_kv_seq); + + test_bwd( + varlen_q, fix_sq, varlen_kv, fix_skv, label, + h_cu_sq, h_cu_sqp, h_q2b, h_cu_skv, h_cu_skvp, + total_padded_q, total_padded_kv_seq); + + std::cout << "====================================\n" << std::endl; +} + +// --------------------------------------------------------------------------- +// main +// --------------------------------------------------------------------------- + +int main() +{ + // Test 1: sq∈[1,16] varlen+pad, skv∈[2,16] varlen+pad + { + using 
Cfg = FmhaKernelConfig<2048, 8, 16, 128, 256, false, CausalMaskType::DISABLE, 16>; + run_test_case(true, 0, true, 0, + "Test 1: sq varlen+pad [1,16], skv varlen+pad [2,16]"); + } + + // Test 2: sq=1 fixed, skv∈[2,16] varlen+pad + { + using Cfg = FmhaKernelConfig<2048, 8, 16, 128, 256, false, CausalMaskType::DISABLE, 1>; + run_test_case(false, 1, true, 0, + "Test 2: sq=1 fixed, skv varlen+pad [2,16]"); + } + + // Test 3: sq=16, skv=16; fixed, no padding + { + using Cfg = FmhaKernelConfig<2048, 8, 16, 128, 256, false, CausalMaskType::DISABLE, 16>; + run_test_case(false, 16, false, 16, + "Test 3: sq=16 fixed, skv=16 fixed"); + } + + // Test 4: sq=17, skv=17; fixed, no padding + { + using Cfg = FmhaKernelConfig<2048, 8, 17, 128, 256, false, CausalMaskType::DISABLE, 17>; + run_test_case(false, 17, false, 17, + "Test 4: sq=17 fixed, skv=17 fixed"); + } + + return 0; +} diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 56d91775a..026d116ed 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -406,11 +406,11 @@ def _get_max_segments_per_sequence(self): if self.qkv_layout.is_thd(): if 90400 <= get_cudnn_version() < 90500: return self.num_segments_per_seq - else: - # +1 for testing runtime_segments < max_segments - return self.num_segments_per_seq + 1 - else: - return 1 + if is_hip_extension() and os.environ.get("NVTE_FUSED_ATTN_CK_SMALLSEQ", "0") == "1": + return self.num_segments_per_seq + # +1 for testing runtime_segments < max_segments + return self.num_segments_per_seq + 1 + return 1 def _check_configs(self): # TODO(rewang): probably adds this in is_fused_attn_available @@ -526,6 +526,49 @@ def _check_configs(self): "the F16_arbitrary_seqlen backend." ) + def _setup_segments_ck_smallseq(self, generate_random_segment_ids): + """ + Segment ids / seqlens for NVTE_FUSED_ATTN_CK_SMALLSEQ + padded ragged layouts. + + num_segments_per_seq follows max_seqlen_q; max_seqlen_q==1 uses a fixed Q row and + corrected seqlens_q. 
KV always uses generate_random_segment_ids. + """ + num_segments_per_seq = self.max_seqlen_q + if self.max_seqlen_q == 1: + segment_ids_q = jnp.ones((self.batch_size, self.max_seqlen_q), dtype=jnp.int32) + segment_pos_q = jnp.zeros((self.batch_size, self.max_seqlen_q), dtype=jnp.int32) + pad_q = jnp.zeros((self.batch_size, self.max_seqlen_q), dtype=jnp.int32) + seqlens_q, offsets_q = get_seqlens_and_offsets(segment_ids_q) + seqlens_q = jnp.ones((self.batch_size, 1), dtype=jnp.int32) + else: + segment_ids_q, segment_pos_q, pad_q = generate_random_segment_ids( + self.batch_size, self.max_seqlen_q, num_segments_per_seq, seed=42 + ) + seqlens_q, offsets_q = get_seqlens_and_offsets(segment_ids_q) + + min_segment_len = None if self.window_size is None else seqlens_q + segment_ids_kv, segment_pos_kv, pad_kv = generate_random_segment_ids( + self.batch_size, + self.max_seqlen_kv, + num_segments_per_seq, + seed=2024, + min_segment_len=min_segment_len, + ) + seqlens_kv, offsets_kv = get_seqlens_and_offsets(segment_ids_kv) + return ( + num_segments_per_seq, + segment_ids_q, + segment_pos_q, + pad_q, + seqlens_q, + offsets_q, + segment_ids_kv, + segment_pos_kv, + pad_kv, + seqlens_kv, + offsets_kv, + ) + def _setup_inputs(self): self._check_configs() @@ -654,30 +697,47 @@ def generate_random_segment_ids( return segment_ids, segment_pos, segment_pad if self.qkv_layout.is_thd(): - self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids( - self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42 - ) - self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q) - # TODO(rewang): record only self attention and find the reason of cross attention - if self.qkv_layout == QKVLayout.T3HD or self.max_seqlen_q == self.max_seqlen_kv: - self.segment_ids_kv = self.segment_ids_q - self.segment_pos_kv = self.segment_pos_q - self.pad_kv = self.pad_q - else: - # Force kv_len >= q_len for swa, otherwise, cuDNN kernels don't support - 
min_segment_len = None - if ( - self.window_size is not None or self.attn_mask_type.is_bottom_right() - ): # SWA or BRCM requires kv_len >= q_len - min_segment_len = self.seqlens_q - self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = generate_random_segment_ids( - self.batch_size, - self.max_seqlen_kv, + if is_hip_extension() and os.environ.get("NVTE_FUSED_ATTN_CK_SMALLSEQ", "0") == "1": + ( self.num_segments_per_seq, - seed=2024, - min_segment_len=min_segment_len, + self.segment_ids_q, + self.segment_pos_q, + self.pad_q, + self.seqlens_q, + self.offsets_q, + self.segment_ids_kv, + self.segment_pos_kv, + self.pad_kv, + self.seqlens_kv, + self.offsets_kv, + ) = self._setup_segments_ck_smallseq(generate_random_segment_ids) + else: + self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids( + self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42 ) - self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(self.segment_ids_kv) + self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q) + # TODO(rewang): record only self attention and find the reason of cross attention + if self.qkv_layout == QKVLayout.T3HD or self.max_seqlen_q == self.max_seqlen_kv: + self.segment_ids_kv = self.segment_ids_q + self.segment_pos_kv = self.segment_pos_q + self.pad_kv = self.pad_q + else: + # Force kv_len >= q_len for swa, otherwise, cuDNN kernels don't support + min_segment_len = None + if ( + self.window_size is not None or self.attn_mask_type.is_bottom_right() + ): # SWA or BRCM requires kv_len >= q_len + min_segment_len = self.seqlens_q + self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = ( + generate_random_segment_ids( + self.batch_size, + self.max_seqlen_kv, + self.num_segments_per_seq, + seed=2024, + min_segment_len=min_segment_len, + ) + ) + self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(self.segment_ids_kv) else: self.segment_ids_q, self.pad_q = gen_valid( self.batch_size, self.max_seqlen_q, 
pad_ratio @@ -1902,3 +1962,100 @@ def fused_fn(q, k, v): for name, x, y in zip(("dQ", "dK", "dV"), grads1, grads2): # Bitwise reproducibility across consecutive runs assert_allclose(x, y, atol=0, rtol=0, err_msg=f"{name} not bitwise reproducible") + + +_SKIP_ROCM_CK_SMALLSEQ = pytest.mark.skipif( + not is_hip_extension(), + reason="CK unfused small-seq tests only on ROCm", +) + +# ROCm CK small-seq tests. +@pytest.fixture +def ck_smallseq_env(monkeypatch): + """Enable CK small-seq path and disable XLA GPU graphs for these tests.""" + # gfx942 uses the dedicated unfused small-seq path (NVTE_FUSED_ATTN_CK_SMALLSEQ), + # which requires XLA GPU graphs disabled. + if get_device_compute_capability(0) == 94: + if "xla_gpu_graph_level=0" not in os.environ.get("XLA_FLAGS", ""): + pytest.skip("Test must be run with XLA_FLAGS='--xla_gpu_graph_level=0'") + monkeypatch.setenv("NVTE_FUSED_ATTN_CK_SMALLSEQ", "1") + yield + +@_SKIP_ROCM_CK_SMALLSEQ +@pytest.mark.usefixtures("ck_smallseq_env") +class TestFusedAttnCkSmallseq: + """ + ROCm CK small-seq (NVTE_FUSED_ATTN_CK_SMALLSEQ). + Covers 1<=s_q<=16, 2<=s_kv<=16 THD self/cross attention and BSHD. 
+ """ + + @staticmethod + @pytest.mark.parametrize("dtype", [jnp.bfloat16, jnp.float16], ids=["BF16", "FP16"]) + @pytest.mark.parametrize( + "b, s_q, s_kv, h_q, h_kv, d_qk, d_v, qkv_layout", + [ + # cross-attention (q=1 attends over kv), THD + no bias + pytest.param(4000, 1, 2, 16, 16, 128, 128, QKVLayout.THD_THD_THD, id="cross-attn-THD_THD_THD-4000-1-2-16-16-128-128"), + pytest.param(4000, 1, 4, 16, 16, 128, 128, QKVLayout.THD_THD_THD, id="cross-attn-THD_THD_THD-4000-1-4-16-16-128-128"), + pytest.param(4000, 1, 6, 16, 16, 128, 128, QKVLayout.THD_THD_THD, id="cross-attn-THD_THD_THD-4000-1-6-16-16-128-128"), + pytest.param(4000, 1, 8, 16, 16, 128, 128, QKVLayout.THD_THD_THD, id="cross-attn-THD_THD_THD-4000-1-8-16-16-128-128"), + pytest.param(4000, 1, 12, 16, 16, 128, 128, QKVLayout.THD_THD_THD, id="cross-attn-THD_THD_THD-4000-1-12-16-16-128-128"), + pytest.param(4000, 1, 16, 16, 16, 128, 128, QKVLayout.THD_THD_THD, id="cross-attn-THD_THD_THD-4000-1-16-16-16-128-128"), + pytest.param(4000, 1, 4, 32, 32, 128, 128, QKVLayout.THD_THD_THD, id="cross-attn-THD_THD_THD-256-1-4-32-32-128-128"), + pytest.param(4000, 1, 6, 16, 16, 256, 256, QKVLayout.THD_THD_THD, id="cross-attn-THD_THD_THD-128-1-6-16-16-256-256"), + # self-attention (s_q == s_kv), THD + padding + pytest.param(4000, 8, 8, 16, 16, 128, 128, QKVLayout.THD_THD_THD, id="self-attn-THD_THD_THD-2-8-8-16-16-128-128"), + pytest.param(4000, 8, 8, 16, 16, 128, 128, QKVLayout.THD_THD_THD, id="self-attn-THD_THD_THD-32-8-8-16-16-128-128"), + pytest.param(4000, 16, 16, 16, 16, 128, 128, QKVLayout.THD_THD_THD, id="self-attn-THD_THD_THD-48-16-16-16-16-128-128"), + pytest.param(4000, 16, 16, 16, 16, 128, 128, QKVLayout.THD_THD_THD, id="self-attn-THD_THD_THD-16-16-16-16-16-128-128"), + pytest.param(4000, 17, 17, 16, 16, 128, 128, QKVLayout.THD_THD_THD, id="self-attn-THD_THD_THD-8-17-17-16-16-128-128"), + # cross-attention (s_q != s_kv), THD + padding + pytest.param(4000, 4, 8, 16, 16, 128, 128, QKVLayout.THD_THD_THD, 
id="cross-attn-THD_THD_THD-64-4-8-16-16-128-128"), + pytest.param(4000, 8, 12, 16, 16, 128, 128, QKVLayout.THD_THD_THD, id="cross-attn-THD_THD_THD-64-8-12-16-16-128-128"), + pytest.param(4000, 12, 16, 16, 16, 128, 128, QKVLayout.THD_THD_THD, id="cross-attn-THD_THD_THD-32-12-16-16-16-128-128"), + # self-attention, BSHD + no mask + no bias + pytest.param(4000, 16, 16, 16, 16, 128, 128, QKVLayout.BSHD_BSHD_BSHD, id="self-attn-BSHD_BSHD_BSHD-4-16-16-16-16-128-128"), + pytest.param(4000, 17, 17, 16, 16, 128, 128, QKVLayout.BSHD_BSHD_BSHD, id="self-attn-BSHD_BSHD_BSHD-4000-17-17-16-16-128-128"), + ], + ) + def test_smallseq( + dtype, + b, + s_q, + s_kv, + h_q, + h_kv, + d_qk, + d_v, + qkv_layout, + ): + """CK small-seq THD/BSHD: no bias; padding mask for THD, no mask for BSHD. + + """ + attn_mask_type = ( + AttnMaskType.NO_MASK + if qkv_layout == QKVLayout.BSHD_BSHD_BSHD + else AttnMaskType.PADDING_MASK + ) + runner = FusedAttnRunner( + batch_size=b, + max_seqlen_q=s_q, + max_seqlen_kv=s_kv, + num_heads_q=h_q, + num_heads_kv=h_kv, + head_dim_qk=d_qk, + head_dim_v=d_v, + attn_bias_type=AttnBiasType.NO_BIAS, + attn_mask_type=attn_mask_type, + softmax_type=AttnSoftmaxType.VANILLA_SOFTMAX, + dropout_prob=0.0, + use_old_rng=True, + dtype=dtype, + is_training=True, + qkv_layout=qkv_layout, + bias_shape=None, + window_size=None, + seq_desc_format=SeqDescFormat.Seqlens, + ) + runner.test_forward() + runner.test_backward() diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 02eaaea93..6eed17b3a 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -290,6 +290,7 @@ if(USE_ROCM) fused_attn_rocm/fused_attn.cpp fused_attn_rocm/fused_attn_aotriton.cpp fused_attn_rocm/fused_attn_ck.cpp + fused_attn_rocm/fused_attn_smallseq.cpp fused_attn_rocm/utils.cpp gemm/ck_grouped_gemm/ck_grouped_gemm.cpp gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp @@ -355,6 +356,11 @@ else() #USE_ROCM 
add_library(transformer_engine SHARED ${te_hip_sources}) + if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/fused_attn_rocm/small_seq_kernels") + target_include_directories(transformer_engine PRIVATE + "${CMAKE_CURRENT_SOURCE_DIR}/fused_attn_rocm/small_seq_kernels") + endif() + # Workaround for TheRock installation that moved some headers from system-wide location # to rocm_sysdeps but missing it in the default include path. target_include_directories(transformer_engine SYSTEM PRIVATE "${ROCM_PATH}/lib/rocm_sysdeps/include") diff --git a/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp b/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp index 127d75b4c..b3146ed42 100644 --- a/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp +++ b/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp @@ -143,6 +143,12 @@ struct CkAttnBwdArgs : CKAttnCommonArgs { hipError_t ck_attn_fwd(const CKAttnFwdArgs& args, hipStream_t stream); hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream); +uint64_t get_runtime_max_seqlen(uint64_t b, + const void* cu_seqlen_ptr, + const void* cu_seqlen_padded_ptr, + void* workspace, + hipStream_t stream); + }//namespace ck_fused_attn #endif // CK_FUSED_ATTN_H diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp index b621d7174..1e7dc5217 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp @@ -10,6 +10,7 @@ #include "transformer_engine/fused_attn.h" #include "fused_attn_aotriton.h" #include "fused_attn_ck.h" +#include "fused_attn_smallseq.h" #include "../common.h" #include "../util/cuda_runtime.h" //cuda::sm_arch #include "../util/system.h" //getenv diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp 
b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp index 0c7d80a6a..e57e10439 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp @@ -13,6 +13,7 @@ #include "../util/cuda_runtime.h" #include "../util/system.h" #include "fused_attn_ck.h" +#include "fused_attn_smallseq.h" #include "utils.h" namespace transformer_engine { @@ -273,6 +274,22 @@ void generate_alibi_slope(uint64_t h, float* alibi_slope_ptr){ } } +// Legacy CK small-seq layout (cross-attn varlen): one thread per batch index fills +// padded_q_to_batch[i] = b for i in [cu_seqlens_q_padded[b], cu_seqlens_q_padded[b+1]). +__global__ void build_padded_q_to_batch_kernel(const int* cu_seqlens_q_padded, + int bs, + int* padded_q_to_batch) { + int b = blockIdx.x * blockDim.x + threadIdx.x; + if(b >= bs) { + return; + } + int start = cu_seqlens_q_padded[b]; + int end = cu_seqlens_q_padded[b + 1]; + for(int i = start; i < end; ++i) { + padded_q_to_batch[i] = b; + } +} + // no device std::upper_bound // in an increasing array with given size len, search for the index that: // array[index] <= target < array[index+1] @@ -498,6 +515,18 @@ void fused_attn_ck_fwd_impl( // (planner returns nullptr, accumulates total) and execution mode. 
WorkspacePlanner planner(workspace); + // Prefix layout matches legacy CK small-seq: [max_seqlen_q probe][max_seqlen_kv probe][padded_q_to_batch] + void* ck_smallseq_workspace_prefix = nullptr; + const bool ck_small_seq_enabled = + is_nvte_ck_small_seq_enabled() && + small_seq_static_config_ok(static_cast(dtype), static_cast(dtype), + bias_type, dropout_probability, d_qk, d_v, h, hg, mask_type) && + is_ragged; + if(ck_small_seq_enabled) { + ck_smallseq_workspace_prefix = + planner.allocate(small_seq_extra_workspace_bytes(max_tokens_q)); + } + void* devPtrAlibiSlope = nullptr; if(bias_type == NVTE_Bias_Type::NVTE_ALIBI){ // ck requires an alibi slope array even if in standard (vanilla) mode @@ -658,21 +687,72 @@ void fused_attn_ck_fwd_impl( pad_remap(dtype, b, h, s_q, d_v, max_tokens_q, false, o_stride[0], o_stride[1], o_stride[2], devPtrO, devPtrCuSeqlensQ, devPtrCuSeqlenPaddedQ, devPtrOWithoutPadding, stream); pad_remap_lse(b, h, s_q, max_tokens_q, false, devPtrSoftmaxAux, devPtrCuSeqlensQ, devPtrCuSeqlenPaddedQ, devPtrSoftmaxLSEWithoutPadding, stream); }else if(bshd_to_thd || is_ragged){ - ck_args.max_tokens_q = max_tokens_q; - ck_args.q_ptr = devPtrQ; - ck_args.stride_h_q = q_stride[1]; ck_args.stride_s_q = q_stride[2]; - ck_args.k_ptr = devPtrK; - ck_args.stride_h_k = k_stride[1]; ck_args.stride_s_k = k_stride[2]; - ck_args.v_ptr = devPtrV; - ck_args.stride_h_v = v_stride[1]; ck_args.stride_s_v = v_stride[2]; - ck_args.cu_seqlen_q_ptr = devPtrCuSeqlensQ; - ck_args.cu_seqlen_kv_ptr = devPtrCuSeqlensKV; - ck_args.cu_seqlen_q_padded_ptr = devPtrCuSeqlenPaddedQ; - ck_args.cu_seqlen_kv_padded_ptr = devPtrCuSeqlenPaddedKV; - ck_args.o_ptr = devPtrO; - ck_args.stride_h_o = o_stride[1]; ck_args.stride_s_o = o_stride[2]; - ck_args.lse_ptr = devPtrSoftmaxLSEWithoutPadding; - NVTE_CHECK_CUDA(ck_fused_attn::ck_attn_fwd(ck_args, stream)); + bool ran_smallseq = false; + if(ck_smallseq_workspace_prefix != nullptr && is_ragged && ck_small_seq_enabled) { + void* workspace_next 
= ck_smallseq_workspace_prefix; + void* max_seqlen_workspace_q = workspace_next; + void* max_seqlen_workspace_kv = + static_cast(static_cast(workspace_next) + sizeof(uint64_t)); + hipStream_t hip_stream = reinterpret_cast(stream); + const size_t runtime_max_seqlen_q = static_cast(ck_fused_attn::get_runtime_max_seqlen( + b, devPtrCuSeqlensQ, devPtrCuSeqlenPaddedQ, max_seqlen_workspace_q, hip_stream)); + const size_t runtime_max_seqlen_kv = static_cast(ck_fused_attn::get_runtime_max_seqlen( + b, devPtrCuSeqlensKV, devPtrCuSeqlenPaddedKV, max_seqlen_workspace_kv, hip_stream)); + workspace_next = + static_cast(static_cast(workspace_next) + 2 * sizeof(uint64_t)); + if(nvte_log_ck_config) { + std::cout << std::endl << "attn_fwd(ck small-seq): "; + std::cout << "b: " << b << ", "; + std::cout << "runtime_max_seqlen_q: " << runtime_max_seqlen_q << ", "; + std::cout << "runtime_max_seqlen_kv: " << runtime_max_seqlen_kv << ", "; + std::cout << "flow: " + << (is_runtime_small_seq_eligible(runtime_max_seqlen_q, runtime_max_seqlen_kv) + ? "ck-smallseq" + : "regular ck/aiter") + << std::endl; + } + if(is_runtime_small_seq_eligible(runtime_max_seqlen_q, runtime_max_seqlen_kv)) { + const int total_padded_q = static_cast(max_tokens_q); + int* devPtrPaddedQToBatch = static_cast(workspace_next); + workspace_next = static_cast(static_cast(workspace_next) + + static_cast(total_padded_q) * sizeof(int)); + (void)workspace_next; // Legacy layout: remainder reserved for small-seq scratch (unused here). 
+ constexpr int k_build_padded_threads = 256; + const int bs = static_cast(b); + if(bs > 0) { + const unsigned grid_x = static_cast( + (static_cast(bs) + k_build_padded_threads - 1) / k_build_padded_threads); + dim3 grid(grid_x); + dim3 block(k_build_padded_threads); + build_padded_q_to_batch_kernel<<>>( + static_cast(devPtrCuSeqlenPaddedQ), bs, devPtrPaddedQToBatch); + NVTE_CHECK_CUDA(hipGetLastError()); + } + NVTE_CHECK_CUDA(hipStreamSynchronize(hip_stream)); + ran_smallseq = fused_attn_smallseq_fwd( + b, h, d_qk, max_tokens_q, max_tokens_kv, scaling_factor, devPtrQ, devPtrK, devPtrV, + devPtrO, devPtrSoftmaxLSEWithoutPadding, devPtrCuSeqlensQ, devPtrCuSeqlenPaddedQ, + devPtrCuSeqlensKV, devPtrCuSeqlenPaddedKV, devPtrPaddedQToBatch, + static_cast(dtype), stream); + } + } + if(!ran_smallseq) { + ck_args.max_tokens_q = max_tokens_q; + ck_args.q_ptr = devPtrQ; + ck_args.stride_h_q = q_stride[1]; ck_args.stride_s_q = q_stride[2]; + ck_args.k_ptr = devPtrK; + ck_args.stride_h_k = k_stride[1]; ck_args.stride_s_k = k_stride[2]; + ck_args.v_ptr = devPtrV; + ck_args.stride_h_v = v_stride[1]; ck_args.stride_s_v = v_stride[2]; + ck_args.cu_seqlen_q_ptr = devPtrCuSeqlensQ; + ck_args.cu_seqlen_kv_ptr = devPtrCuSeqlensKV; + ck_args.cu_seqlen_q_padded_ptr = devPtrCuSeqlenPaddedQ; + ck_args.cu_seqlen_kv_padded_ptr = devPtrCuSeqlenPaddedKV; + ck_args.o_ptr = devPtrO; + ck_args.stride_h_o = o_stride[1]; ck_args.stride_s_o = o_stride[2]; + ck_args.lse_ptr = devPtrSoftmaxLSEWithoutPadding; + NVTE_CHECK_CUDA(ck_fused_attn::ck_attn_fwd(ck_args, stream)); + } // aiter asm output softmax_lse with padding pad_remap_lse(b, h, s_q, max_tokens_q, is_ragged, devPtrSoftmaxAux, devPtrCuSeqlenPaddedQ, devPtrCuSeqlenPaddedQ, devPtrSoftmaxLSEWithoutPadding, stream); }else{ @@ -746,6 +826,18 @@ void fused_attn_ck_bwd_impl( // (planner returns nullptr, accumulates total) and execution mode. 
WorkspacePlanner planner(workspace); + // Prefix layout matches legacy CK small-seq: [max_seqlen_q probe][max_seqlen_kv probe][padded_q_to_batch] + void* ck_smallseq_workspace_prefix = nullptr; + const bool ck_small_seq_enabled = + is_nvte_ck_small_seq_enabled() && + small_seq_static_config_ok(static_cast(dtype), static_cast(dtype), + bias_type, dropout_probability, d_qk, d_v, h, hg, mask_type) && + is_ragged; + if(ck_small_seq_enabled) { + ck_smallseq_workspace_prefix = + planner.allocate(small_seq_extra_workspace_bytes(max_tokens_q)); + } + // First h*max_tokens_q*sizeof(float) is the lse-d buffer (passed as softmax_lsed) void* lse_workspace = planner.allocate(h*max_tokens_q*sizeof(float)); @@ -1024,32 +1116,70 @@ void fused_attn_ck_bwd_impl( pad_remap(dtype, b, hg, s_kv, d_v, max_tokens_kv, is_ragged, v_stride[0], v_stride[1], v_stride[2], devPtrdV, devPtrCuSeqlensKV, devPtrSeqOffsetsKV, devPtrdVWithoutPadding, stream); }else if(bshd_to_thd || is_ragged){ pad_remap_lse(b, h, s_q, max_tokens_q, is_ragged, devPtrSoftmaxAux, devPtrCuSeqlenPaddedQ, devPtrCuSeqlenPaddedQ, devPtrSoftmaxLSEWithoutPadding, stream); - ck_args.max_tokens_q = max_tokens_q; ck_args.max_tokens_kv = max_tokens_kv; - ck_args.q_ptr = devPtrQ; - ck_args.stride_h_q = q_stride[1]; ck_args.stride_s_q = q_stride[2]; - ck_args.k_ptr = devPtrK; - ck_args.stride_h_k = k_stride[1]; ck_args.stride_s_k = k_stride[2]; - ck_args.v_ptr = devPtrV; - ck_args.stride_h_v = v_stride[1]; ck_args.stride_s_v = v_stride[2]; - ck_args.cu_seqlen_q_ptr = devPtrCuSeqlensQ; - ck_args.cu_seqlen_kv_ptr = devPtrCuSeqlensKV; - ck_args.cu_seqlen_q_padded_ptr = devPtrCuSeqlenPaddedQ; - ck_args.cu_seqlen_kv_padded_ptr = devPtrCuSeqlenPaddedKV; - ck_args.o_ptr = devPtrO; - ck_args.stride_h_o = o_stride[1]; ck_args.stride_s_o = o_stride[2]; - ck_args.lse_ptr = devPtrSoftmaxLSEWithoutPadding; - // dO and O share the same stride - ck_args.do_ptr = devPtrdO; - ck_args.stride_h_do = o_stride[1]; ck_args.stride_s_do = o_stride[2]; - 
ck_args.dq_ptr = devPtrdQ; - ck_args.stride_h_dq = q_stride[1]; ck_args.stride_s_dq = q_stride[2]; - ck_args.stride_h_dk_expanded = dk_expanded_stride[1]; ck_args.stride_s_dk_expanded = dk_expanded_stride[2]; - ck_args.stride_h_dv_expanded = dv_expanded_stride[1]; ck_args.stride_s_dv_expanded = dv_expanded_stride[2]; - ck_args.dk_ptr = devPtrdK; - ck_args.stride_h_dk = k_stride[1]; ck_args.stride_s_dk = k_stride[2]; - ck_args.dv_ptr = devPtrdV; - ck_args.stride_h_dv = v_stride[1]; ck_args.stride_s_dv = v_stride[2]; - NVTE_CHECK_CUDA(ck_fused_attn::ck_attn_bwd(ck_args, stream)); + bool ran_smallseq_bwd = false; + if(ck_smallseq_workspace_prefix != nullptr && is_ragged && ck_small_seq_enabled) { + void* workspace_next = ck_smallseq_workspace_prefix; + void* max_seqlen_workspace_q = workspace_next; + void* max_seqlen_workspace_kv = + static_cast(static_cast(workspace_next) + sizeof(uint64_t)); + hipStream_t hip_stream = reinterpret_cast(stream); + const size_t runtime_max_seqlen_q = static_cast(ck_fused_attn::get_runtime_max_seqlen( + b, devPtrCuSeqlensQ, devPtrCuSeqlenPaddedQ, max_seqlen_workspace_q, hip_stream)); + const size_t runtime_max_seqlen_kv = static_cast(ck_fused_attn::get_runtime_max_seqlen( + b, devPtrCuSeqlensKV, devPtrCuSeqlenPaddedKV, max_seqlen_workspace_kv, hip_stream)); + workspace_next = + static_cast(static_cast(workspace_next) + 2 * sizeof(uint64_t)); + if(nvte_log_ck_config) { + std::cout << std::endl << "attn_bwd(ck small-seq): "; + std::cout << "b: " << b << ", "; + std::cout << "runtime_max_seqlen_q: " << runtime_max_seqlen_q << ", "; + std::cout << "runtime_max_seqlen_kv: " << runtime_max_seqlen_kv << ", "; + std::cout << "flow: " + << (is_runtime_small_seq_eligible(runtime_max_seqlen_q, runtime_max_seqlen_kv) + ? 
"ck-smallseq" + : "regular ck/aiter") + << std::endl; + } + if(is_runtime_small_seq_eligible(runtime_max_seqlen_q, runtime_max_seqlen_kv)) { + const int total_padded_q = static_cast(max_tokens_q); + workspace_next = static_cast(static_cast(workspace_next) + + static_cast(total_padded_q) * sizeof(int)); + (void)workspace_next; // Same prefix carve as fwd; small-seq bwd path does not consume padded map. + ran_smallseq_bwd = fused_attn_smallseq_bwd( + b, h, d_qk, max_tokens_q, max_tokens_kv, scaling_factor, devPtrQ, devPtrK, devPtrV, + devPtrdO, devPtrSoftmaxLSEWithoutPadding, devPtrdQ, devPtrdK, devPtrdV, + devPtrCuSeqlensQ, devPtrCuSeqlenPaddedQ, devPtrCuSeqlensKV, devPtrCuSeqlenPaddedKV, + static_cast(dtype), stream); + } + } + if(!ran_smallseq_bwd) { + ck_args.max_tokens_q = max_tokens_q; ck_args.max_tokens_kv = max_tokens_kv; + ck_args.q_ptr = devPtrQ; + ck_args.stride_h_q = q_stride[1]; ck_args.stride_s_q = q_stride[2]; + ck_args.k_ptr = devPtrK; + ck_args.stride_h_k = k_stride[1]; ck_args.stride_s_k = k_stride[2]; + ck_args.v_ptr = devPtrV; + ck_args.stride_h_v = v_stride[1]; ck_args.stride_s_v = v_stride[2]; + ck_args.cu_seqlen_q_ptr = devPtrCuSeqlensQ; + ck_args.cu_seqlen_kv_ptr = devPtrCuSeqlensKV; + ck_args.cu_seqlen_q_padded_ptr = devPtrCuSeqlenPaddedQ; + ck_args.cu_seqlen_kv_padded_ptr = devPtrCuSeqlenPaddedKV; + ck_args.o_ptr = devPtrO; + ck_args.stride_h_o = o_stride[1]; ck_args.stride_s_o = o_stride[2]; + ck_args.lse_ptr = devPtrSoftmaxLSEWithoutPadding; + // dO and O share the same stride + ck_args.do_ptr = devPtrdO; + ck_args.stride_h_do = o_stride[1]; ck_args.stride_s_do = o_stride[2]; + ck_args.dq_ptr = devPtrdQ; + ck_args.stride_h_dq = q_stride[1]; ck_args.stride_s_dq = q_stride[2]; + ck_args.stride_h_dk_expanded = dk_expanded_stride[1]; ck_args.stride_s_dk_expanded = dk_expanded_stride[2]; + ck_args.stride_h_dv_expanded = dv_expanded_stride[1]; ck_args.stride_s_dv_expanded = dv_expanded_stride[2]; + ck_args.dk_ptr = devPtrdK; + 
ck_args.stride_h_dk = k_stride[1]; ck_args.stride_s_dk = k_stride[2]; + ck_args.dv_ptr = devPtrdV; + ck_args.stride_h_dv = v_stride[1]; ck_args.stride_s_dv = v_stride[2]; + NVTE_CHECK_CUDA(ck_fused_attn::ck_attn_bwd(ck_args, stream)); + } }else{ ck_args.bias_b = bias_b; ck_args.bias_h = bias_h; ck_args.q_ptr = devPtrQ; diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.cpp new file mode 100644 index 000000000..a19019df7 --- /dev/null +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.cpp @@ -0,0 +1,368 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#include "fused_attn_smallseq.h" + +#include +#include +#include + +#include +#include +#include + +#include "../common.h" +#include "../util/cuda_runtime.h" +#include "utils.h" + +#ifdef USE_FUSED_ATTN_CK +#include +#include "attn_bwd_mfma_16x16.h" +#include "attn_fwd_mfma_dispatch.h" +#endif + +namespace transformer_engine { +namespace fused_attn_rocm { + +bool small_seq_static_config_ok(NVTEDType q_dtype, + NVTEDType kv_dtype, + NVTE_Bias_Type bias_type, + float dropout, + size_t head_dim_qk, + size_t head_dim_v, + size_t num_attn_heads, + size_t num_gqa_groups, + NVTE_Mask_Type mask_type) { + if(dropout != 0.0f) return false; + if(bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) return false; + if(q_dtype != kv_dtype) return false; + if(!(q_dtype == NVTEDType::kNVTEFloat16 || q_dtype == NVTEDType::kNVTEBFloat16)) return false; + if(head_dim_qk != head_dim_v) return false; + if(head_dim_qk != 128 && head_dim_qk != 256) return false; + if(num_gqa_groups == 0 || num_attn_heads % num_gqa_groups != 0) return false; + if(num_attn_heads != num_gqa_groups) return 
false; + if(num_attn_heads != 16 && num_attn_heads != 32) return false; + if(!(is_padding_mask(mask_type) || mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) return false; + return true; +} + +bool is_runtime_small_seq_eligible(size_t runtime_max_seqlen_q, size_t runtime_max_seqlen_kv) { + return runtime_max_seqlen_q > 0 && runtime_max_seqlen_q <= kSmallSeqMaxSeqlen && + runtime_max_seqlen_kv > 0 && runtime_max_seqlen_kv <= kSmallSeqMaxSeqlen; +} + +bool supports_hip_small_seq(size_t num_attn_heads, + size_t num_gqa_groups, + size_t head_dim_qk, + size_t head_dim_v) { + if(num_attn_heads != num_gqa_groups) return false; + if(num_attn_heads != 16 && num_attn_heads != 32) return false; + if(head_dim_qk != head_dim_v) return false; + return head_dim_qk == 128 || head_dim_qk == 256; +} + +size_t small_seq_extra_workspace_bytes(size_t max_tokens_q) { + return 2 * sizeof(uint64_t) + max_tokens_q * sizeof(int32_t); +} + +bool is_nvte_ck_small_seq_enabled() { + if (transformer_engine::cuda::sm_arch() != 94) { + return false; + } + const char* env_p = std::getenv("NVTE_FUSED_ATTN_CK_SMALLSEQ"); + return env_p != nullptr && std::strcmp(env_p, "1") == 0; +} + +#ifndef USE_FUSED_ATTN_CK + +bool fused_attn_smallseq_fwd(size_t, size_t, size_t, size_t, size_t, float, + const void*, const void*, const void*, void*, void*, + const void*, const void*, const void*, const void*, + const void*, NVTEDType, cudaStream_t) { + return false; +} + +bool fused_attn_smallseq_bwd(size_t, size_t, size_t, size_t, size_t, float, + const void*, const void*, const void*, const void*, const void*, + void*, void*, void*, + const void*, const void*, const void*, const void*, + NVTEDType, cudaStream_t) { + return false; +} + +#else // USE_FUSED_ATTN_CK + +// HIP small-seq kernels: head dim 512 is not supported here — upstream instantiations exceed the +// 64 KiB LDS limit on CDNA (gfx942 / gfx950) for the 17×17 small-seq tile configuration. 
+namespace { + +constexpr int kMaxBsInst = 16384; + +hipStream_t to_hip_stream(cudaStream_t s) { + return reinterpret_cast(s); +} + +template +bool launch_fwd_inst(size_t actual_batch, + float attn_scale, + const T* Q, + const T* K, + const T* V, + T* O, + float* softmax_lse, + const int* cu_q, + const int* cu_qp, + const int* cu_kv, + const int* cu_kvp, + const int* padded_q_to_batch, + int total_padded_q, + hipStream_t stream) { + if(actual_batch > static_cast(kMaxBsInst)) { + return false; + } + using Config = + FmhaKernelConfig; + using Launcher = AttnForwardMfmaDispatchLauncher; + const float sqr_dk_scale = attn_scale / std::sqrt(static_cast(HEAD_DIM)); + Launcher::run_attn_fwd_kernel(Q, K, V, nullptr, 0.0f, sqr_dk_scale, O, softmax_lse, cu_q, cu_qp, + cu_kv, cu_kvp, padded_q_to_batch, total_padded_q); + NVTE_CHECK_CUDA(hipStreamSynchronize(stream)); + return true; +} + +template +bool launch_fwd_dispatch(size_t batch, + size_t num_heads, + int head_dim, + float attn_scale, + const T* Q, + const T* K, + const T* V, + T* O, + float* softmax_lse, + const int* cu_q, + const int* cu_qp, + const int* cu_kv, + const int* cu_kvp, + const int* padded_q_to_batch, + int total_padded_q, + hipStream_t stream) { + if(num_heads == 16) { + if(head_dim == 128) { + return launch_fwd_inst(batch, attn_scale, Q, K, V, O, softmax_lse, cu_q, cu_qp, + cu_kv, cu_kvp, padded_q_to_batch, total_padded_q, stream); + } + if(head_dim == 256) { + return launch_fwd_inst(batch, attn_scale, Q, K, V, O, softmax_lse, cu_q, cu_qp, + cu_kv, cu_kvp, padded_q_to_batch, total_padded_q, stream); + } + } + if(num_heads == 32) { + if(head_dim == 128) { + return launch_fwd_inst(batch, attn_scale, Q, K, V, O, softmax_lse, cu_q, cu_qp, + cu_kv, cu_kvp, padded_q_to_batch, total_padded_q, stream); + } + if(head_dim == 256) { + return launch_fwd_inst(batch, attn_scale, Q, K, V, O, softmax_lse, cu_q, cu_qp, + cu_kv, cu_kvp, padded_q_to_batch, total_padded_q, stream); + } + } + return false; +} + +template 
+bool launch_bwd_inst(size_t actual_batch, + float attn_scale, + const T* Q, + const T* K, + const T* V, + const T* dO, + const float* softmax_lse, + T* dQ, + T* dK, + T* dV, + const int* cu_q, + const int* cu_qp, + const int* cu_kv, + const int* cu_kvp, + hipStream_t stream) { + if(actual_batch > static_cast(kMaxBsInst)) { + return false; + } + using Config = + FmhaKernelConfig; + using Launcher = AttnBackwardMfma16x16KernelLauncher; + const float sqr_dk_scale = attn_scale / std::sqrt(static_cast(HEAD_DIM)); + Launcher::run_attn_bwd_kernel(Q, K, V, dO, softmax_lse, dQ, dK, dV, sqr_dk_scale, cu_q, cu_qp, + cu_kv, cu_kvp); + NVTE_CHECK_CUDA(hipStreamSynchronize(stream)); + return true; +} + +template +bool launch_bwd_dispatch(size_t batch, + size_t num_heads, + int head_dim, + float attn_scale, + const T* Q, + const T* K, + const T* V, + const T* dO, + const float* softmax_lse, + T* dQ, + T* dK, + T* dV, + const int* cu_q, + const int* cu_qp, + const int* cu_kv, + const int* cu_kvp, + hipStream_t stream) { + if(num_heads == 16) { + if(head_dim == 128) { + return launch_bwd_inst(batch, attn_scale, Q, K, V, dO, softmax_lse, dQ, dK, dV, + cu_q, cu_qp, cu_kv, cu_kvp, stream); + } + if(head_dim == 256) { + return launch_bwd_inst(batch, attn_scale, Q, K, V, dO, softmax_lse, dQ, dK, dV, + cu_q, cu_qp, cu_kv, cu_kvp, stream); + } + } + if(num_heads == 32) { + if(head_dim == 128) { + return launch_bwd_inst(batch, attn_scale, Q, K, V, dO, softmax_lse, dQ, dK, dV, + cu_q, cu_qp, cu_kv, cu_kvp, stream); + } + if(head_dim == 256) { + return launch_bwd_inst(batch, attn_scale, Q, K, V, dO, softmax_lse, dQ, dK, dV, + cu_q, cu_qp, cu_kv, cu_kvp, stream); + } + } + return false; +} + +} // namespace + +bool fused_attn_smallseq_fwd(size_t batch_size, + size_t num_heads, + size_t head_dim_qk, + size_t max_tokens_q, + size_t max_tokens_kv, + float attn_scale, + const void* dev_ptr_q, + const void* dev_ptr_k, + const void* dev_ptr_v, + void* dev_ptr_o, + void* dev_ptr_softmax_lse, + 
const void* dev_ptr_cu_seqlens_q, + const void* dev_ptr_cu_seqlens_q_padded, + const void* dev_ptr_cu_seqlens_kv, + const void* dev_ptr_cu_seqlens_kv_padded, + const void* dev_ptr_padded_q_to_batch, + NVTEDType dtype, + cudaStream_t stream) { + (void)max_tokens_kv; + const int* cu_q = static_cast(dev_ptr_cu_seqlens_q); + const int* cu_qp = static_cast(dev_ptr_cu_seqlens_q_padded); + const int* cu_kv = static_cast(dev_ptr_cu_seqlens_kv); + const int* cu_kvp = static_cast(dev_ptr_cu_seqlens_kv_padded); + const int* padded_q_to_batch = static_cast(dev_ptr_padded_q_to_batch); + float* softmax_lse = static_cast(dev_ptr_softmax_lse); + const int total_padded_q = static_cast(max_tokens_q); + const int hd = static_cast(head_dim_qk); + const hipStream_t hip_stream = to_hip_stream(stream); + + if(!supports_hip_small_seq(num_heads, num_heads, head_dim_qk, head_dim_qk)) { + return false; + } + + if(dtype == NVTEDType::kNVTEBFloat16) { + using T = hip_bfloat16; + const T* Q = static_cast(dev_ptr_q); + const T* K = static_cast(dev_ptr_k); + const T* V = static_cast(dev_ptr_v); + T* O = static_cast(dev_ptr_o); + return launch_fwd_dispatch(batch_size, num_heads, hd, attn_scale, Q, K, V, O, softmax_lse, + cu_q, cu_qp, cu_kv, cu_kvp, padded_q_to_batch, total_padded_q, + hip_stream); + } + if(dtype == NVTEDType::kNVTEFloat16) { + using T = __half; + const T* Q = static_cast(dev_ptr_q); + const T* K = static_cast(dev_ptr_k); + const T* V = static_cast(dev_ptr_v); + T* O = static_cast(dev_ptr_o); + return launch_fwd_dispatch(batch_size, num_heads, hd, attn_scale, Q, K, V, O, softmax_lse, + cu_q, cu_qp, cu_kv, cu_kvp, padded_q_to_batch, total_padded_q, + hip_stream); + } + return false; +} + +bool fused_attn_smallseq_bwd(size_t batch_size, + size_t num_heads, + size_t head_dim_qk, + size_t max_tokens_q, + size_t max_tokens_kv, + float attn_scale, + const void* dev_ptr_q, + const void* dev_ptr_k, + const void* dev_ptr_v, + const void* dev_ptr_do, + const void* dev_ptr_softmax_lse, + 
void* dev_ptr_dq, + void* dev_ptr_dk, + void* dev_ptr_dv, + const void* dev_ptr_cu_seqlens_q, + const void* dev_ptr_cu_seqlens_q_padded, + const void* dev_ptr_cu_seqlens_kv, + const void* dev_ptr_cu_seqlens_kv_padded, + NVTEDType dtype, + cudaStream_t stream) { + (void)max_tokens_q; + (void)max_tokens_kv; + const int* cu_q = static_cast(dev_ptr_cu_seqlens_q); + const int* cu_qp = static_cast(dev_ptr_cu_seqlens_q_padded); + const int* cu_kv = static_cast(dev_ptr_cu_seqlens_kv); + const int* cu_kvp = static_cast(dev_ptr_cu_seqlens_kv_padded); + const float* softmax_lse = static_cast(dev_ptr_softmax_lse); + const int hd = static_cast(head_dim_qk); + const hipStream_t hip_stream = to_hip_stream(stream); + + if(!supports_hip_small_seq(num_heads, num_heads, head_dim_qk, head_dim_qk)) { + return false; + } + + if(dtype == NVTEDType::kNVTEBFloat16) { + using T = hip_bfloat16; + const T* Q = static_cast(dev_ptr_q); + const T* K = static_cast(dev_ptr_k); + const T* V = static_cast(dev_ptr_v); + const T* dO = static_cast(dev_ptr_do); + T* dQ = static_cast(dev_ptr_dq); + T* dK = static_cast(dev_ptr_dk); + T* dV = static_cast(dev_ptr_dv); + return launch_bwd_dispatch(batch_size, num_heads, hd, attn_scale, Q, K, V, dO, softmax_lse, + dQ, dK, dV, cu_q, cu_qp, cu_kv, cu_kvp, hip_stream); + } + if(dtype == NVTEDType::kNVTEFloat16) { + using T = __half; + const T* Q = static_cast(dev_ptr_q); + const T* K = static_cast(dev_ptr_k); + const T* V = static_cast(dev_ptr_v); + const T* dO = static_cast(dev_ptr_do); + T* dQ = static_cast(dev_ptr_dq); + T* dK = static_cast(dev_ptr_dk); + T* dV = static_cast(dev_ptr_dv); + return launch_bwd_dispatch(batch_size, num_heads, hd, attn_scale, Q, K, V, dO, softmax_lse, + dQ, dK, dV, cu_q, cu_qp, cu_kv, cu_kvp, hip_stream); + } + return false; +} + +#endif // USE_FUSED_ATTN_CK + +} // namespace fused_attn_rocm +} // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.h 
b/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.h new file mode 100644 index 000000000..412c66753 --- /dev/null +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_smallseq.h @@ -0,0 +1,84 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_ROCM_FUSED_ATTN_SMALLSEQ_H_ +#define TRANSFORMER_ENGINE_FUSED_ATTN_ROCM_FUSED_ATTN_SMALLSEQ_H_ + +#include + +#include + +namespace transformer_engine { +namespace fused_attn_rocm { + +constexpr size_t kSmallSeqMaxSeqlen = 17; + +/** Static config only (no packed max seqlen as proof of per-segment lengths). */ +bool small_seq_static_config_ok(NVTEDType q_dtype, + NVTEDType kv_dtype, + NVTE_Bias_Type bias_type, + float dropout, + size_t head_dim_qk, + size_t head_dim_v, + size_t num_attn_heads, + size_t num_gqa_groups, + NVTE_Mask_Type mask_type); + +bool is_runtime_small_seq_eligible(size_t runtime_max_seqlen_q, size_t runtime_max_seqlen_kv); + +bool supports_hip_small_seq(size_t num_attn_heads, + size_t num_gqa_groups, + size_t head_dim_qk, + size_t head_dim_v); + +size_t small_seq_extra_workspace_bytes(size_t max_tokens_q); + +bool is_nvte_ck_small_seq_enabled(); + +bool fused_attn_smallseq_fwd(size_t batch_size, + size_t num_heads, + size_t head_dim_qk, + size_t max_tokens_q, + size_t max_tokens_kv, + float attn_scale, + const void* dev_ptr_q, + const void* dev_ptr_k, + const void* dev_ptr_v, + void* dev_ptr_o, + void* dev_ptr_softmax_lse, + const void* dev_ptr_cu_seqlens_q, + const void* dev_ptr_cu_seqlens_q_padded, + const void* dev_ptr_cu_seqlens_kv, + const void* dev_ptr_cu_seqlens_kv_padded, + const void* dev_ptr_padded_q_to_batch, + NVTEDType dtype, + cudaStream_t stream); + +bool 
fused_attn_smallseq_bwd(size_t batch_size, + size_t num_heads, + size_t head_dim_qk, + size_t max_tokens_q, + size_t max_tokens_kv, + float attn_scale, + const void* dev_ptr_q, + const void* dev_ptr_k, + const void* dev_ptr_v, + const void* dev_ptr_do, + const void* dev_ptr_softmax_lse, + void* dev_ptr_dq, + void* dev_ptr_dk, + void* dev_ptr_dv, + const void* dev_ptr_cu_seqlens_q, + const void* dev_ptr_cu_seqlens_q_padded, + const void* dev_ptr_cu_seqlens_kv, + const void* dev_ptr_cu_seqlens_kv_padded, + NVTEDType dtype, + cudaStream_t stream); + +} // namespace fused_attn_rocm +} // namespace transformer_engine + +#endif diff --git a/transformer_engine/common/fused_attn_rocm/small_seq_kernels/attn_bwd.h b/transformer_engine/common/fused_attn_rocm/small_seq_kernels/attn_bwd.h new file mode 100644 index 000000000..c9856ed6a --- /dev/null +++ b/transformer_engine/common/fused_attn_rocm/small_seq_kernels/attn_bwd.h @@ -0,0 +1,565 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT +#pragma once + +#include "attn_common.h" +#include + +// --------------------------------------------------------------------------- +// Kernel 1: compute_grad_v_kernel +// +// Computes grad_V = attn_weights^T @ grad_O +// attn_weights layout: [total_padded_q, head_num, max_seq_kv] +// grad_O layout: [total_padded_q, head_num, head_dim] +// grad_V layout: [total_padded_kv_seq, head_num, head_dim] +// --------------------------------------------------------------------------- + +template +__global__ void compute_grad_v_kernel(const T* attn_weights, + const T* grad_O, + T* grad_V, + const int* cu_seqlens_q, + const int* cu_seqlens_q_padded, + const int* cu_seqlens_kv, + const int* cu_seqlens_kv_padded) +{ + constexpr int seq_q = Config::seq_q; // == 1 + constexpr int max_seq_kv = Config::max_seq_kv; + constexpr int head_dim = Config::head_dim; + constexpr int block_k = BLOCK_K; + constexpr int dwordx4_load_elt = 16 / sizeof(T); + constexpr int warp_size = 64; + constexpr int process_head_per_warp = warp_size / (head_dim / block_k); + constexpr int tasks_per_block = TASKS_PER_BLOCK; + + int base_block_offset = blockIdx.x * process_head_per_warp * tasks_per_block; + int thread_id = threadIdx.x; + int thread_batch_offset = thread_id / (head_dim / block_k); + int thread_head_offset = thread_id % (head_dim / block_k) * block_k; + + uint4 load_dwordx4_tmp_var[block_k / dwordx4_load_elt]; + T attn[max_seq_kv]; + + for(int task = 0; task < tasks_per_block; task++) + { + int block_batch_head_idx = base_block_offset + task * process_head_per_warp; + int cur_idx = block_batch_head_idx + thread_batch_offset; + + int batch_idx = cur_idx / (Config::seq_q * Config::head_num); + int seq_head_idx = cur_idx % (Config::seq_q * Config::head_num); + int head_idx = seq_head_idx % Config::head_num; + + if(batch_idx >= Config::bs) + continue; + + // Skip batches where actual Q seq is 0 — no grad_O to read from. 
+ int actual_seq_q = cu_seqlens_q[batch_idx + 1] - cu_seqlens_q[batch_idx]; + if(actual_seq_q == 0) + continue; + + int seq_kv = cu_seqlens_kv[batch_idx + 1] - cu_seqlens_kv[batch_idx]; + int q_storage_offset = cu_seqlens_q_padded[batch_idx]; // seq_q_idx == 0 + + // attn_weights layout: [total_padded_q, head_num, max_seq_kv] + int attn_offset = (q_storage_offset * Config::head_num + head_idx) * max_seq_kv; +#pragma unroll + for(int i = 0; i < max_seq_kv; i++) + attn[i] = attn_weights[attn_offset + i]; + + // Compute grad_V = attn_weights^T @ grad_O + for(int j = 0; j < seq_kv; j++) + { + uint4 store_dwordx4_tmp_var[block_k / dwordx4_load_elt]; +#pragma unroll + for(int i = 0; i < block_k / dwordx4_load_elt; i++) + { + store_dwordx4_tmp_var[i].x = 0; + store_dwordx4_tmp_var[i].y = 0; + store_dwordx4_tmp_var[i].z = 0; + store_dwordx4_tmp_var[i].w = 0; + } + + // grad_O layout: [total_padded_seq_q, head_num, head_dim] +#pragma unroll + for(int i = 0; i < block_k / dwordx4_load_elt; i++) + { + load_dwordx4_tmp_var[i] = + *((uint4*)&grad_O[(q_storage_offset * Config::head_num + head_idx) * head_dim + + thread_head_offset + i * dwordx4_load_elt]); + } + +#pragma unroll + for(int b = 0; b < block_k; b++) + { + ((T*)&store_dwordx4_tmp_var[b / dwordx4_load_elt])[b % dwordx4_load_elt] += + attn[j] * + ((T*)&load_dwordx4_tmp_var[b / dwordx4_load_elt])[b % dwordx4_load_elt]; + } + +#pragma unroll + for(int i = 0; i < block_k / dwordx4_load_elt; i++) + { + int grad_v_idx = + (cu_seqlens_kv_padded[batch_idx] + j) * Config::head_num * head_dim + + head_idx * head_dim + thread_head_offset + i * dwordx4_load_elt; + *((uint4*)&grad_V[grad_v_idx]) = store_dwordx4_tmp_var[i]; + } + } + } +} + +// --------------------------------------------------------------------------- +// Kernel 2: compute_grad_attn_kernel +// +// Computes grad_attn = grad_O @ V^T (same structure as compute_scores_kernel) +// grad_O layout: [total_padded_q, head_num, head_dim] +// V layout: [total_padded_kv_seq, 
head_num, head_dim] +// grad_attn layout: [total_padded_q, head_num, max_seq_kv] (workspace reuse) +// --------------------------------------------------------------------------- + +template +__global__ void compute_grad_attn_kernel(const T* grad_O, + const T* V, + T* grad_attn, + const int* cu_seqlens_q, + const int* cu_seqlens_q_padded, + const int* cu_seqlens_kv, + const int* cu_seqlens_kv_padded) +{ + constexpr int seq_q = Config::seq_q; // == 1 + static_assert(seq_q == 1, "seq_q must be 1 for this kernel implementation."); + constexpr int max_seq_kv = Config::max_seq_kv; + constexpr int head_dim = Config::head_dim; + constexpr int block_k = 64; + constexpr int thread_block_size = 64; + constexpr int tasks_per_block = TASKS_PER_BLOCK; + + int base_block_offset = blockIdx.x * thread_block_size * tasks_per_block; + int thread_id = threadIdx.x; + + for(int task = 0; task < tasks_per_block; task++) + { + int cur_batch_idx = base_block_offset + task * thread_block_size + thread_id; + int batch_idx = cur_batch_idx / (Config::seq_q * Config::head_num); + int seq_head_idx = cur_batch_idx % (Config::seq_q * Config::head_num); + int head_idx = seq_head_idx % Config::head_num; + + if(batch_idx >= Config::bs) + continue; + + // Skip batches where actual Q seq is 0 — no row exists in workspace for them. 
+ int actual_seq_q = cu_seqlens_q[batch_idx + 1] - cu_seqlens_q[batch_idx]; + if(actual_seq_q == 0) + continue; + + int seq_kv = cu_seqlens_kv[batch_idx + 1] - cu_seqlens_kv[batch_idx]; + int q_storage_offset = cu_seqlens_q_padded[batch_idx]; // seq_idx == 0 + + float results[max_seq_kv]; + T fetch_grad_O[block_k]; + T fetch_V[block_k]; + + // grad_O layout: [total_padded_seq_q, head_num, head_dim] + T* grad_O_ptr = + (T*)&grad_O[(q_storage_offset * Config::head_num + head_idx) * head_dim]; + + const T* V_base = + &V[cu_seqlens_kv_padded[batch_idx] * Config::head_num * head_dim + head_idx * head_dim]; + int V_stride = Config::head_num * head_dim; + + // workspace layout: [total_padded_q, head_num, max_seq_kv] + T* grad_attn_ptr = (T*)&grad_attn[(q_storage_offset * Config::head_num + head_idx) * max_seq_kv]; + + uint4 ls_dwordx4_tmp_var; + + for(int i = 0; i < seq_kv; i++) + results[i] = 0.0f; + + for(int dim_offset = 0; dim_offset < head_dim; dim_offset += block_k) + { + if constexpr(std::is_same::value) + { + for(int k = 0; k < block_k / 8; k++) + { + ls_dwordx4_tmp_var = *((uint4*)&grad_O_ptr[dim_offset + k * 8]); + fetch_grad_O[k * 8 + 0] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.x)[0]; + fetch_grad_O[k * 8 + 1] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.x)[1]; + fetch_grad_O[k * 8 + 2] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.y)[0]; + fetch_grad_O[k * 8 + 3] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.y)[1]; + fetch_grad_O[k * 8 + 4] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.z)[0]; + fetch_grad_O[k * 8 + 5] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.z)[1]; + fetch_grad_O[k * 8 + 6] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.w)[0]; + fetch_grad_O[k * 8 + 7] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.w)[1]; + } + + for(int kv_idx = 0; kv_idx < seq_kv; kv_idx++) + { + for(int k = 0; k < block_k / 8; k++) + { + ls_dwordx4_tmp_var = + *((uint4*)&V_base[kv_idx * V_stride + dim_offset + k * 8]); + fetch_V[k * 8 + 0] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.x)[0]; + fetch_V[k * 8 + 1] = 
((hip_bfloat16*)&ls_dwordx4_tmp_var.x)[1]; + fetch_V[k * 8 + 2] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.y)[0]; + fetch_V[k * 8 + 3] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.y)[1]; + fetch_V[k * 8 + 4] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.z)[0]; + fetch_V[k * 8 + 5] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.z)[1]; + fetch_V[k * 8 + 6] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.w)[0]; + fetch_V[k * 8 + 7] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.w)[1]; + } +#pragma unroll + for(int k = 0; k < block_k; k++) + { + results[kv_idx] += + static_cast(fetch_grad_O[k]) * static_cast(fetch_V[k]); + } + } + } + else + { + for(int k = 0; k < block_k / 4; k++) + { + ls_dwordx4_tmp_var = *((uint4*)&grad_O_ptr[dim_offset + k * 4]); + fetch_grad_O[k * 4 + 0] = *((T*)&ls_dwordx4_tmp_var.x); + fetch_grad_O[k * 4 + 1] = *((T*)&ls_dwordx4_tmp_var.y); + fetch_grad_O[k * 4 + 2] = *((T*)&ls_dwordx4_tmp_var.z); + fetch_grad_O[k * 4 + 3] = *((T*)&ls_dwordx4_tmp_var.w); + } + + for(int kv_idx = 0; kv_idx < seq_kv; kv_idx++) + { + for(int k = 0; k < block_k / 4; k++) + { + ls_dwordx4_tmp_var = + *((uint4*)&V_base[kv_idx * V_stride + dim_offset + k * 4]); + fetch_V[k * 4 + 0] = *((T*)&ls_dwordx4_tmp_var.x); + fetch_V[k * 4 + 1] = *((T*)&ls_dwordx4_tmp_var.y); + fetch_V[k * 4 + 2] = *((T*)&ls_dwordx4_tmp_var.z); + fetch_V[k * 4 + 3] = *((T*)&ls_dwordx4_tmp_var.w); + } +#pragma unroll + for(int k = 0; k < block_k; k++) + { + results[kv_idx] += fetch_grad_O[k] * fetch_V[k]; + } + } + } + } + + for(int i = 0; i < seq_kv; i++) + { + grad_attn_ptr[i] = T(results[i]); + } + // Zero out padding positions beyond seq_kv + for(int i = seq_kv; i < max_seq_kv; i++) + { + grad_attn_ptr[i] = T(0.0f); + } + } +} + +// --------------------------------------------------------------------------- +// Kernel 3: softmax_backward_kernel +// +// Softmax backprop: attn * (grad_attn - sum(grad_attn * attn)) +// Writes the result back into grad_attn (workspace reuse as grad_scores). 
+// --------------------------------------------------------------------------- + +template +__global__ void softmax_backward_kernel(const T* attn_weights, + const T* dropout_mask, + T* grad_attn, + float dropout_scale, + const int* cu_seqlens_kv, + const int* padded_q_to_batch, + uint32_t total_elt) +{ + const uint32_t block_id = blockIdx.x; + const uint32_t thread_id = threadIdx.x; + constexpr int max_seq_kv = Config::max_seq_kv; + constexpr int block_size = Config::step2_block_size; + constexpr int per_grad_attn_size = max_seq_kv; // seq_q == 1 + constexpr int valid_thread_range = block_size / per_grad_attn_size * per_grad_attn_size; + const uint32_t cur_block_offset = block_id * valid_thread_range + thread_id; + bool is_tail = block_id * valid_thread_range + block_size >= total_elt; + int real_row_num = is_tail ? (total_elt - block_id * valid_thread_range) / max_seq_kv + : valid_thread_range / max_seq_kv; + + if(cur_block_offset < total_elt && thread_id < valid_thread_range) + { + __shared__ T tmp_grad_score[valid_thread_range]; + constexpr int row_num = valid_thread_range / max_seq_kv; + __shared__ T reduce_grad_score[row_num]; + + // [total_padded_q, head_num, max_seq_kv] flat layout + int global_row_idx = cur_block_offset / max_seq_kv; + int padded_q_slot = global_row_idx / Config::head_num; + int k_idx = cur_block_offset % max_seq_kv; + + // All rows in the buffer belong to active batches (empty-Q batches have no row). 
+ int batch_idx = padded_q_to_batch[padded_q_slot]; + int seq_kv = cu_seqlens_kv[batch_idx + 1] - cu_seqlens_kv[batch_idx]; + + T grad_attn_value = grad_attn[cur_block_offset]; + if constexpr(Config::enable_dropout_mask) + { + grad_attn_value = grad_attn_value * dropout_mask[cur_block_offset] * dropout_scale; + } + T attn_weight = attn_weights[cur_block_offset]; + T grad_score = grad_attn_value * attn_weight; + tmp_grad_score[thread_id] = grad_score; + __syncthreads(); + + // Reduce within block + if(thread_id < real_row_num) + { + T sum = T(0.0f); +#pragma unroll + for(int i = 0; i < max_seq_kv; i++) + sum += tmp_grad_score[thread_id * max_seq_kv + i]; + reduce_grad_score[thread_id] = sum; + } + __syncthreads(); + + grad_score -= attn_weight * reduce_grad_score[thread_id / max_seq_kv]; + + // Apply causal mask and KV-padding mask + if constexpr(Config::mask_type == CausalMaskType::TOP_LEFT) + { + // q_idx == 0; mask: k_idx > 0 || k_idx >= seq_kv + if(k_idx > 0 || k_idx >= seq_kv) + grad_score = T(0.0f); + } + else if constexpr(Config::mask_type == CausalMaskType::BOTTOM_RIGHT) + { + if(k_idx >= seq_kv) + grad_score = T(0.0f); + } + else + { + if(k_idx >= seq_kv) + grad_score = T(0.0f); + } + + grad_attn[cur_block_offset] = grad_score; + } +} + +// --------------------------------------------------------------------------- +// Kernel 4: compute_grad_qk_kernel +// +// Computes grad_Q = grad_scores @ K * scale +// grad_K = grad_scores^T @ Q * scale +// --------------------------------------------------------------------------- + +template +__global__ void compute_grad_qk_kernel(const T* grad_scores, + const T* Q, + const T* K, + T* grad_Q, + T* grad_K, + float scale, + const int* cu_seqlens_q, + const int* cu_seqlens_q_padded, + const int* cu_seqlens_kv, + const int* cu_seqlens_kv_padded) +{ + constexpr int seq_q = Config::seq_q; + constexpr int max_seq_kv = Config::max_seq_kv; + constexpr int head_dim = Config::head_dim; + constexpr int block_k = BLOCK_K; + 
constexpr int dwordx4_load_elt = 16 / sizeof(T); + constexpr int warp_size = 64; + constexpr int process_head_per_warp = warp_size / (head_dim / block_k); + constexpr int tasks_per_block = TASKS_PER_BLOCK; + + int base_block_offset = blockIdx.x * process_head_per_warp * tasks_per_block; + int thread_id = threadIdx.x; + int thread_batch_offset = thread_id / (head_dim / block_k); + int thread_head_offset = thread_id % (head_dim / block_k) * block_k; + + uint4 load_dwordx4_tmp_var[block_k / dwordx4_load_elt]; + T grad_score_vals[max_seq_kv]; + + for(int task = 0; task < tasks_per_block; task++) + { + int block_batch_head_idx = base_block_offset + task * process_head_per_warp; + int cur_idx = block_batch_head_idx + thread_batch_offset; + + int batch_idx = cur_idx / (Config::seq_q * Config::head_num); + int seq_head_idx = cur_idx % (Config::seq_q * Config::head_num); + int seq_q_idx = seq_head_idx / Config::head_num; + int head_idx = seq_head_idx % Config::head_num; + + if(batch_idx >= Config::bs) + continue; + + // Skip batches where actual Q seq is 0 — no grad_Q/grad_K to compute. 
+ int actual_seq_q = cu_seqlens_q[batch_idx + 1] - cu_seqlens_q[batch_idx]; + if(actual_seq_q == 0) + continue; + + int seq_kv = cu_seqlens_kv[batch_idx + 1] - cu_seqlens_kv[batch_idx]; + int q_storage_offset = cu_seqlens_q_padded[batch_idx]; // seq_q_idx == 0 + + // workspace layout: [total_padded_q, head_num, max_seq_kv] + int gs_offset = (q_storage_offset * Config::head_num + head_idx) * max_seq_kv; +#pragma unroll + for(int i = 0; i < max_seq_kv; i++) + grad_score_vals[i] = grad_scores[gs_offset + i]; + + // Compute grad_Q = grad_scores @ K * scale + uint4 store_dwordx4_tmp_var[block_k / dwordx4_load_elt]; +#pragma unroll + for(int i = 0; i < block_k / dwordx4_load_elt; i++) + { + store_dwordx4_tmp_var[i].x = 0; + store_dwordx4_tmp_var[i].y = 0; + store_dwordx4_tmp_var[i].z = 0; + store_dwordx4_tmp_var[i].w = 0; + } + + for(int j = 0; j < seq_kv; j++) + { +#pragma unroll + for(int i = 0; i < block_k / dwordx4_load_elt; i++) + { + int k_idx = (cu_seqlens_kv_padded[batch_idx] + j) * Config::head_num * head_dim + + head_idx * head_dim + thread_head_offset + i * dwordx4_load_elt; + load_dwordx4_tmp_var[i] = *((uint4*)&K[k_idx]); + } +#pragma unroll + for(int b = 0; b < block_k; b++) + { + ((T*)&store_dwordx4_tmp_var[b / dwordx4_load_elt])[b % dwordx4_load_elt] += + grad_score_vals[j] * + ((T*)&load_dwordx4_tmp_var[b / dwordx4_load_elt])[b % dwordx4_load_elt]; + } + } + +#pragma unroll + for(int i = 0; i < block_k / dwordx4_load_elt; i++) + { + // grad_Q layout: [total_padded_seq_q, head_num, head_dim] + T* grad_Q_ptr = &grad_Q[(q_storage_offset * Config::head_num + head_idx) * head_dim + + thread_head_offset + i * dwordx4_load_elt]; + for(int b = 0; b < dwordx4_load_elt; b++) + { + grad_Q_ptr[b] = ((T*)&store_dwordx4_tmp_var[i])[b] * scale; + } + } + + // Compute grad_K = grad_scores^T @ Q * scale + // Q layout: [total_padded_seq_q, head_num, head_dim] +#pragma unroll + for(int i = 0; i < block_k / dwordx4_load_elt; i++) + { + load_dwordx4_tmp_var[i] = + 
*((uint4*)&Q[(q_storage_offset * Config::head_num + head_idx) * head_dim + + thread_head_offset + i * dwordx4_load_elt]); + } + + for(int j = 0; j < seq_kv; j++) + { +#pragma unroll + for(int b = 0; b < block_k; b++) + { + T val = grad_score_vals[j] * + ((T*)&load_dwordx4_tmp_var[b / dwordx4_load_elt])[b % dwordx4_load_elt] * + T(scale); + int grad_k_idx = + (cu_seqlens_kv_padded[batch_idx] + j) * Config::head_num * head_dim + + head_idx * head_dim + thread_head_offset + b; + grad_K[grad_k_idx] = val; + } + } + } +} + +// --------------------------------------------------------------------------- +// AttnBackwardKernelLauncher +// +// Orchestrates the 4-kernel backward pipeline: +// 1. compute_grad_v_kernel (attn_weights^T @ grad_O) +// 2. compute_grad_attn_kernel (grad_O @ V^T) +// 3. softmax_backward_kernel +// 4. compute_grad_qk_kernel (grad_scores @ K, grad_scores^T @ Q) +// --------------------------------------------------------------------------- + +template +struct AttnBackwardKernelLauncher +{ + // workspace layout: [total_padded_q, head_num, max_seq_kv] + static size_t calc_workspace_size(int total_padded_q) + { + constexpr int head_num = Config::head_num; + constexpr int max_seq_kv = Config::max_seq_kv; + return (size_t)total_padded_q * head_num * max_seq_kv * sizeof(T); + } + + static void run_attn_bwd_kernel(const T* Q, + const T* K, + const T* V, + const T* grad_O, + const T* attn_weights, + const T* dropout_mask, + float dropout_p, + float sqr_dk_scale, + T* grad_Q, + T* grad_K, + T* grad_V, + T* workspace, + const int* cu_seqlens_q, + const int* cu_seqlens_q_padded, + const int* cu_seqlens_kv, + const int* cu_seqlens_kv_padded, + const int* padded_q_to_batch, + int total_padded_q) + { + constexpr int bs = Config::bs; + constexpr int head_num = Config::head_num; + constexpr int seq_q = Config::seq_q; + constexpr int max_seq_kv = Config::max_seq_kv; + constexpr int head_dim = Config::head_dim; + constexpr int warp_size = 64; + + constexpr int merge_bs 
= bs * head_num; + float scale = sqr_dk_scale; + float dropout_scale = (dropout_p > 0.0f) ? (1.0f / (1.0f - dropout_p)) : 1.0f; + + dim3 block(warp_size); + + // Step 1: Compute grad_V = attn_weights^T @ grad_O — grid covers all (bs * head_num) tasks + constexpr int tasks_per_block_v = 16; + dim3 grid_v((bs * seq_q * head_num + tasks_per_block_v - 1) / tasks_per_block_v); + compute_grad_v_kernel<<>>( + attn_weights, grad_O, grad_V, cu_seqlens_q, cu_seqlens_q_padded, cu_seqlens_kv, + cu_seqlens_kv_padded); + + // Step 2: Compute grad_attn = grad_O @ V^T — grid covers all (bs * head_num) tasks + constexpr int tasks_per_block_attn = 16; + constexpr int process_head_per_warp = warp_size / (head_dim / 64); + dim3 grid_grad_attn( + (bs * seq_q * head_num + tasks_per_block_attn * process_head_per_warp - 1) / + (tasks_per_block_attn * process_head_per_warp)); + compute_grad_attn_kernel<<>>( + grad_O, V, workspace, cu_seqlens_q, cu_seqlens_q_padded, cu_seqlens_kv, + cu_seqlens_kv_padded); + + // Step 3: Softmax backward — grid covers [total_padded_q, head_num, max_seq_kv] elements + constexpr int work_thread_num = Config::step2_block_size / max_seq_kv * max_seq_kv; + uint32_t total_elt = (uint32_t)total_padded_q * head_num * max_seq_kv; + dim3 grid_softmax((total_elt + work_thread_num - 1) / work_thread_num); + dim3 block_softmax(Config::step2_block_size); + softmax_backward_kernel<<>>( + attn_weights, dropout_mask, workspace, dropout_scale, cu_seqlens_kv, + padded_q_to_batch, total_elt); + + // Step 4: Compute grad_Q and grad_K — grid covers all (bs * head_num) tasks + constexpr int tasks_per_block_qk = 4; + dim3 grid_qk((bs * seq_q * head_num + tasks_per_block_qk - 1) / tasks_per_block_qk); + compute_grad_qk_kernel<<>>( + workspace, Q, K, grad_Q, grad_K, scale, cu_seqlens_q, cu_seqlens_q_padded, + cu_seqlens_kv, cu_seqlens_kv_padded); + } +}; diff --git a/transformer_engine/common/fused_attn_rocm/small_seq_kernels/attn_bwd_mfma_16x16.h 
b/transformer_engine/common/fused_attn_rocm/small_seq_kernels/attn_bwd_mfma_16x16.h new file mode 100644 index 000000000..9b6b50b9b --- /dev/null +++ b/transformer_engine/common/fused_attn_rocm/small_seq_kernels/attn_bwd_mfma_16x16.h @@ -0,0 +1,693 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#pragma once + +#include "attn_common.h" +#include +#include + +#ifndef ATTN_MFMA_TYPES_DEFINED +#define ATTN_MFMA_TYPES_DEFINED +using bhalf_t = __bf16; +using bf16x4 = __bf16 __attribute__((ext_vector_type(4))); +using bf16x8 = __bf16 __attribute__((ext_vector_type(8))); +using floatx4 = float __attribute__((ext_vector_type(4))); +#endif + +#ifndef CEIL_DIV +#define CEIL_DIV(a, b) (((a) + (b)-1) / (b)) +#endif + +template +__device__ __forceinline__ bf16x8 bwd_load_cvt_bf16x8(const T* src) +{ + if constexpr(sizeof(T) == 2) + { + return *(const bf16x8*)src; + } + else + { + bf16x8 r; + #pragma unroll + for(int i = 0; i < 8; i++) + r[i] = static_cast(src[i]); + return r; + } +} + +// --------------------------------------------------------------------------- +// grad_V kernel: grad_V = attn^T @ grad_O +// Grid: (1, head_num, bs), Block: 256 +// --------------------------------------------------------------------------- + +template +__launch_bounds__(256, 1) +__global__ void fmha_bwd_grad_v_mfma_16x16_kernel( + const T* Q, + const T* K, + const float* softmax_lse, + const T* grad_O, + T* grad_V, + float scale, + const int* cu_seqlens_q, + const int* cu_seqlens_q_padded, + const int* cu_seqlens_kv, + const int* cu_seqlens_kv_padded) +{ + constexpr int head_dim = Config::head_dim; + constexpr int head_num = Config::head_num; + constexpr int max_seq_kv = Config::max_seq_kv; + constexpr int max_seq_q = Config::max_seq_q; + constexpr int hd_pad = head_dim + 4; + constexpr int q_tiles = CEIL_DIV(max_seq_q, 16); + constexpr int kv_tiles = CEIL_DIV(max_seq_kv, 16); + constexpr int lds_q_rows = q_tiles * 16; + constexpr int 
lds_kv_rows = kv_tiles * 16; + constexpr int attn_pad = lds_kv_rows + 4; + + const int batch_idx = blockIdx.z; + const int head_idx = blockIdx.y; + const int tid = threadIdx.x; + const int warp_id = tid / 64; + const int lane_id = tid % 64; + const int lane_row = lane_id / 16; + const int lane_col = lane_id % 16; + + const int actual_q = cu_seqlens_q[batch_idx + 1] - cu_seqlens_q[batch_idx]; + if(actual_q == 0) + return; + + const int seq_kv = cu_seqlens_kv[batch_idx + 1] - cu_seqlens_kv[batch_idx]; + const int q_offset = cu_seqlens_q_padded[batch_idx]; + const int kv_offset = cu_seqlens_kv_padded[batch_idx]; + + __shared__ __attribute__((aligned(128))) float attn_lds[lds_q_rows * attn_pad]; + __shared__ __attribute__((aligned(128))) bhalf_t Q_lds_bwd[lds_q_rows * hd_pad]; + __shared__ __attribute__((aligned(128))) bhalf_t K_lds_bwd[lds_kv_rows * hd_pad]; + __shared__ __attribute__((aligned(128))) bhalf_t dO_lds[lds_q_rows * hd_pad]; + + // Load Q → Q_lds_bwd + { + constexpr int threads_per_row = head_dim / 8; + const int row = tid / threads_per_row; + const int col = (tid % threads_per_row) * 8; + + for(int r = row; r < lds_q_rows; r += (256 / threads_per_row)) + { + if(r < actual_q) + { + const T* q_src = Q + ((size_t)(q_offset + r) * head_num + head_idx) * head_dim; + *(bf16x8*)(&Q_lds_bwd[r * hd_pad + col]) = bwd_load_cvt_bf16x8(q_src + col); + } + else + *(bf16x8*)(&Q_lds_bwd[r * hd_pad + col]) = bf16x8{0, 0, 0, 0, 0, 0, 0, 0}; + } + } + + // Load K → K_lds_bwd + { + constexpr int threads_per_row = head_dim / 8; + const int row = tid / threads_per_row; + const int col = (tid % threads_per_row) * 8; + + for(int r = row; r < lds_kv_rows; r += (256 / threads_per_row)) + { + if(r < seq_kv) + { + const T* k_src = K + ((size_t)(kv_offset + r) * head_num + head_idx) * head_dim; + *(bf16x8*)(&K_lds_bwd[r * hd_pad + col]) = bwd_load_cvt_bf16x8(k_src + col); + } + else + *(bf16x8*)(&K_lds_bwd[r * hd_pad + col]) = bf16x8{0, 0, 0, 0, 0, 0, 0, 0}; + } + } + + 
__syncthreads(); + + // QK^T (same MFMA tiling as forward) → exp(S - LSE) = P + float P_reg[q_tiles * kv_tiles * 4]; + #pragma unroll + for(int qt = 0; qt < q_tiles; qt++) + { + #pragma unroll + for(int kvt = 0; kvt < kv_tiles; kvt++) + { + floatx4 acc = {0, 0, 0, 0}; + constexpr int total_hd_tiles = CEIL_DIV(head_dim, 16); + + #pragma unroll + for(int k = 0; k < total_hd_tiles; ++k) + { + const int dim_off = k * 16; + bf16x4 a = *(const bf16x4*)(&Q_lds_bwd[(qt * 16 + lane_col) * hd_pad + dim_off + lane_row * 4]); + bf16x4 b = *(const bf16x4*)(&K_lds_bwd[(kvt * 16 + lane_col) * hd_pad + dim_off + lane_row * 4]); + acc = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a, b, acc, 0, 0, 0); + } + + int reg_base = (qt * kv_tiles + kvt) * 4; + #pragma unroll + for(int i = 0; i < 4; i++) + { + int q_row = qt * 16 + lane_row * 4 + i; + int kv_pos = kvt * 16 + lane_col; + bool masked = (kv_pos >= seq_kv) || (q_row >= actual_q); + if constexpr(Config::mask_type == CausalMaskType::TOP_LEFT) + { + if(kv_pos > q_row) + masked = true; + } + float S = acc[i] * scale; + float lse = + softmax_lse[((size_t)(q_offset + q_row) * head_num + head_idx)]; + float pr = masked ? 0.0f : expf(S - lse); + P_reg[reg_base + i] = pr; + } + } + } + + // Scatter P_reg → attn_lds (same pattern as former workspace write) + if(warp_id == 0) + { + #pragma unroll + for(int qt = 0; qt < q_tiles; qt++) + { + #pragma unroll + for(int kvt = 0; kvt < kv_tiles; kvt++) + { + #pragma unroll + for(int i = 0; i < 4; i++) + { + int q_row = qt * 16 + lane_row * 4 + i; + int kv_pos = kvt * 16 + lane_col; + if(q_row < actual_q && kv_pos < max_seq_kv) + { + int reg_idx = (qt * kv_tiles + kvt) * 4 + i; + float w = (kv_pos < seq_kv) ? 
P_reg[reg_idx] : 0.0f; + attn_lds[q_row * attn_pad + kv_pos] = w; + } + } + } + } + } + + __syncthreads(); + + // Load grad_O → dO_lds + { + constexpr int threads_per_row = head_dim / 8; + const int do_row = tid / threads_per_row; + const int do_col = (tid % threads_per_row) * 8; + + for(int r = do_row; r < lds_q_rows; r += (256 / threads_per_row)) + { + if(r < actual_q) + { + const T* do_src = grad_O + ((size_t)(q_offset + r) * head_num + head_idx) * head_dim; + *(bf16x8*)(&dO_lds[r * hd_pad + do_col]) = bwd_load_cvt_bf16x8(do_src + do_col); + } + else + { + *(bf16x8*)(&dO_lds[r * hd_pad + do_col]) = bf16x8{0, 0, 0, 0, 0, 0, 0, 0}; + } + } + } + + __syncthreads(); + + // MFMA: grad_V = attn^T @ grad_O (4 warps split head_dim) + constexpr int BK = 64; + + #pragma unroll + for(int kv_tile = 0; kv_tile < kv_tiles; kv_tile++) + { + constexpr int total_d_tiles = CEIL_DIV(head_dim, BK); + + #pragma unroll + for(int d = 0; d < total_d_tiles; d++) + { + const int dim_idx = d * BK + warp_id * 16; + + floatx4 acc = {0, 0, 0, 0}; + + #pragma unroll + for(int q_tile = 0; q_tile < q_tiles; q_tile++) + { + bf16x4 a; + #pragma unroll + for(int k = 0; k < 4; k++) + { + int q_row = q_tile * 16 + lane_row * 4 + k; + int kv_pos = kv_tile * 16 + lane_col; + float val = (q_row < actual_q && kv_pos < seq_kv) + ? 
attn_lds[q_row * attn_pad + kv_pos] : 0.0f;
+                    a[k] = static_cast(val);
+                }
+
+                // B: dO[q, d]
+                bf16x4 b;
+                #pragma unroll
+                for(int k = 0; k < 4; k++)
+                {
+                    int q_row = q_tile * 16 + lane_row * 4 + k;
+                    b[k] = dO_lds[q_row * hd_pad + dim_idx + lane_col];
+                }
+
+                acc = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a, b, acc, 0, 0, 0);
+            }
+
+            // Write grad_V
+            #pragma unroll
+            for(int i = 0; i < 4; i++)
+            {
+                int kv_pos = kv_tile * 16 + lane_row * 4 + i;
+                if(kv_pos < seq_kv)
+                {
+                    int gv_idx = (kv_offset + kv_pos) * head_num * head_dim
+                                 + head_idx * head_dim + dim_idx + lane_col;
+                    grad_V[gv_idx] = static_cast(acc[i]);
+                }
+            }
+        }
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Fused backward kernel: grad_attn → softmax_bwd → grad_Q + grad_K
+// Grid: (1, head_num, bs), Block: 256
+//
+// One block handles one (batch, head) pair (blockIdx.z / blockIdx.y):
+//   1. Stage Q, dO and V in LDS. Q/dO rows >= actual_q are zero-filled; V rows
+//      >= seq_kv are clamped copies of the last valid row.
+//   2. grad_attn = dO @ V^T with 16x16x16 bf16 MFMA (every warp computes all tiles).
+//   3. Reload K over V in KV_lds and recompute P = exp(S * scale - LSE) with
+//      padding / TOP_LEFT causal masking.
+//   4. Softmax backward: grad_score = P * (grad_attn - rowsum(grad_attn * P)),
+//      written to SM_lds.
+//   5. grad_Q = grad_score @ K * scale and grad_K = grad_score^T @ Q * scale;
+//      warp w owns the 16 columns starting at d * 64 + w * 16.
+//
+// Blocks whose batch has actual_q == 0 return without writing grad_Q / grad_K.
+// NOTE(review): the column split in step 5 and the 8-wide row loads look like
+// they require head_dim to be a multiple of 64 — confirm against the Config
+// instantiations.
+// ---------------------------------------------------------------------------
+
+template
+__launch_bounds__(256, 1)
+__global__ void fmha_bwd_fused_mfma_16x16_kernel(
+    const T* Q,
+    const T* K,
+    const T* V,
+    const T* grad_O,
+    const float* softmax_lse,
+    T* grad_Q,
+    T* grad_K,
+    float scale,
+    const int* cu_seqlens_q,
+    const int* cu_seqlens_q_padded,
+    const int* cu_seqlens_kv,
+    const int* cu_seqlens_kv_padded)
+{
+    constexpr int head_dim = Config::head_dim;
+    constexpr int head_num = Config::head_num;
+    constexpr int max_seq_kv = Config::max_seq_kv;
+    constexpr int max_seq_q = Config::max_seq_q;
+    constexpr int hd_pad = head_dim + 4;
+    constexpr int q_tiles = CEIL_DIV(max_seq_q, 16);
+    constexpr int kv_tiles = CEIL_DIV(max_seq_kv, 16);
+    constexpr int lds_q_rows = q_tiles * 16;
+    constexpr int lds_kv_rows = kv_tiles * 16;
+    constexpr int lds_sm_stride = lds_kv_rows + 4;
+
+    const int batch_idx = blockIdx.z;
+    const int head_idx = blockIdx.y;
+    const int tid = threadIdx.x;
+    const int warp_id = tid / 64;
+    const int lane_id = tid % 64;
+    const int lane_row = lane_id / 16;
+    const int lane_col = lane_id % 16;
+
+    // Block-uniform early exit (depends only on blockIdx), so the barriers
+    // below are still reached by every thread of a block that continues.
+    const int actual_q = cu_seqlens_q[batch_idx + 1] - cu_seqlens_q[batch_idx];
+    if(actual_q == 0)
+        return;
+
+    const int seq_kv = cu_seqlens_kv[batch_idx + 1] - cu_seqlens_kv[batch_idx];
+    const int q_offset = cu_seqlens_q_padded[batch_idx];
+    const int kv_offset = cu_seqlens_kv_padded[batch_idx];
+
+    __shared__ __attribute__((aligned(128))) bhalf_t Q_lds[lds_q_rows * hd_pad];
+    __shared__ __attribute__((aligned(128))) bhalf_t dO_lds[lds_q_rows * hd_pad];
+    __shared__ __attribute__((aligned(128))) bhalf_t KV_lds[lds_kv_rows * hd_pad];
+    __shared__ float SM_lds[lds_q_rows * lds_sm_stride];
+
+    // Load Q → Q_lds
+    {
+        constexpr int threads_per_row = head_dim / 8;
+        const int row = tid / threads_per_row;
+        const int col = (tid % threads_per_row) * 8;
+
+        for(int r = row; r < lds_q_rows; r += (256 / threads_per_row))
+        {
+            if(r < actual_q)
+            {
+                const T* q_src = Q + ((size_t)(q_offset + r) * head_num + head_idx) * head_dim;
+                *(bf16x8*)(&Q_lds[r * hd_pad + col]) = bwd_load_cvt_bf16x8(q_src + col);
+            }
+            else
+            {
+                *(bf16x8*)(&Q_lds[r * hd_pad + col]) = bf16x8{0, 0, 0, 0, 0, 0, 0, 0};
+            }
+        }
+    }
+
+    // Load dO → dO_lds
+    {
+        constexpr int threads_per_row = head_dim / 8;
+        const int row = tid / threads_per_row;
+        const int col = (tid % threads_per_row) * 8;
+
+        for(int r = row; r < lds_q_rows; r += (256 / threads_per_row))
+        {
+            if(r < actual_q)
+            {
+                const T* do_src = grad_O + ((size_t)(q_offset + r) * head_num + head_idx) * head_dim;
+                *(bf16x8*)(&dO_lds[r * hd_pad + col]) = bwd_load_cvt_bf16x8(do_src + col);
+            }
+            else
+            {
+                *(bf16x8*)(&dO_lds[r * hd_pad + col]) = bf16x8{0, 0, 0, 0, 0, 0, 0, 0};
+            }
+        }
+    }
+
+    // Load V → KV_lds
+    {
+        constexpr int threads_per_row = head_dim / 8;
+        const int row = tid / threads_per_row;
+        const int col = (tid % threads_per_row) * 8;
+        // Rows past seq_kv re-read the last valid row instead of going out of range.
+        const int clamped_max = max(seq_kv - 1, 0);
+
+        for(int r = row; r < lds_kv_rows; r += (256 / threads_per_row))
+        {
+            const int clamped_r = min(r, clamped_max);
+            const T* v_src = V + ((size_t)(kv_offset + clamped_r) * head_num + head_idx) * head_dim;
+            *(bf16x8*)(&KV_lds[r * hd_pad + col]) = bwd_load_cvt_bf16x8(v_src + col);
+        }
+    }
+
+    __syncthreads();
+
+    // grad_attn = dO @ V^T via MFMA (all 4 warps redundant)
+    float grad_attn[q_tiles * kv_tiles * 4];
+
+    #pragma unroll
+    for(int qt = 0; qt < q_tiles; qt++)
+    {
+        #pragma unroll
+        for(int kvt = 0; kvt < kv_tiles; kvt++)
+        {
+            floatx4 acc = {0, 0, 0, 0};
+            constexpr int total_d_tiles = CEIL_DIV(head_dim, 16);
+
+            #pragma unroll
+            for(int dtile = 0; dtile < total_d_tiles; dtile++)
+            {
+                const int dim_off = dtile * 16;
+                // A: dO[q, d]
+                bf16x4 a = *(const bf16x4*)(&dO_lds[(qt * 16 + lane_col) * hd_pad + dim_off + lane_row * 4]);
+                // B: V[kv, d]
+                bf16x4 b = *(const bf16x4*)(&KV_lds[(kvt * 16 + lane_col) * hd_pad + dim_off + lane_row * 4]);
+
+                acc = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a, b, acc, 0, 0, 0);
+            }
+
+            int reg_base = (qt * kv_tiles + kvt) * 4;
+            #pragma unroll
+            for(int i = 0; i < 4; i++)
+                grad_attn[reg_base + i] = acc[i];
+        }
+    }
+
+    // Every warp reads V out of KV_lds in the loop above. Without this barrier
+    // a fast warp could start overwriting KV_lds with K while a slower warp is
+    // still reading V.
+    __syncthreads();
+
+    // Reload K into KV_lds (overwrite V) and recompute P_ij = exp(S_ij - LSE_i)
+    {
+        constexpr int threads_per_row = head_dim / 8;
+        const int row = tid / threads_per_row;
+        const int col = (tid % threads_per_row) * 8;
+        const int clamped_max = max(seq_kv - 1, 0);
+
+        for(int r = row; r < lds_kv_rows; r += (256 / threads_per_row))
+        {
+            const int clamped_r = min(r, clamped_max);
+            const T* k_src = K + ((size_t)(kv_offset + clamped_r) * head_num + head_idx) * head_dim;
+            *(bf16x8*)(&KV_lds[r * hd_pad + col]) = bwd_load_cvt_bf16x8(k_src + col);
+        }
+    }
+
+    __syncthreads();
+
+    float attn_reg[q_tiles * kv_tiles * 4];
+
+    #pragma unroll
+    for(int qt = 0; qt < q_tiles; qt++)
+    {
+        #pragma unroll
+        for(int kvt = 0; kvt < kv_tiles; kvt++)
+        {
+            floatx4 acc_s = {0, 0, 0, 0};
+            constexpr int total_hd_tiles = CEIL_DIV(head_dim, 16);
+
+            #pragma unroll
+            for(int k = 0; k < total_hd_tiles; ++k)
+            {
+                const int dim_off = k * 16;
+                bf16x4 a = *(const bf16x4*)(&Q_lds[(qt * 16 + lane_col) * hd_pad + dim_off + lane_row * 4]);
+                bf16x4 b = *(const bf16x4*)(&KV_lds[(kvt * 16 + lane_col) * hd_pad + dim_off + lane_row * 4]);
+                acc_s = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a, b, acc_s, 0, 0, 0);
+            }
+
+            int reg_base = (qt * kv_tiles + kvt) * 4;
+            #pragma unroll
+            for(int i = 0; i < 4; i++)
+            {
+                int q_row = qt * 16 + lane_row * 4 + i;
+                int kv_pos = kvt * 16 + lane_col;
+                bool masked = (kv_pos >= seq_kv) || (q_row >= actual_q);
+                if constexpr(Config::mask_type == CausalMaskType::TOP_LEFT)
+                {
+                    if(kv_pos > q_row)
+                        masked = true;
+                }
+                float S = acc_s[i] * scale;
+                // Read the LSE only for unmasked entries: padded tile rows
+                // (q_row >= actual_q) would index past this sequence's rows.
+                float p = 0.0f;
+                if(!masked)
+                {
+                    const float lse =
+                        softmax_lse[((size_t)(q_offset + q_row) * head_num + head_idx)];
+                    p = expf(S - lse);
+                }
+                attn_reg[reg_base + i] = p;
+            }
+        }
+    }
+
+    // Softmax backward: grad_score = attn * (grad_attn - dot_sum)
+    float grad_score[q_tiles * kv_tiles * 4];
+
+    #pragma unroll
+    for(int qt = 0; qt < q_tiles; qt++)
+    {
+        #pragma unroll
+        for(int i = 0; i < 4; i++)
+        {
+            int q_row = qt * 16 + lane_row * 4 + i;
+
+            // dot_sum = sum_kv(grad_attn * attn)
+            float dot_sum = 0.0f;
+            #pragma unroll
+            for(int kvt = 0; kvt < kv_tiles; kvt++)
+            {
+                int reg_idx = (qt * kv_tiles + kvt) * 4 + i;
+                float partial = grad_attn[reg_idx] * attn_reg[reg_idx];
+
+                // Reduce across lane_col
+                #pragma unroll
+                for(int off = 8; off > 0; off /= 2)
+                    partial += __shfl_xor(partial, off, 16);
+
+                dot_sum += partial;
+            }
+
+            // grad_score = attn * (grad_attn - dot_sum)
+            #pragma unroll
+            for(int kvt = 0; kvt < kv_tiles; kvt++)
+            {
+                int reg_idx = (qt * kv_tiles + kvt) * 4 + i;
+                int kv_pos = kvt * 16 + lane_col;
+                float gs = attn_reg[reg_idx] * (grad_attn[reg_idx] - dot_sum);
+
+                // Zero invalid
+                if(q_row >= actual_q || kv_pos >= seq_kv)
+                    gs = 0.0f;
+
+                grad_score[reg_idx] = gs;
+            }
+        }
+    }
+
+    // Write grad_scores → SM_lds
+    #pragma unroll
+    for(int qt = 0; qt < q_tiles; qt++)
+    {
+        #pragma unroll
+        for(int kvt = 0; kvt < kv_tiles; kvt++)
+        {
+            #pragma unroll
+            for(int i = 0; i < 4; i++)
+            {
+                int q_row = qt * 16 + lane_row * 4 + i;
+                int kv_pos = kvt * 16 + lane_col;
+                int reg_idx = (qt * kv_tiles + kvt) * 4 + i;
+                SM_lds[q_row * lds_sm_stride + kv_pos] = grad_score[reg_idx];
+            }
+        }
+    }
+
+    __syncthreads();
+
+    // K is already in KV_lds from P recomputation
+
+    // grad_Q = grad_scores @ K * scale (4 warps split head_dim)
+    #pragma unroll
+    for(int qt = 0; qt < q_tiles; qt++)
+    {
+        constexpr int BK = 64;
+        constexpr int total_d_tiles = CEIL_DIV(head_dim, BK);
+
+        #pragma unroll
+        for(int d = 0; d < total_d_tiles; d++)
+        {
+            const int dim_idx = d * BK + warp_id * 16;
+            floatx4 acc = {0, 0, 0, 0};
+
+            #pragma unroll
+            for(int kvt = 0; kvt < kv_tiles; kvt++)
+            {
+                // A: grad_scores (transposed SM_lds read)
+                bf16x4 a;
+                #pragma unroll
+                for(int k = 0; k < 4; k++)
+                {
+                    int q_row = qt * 16 + lane_col;
+                    int kv_pos = kvt * 16 + lane_row * 4 + k;
+                    a[k] = static_cast(SM_lds[q_row * lds_sm_stride + kv_pos]);
+                }
+
+                // B: K[kv, d]
+                bf16x4 b;
+                const int kv_base = kvt * 16;
+                #pragma unroll
+                for(int k = 0; k < 4; k++)
+                {
+                    b[k] = KV_lds[(kv_base + lane_row * 4 + k) * hd_pad + dim_idx + lane_col];
+                }
+
+                acc = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a, b, acc, 0, 0, 0);
+            }
+
+            // Write grad_Q
+            #pragma unroll
+            for(int i = 0; i < 4; i++)
+            {
+                int q_row = qt * 16 + lane_row * 4 + i;
+                if(q_row < actual_q)
+                {
+                    // Keep the element offset in size_t, matching the loads above.
+                    const size_t gq_idx =
+                        ((size_t)(q_offset + q_row) * head_num + head_idx) * head_dim
+                        + dim_idx + lane_col;
+                    grad_Q[gq_idx] = static_cast(acc[i] * scale);
+                }
+            }
+        }
+    }
+
+    // grad_K = grad_scores^T @ Q * scale (4 warps split head_dim)
+    #pragma unroll
+    for(int kvt = 0; kvt < kv_tiles; kvt++)
+    {
+        constexpr int BK = 64;
+        constexpr int total_d_tiles = CEIL_DIV(head_dim, BK);
+
+        #pragma unroll
+        for(int d = 0; d < total_d_tiles; d++)
+        {
+            const int dim_idx = d * BK + warp_id * 16;
+            floatx4 acc = {0, 0, 0, 0};
+
+            #pragma unroll
+            for(int qt = 0; qt < q_tiles; qt++)
+            {
+                // A: grad_scores^T (direct SM_lds read)
+                bf16x4 a;
+                #pragma unroll
+                for(int k = 0; k < 4; k++)
+                {
+                    int q_row = qt * 16 + lane_row * 4 + k;
+                    int kv_pos = kvt * 16 + lane_col;
+                    a[k] = static_cast(SM_lds[q_row * lds_sm_stride + kv_pos]);
+                }
+
+                // B: Q[q, d]
+                bf16x4 b;
+                const int q_base = qt * 16;
+                #pragma unroll
+                for(int k = 0; k < 4; k++)
+                {
+                    b[k] = Q_lds[(q_base + lane_row * 4 + k) * hd_pad + dim_idx + lane_col];
+                }
+
+                acc = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a, b, acc, 0, 0, 0);
+            }
+
+            // Write grad_K
+            #pragma unroll
+            for(int i = 0; i < 4; i++)
+            {
+                int kv_pos = kvt * 16 + lane_row * 4 + i;
+                if(kv_pos < seq_kv)
+                {
+                    // Same element as before, but computed in size_t so large
+                    // tensors cannot overflow 32-bit int arithmetic.
+                    const size_t gk_idx =
+                        ((size_t)(kv_offset + kv_pos) * head_num + head_idx) * head_dim
+                        + dim_idx + lane_col;
+                    grad_K[gk_idx] = static_cast(acc[i] * scale);
+                }
+            }
+        }
+    }
+}
+
+// ---------------------------------------------------------------------------
+// AttnBackwardMfma16x16KernelLauncher — Grid: (1, head_num, bs), Block: 256
+// ---------------------------------------------------------------------------
+
+template
+struct AttnBackwardMfma16x16KernelLauncher
+{
+    using bwd_softmax_aux_scalar = float;
+
+    /// Option A: backward recomputes P from Q, K, and softmax_lse — no P workspace.
+    // Returns the workspace size in bytes needed by run_attn_bwd_kernel.
+    // Always 0; total_padded_q is accepted but ignored.
+    static size_t calc_workspace_size(int total_padded_q)
+    {
+        (void)total_padded_q;
+        return 0;
+    }
+
+    // Issues the two backward kernels in order: the grad_V kernel, then the
+    // fused grad_Q / grad_K kernel. sqr_dk_scale is forwarded unchanged as the
+    // kernels' `scale` argument. No launch status is checked here.
+    //
+    // NOTE(review): both kernels return early for a batch whose Q length is 0,
+    // so grad_Q / grad_K / grad_V rows of such batches are not written by these
+    // launches — confirm the caller zero-initialises the gradient buffers.
+    static void run_attn_bwd_kernel(const T* Q,
+                                    const T* K,
+                                    const T* V,
+                                    const T* grad_O,
+                                    const float* softmax_lse,
+                                    T* grad_Q,
+                                    T* grad_K,
+                                    T* grad_V,
+                                    float sqr_dk_scale,
+                                    const int* cu_seqlens_q,
+                                    const int* cu_seqlens_q_padded,
+                                    const int* cu_seqlens_kv,
+                                    const int* cu_seqlens_kv_padded)
+    {
+        float scale = sqr_dk_scale;
+
+        // One block per (head, batch) pair, 256 threads per block.
+        dim3 grid(1, Config::head_num, Config::bs);
+        dim3 block(256);
+
+        // Kernel B: grad_V = P^T @ grad_O (P recomputed from Q, K, LSE)
+        fmha_bwd_grad_v_mfma_16x16_kernel<<>>(
+            Q, K, softmax_lse, grad_O, grad_V, scale,
+            cu_seqlens_q, cu_seqlens_q_padded,
+            cu_seqlens_kv, cu_seqlens_kv_padded);
+
+        // Kernel A: fused grad_attn / softmax_bwd / grad_Q / grad_K
+        fmha_bwd_fused_mfma_16x16_kernel<<>>(
+            Q, K, V, grad_O, softmax_lse,
+            grad_Q, grad_K, scale,
+            cu_seqlens_q, cu_seqlens_q_padded,
+            cu_seqlens_kv, cu_seqlens_kv_padded);
+    }
+};
diff --git a/transformer_engine/common/fused_attn_rocm/small_seq_kernels/attn_common.h b/transformer_engine/common/fused_attn_rocm/small_seq_kernels/attn_common.h
new file mode 100644
index 000000000..fd21080b4
--- /dev/null
+++ b/transformer_engine/common/fused_attn_rocm/small_seq_kernels/attn_common.h
@@ -0,0 +1,69 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+#pragma once
+
+#include
+#include
+#include
+#include
+
+// ---------------------------------------------------------------------------
+// Error checking macro
+//
+// Evaluates `call` exactly once; on any result other than hipSuccess prints the
+// file, line and HIP error string, then terminates the process with exit(1).
+//
+// The argument is parenthesised and the temporary has a macro-private name so
+// that HIP_CHECK(err), with a caller variable named `err`, checks the caller's
+// value instead of expanding to the self-initialisation `hipError_t err = err;`.
+// ---------------------------------------------------------------------------
+
+#define HIP_CHECK(call)                                                        \
+    do                                                                         \
+    {                                                                          \
+        const hipError_t hip_check_status_ = (call);                           \
+        if(hip_check_status_ != hipSuccess)                                    \
+        {                                                                      \
+            printf("HIP error %s:%d: '%s'\n", __FILE__, __LINE__,              \
+                   hipGetErrorString(hip_check_status_));                      \
+            exit(1);                                                           \
+        }                                                                      \
+    } while(0)
+
+// ---------------------------------------------------------------------------
+// Causal mask type
+// ---------------------------------------------------------------------------
+
+enum class CausalMaskType
+{
+    DISABLE      = 0,
+    TOP_LEFT     = 1,
+    BOTTOM_RIGHT = 2
+};
+
+// Printable name for each CausalMaskType value.
+// inline to avoid ODR violation across multiple translation units (C++17)
+inline std::map CausalMaskTypeName = {
+    {CausalMaskType::DISABLE, "DISABLE"},
+    {CausalMaskType::TOP_LEFT, "TOP_LEFT"},
+    {CausalMaskType::BOTTOM_RIGHT, "BOTTOM_RIGHT"}};
+
+// ---------------------------------------------------------------------------
+// Kernel configuration struct
+//
+// Template parameters encode the static layout dimensions used by all kernels.
+// Runtime variability (actual Q/KV lengths per batch) is handled via cu_seqlens.
+// ---------------------------------------------------------------------------
+
+template
+struct FmhaKernelConfig
+{
+    static constexpr int bs = BS;
+    static constexpr int head_num = HEAD_NUM;
+    static constexpr int max_seq_q = MAX_SEQ_Q;
+    // Backward compat alias for scalar fwd/bwd kernels (hardcoded seq_q=1)
+    static constexpr int seq_q = 1;
+    static constexpr int max_seq_kv = MAX_SEQ_KV;
+    static constexpr int head_dim = HEAD_DIM;
+    static constexpr int step2_block_size = STEP2_BLOCK_SIZE;
+    static constexpr bool enable_dropout_mask = ENABLE_DROPOUT_MASK;
+    static constexpr enum CausalMaskType mask_type = MAKS_TYPE;
+};
diff --git a/transformer_engine/common/fused_attn_rocm/small_seq_kernels/attn_fwd.h b/transformer_engine/common/fused_attn_rocm/small_seq_kernels/attn_fwd.h
new file mode 100644
index 000000000..cbf48144c
--- /dev/null
+++ b/transformer_engine/common/fused_attn_rocm/small_seq_kernels/attn_fwd.h
@@ -0,0 +1,450 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+#pragma once
+
+#include "attn_common.h"
+#include
+
+// ---------------------------------------------------------------------------
+// Kernel 1: compute_scores_kernel
+//
+// Computes attention scores: Q @ K^T * scale
+//   Q layout:      [total_padded_q, head_num, head_dim]
+//   K layout:      [total_padded_kv_seq, head_num, head_dim]
+//   scores layout: [total_padded_q, head_num, max_seq_kv]
+//
+// Each thread handles one (batch, head) pair per task and walks head_dim in
+// chunks of 64 elements using 16-byte loads. Entries [seq_kv, max_seq_kv) of
+// each score row are filled with -1e9 so they vanish in the later softmax.
+// Pairs with batch_idx >= Config::bs, or whose batch has no Q row, are skipped.
+// ---------------------------------------------------------------------------
+
+template
+__global__ void compute_scores_kernel(const T* Q,
+                                      const T* K,
+                                      T* scores,
+                                      float scale,
+                                      const int* cu_seqlens_q,
+                                      const int* cu_seqlens_q_padded,
+                                      const int* cu_seqlens_kv,
+                                      const int* cu_seqlens_kv_padded)
+{
+    // seq_q is 1 in static layout (storage), but actual Q length per batch may be 0 or 1.
+    constexpr int seq_q = Config::seq_q; // == 1 (padded storage dimension)
+    static_assert(seq_q == 1, "seq_q must be 1 for this kernel implementation.");
+    constexpr int max_seq_kv = Config::max_seq_kv;
+    constexpr int head_dim = Config::head_dim;
+    constexpr int block_k = 64;
+    constexpr int thread_block_size = 64;
+    constexpr int tasks_per_block = TASKS_PER_BLOCK;
+
+    int base_block_offset = blockIdx.x * thread_block_size * tasks_per_block;
+    int thread_id = threadIdx.x;
+
+    for(int task = 0; task < tasks_per_block; task++)
+    {
+        int cur_batch_idx = base_block_offset + task * thread_block_size + thread_id;
+        // Layout: [batch, seq_q(storage=1), head_num, head_dim]
+        // cur_batch_idx represents the combined index for (batch * seq_q * head_num)
+        int batch_idx = cur_batch_idx / (Config::seq_q * Config::head_num);
+        int seq_head_idx = cur_batch_idx % (Config::seq_q * Config::head_num);
+        int seq_idx = seq_head_idx / Config::head_num;
+        int head_idx = seq_head_idx % Config::head_num;
+
+        if(batch_idx >= Config::bs)
+            continue;
+
+        // Skip batches where actual Q sequence length is 0.
+        // Memory is still allocated (padded to seq_q=1), but no computation needed.
+        int actual_seq_q = cu_seqlens_q[batch_idx + 1] - cu_seqlens_q[batch_idx];
+        if(actual_seq_q == 0)
+            continue;
+
+        // Get actual sequence length for this batch
+        int seq_kv = cu_seqlens_kv[batch_idx + 1] - cu_seqlens_kv[batch_idx];
+        int kv_offset = cu_seqlens_kv_padded[batch_idx];
+
+        // Q storage offset: cu_seqlens_q_padded[batch_idx] is the slot for this batch.
+        // seq_idx is always 0 because seq_q == 1.
+        int q_storage_offset = cu_seqlens_q_padded[batch_idx];
+
+        float results[max_seq_kv];
+        T fetch_Q[block_k];
+        T fetch_K[block_k];
+        // Base offsets are widened to size_t before multiplying so that large
+        // tensors cannot overflow 32-bit int arithmetic.
+        // Q: [total_padded_seq_q, head_num, head_dim]
+        T* Q_ptr = (T*)&Q[((size_t)q_storage_offset * Config::head_num + head_idx) * head_dim];
+        // K: [total_padded_seq_kv, head_num, head_dim]
+        T* K_ptr = (T*)&K[((size_t)kv_offset * Config::head_num + head_idx) * head_dim];
+        // scores workspace: [total_padded_q, head_num, max_seq_kv]
+        // index by padded Q slot: cu_seqlens_q_padded[batch_idx]
+        T* score_ptr =
+            (T*)&scores[((size_t)cu_seqlens_q_padded[batch_idx] * Config::head_num + head_idx) *
+                        max_seq_kv];
+        uint4 ls_dwordx4_tmp_var;
+        for(int i = 0; i < seq_kv; i++)
+            results[i] = 0.0f;
+        for(int dim_offset = 0; dim_offset < head_dim; dim_offset += block_k)
+        {
+            if constexpr(std::is_same::value)
+            {
+                // 16-byte load = 8 two-byte elements per iteration.
+                for(int k = 0; k < block_k / 8; k++)
+                {
+                    ls_dwordx4_tmp_var = *((uint4*)&Q_ptr[dim_offset + k * 8]);
+                    fetch_Q[k * 8 + 0] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.x)[0];
+                    fetch_Q[k * 8 + 1] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.x)[1];
+                    fetch_Q[k * 8 + 2] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.y)[0];
+                    fetch_Q[k * 8 + 3] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.y)[1];
+                    fetch_Q[k * 8 + 4] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.z)[0];
+                    fetch_Q[k * 8 + 5] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.z)[1];
+                    fetch_Q[k * 8 + 6] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.w)[0];
+                    fetch_Q[k * 8 + 7] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.w)[1];
+                }
+                for(int kv_idx = 0; kv_idx < seq_kv; kv_idx++)
+                {
+                    for(int k = 0; k < block_k / 8; k++)
+                    {
+                        // K layout: [batch, seq_kv, head_num, head_dim]
+                        ls_dwordx4_tmp_var = *((uint4*)&K_ptr[kv_idx * Config::head_num * head_dim
+                                                              + dim_offset + k * 8]);
+                        fetch_K[k * 8 + 0] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.x)[0];
+                        fetch_K[k * 8 + 1] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.x)[1];
+                        fetch_K[k * 8 + 2] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.y)[0];
+                        fetch_K[k * 8 + 3] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.y)[1];
+                        fetch_K[k * 8 + 4] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.z)[0];
+                        fetch_K[k * 8 + 5] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.z)[1];
+                        fetch_K[k * 8 + 6] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.w)[0];
+                        fetch_K[k * 8 + 7] = ((hip_bfloat16*)&ls_dwordx4_tmp_var.w)[1];
+                    }
+#pragma unroll
+                    for(int k = 0; k < block_k; k++)
+                    {
+                        results[kv_idx] +=
+                            static_cast(fetch_Q[k]) * static_cast(fetch_K[k]);
+                    }
+                }
+            }
+            else
+            {
+                // 16-byte load = 4 elements per iteration.
+                // NOTE(review): this path treats T as a 4-byte type — confirm.
+                for(int k = 0; k < block_k / 4; k++)
+                {
+                    ls_dwordx4_tmp_var = *((uint4*)&Q_ptr[dim_offset + k * 4]);
+                    fetch_Q[k * 4 + 0] = *((T*)&ls_dwordx4_tmp_var.x);
+                    fetch_Q[k * 4 + 1] = *((T*)&ls_dwordx4_tmp_var.y);
+                    fetch_Q[k * 4 + 2] = *((T*)&ls_dwordx4_tmp_var.z);
+                    fetch_Q[k * 4 + 3] = *((T*)&ls_dwordx4_tmp_var.w);
+                }
+                for(int kv_idx = 0; kv_idx < seq_kv; kv_idx++)
+                {
+                    for(int k = 0; k < block_k / 4; k++)
+                    {
+                        // K layout: [batch, seq_kv, head_num, head_dim]
+                        ls_dwordx4_tmp_var = *((uint4*)&K_ptr[kv_idx * Config::head_num * head_dim
+                                                              + dim_offset + k * 4]);
+                        fetch_K[k * 4 + 0] = *((T*)&ls_dwordx4_tmp_var.x);
+                        fetch_K[k * 4 + 1] = *((T*)&ls_dwordx4_tmp_var.y);
+                        fetch_K[k * 4 + 2] = *((T*)&ls_dwordx4_tmp_var.z);
+                        fetch_K[k * 4 + 3] = *((T*)&ls_dwordx4_tmp_var.w);
+                    }
+#pragma unroll
+                    for(int k = 0; k < block_k; k++)
+                    {
+                        results[kv_idx] += fetch_Q[k] * fetch_K[k];
+                    }
+                }
+            }
+        }
+        for(int i = 0; i < seq_kv; i++)
+        {
+            score_ptr[i] = T(results[i] * scale);
+        }
+        // Fill padding positions with a large negative score (not zero) so they
+        // contribute nothing once the softmax exponentiates them.
+        for(int i = seq_kv; i < max_seq_kv; i++)
+        {
+            score_ptr[i] = T(-1e9f);
+        }
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Kernel 2: apply_mask_and_softmax_kernel
+//
+// Applies causal/KV-padding masks and computes numerically stable softmax.
+//   scores layout: [total_padded_q, head_num, max_seq_kv]
+//   padded_q_to_batch: host-precomputed reverse map [padded_q_slot] -> batch_idx
+//
+// Each block processes valid_thread_range consecutive score elements (a whole
+// number of rows), one element per thread; threads beyond that range, or beyond
+// total_elt, only take part in the barriers.
+// ---------------------------------------------------------------------------
+
+template
+__global__ void apply_mask_and_softmax_kernel(T* scores,
+                                              const T* dropout_mask,
+                                              float dropout_scale,
+                                              const int* cu_seqlens_kv,
+                                              const int* padded_q_to_batch,
+                                              uint32_t total_elt)
+{
+    const uint32_t block_id = blockIdx.x;
+    const uint32_t thread_id = threadIdx.x;
+    constexpr int max_seq_kv = Config::max_seq_kv;
+    constexpr int block_size = Config::step2_block_size;
+    constexpr int per_score_size = max_seq_kv; // seq_q == 1
+    constexpr int valid_thread_range = block_size / per_score_size * per_score_size;
+    constexpr int row_num = valid_thread_range / max_seq_kv;
+    const uint32_t cur_block_offset = block_id * valid_thread_range + thread_id;
+    bool is_tail = block_id * valid_thread_range + block_size >= total_elt;
+    int real_row_num = is_tail ? (total_elt - block_id * valid_thread_range) / max_seq_kv
+                               : valid_thread_range / max_seq_kv;
+
+    __shared__ T tmp_scores[valid_thread_range];
+    __shared__ T row_max[row_num];
+    __shared__ T row_sum[row_num];
+
+    // Threads that own a score element. The barriers below are deliberately
+    // kept outside any `if(active)` so every thread of the block reaches them;
+    // a __syncthreads() inside a thread-divergent branch is undefined behaviour.
+    const bool active = cur_block_offset < total_elt && thread_id < valid_thread_range;
+
+    if(active)
+    {
+        // scores layout: [total_padded_q, head_num, max_seq_kv]
+        // global_row_idx encodes (padded_q_slot * head_num + head_idx)
+        int global_row_idx = cur_block_offset / max_seq_kv;
+        int padded_q_slot = global_row_idx / Config::head_num;
+        int k_idx = cur_block_offset % max_seq_kv;
+
+        // Reverse-map padded Q slot to batch_idx via host-precomputed table.
+        // All slots in the buffer are guaranteed to belong to a valid (active) batch
+        // because empty-Q batches contribute no rows.
+        int batch_idx = padded_q_to_batch[padded_q_slot];
+
+        // Get actual KV sequence length for this batch
+        int seq_kv = cu_seqlens_kv[batch_idx + 1] - cu_seqlens_kv[batch_idx];
+
+        tmp_scores[thread_id] = scores[cur_block_offset];
+
+        // Apply causal mask / KV-padding mask
+        if constexpr(Config::mask_type == CausalMaskType::TOP_LEFT)
+        {
+            // q_idx == 0 (seq_q == 1); mask: k_idx > 0 || k_idx >= seq_kv
+            if(k_idx > 0 || k_idx >= seq_kv)
+                tmp_scores[thread_id] = T(-1e9f);
+        }
+        else if constexpr(Config::mask_type == CausalMaskType::BOTTOM_RIGHT)
+        {
+            // q_idx == 0; mask: k_idx < 0 (never) || k_idx >= seq_kv
+            if(k_idx >= seq_kv)
+                tmp_scores[thread_id] = T(-1e9f);
+        }
+        else
+        {
+            if(k_idx >= seq_kv)
+                tmp_scores[thread_id] = T(-1e9f);
+        }
+    }
+    __syncthreads();
+
+    // Find max for each row (numerically stable softmax)
+    if(active && thread_id < real_row_num)
+    {
+        T max_val = T(-1e9f);
+#pragma unroll
+        for(int i = 0; i < max_seq_kv; i++)
+        {
+            max_val = max(max_val, tmp_scores[thread_id * max_seq_kv + i]);
+        }
+        row_max[thread_id] = max_val;
+    }
+    __syncthreads();
+
+    // Compute exp(score - max) and sum for each row
+    if(active)
+    {
+        T exp_val = T(exp(float(tmp_scores[thread_id] - row_max[thread_id / max_seq_kv])));
+        tmp_scores[thread_id] = exp_val;
+    }
+    __syncthreads();
+
+    if(active && thread_id < real_row_num)
+    {
+        T sum = T(0.0f);
+#pragma unroll
+        for(int i = 0; i < max_seq_kv; i++)
+        {
+            sum += tmp_scores[thread_id * max_seq_kv + i];
+        }
+        row_sum[thread_id] = sum;
+    }
+    __syncthreads();
+
+    // Normalize and apply dropout
+    if(active)
+    {
+        T attn_weight = tmp_scores[thread_id] / row_sum[thread_id / max_seq_kv];
+
+        if constexpr(Config::enable_dropout_mask)
+        {
+            attn_weight = attn_weight * dropout_mask[cur_block_offset] * dropout_scale;
+        }
+
+        scores[cur_block_offset] = attn_weight;
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Kernel 3: compute_output_kernel
+//
+// Computes attention output: attn_weights @ V
+//   attn_weights layout: [total_padded_q, head_num,
max_seq_kv]
+//   V layout: [total_padded_kv_seq, head_num, head_dim]
+//   O layout: [total_padded_q, head_num, head_dim]
+//
+// head_dim / BLOCK_K threads cooperate on one (batch, head) pair; each thread
+// accumulates BLOCK_K output elements over the seq_kv value rows and stores
+// them with 16-byte writes. Pairs with batch_idx >= Config::bs, or whose batch
+// has no Q row, are skipped.
+// ---------------------------------------------------------------------------
+
+template
+__global__ void compute_output_kernel(const T* attn_weights,
+                                      const T* V,
+                                      T* O,
+                                      const int* cu_seqlens_q,
+                                      const int* cu_seqlens_q_padded,
+                                      const int* cu_seqlens_kv,
+                                      const int* cu_seqlens_kv_padded)
+{
+    constexpr int seq_q = Config::seq_q;
+    constexpr int max_seq_kv = Config::max_seq_kv;
+    constexpr int head_dim = Config::head_dim;
+    constexpr int block_k = BLOCK_K;
+    constexpr int dwordx4_load_elt = 16 / sizeof(T);
+    constexpr int warp_size = 64;
+    constexpr int process_head_per_warp = warp_size / (head_dim / block_k);
+    constexpr int tasks_per_block = TASKS_PER_BLOCK;
+
+    int base_block_offset = blockIdx.x * process_head_per_warp * tasks_per_block;
+    int thread_id = threadIdx.x;
+    int thread_batch_offset = thread_id / (head_dim / block_k);
+    int thread_head_offset = thread_id % (head_dim / block_k) * block_k;
+
+    uint4 load_dwordx4_tmp_var[block_k / dwordx4_load_elt],
+        store_dwordx4_tmp_var[block_k / dwordx4_load_elt];
+    T result[block_k];
+    T attn[max_seq_kv];
+
+    for(int task = 0; task < tasks_per_block; task++)
+    {
+        int block_batch_head_idx = base_block_offset + task * process_head_per_warp;
+        int cur_idx = block_batch_head_idx + thread_batch_offset;
+
+        // Layout: [batch, seq_q(storage=1), head_num, head_dim]
+        int batch_idx = cur_idx / (Config::seq_q * Config::head_num);
+        int seq_head_idx = cur_idx % (Config::seq_q * Config::head_num);
+        int seq_q_idx = seq_head_idx / Config::head_num;
+        int head_idx = seq_head_idx % Config::head_num;
+
+        if(batch_idx >= Config::bs)
+            continue;
+
+        // Skip batches where actual Q seq length is 0 — no output to write.
+        int actual_seq_q = cu_seqlens_q[batch_idx + 1] - cu_seqlens_q[batch_idx];
+        if(actual_seq_q == 0)
+            continue;
+
+        // Get actual sequence length for this batch
+        int seq_kv = cu_seqlens_kv[batch_idx + 1] - cu_seqlens_kv[batch_idx];
+        int kv_offset = cu_seqlens_kv_padded[batch_idx];
+
+        // Q output storage offset: one slot per batch, seq_q_idx is always 0.
+        int q_storage_offset = cu_seqlens_q_padded[batch_idx];
+
+        // Clear the accumulator (all-zero bits).
+#pragma unroll
+        for(int i = 0; i < block_k / dwordx4_load_elt; i++)
+        {
+            store_dwordx4_tmp_var[i].x = 0;
+            store_dwordx4_tmp_var[i].y = 0;
+            store_dwordx4_tmp_var[i].z = 0;
+            store_dwordx4_tmp_var[i].w = 0;
+        }
+        // Element offsets below are widened to size_t before multiplying so
+        // that large tensors cannot overflow 32-bit int arithmetic.
+        // attn_weights layout: [total_padded_q, head_num, max_seq_kv]
+        const size_t attn_offset =
+            ((size_t)cu_seqlens_q_padded[batch_idx] * Config::head_num + head_idx) * max_seq_kv;
+#pragma unroll
+        for(int i = 0; i < max_seq_kv; i++)
+            attn[i] = attn_weights[attn_offset + i];
+        for(int j = 0; j < seq_kv; j++)
+        {
+#pragma unroll
+            for(int i = 0; i < block_k / dwordx4_load_elt; i++)
+            {
+                // V layout: [total_padded_seq_kv, head_num, head_dim]
+                load_dwordx4_tmp_var[i] =
+                    *((uint4*)&V[((size_t)(kv_offset + j) * Config::head_num + head_idx) * head_dim
+                                 + thread_head_offset + i * dwordx4_load_elt]);
+            }
+            // Accumulation happens in T, element by element inside the packed registers.
+#pragma unroll
+            for(int b = 0; b < block_k; b++)
+                ((T*)&store_dwordx4_tmp_var[b / dwordx4_load_elt])[b % dwordx4_load_elt] +=
+                    attn[j] *
+                    ((T*)&load_dwordx4_tmp_var[b / dwordx4_load_elt])[b % dwordx4_load_elt];
+        }
+#pragma unroll
+        for(int i = 0; i < block_k / dwordx4_load_elt; i++)
+            // O layout: [total_padded_seq_q, head_num, head_dim]
+            *((uint4*)&O[((size_t)q_storage_offset * Config::head_num + head_idx) * head_dim
+                         + thread_head_offset + i * dwordx4_load_elt]) = store_dwordx4_tmp_var[i];
+    }
+}
+
+// ---------------------------------------------------------------------------
+// AttnForwardKernelLauncher
+//
+// Orchestrates the 3-kernel forward pipeline:
+//   1. compute_scores_kernel        (Q @ K^T * scale)
+//   2. apply_mask_and_softmax_kernel
+//   3.
compute_output_kernel (attn_weights @ V)
+// ---------------------------------------------------------------------------
+
+template
+struct AttnForwardKernelLauncher
+{
+    // Returns the size in bytes of the score / attention-weight workspace.
+    // workspace layout: [total_padded_q, head_num, max_seq_kv]
+    // total_padded_q = cu_seqlens_q_padded[bs] — known on host before kernel launch.
+    static size_t calc_workspace_size(int total_padded_q)
+    {
+        constexpr int head_num = Config::head_num;
+        constexpr int max_seq_kv = Config::max_seq_kv;
+        return (size_t)total_padded_q * head_num * max_seq_kv * sizeof(T);
+    }
+
+    // Runs the three forward kernels in order, using `workspace` (sized by
+    // calc_workspace_size) for the scores and then the attention weights.
+    // dropout_p is turned into the 1 / (1 - p) rescale factor passed to the
+    // softmax kernel. Returns immediately when there are no padded Q rows.
+    static void run_attn_fwd_kernel(const T* Q,
+                                    const T* K,
+                                    const T* V,
+                                    const T* dropout_mask,
+                                    float dropout_p,
+                                    float sqr_dk_scale,
+                                    T* O,
+                                    T* workspace,
+                                    const int* cu_seqlens_q,
+                                    const int* cu_seqlens_q_padded,
+                                    const int* cu_seqlens_kv,
+                                    const int* cu_seqlens_kv_padded,
+                                    const int* padded_q_to_batch,
+                                    int total_padded_q)
+    {
+        constexpr int bs = Config::bs;
+        constexpr int head_num = Config::head_num;
+        constexpr int seq_q = Config::seq_q;
+        constexpr int max_seq_kv = Config::max_seq_kv;
+        constexpr int head_dim = Config::head_dim;
+        constexpr int warp_size = 64;
+
+        // No Q rows at all: nothing to compute, and the step-2 grid would be
+        // zero-sized, which is not a valid launch configuration.
+        if(total_padded_q <= 0)
+            return;
+
+        constexpr int merge_bs = bs * head_num;
+        float scale = sqr_dk_scale;
+        float dropout_scale = (dropout_p > 0.0f) ? (1.0f / (1.0f - dropout_p)) : 1.0f;
+
+        // Step 1: QK^T scores — grid covers all (batch * head_num) tasks
+        // Rounded up so a (batch * head_num) count that is below, or not a
+        // multiple of, the block size is still fully covered; the kernel skips
+        // indices with batch_idx >= bs.
+        constexpr int kernel1_threads = 64;
+        dim3 block(kernel1_threads);
+        dim3 grid((merge_bs + kernel1_threads - 1) / kernel1_threads);
+        compute_scores_kernel<<>>(
+            Q, K, workspace, scale, cu_seqlens_q, cu_seqlens_q_padded, cu_seqlens_kv,
+            cu_seqlens_kv_padded);
+
+        // Step 2: Mask + softmax — grid covers [total_padded_q, head_num, max_seq_kv] elements
+        constexpr int work_thread_num =
+            Config::step2_block_size / max_seq_kv * max_seq_kv; // seq_q == 1
+        uint32_t total_elt = (uint32_t)total_padded_q * head_num * max_seq_kv;
+        dim3 grid2((total_elt + work_thread_num - 1) / work_thread_num);
+        dim3 block2(Config::step2_block_size);
+        apply_mask_and_softmax_kernel
+            <<>>(workspace, dropout_mask, dropout_scale, cu_seqlens_kv,
+                 padded_q_to_batch, total_elt);
+
+        // Step 3: Weighted sum over V — grid covers all (batch * head_num) tasks
+        // Both divisions round up (the inner one used to truncate and could
+        // drop the trailing heads); the kernel skips indices with batch_idx >= bs.
+        constexpr int kernel3_block_k = 8;
+        constexpr int kernel3_threads = 64;
+        constexpr int process_head_per_warp = warp_size / (head_dim / kernel3_block_k);
+        constexpr int kernel3_groups =
+            (merge_bs + process_head_per_warp - 1) / process_head_per_warp;
+        dim3 block3(kernel3_threads);
+        dim3 grid3((kernel3_groups + 2 - 1) / 2);
+        compute_output_kernel<<>>(
+            workspace, V, O, cu_seqlens_q, cu_seqlens_q_padded, cu_seqlens_kv,
+            cu_seqlens_kv_padded);
+    }
+};
diff --git a/transformer_engine/common/fused_attn_rocm/small_seq_kernels/attn_fwd_mfma.h b/transformer_engine/common/fused_attn_rocm/small_seq_kernels/attn_fwd_mfma.h
new file mode 100644
index 000000000..e753e7ab5
--- /dev/null
+++ b/transformer_engine/common/fused_attn_rocm/small_seq_kernels/attn_fwd_mfma.h
@@ -0,0 +1,398 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+#pragma once
+
+#include "attn_common.h"
+#include
+
+using bhalf_t = __bf16;
+using bf16x4 = __bf16 __attribute__((ext_vector_type(4)));
+using bf16x8 = __bf16 __attribute__((ext_vector_type(8)));
+using floatx4 = float __attribute__((ext_vector_type(4)));
+
+#ifndef CEIL_DIV
+#define CEIL_DIV(a, b) (((a) + (b)-1) / (b))
+#endif
+
+template // NOTE(review): parameter list is missing in this text (T is used below) — restore before building
+__device__ __forceinline__ bf16x8 load_cvt_bf16x8(const T* src) // returns 8 consecutive elements of src as a bf16x8
+{
+    if constexpr(sizeof(T) == 2)
+    {
+        return *(const bf16x8*)src; // 2-byte T: reinterpreted as bf16 without conversion
+    }
+    else
+    {
+        // T = float
+        bf16x8 r;
+        #pragma unroll
+        for(int i = 0; i < 8; i++)
+        {
+            r[i] = static_cast(src[i]);
+        }
+        return r;
+    }
+}
+
+// ---------------------------------------------------------------------------
+// MFMA 4x4x4 forward kernel (seq_q ≤ 4, online softmax, 16 heads/wave)
+//
+// Thread: warp[0-3], lane[0-63], mfma_block=lane/4 (head), mfma_tid=lane%4 (Q row)
+// LDS: Q_lds[seq_q × 16 × hd_pad], KV_lds[4 × 16 × hd_pad] (reused K→V)
+// Grid: (1, ceil(heads/16), bs), Block: 256
+// ---------------------------------------------------------------------------
+
+template
+__launch_bounds__(256, (Config::head_dim == 128) ? 3 : 1)
+__global__ void fmha_fwd_mfma_kernel(
+    const T* Q,
+    const T* K,
+    const T* V,
+    T* O,
+    T* workspace, // not referenced in the kernel body
+    const T* dropout_mask, // read only when Config::enable_dropout_mask; indexed [padded q row, head, kv position]
+    float dropout_scale,
+    float scale,
+    const int* cu_seqlens_q,
+    const int* cu_seqlens_q_padded,
+    const int* cu_seqlens_kv,
+    const int* cu_seqlens_kv_padded)
+{
+    // Compile-time constants
+    constexpr int head_dim = Config::head_dim;
+    constexpr int head_num = Config::head_num;
+    constexpr int max_seq_kv = Config::max_seq_kv;
+    constexpr int max_seq_q = Config::max_seq_q;
+    constexpr int hd_pad = head_dim + 4; // LDS row stride: head_dim plus 4 padding elements
+
+    static_assert(max_seq_q >= 1 && max_seq_q <= 4, "4x4x4 kernel supports max_seq_q 1..4");
+
+    // 4 warps split head_dim for Attn×V
+    constexpr int dims_per_warp = head_dim / 4;
+    constexpr int num_dim_groups = dims_per_warp / 4;
+
+    // Thread mapping
+    const int batch_idx = blockIdx.z;
+    const int head_group = blockIdx.y;
+    const int tid = threadIdx.x;
+    const int warp_id = tid / 64;
+    const int lane_id = tid % 64;
+    const int mfma_block = lane_id / 4; // which head within group [0,16)
+    const int mfma_tid = lane_id % 4; // Q-row worker within MFMA block [0,4)
+
+    const int head_base = head_group * 16;
+    const int head_idx = head_base + mfma_block;
+    const bool valid_head = (head_idx < head_num);
+
+    const int actual_q = cu_seqlens_q[batch_idx + 1] - cu_seqlens_q[batch_idx];
+    if(actual_q == 0)
+        return; // depends only on blockIdx.z, so the whole block exits together before any __syncthreads()
+
+    const int seq_kv = cu_seqlens_kv[batch_idx + 1] - cu_seqlens_kv[batch_idx];
+    const int kv_offset = cu_seqlens_kv_padded[batch_idx];
+    const int q_offset = cu_seqlens_q_padded[batch_idx];
+
+    const int warp_dim_start = warp_id * dims_per_warp;
+
+    // LDS
+    __shared__ __attribute__((aligned(128))) bhalf_t Q_lds[max_seq_q * 16 * hd_pad];
+    __shared__ __attribute__((aligned(128))) bhalf_t KV_lds[4 * 16 * hd_pad];
+
+    // Cooperative load: each thread loads 8 bf16 values
+    const int load_idx = tid * 8; // NOTE(review): 256 threads x 8 elements equals 16 heads x head_dim only when head_dim == 128; confirm no other head_dim is instantiated
+    const int load_head = load_idx / head_dim;
+    const int load_dim = load_idx % head_dim;
+    const int load_lds_off = load_head * hd_pad + load_dim;
+
+    // MFMA LDS read offsets
+    const int q_lds_base = mfma_block * hd_pad;
+    const int k_lds_base = mfma_tid * 16 * hd_pad + mfma_block * hd_pad;
+
+    // Load Q → Q_lds
+    #pragma unroll
+    for(int qr = 0; qr < max_seq_q; qr++)
+    {
+        const int q_lds_offset = qr * 16 * hd_pad;
+
+        if(qr < actual_q && head_base + load_head < head_num)
+        {
+            const T* q_src = Q + ((size_t)(q_offset + qr) * head_num + head_base) * head_dim;
+            *(bf16x8*)(&Q_lds[q_lds_offset + load_lds_off]) = load_cvt_bf16x8(q_src + load_idx);
+        }
+        else
+        {
+            *(bf16x8*)(&Q_lds[q_lds_offset + load_lds_off]) = bf16x8{0, 0, 0, 0, 0, 0, 0, 0};
+        }
+    }
+
+    // Online attention: fused QK^T → softmax → Attn×V per KV group of 4
+    float running_max[max_seq_q];
+    float running_sum[max_seq_q];
+    float v_acc[max_seq_q][num_dim_groups];
+
+    #pragma unroll
+    for(int m = 0; m < max_seq_q; m++)
+    {
+        running_max[m] = -INFINITY;
+        running_sum[m] = 0.0f;
+        #pragma unroll
+        for(int dg = 0; dg < num_dim_groups; dg++)
+            v_acc[m][dg] = 0.0f;
+    }
+
+    const int num_kv_groups = CEIL_DIV(seq_kv, 4);
+
+    for(int kv_grp = 0; kv_grp < num_kv_groups; kv_grp++)
+    {
+        const int kv_base = kv_grp * 4;
+
+        // Load K[4 positions] → KV_lds
+        #pragma unroll
+        for(int kv = 0; kv < 4; kv++)
+        {
+            const int kv_pos = kv_base + kv;
+            const int clamped_kv = min(kv_pos, max(seq_kv - 1, 0)); // positions past seq_kv re-read the last valid row; their scores are masked below
+            const T* k_src = K + ((size_t)(kv_offset + clamped_kv) * head_num + head_base) * head_dim;
+            const int kv_lds_offset = kv * 16 * hd_pad;
+
+            if(head_base + load_head < head_num)
+            {
+                *(bf16x8*)(&KV_lds[kv_lds_offset + load_lds_off]) = load_cvt_bf16x8(k_src + load_idx);
+            }
+            else
+            {
+                *(bf16x8*)(&KV_lds[kv_lds_offset + load_lds_off]) = bf16x8{0, 0, 0, 0, 0, 0, 0, 0};
+            }
+        }
+
+        __syncthreads();
+
+        // MFMA QK^T
+        floatx4 qk_acc = {0, 0, 0, 0};
+
+        #pragma unroll
+        for(int k = 0; k < head_dim; k += 4)
+        {
+            bf16x4 q_a, k_b;
+
+            if(mfma_tid < actual_q)
+            {
+                q_a = *(const bf16x4*)(&Q_lds[mfma_tid * 16 * hd_pad + q_lds_base + k]);
+            }
+            else
+            {
+                q_a = bf16x4{0, 0, 0, 0};
+            }
+
+            k_b = *(const bf16x4*)(&KV_lds[k_lds_base + k]);
+
+            qk_acc = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(q_a, k_b, qk_acc, 0, 0, 0);
+        }
+
+        // Online softmax: extract scores, update running_max/sum, rescale v_acc
+        float my_weights[4]; // filled only for the row m == mfma_tid; consumed only when mfma_tid < actual_q
+
+        #pragma unroll
+        for(int m = 0; m < max_seq_q; m++)
+        {
+            float scores[4];
+            #pragma unroll
+            for(int s = 0; s < 4; s++)
+            {
+                int kv_pos = kv_base + s;
+                bool masked = (kv_pos >= seq_kv) || (m >= actual_q);
+                if constexpr(Config::mask_type == CausalMaskType::TOP_LEFT)
+                {
+                    if(kv_pos > m)
+                        masked = true;
+                }
+                scores[s] = masked ? -INFINITY : (__shfl(qk_acc[m], s, 4) * scale);
+            }
+
+            float tile_max = fmaxf(fmaxf(scores[0], scores[1]), fmaxf(scores[2], scores[3]));
+            float new_max = fmaxf(running_max[m], tile_max);
+
+            // Rescale previous accumulations (guard -inf - (-inf) = NaN)
+            if(running_max[m] > -INFINITY)
+            {
+                float rescale = expf(running_max[m] - new_max);
+                running_sum[m] *= rescale;
+                #pragma unroll
+                for(int dg = 0; dg < num_dim_groups; dg++)
+                    v_acc[m][dg] *= rescale;
+            }
+            running_max[m] = new_max;
+
+            float weights[4];
+            #pragma unroll
+            for(int s = 0; s < 4; s++)
+            {
+                weights[s] = (running_max[m] > -INFINITY) ? expf(scores[s] - running_max[m]) : 0.0f;
+                running_sum[m] += weights[s]; // accumulated before dropout is applied
+            }
+
+            if(m == mfma_tid)
+            {
+                #pragma unroll
+                for(int s = 0; s < 4; s++)
+                    my_weights[s] = weights[s];
+            }
+        }
+
+        // Apply dropout
+        if constexpr(Config::enable_dropout_mask)
+        {
+            if(valid_head && mfma_tid < actual_q)
+            {
+                const size_t ws_off = ((size_t)(q_offset + mfma_tid) * head_num + head_idx) * max_seq_kv; // widened like the Q/K/V/O indices so large products cannot wrap a 32-bit int
+                #pragma unroll
+                for(int s = 0; s < 4; s++)
+                {
+                    int kv_pos = kv_base + s;
+                    if(kv_pos < seq_kv)
+                    {
+                        my_weights[s] *= static_cast(dropout_mask[ws_off + kv_pos])
+                            * dropout_scale;
+                    }
+                }
+            }
+        }
+
+        // Convert weights to bf16 for V MFMA
+        bf16x4 weight_a;
+        if(mfma_tid < actual_q)
+        {
+            #pragma unroll
+            for(int i = 0; i < 4; i++)
+                weight_a[i] = static_cast(my_weights[i]);
+        }
+        else
+        {
+            weight_a = bf16x4{0, 0, 0, 0};
+        }
+
+        __syncthreads(); // all K reads from KV_lds must finish before it is overwritten with V
+
+        // Load V[4 positions] → KV_lds
+        #pragma unroll
+        for(int kv = 0; kv < 4; kv++)
+        {
+            const int kv_pos = kv_base + kv;
+            const int kv_lds_offset = kv * 16 * hd_pad;
+
+            if(kv_pos < seq_kv && head_base + load_head < head_num)
+            {
+                const T* v_src = V + ((size_t)(kv_offset + kv_pos) * head_num + head_base) * head_dim;
+                *(bf16x8*)(&KV_lds[kv_lds_offset + load_lds_off]) = load_cvt_bf16x8(v_src + load_idx);
+            }
+            else
+            {
+                *(bf16x8*)(&KV_lds[kv_lds_offset + load_lds_off]) = bf16x8{0, 0, 0, 0, 0, 0, 0, 0};
+            }
+        }
+
+        __syncthreads();
+
+        // MFMA weights × V → accumulate v_acc
+        #pragma unroll
+        for(int dg = 0; dg < num_dim_groups; dg++)
+        {
+            const int out_d = warp_dim_start + dg * 4 + mfma_tid;
+
+            bf16x4 v_b;
+            #pragma unroll
+            for(int i = 0; i < 4; i++)
+            {
+                v_b[i] = KV_lds[i * 16 * hd_pad + mfma_block * hd_pad + out_d];
+            }
+
+            floatx4 mfma_acc;
+            #pragma unroll
+            for(int m = 0; m < max_seq_q; m++)
+                mfma_acc[m] = v_acc[m][dg];
+            #pragma unroll
+            for(int m = max_seq_q; m < 4; m++)
+                mfma_acc[m] = 0.0f;
+
+            mfma_acc = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(
+                weight_a, v_b, mfma_acc, 0, 0, 0);
+
+            #pragma unroll
+            for(int m = 0; m < max_seq_q; m++)
+                v_acc[m][dg] = mfma_acc[m];
+        }
+
+        __syncthreads(); // V reads must finish before the next iteration reloads KV_lds with K
+    }
+
+    // Normalize: v_acc /= running_sum
+    #pragma unroll
+    for(int m = 0; m < max_seq_q; m++)
+    {
+        float inv_sum = (running_sum[m] > 0.0f) ? (1.0f / running_sum[m]) : 0.0f; // rows with no unmasked key produce zeros, not NaN
+        #pragma unroll
+        for(int dg = 0; dg < num_dim_groups; dg++)
+            v_acc[m][dg] *= inv_sum;
+    }
+
+    // Write output O[total_padded_q, head_num, head_dim]
+    if(valid_head)
+    {
+        #pragma unroll
+        for(int m = 0; m < max_seq_q; m++)
+        {
+            if(m < actual_q)
+            {
+                #pragma unroll
+                for(int dg = 0; dg < num_dim_groups; dg++)
+                {
+                    const int out_d = warp_dim_start + dg * 4 + mfma_tid;
+                    O[((size_t)(q_offset + m) * head_num + head_idx) * head_dim + out_d] =
+                        static_cast(v_acc[m][dg]);
+                }
+            }
+        }
+    }
+}
+
+// ---------------------------------------------------------------------------
+// AttnForwardMfmaKernelLauncher — Grid: (1, ceil(heads/16), bs), Block: 256
+// ---------------------------------------------------------------------------
+
+template
+struct AttnForwardMfmaKernelLauncher
+{
+    using fwd_aux_buffer_scalar = T;
+
+    static size_t calc_workspace_size(int total_padded_q) // bytes for total_padded_q * head_num * max_seq_kv elements of T
+    {
+        return (size_t)total_padded_q * Config::head_num * Config::max_seq_kv * sizeof(T);
+    }
+
+    static void run_attn_fwd_kernel(const T* Q,
+        const T* K,
+        const T* V,
+        const T* dropout_mask,
+        float dropout_p,
+        float sqr_dk_scale,
+        T* O,
+        T* workspace,
+        const int* cu_seqlens_q,
+        const int* cu_seqlens_q_padded,
+        const int* cu_seqlens_kv,
+        const int* cu_seqlens_kv_padded,
+        const int* padded_q_to_batch, // not used by this launcher
+        int total_padded_q) // not used by this launcher
+    {
+        float dropout_scale = (dropout_p > 0.0f) ? (1.0f / (1.0f - dropout_p)) : 1.0f;
+
+        dim3 grid(1, CEIL_DIV(Config::head_num, 16), Config::bs); // one block per (group of 16 heads, batch entry)
+        dim3 block(256);
+
+        fmha_fwd_mfma_kernel<<>>( // NOTE(review): template arguments and launch configuration are missing in this text
+            Q, K, V, O, workspace,
+            dropout_mask, dropout_scale, sqr_dk_scale,
+            cu_seqlens_q, cu_seqlens_q_padded,
+            cu_seqlens_kv, cu_seqlens_kv_padded);
+    }
+};
diff --git a/transformer_engine/common/fused_attn_rocm/small_seq_kernels/attn_fwd_mfma_16x16.h b/transformer_engine/common/fused_attn_rocm/small_seq_kernels/attn_fwd_mfma_16x16.h
new file mode 100644
index 000000000..bf16b32a0
--- /dev/null
+++ b/transformer_engine/common/fused_attn_rocm/small_seq_kernels/attn_fwd_mfma_16x16.h
@@ -0,0 +1,397 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+#pragma once
+
+#include "attn_common.h"
+#include
+
+
+#ifndef ATTN_MFMA_TYPES_DEFINED
+#define ATTN_MFMA_TYPES_DEFINED
+using bhalf_t = __bf16;
+using bf16x4 = __bf16 __attribute__((ext_vector_type(4)));
+using bf16x8 = __bf16 __attribute__((ext_vector_type(8)));
+using floatx4 = float __attribute__((ext_vector_type(4)));
+#endif
+
+#ifndef CEIL_DIV
+#define CEIL_DIV(a, b) (((a) + (b)-1) / (b))
+#endif
+
+template
+__device__ __forceinline__ bf16x8 load_cvt_bf16x8_16(const T* src) // same body as load_cvt_bf16x8 in attn_fwd_mfma.h, under a distinct name
+{
+    if constexpr(sizeof(T) == 2)
+    {
+        return *(const bf16x8*)src;
+    }
+    else
+    {
+        // T = float
+        bf16x8 r;
+        #pragma unroll
+        for(int i = 0; i < 8; i++)
+        {
+            r[i] = static_cast(src[i]);
+        }
+        return r;
+    }
+}
+
+// ---------------------------------------------------------------------------
+// MFMA 16x16x16 forward kernel (tiled Q and KV, 1 head/block)
+//
+// Thread: warp[0-3], lane_row=lane/16 [0,4), lane_col=lane%16 [0,16)
+// LDS: Q_lds[lds_q_rows × hd_pad], KV_lds[lds_kv_rows × hd_pad],
+// SM_lds[lds_q_rows × lds_sm_stride]
+// Grid: (1, head_num, bs), Block: 256
+//
+// softmax_lse (Option A / FA2-style aux): one float per (padded Q row, head),
+// index ((q_offset + q_row) * head_num + head_idx),
+// value log(sum_j exp(scale *
QK^T_{row,j})) = row_max + log(row_sum_exp).
+// ---------------------------------------------------------------------------
+
+template // NOTE(review): parameter list is missing in this text (T and Config are used below) — restore before building
+__launch_bounds__(256, 1)
+__global__ void fmha_fwd_mfma_16x16_kernel(
+    const T* Q,
+    const T* K,
+    const T* V,
+    T* O,
+    float* softmax_lse, // written once per valid (padded Q row, head); pre-dropout log-sum-exp
+    const T* dropout_mask, // read only when Config::enable_dropout_mask; indexed [padded q row, head, kv position]
+    float dropout_scale,
+    float scale,
+    const int* cu_seqlens_q,
+    const int* cu_seqlens_q_padded,
+    const int* cu_seqlens_kv,
+    const int* cu_seqlens_kv_padded)
+{
+    // Compile-time constants
+    constexpr int head_dim = Config::head_dim;
+    constexpr int head_num = Config::head_num;
+    constexpr int max_seq_kv = Config::max_seq_kv;
+    constexpr int max_seq_q = Config::max_seq_q;
+    constexpr int hd_pad = head_dim + 4; // LDS row stride: head_dim plus 4 padding elements
+    constexpr int q_tiles = CEIL_DIV(max_seq_q, 16);
+    constexpr int kv_tiles = CEIL_DIV(max_seq_kv, 16);
+    constexpr int lds_q_rows = q_tiles * 16;
+    constexpr int lds_kv_rows = kv_tiles * 16;
+    constexpr int lds_sm_stride = lds_kv_rows + 4;
+
+    static_assert(max_seq_q >= 1, "max_seq_q must be >= 1");
+
+    // Thread mapping
+    const int batch_idx = blockIdx.z;
+    const int head_idx = blockIdx.y;
+    const int tid = threadIdx.x;
+    const int warp_id = tid / 64;
+    const int lane_id = tid % 64;
+    const int lane_row = lane_id / 16;
+    const int lane_col = lane_id % 16;
+
+    const int actual_q = cu_seqlens_q[batch_idx + 1] - cu_seqlens_q[batch_idx];
+    if(actual_q == 0)
+        return; // depends only on blockIdx.z, so the whole block exits together before any __syncthreads()
+
+    const int seq_kv = cu_seqlens_kv[batch_idx + 1] - cu_seqlens_kv[batch_idx];
+    const int kv_offset = cu_seqlens_kv_padded[batch_idx];
+    const int q_offset = cu_seqlens_q_padded[batch_idx];
+
+    // LDS
+    __shared__ __attribute__((aligned(128))) bhalf_t Q_lds[lds_q_rows * hd_pad];
+    __shared__ __attribute__((aligned(128))) bhalf_t KV_lds[lds_kv_rows * hd_pad];
+    __shared__ float SM_lds[lds_q_rows * lds_sm_stride];
+
+    // Load Q → Q_lds
+    {
+        constexpr int threads_per_row = head_dim / 8; // each thread moves 8 elements of one row
+        const int row = tid / threads_per_row;
+        const int col = (tid % threads_per_row) * 8;
+
+        for(int r = row; r < lds_q_rows; r += (256 / threads_per_row))
+        {
+            if(r < actual_q)
+            {
+                const T* q_src = Q + ((size_t)(q_offset + r) * head_num + head_idx) * head_dim;
+                *(bf16x8*)(&Q_lds[r * hd_pad + col]) = load_cvt_bf16x8_16(q_src + col);
+            }
+            else
+            {
+                *(bf16x8*)(&Q_lds[r * hd_pad + col]) = bf16x8{0, 0, 0, 0, 0, 0, 0, 0};
+            }
+        }
+    }
+
+    // Load K → KV_lds
+    {
+        constexpr int threads_per_row = head_dim / 8;
+        const int row = tid / threads_per_row;
+        const int col = (tid % threads_per_row) * 8;
+
+        for(int r = row; r < lds_kv_rows; r += (256 / threads_per_row))
+        {
+            if(r < seq_kv)
+            {
+                const T* k_src = K + ((size_t)(kv_offset + r) * head_num + head_idx) * head_dim;
+                *(bf16x8*)(&KV_lds[r * hd_pad + col]) = load_cvt_bf16x8_16(k_src + col);
+            }
+            else
+            {
+                *(bf16x8*)(&KV_lds[r * hd_pad + col]) = bf16x8{0, 0, 0, 0, 0, 0, 0, 0};
+            }
+        }
+    }
+
+    __syncthreads();
+
+    // QK^T via MFMA (all 4 warps redundant)
+    float attn_weight[q_tiles * kv_tiles * 4];
+
+    #pragma unroll
+    for(int qt = 0; qt < q_tiles; qt++)
+    {
+        #pragma unroll
+        for(int kvt = 0; kvt < kv_tiles; kvt++)
+        {
+            floatx4 acc = {0, 0, 0, 0};
+            constexpr int total_hd_tiles = CEIL_DIV(head_dim, 16);
+
+            #pragma unroll
+            for(int k = 0; k < total_hd_tiles; ++k)
+            {
+                const int dim_off = k * 16;
+                bf16x4 a = *(const bf16x4*)(&Q_lds[(qt * 16 + lane_col) * hd_pad + dim_off + lane_row * 4]);
+                bf16x4 b = *(const bf16x4*)(&KV_lds[(kvt * 16 + lane_col) * hd_pad + dim_off + lane_row * 4]);
+                acc = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a, b, acc, 0, 0, 0);
+            }
+
+            int reg_base = (qt * kv_tiles + kvt) * 4;
+            #pragma unroll
+            for(int i = 0; i < 4; i++)
+                attn_weight[reg_base + i] = acc[i] * scale;
+        }
+    }
+
+    // Softmax: two-pass across KV tiles per Q row
+    #pragma unroll
+    for(int qt = 0; qt < q_tiles; qt++)
+    {
+        #pragma unroll
+        for(int i = 0; i < 4; i++)
+        {
+            int q_row = qt * 16 + lane_row * 4 + i;
+
+            // Pass 1: find global row_max across all KV tiles
+            float row_max = -INFINITY;
+            #pragma unroll
+            for(int kvt = 0; kvt < kv_tiles; kvt++)
+            {
+                int reg_idx = (qt * kv_tiles + kvt) * 4 + i;
+                int kv_pos = kvt * 16 + lane_col;
+
+                bool masked = (kv_pos >= seq_kv) || (q_row >= actual_q);
+                if constexpr(Config::mask_type == CausalMaskType::TOP_LEFT)
+                {
+                    if(kv_pos > q_row)
+                        masked = true;
+                }
+
+                float val = masked ? -INFINITY : attn_weight[reg_idx];
+
+                float tile_max = val;
+                #pragma unroll
+                for(int off = 8; off > 0; off /= 2)
+                    tile_max = fmaxf(tile_max, __shfl_xor(tile_max, off, 16)); // butterfly max within each group of 16 lanes
+
+                row_max = fmaxf(row_max, tile_max);
+            }
+
+            // Pass 2: compute exp and sum across all KV tiles
+            float row_sum = 0.0f;
+            #pragma unroll
+            for(int kvt = 0; kvt < kv_tiles; kvt++)
+            {
+                int reg_idx = (qt * kv_tiles + kvt) * 4 + i;
+                int kv_pos = kvt * 16 + lane_col;
+
+                bool masked = (kv_pos >= seq_kv) || (q_row >= actual_q);
+                if constexpr(Config::mask_type == CausalMaskType::TOP_LEFT)
+                {
+                    if(kv_pos > q_row)
+                        masked = true;
+                }
+
+                float exp_val = masked ? 0.0f : expf(attn_weight[reg_idx] - row_max);
+                attn_weight[reg_idx] = exp_val;
+
+                float tile_sum = exp_val;
+                #pragma unroll
+                for(int off = 8; off > 0; off /= 2)
+                    tile_sum += __shfl_xor(tile_sum, off, 16);
+                row_sum += tile_sum;
+            }
+
+            // Log-sum-exp per row (matches FlashAttention-style LSE; pre-dropout)
+            float lse_row = (row_sum > 0.0f) ? (row_max + logf(row_sum)) : -INFINITY;
+            if(lane_col == 0 && q_row < actual_q)
+            {
+                softmax_lse[((size_t)(q_offset + q_row) * head_num + head_idx)] = lse_row;
+            }
+
+            // Normalize and apply dropout
+            float inv_sum = (row_sum > 0.0f) ? __builtin_amdgcn_rcpf(row_sum) : 0.0f; // fully masked row (e.g. seq_kv == 0): emit zero weights, as the 4x4x4 kernel does, instead of the NaN from 0 * (1/0)
+            #pragma unroll
+            for(int kvt = 0; kvt < kv_tiles; kvt++)
+            {
+                int reg_idx = (qt * kv_tiles + kvt) * 4 + i;
+                attn_weight[reg_idx] *= inv_sum;
+
+                if constexpr(Config::enable_dropout_mask)
+                {
+                    int kv_pos = kvt * 16 + lane_col;
+                    if(q_row < actual_q && kv_pos < seq_kv)
+                    {
+                        const size_t ws_offset = ((size_t)(q_offset + q_row) * head_num + head_idx) * max_seq_kv; // widened like the Q/K/V/O indices so large products cannot wrap a 32-bit int
+                        attn_weight[reg_idx] *= static_cast(dropout_mask[ws_offset + kv_pos]) * dropout_scale;
+                    }
+                }
+            }
+        }
+    }
+
+    // Write weights to SM_lds for Attn×V
+    #pragma unroll
+    for(int qt = 0; qt < q_tiles; qt++)
+    {
+        #pragma unroll
+        for(int kvt = 0; kvt < kv_tiles; kvt++)
+        {
+            #pragma unroll
+            for(int i = 0; i < 4; i++)
+            {
+                int q_row = qt * 16 + lane_row * 4 + i;
+                int kv_pos = kvt * 16 + lane_col;
+                int reg_idx = (qt * kv_tiles + kvt) * 4 + i;
+                SM_lds[q_row * lds_sm_stride + kv_pos] = attn_weight[reg_idx];
+            }
+        }
+    }
+
+    __syncthreads(); // K reads from KV_lds and the SM_lds writes must finish before V overwrites KV_lds
+
+    // Load V → KV_lds (clamped; invalid positions zeroed by softmax weights)
+    {
+        constexpr int threads_per_row = head_dim / 8;
+        const int v_row = tid / threads_per_row;
+        const int v_col = (tid % threads_per_row) * 8;
+        const int clamped_max = max(seq_kv - 1, 0); // NOTE(review): with seq_kv == 0 this still reads row kv_offset of V — confirm that row is always addressable
+
+        for(int r = v_row; r < lds_kv_rows; r += (256 / threads_per_row))
+        {
+            const int clamped_r = min(r, clamped_max);
+            const T* v_src = V + ((size_t)(kv_offset + clamped_r) * head_num + head_idx) * head_dim;
+            *(bf16x8*)(&KV_lds[r * hd_pad + v_col]) = load_cvt_bf16x8_16(v_src + v_col);
+        }
+    }
+
+    __syncthreads();
+
+    // Attn×V via MFMA (4 warps split head_dim, tiled over Q and KV)
+    {
+        constexpr int BK = 64; // 4 warps x 16 output columns per step
+        constexpr int total_d_tiles = CEIL_DIV(head_dim, BK);
+
+        #pragma unroll
+        for(int qt = 0; qt < q_tiles; qt++)
+        {
+            #pragma unroll
+            for(int d = 0; d < total_d_tiles; d++)
+            {
+                const int dim_idx = d * BK + warp_id * 16; // NOTE(review): stays below head_dim only when head_dim is a multiple of 64; confirm
+                floatx4 acc = {0, 0, 0, 0};
+
+                #pragma unroll
+                for(int kvt = 0; kvt < kv_tiles; kvt++)
+                {
+                    // A: softmax weights (transposed read from SM_lds)
+                    bf16x4 a;
+                    #pragma unroll
+                    for(int k = 0; k < 4; k++)
+                    {
+                        int q_idx = qt * 16 + lane_col;
+                        int kv_pos = kvt * 16 + lane_row * 4 + k;
+                        a[k] = static_cast(SM_lds[q_idx * lds_sm_stride + kv_pos]);
+                    }
+
+                    // B: V[kv, d]
+                    bf16x4 b;
+                    const int kv_base = kvt * 16;
+                    #pragma unroll
+                    for(int k = 0; k < 4; k++)
+                    {
+                        b[k] = KV_lds[(kv_base + lane_row * 4 + k) * hd_pad + dim_idx + lane_col];
+                    }
+
+                    acc = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a, b, acc, 0, 0, 0);
+                }
+
+                // Write output
+                #pragma unroll
+                for(int i = 0; i < 4; i++)
+                {
+                    int q_row = qt * 16 + lane_row * 4 + i;
+                    if(q_row < actual_q)
+                    {
+                        O[((size_t)(q_offset + q_row) * head_num + head_idx) * head_dim + dim_idx + lane_col] =
+                            static_cast(acc[i]);
+                    }
+                }
+            }
+        }
+    }
+}
+
+// ---------------------------------------------------------------------------
+// AttnForwardMfma16x16KernelLauncher — Grid: (1, head_num, bs), Block: 256
+// ---------------------------------------------------------------------------
+
+template
+struct AttnForwardMfma16x16KernelLauncher
+{
+    using fwd_aux_buffer_scalar = float;
+
+    /// Per-(padded Q row, head) softmax log-sum-exp (float), FA2-compatible aux.
+    static size_t calc_workspace_size(int total_padded_q)
+    {
+        return (size_t)total_padded_q * Config::head_num * sizeof(float);
+    }
+
+    static void run_attn_fwd_kernel(const T* Q,
+        const T* K,
+        const T* V,
+        const T* dropout_mask,
+        float dropout_p,
+        float sqr_dk_scale,
+        T* O,
+        float* softmax_lse,
+        const int* cu_seqlens_q,
+        const int* cu_seqlens_q_padded,
+        const int* cu_seqlens_kv,
+        const int* cu_seqlens_kv_padded,
+        const int* padded_q_to_batch, // not used by this launcher
+        int total_padded_q) // not used by this launcher
+    {
+        float dropout_scale = (dropout_p > 0.0f) ? (1.0f / (1.0f - dropout_p)) : 1.0f;
+        float scale = sqr_dk_scale;
+
+        dim3 grid(1, Config::head_num, Config::bs); // one block per (head, batch entry)
+        dim3 block(256);
+
+        fmha_fwd_mfma_16x16_kernel<<>>( // NOTE(review): template arguments and launch configuration are missing in this text
+            Q, K, V, O, softmax_lse,
+            dropout_mask, dropout_scale, scale,
+            cu_seqlens_q, cu_seqlens_q_padded,
+            cu_seqlens_kv, cu_seqlens_kv_padded);
+    }
+};
diff --git a/transformer_engine/common/fused_attn_rocm/small_seq_kernels/attn_fwd_mfma_dispatch.h b/transformer_engine/common/fused_attn_rocm/small_seq_kernels/attn_fwd_mfma_dispatch.h
new file mode 100644
index 000000000..d00ea5f79
--- /dev/null
+++ b/transformer_engine/common/fused_attn_rocm/small_seq_kernels/attn_fwd_mfma_dispatch.h
@@ -0,0 +1,58 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+#pragma once
+
+#include "attn_fwd_mfma.h"
+#include "attn_fwd_mfma_16x16.h"
+
+// ---------------------------------------------------------------------------
+// Dispatch: seq_q ≤ 4 → 4x4x4 (16 heads/wave), seq_q > 4 → 16x16x16
+// ---------------------------------------------------------------------------
+
+template
+struct AttnForwardMfmaDispatchLauncher
+{
+    static_assert(Config::max_seq_q >= 1,
+        "max_seq_q must be >= 1");
+
+    static size_t calc_workspace_size(int total_padded_q) // byte size of whichever aux buffer the selected path needs
+    {
+        if constexpr(Config::max_seq_q <= 4)
+            return AttnForwardMfmaKernelLauncher::calc_workspace_size(total_padded_q);
+        else
+            return AttnForwardMfma16x16KernelLauncher::calc_workspace_size(total_padded_q);
+    }
+
+    /// `aux`: 4x4 path = `T*` attention workspace; 16x16 path = `float*` softmax LSE (see
+    /// AttnForwardMfma16x16KernelLauncher::calc_workspace_size).
+    static void run_attn_fwd_kernel(const T* Q, // forwards every argument unchanged to the launcher selected by Config::max_seq_q
+        const T* K,
+        const T* V,
+        const T* dropout_mask,
+        float dropout_p,
+        float sqr_dk_scale,
+        T* O,
+        void* aux, // untyped because the two paths expect different buffer element types
+        const int* cu_seqlens_q,
+        const int* cu_seqlens_q_padded,
+        const int* cu_seqlens_kv,
+        const int* cu_seqlens_kv_padded,
+        const int* padded_q_to_batch,
+        int total_padded_q)
+    {
+        if constexpr(Config::max_seq_q <= 4) // same compile-time threshold as calc_workspace_size
+        {
+            AttnForwardMfmaKernelLauncher::run_attn_fwd_kernel(
+                Q, K, V, dropout_mask, dropout_p, sqr_dk_scale, O, static_cast(aux), // NOTE(review): cast target type is missing in this text
+                cu_seqlens_q, cu_seqlens_q_padded, cu_seqlens_kv, cu_seqlens_kv_padded,
+                padded_q_to_batch, total_padded_q);
+        }
+        else
+        {
+            AttnForwardMfma16x16KernelLauncher::run_attn_fwd_kernel(
+                Q, K, V, dropout_mask, dropout_p, sqr_dk_scale, O, static_cast(aux), // NOTE(review): cast target type is missing in this text
+                cu_seqlens_q, cu_seqlens_q_padded, cu_seqlens_kv, cu_seqlens_kv_padded,
+                padded_q_to_batch, total_padded_q);
+        }
+    }
+}; diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 02efd1b38..c6c8ce29d 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -9,6 +9,9 @@ #include "../extensions.h" #include "transformer_engine/fused_attn.h" #include "transformer_engine/transformer_engine.h" +#ifdef USE_ROCM +#include "common/fused_attn_rocm/fused_attn_smallseq.h" +#endif namespace transformer_engine { namespace jax { @@ -214,6 +217,19 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( nvte_tensor_pack_destroy(&aux_output_tensors); auto workspace_shape = MakeShapeVector(query_workspace_tensor.shape()); +#ifdef USE_ROCM + if (is_ragged && ::transformer_engine::fused_attn_rocm::is_nvte_ck_small_seq_enabled()) { + if (::transformer_engine::fused_attn_rocm::small_seq_static_config_ok( + static_cast(dtype), static_cast(dtype), bias_type, + dropout_probability, qk_head_dim, v_head_dim, attn_heads, num_gqa_groups, mask_type)) { + const size_t max_tokens_q_upper = input_batch * 
q_max_seqlen; + size_t total_ws_bytes = workspace_shape.empty() ? static_cast(1) : workspace_shape[0]; + total_ws_bytes += + ::transformer_engine::fused_attn_rocm::small_seq_extra_workspace_bytes(max_tokens_q_upper); + workspace_shape = std::vector{std::max(total_ws_bytes, static_cast(1))}; + } + } +#endif return pybind11::make_tuple(workspace_shape, query_workspace_tensor.dtype()); } @@ -504,6 +520,19 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( nvte_tensor_pack_destroy(&aux_input_tensors); auto work_shape = MakeShapeVector(query_workspace_tensor.shape()); +#ifdef USE_ROCM + if (is_ragged && ::transformer_engine::fused_attn_rocm::is_nvte_ck_small_seq_enabled()) { + if (::transformer_engine::fused_attn_rocm::small_seq_static_config_ok( + static_cast(dtype), static_cast(dtype), bias_type, + dropout_probability, qk_head_dim, v_head_dim, attn_heads, num_gqa_groups, mask_type)) { + const size_t max_tokens_q_upper = input_batch * q_max_seqlen; + size_t total_ws_bytes = work_shape.empty() ? static_cast(1) : work_shape[0]; + total_ws_bytes += + ::transformer_engine::fused_attn_rocm::small_seq_extra_workspace_bytes(max_tokens_q_upper); + work_shape = std::vector{std::max(total_ws_bytes, static_cast(1))}; + } + } +#endif return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype()); }