From d89f90fd30b0510862869404cdaa38a9b56a7e21 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 28 Apr 2026 17:00:41 +0000 Subject: [PATCH 01/20] Updated QoLA (to port CK receipt patch) and TE manifest --- 3rdparty/QoLA | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/QoLA b/3rdparty/QoLA index 549844d77..a597de03f 160000 --- a/3rdparty/QoLA +++ b/3rdparty/QoLA @@ -1 +1 @@ -Subproject commit 549844d771ed3155dd75a6bf2c714cb3f710bada +Subproject commit a597de03f36bf4ea59fe3681675c45e24c441669 From 312212f03eff1ffdb23c80858ea6af3fe19c023f Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 28 Apr 2026 17:04:34 +0000 Subject: [PATCH 02/20] Updated manifest --- transformer_engine/common/ck_fused_attn/qola_manifest.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/common/ck_fused_attn/qola_manifest.toml b/transformer_engine/common/ck_fused_attn/qola_manifest.toml index b6877c6c0..9bdce2e60 100644 --- a/transformer_engine/common/ck_fused_attn/qola_manifest.toml +++ b/transformer_engine/common/ck_fused_attn/qola_manifest.toml @@ -1,5 +1,5 @@ [qola] -aiter_commit = "33f2e6af5f39379c739720080ed0033d533f5cb2" # pinned AITER submodule commit +aiter_commit = "8f816a049449f39609ee7daca8c21d63aa4274ed" # pinned AITER submodule commit namespace = "te" rocm_versions = ["7.2"] From 89f6983c1a41f6547b6a243921f99f833375e6bf Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 28 Apr 2026 17:53:17 +0000 Subject: [PATCH 03/20] Corrected AITER mha args validation against pinned commit --- .../common/ck_fused_attn/CMakeLists.txt | 33 +++++++++++++++++-- .../common/ck_fused_attn/aiter_prebuilt.cmake | 18 ++++++++-- .../ck_fused_attn/check_aiter_mha_args.py | 5 ++- .../ck_fused_attn/src/ck_fused_attn_bwd.cpp | 2 ++ 4 files changed, 52 insertions(+), 6 deletions(-) diff --git a/transformer_engine/common/ck_fused_attn/CMakeLists.txt b/transformer_engine/common/ck_fused_attn/CMakeLists.txt index b11e848dd..7ae3e3c7a 100644 --- a/transformer_engine/common/ck_fused_attn/CMakeLists.txt +++ b/transformer_engine/common/ck_fused_attn/CMakeLists.txt @@ -33,9 +33,37 @@ if(NOT Python_EXECUTABLE) find_package(Python COMPONENTS Interpreter QUIET) endif() +# Resolve the manifest-pinned AITER commit (defines AITER_SHA) and bring the +# QoLA-nested AITER source tree to that commit before any consumer reads it +# (header validation below, header includes for the .cpp build later, and +# QoLA's own kernel build if the prebuilt cache misses). +include("${CMAKE_CURRENT_LIST_DIR}/aiter_prebuilt.cmake") + if(Python_EXECUTABLE) + set(__QOLA_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../3rdparty/QoLA") + execute_process( + COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${__QOLA_DIR}:$ENV{PYTHONPATH}" + ${Python_EXECUTABLE} -c + "from qola.build_tools.submodule import ensure_aiter_commit; ensure_aiter_commit(r'${__AITER_SOURCE_DIR}', r'${AITER_SHA}')" + RESULT_VARIABLE AITER_CHECKOUT_RESULT + OUTPUT_VARIABLE AITER_CHECKOUT_OUTPUT + ERROR_VARIABLE AITER_CHECKOUT_ERROR + OUTPUT_STRIP_TRAILING_WHITESPACE + ERROR_STRIP_TRAILING_WHITESPACE + ) + if(NOT AITER_CHECKOUT_RESULT EQUAL 0) + message(FATAL_ERROR + "Failed to sync AITER source tree at ${__AITER_SOURCE_DIR} to " + "manifest-pinned commit ${AITER_SHA}.\n" + "${AITER_CHECKOUT_OUTPUT}\n${AITER_CHECKOUT_ERROR}") + endif() + message(STATUS "[AITER] Synced ${__AITER_SOURCE_DIR} to ${AITER_SHA}") + execute_process( - COMMAND ${Python_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/check_aiter_mha_args.py --mode both --te-dir "${CMAKE_CURRENT_LIST_DIR}/../../.." + COMMAND ${Python_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/check_aiter_mha_args.py + --mode both + --te-dir "${CMAKE_CURRENT_LIST_DIR}/../../.." + --aiter-root "${__AITER_SOURCE_DIR}" RESULT_VARIABLE AITER_ARG_CHECK_RESULT OUTPUT_VARIABLE AITER_ARG_CHECK_OUTPUT ERROR_VARIABLE AITER_ARG_CHECK_ERROR @@ -50,7 +78,7 @@ if(Python_EXECUTABLE) endif() message(STATUS "AITER API validation passed via check_aiter_mha_args.py") else() - message(WARNING "Python interpreter not found; skipping AITER API validation.") + message(WARNING "Python interpreter not found; skipping AITER source-tree sync and API validation.") endif() # so far, there are only gfx942 and gfx950 v3 kernels @@ -78,7 +106,6 @@ if(DEFINED AITER_MHA_PATH) set(__AITER_MHA_PATH ${AITER_MHA_PATH}) else() set(__AITER_MHA_PATH "") - include("${CMAKE_CURRENT_LIST_DIR}/aiter_prebuilt.cmake") get_prebuilt_aiter(__AITER_MHA_PATH) if(__AITER_MHA_PATH STREQUAL "") diff --git a/transformer_engine/common/ck_fused_attn/aiter_prebuilt.cmake b/transformer_engine/common/ck_fused_attn/aiter_prebuilt.cmake index ea0396116..65ee3ec81 100644 --- a/transformer_engine/common/ck_fused_attn/aiter_prebuilt.cmake +++ b/transformer_engine/common/ck_fused_attn/aiter_prebuilt.cmake @@ -18,8 +18,22 @@ string(STRIP "${ROCM_VER_CONTENT}" ROCM_VER_CONTENT) string(REGEX MATCH "^[0-9]+\\.[0-9]+" ROCM_VER "${ROCM_VER_CONTENT}") string(REGEX MATCH "^[0-9]+" ROCM_VER_MAJOR "${ROCM_VER}") -# AITER commit -get_git_commit("${__AITER_SOURCE_DIR}" AITER_SHA) +# AITER commit — read from the QoLA manifest so the cache key tracks the +# commit QoLA will actually check out and build, not whatever happens to be +# the submodule's current HEAD at configure time. +set(__QOLA_MANIFEST "${CMAKE_CURRENT_LIST_DIR}/qola_manifest.toml") +set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS "${__QOLA_MANIFEST}") +file(STRINGS "${__QOLA_MANIFEST}" __AITER_COMMIT_LINES + REGEX "^[ \t]*aiter_commit[ \t]*=[ \t]*\"[^\"]+\"") +list(LENGTH __AITER_COMMIT_LINES __AITER_COMMIT_COUNT) +if(NOT __AITER_COMMIT_COUNT EQUAL 1) + message(FATAL_ERROR + "Expected exactly one 'aiter_commit = \"...\"' line in " + "${__QOLA_MANIFEST}, found ${__AITER_COMMIT_COUNT}.") +endif() +list(GET __AITER_COMMIT_LINES 0 __AITER_COMMIT_LINE) +string(REGEX MATCH "\"([^\"]+)\"" _UNUSED "${__AITER_COMMIT_LINE}") +set(AITER_SHA "${CMAKE_MATCH_1}") # Cache key & local paths set(AITER_CACHE_ROOT "${CMAKE_CURRENT_LIST_DIR}/../../../build/aiter-prebuilts") diff --git a/transformer_engine/common/ck_fused_attn/check_aiter_mha_args.py b/transformer_engine/common/ck_fused_attn/check_aiter_mha_args.py index 2e9831f1a..2cce9484f 100644 --- a/transformer_engine/common/ck_fused_attn/check_aiter_mha_args.py +++ b/transformer_engine/common/ck_fused_attn/check_aiter_mha_args.py @@ -64,11 +64,14 @@ def main() -> int: parser = argparse.ArgumentParser(description="Check aiter args usage vs header definition") parser.add_argument("--mode", choices=["fwd", "bwd", "both"], default="both", help="Mode: fwd, bwd, or both") parser.add_argument("--te-dir", type=Path, default=Path(__file__).parent.parent.parent.parent, help="Root directory of TransformerEngine") + parser.add_argument("--aiter-root", type=Path, default=None, + help="AITER source tree root. Defaults to /3rdparty/aiter.") args = parser.parse_args() + aiter_root = args.aiter_root if args.aiter_root else args.te_dir / "3rdparty/aiter" modes = ["fwd", "bwd"] if args.mode == "both" else [args.mode] mismatch = 0 for mode in modes: - header_path = args.te_dir / f"3rdparty/aiter/csrc/include/mha_{mode}.h" + header_path = aiter_root / f"csrc/include/mha_{mode}.h" source_path = args.te_dir / f"transformer_engine/common/ck_fused_attn/src/ck_fused_attn_{mode}.cpp" header_text = header_path.read_text(encoding="utf-8") source_text = source_path.read_text(encoding="utf-8") diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp index 5f6af0a41..c68a6a3c7 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp @@ -527,6 +527,8 @@ hipError_t _ck_attn_bwd_impl( } aiter::mha_bwd_args fmha_args{}; + fmha_args.sink_ptr=nullptr; + fmha_args.d_sink_ptr=nullptr; fmha_args.mask_type = static_cast(mask_type); fmha_args.use_asm_v3 = uses_bwd_v3; fmha_args.v3_atomic_fp32 = is_v3_atomic_fp32; From c417e4070b9ba22c82da2493ea81e31b6638489e Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 28 Apr 2026 20:53:27 +0000 Subject: [PATCH 04/20] Updated cmake w/ dubious ownership protection --- .../common/ck_fused_attn/CMakeLists.txt | 34 +++++++++++++------ .../common/ck_fused_attn/qola_manifest.toml | 2 ++ 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/transformer_engine/common/ck_fused_attn/CMakeLists.txt b/transformer_engine/common/ck_fused_attn/CMakeLists.txt index 7ae3e3c7a..5aea4db7d 100644 --- a/transformer_engine/common/ck_fused_attn/CMakeLists.txt +++ b/transformer_engine/common/ck_fused_attn/CMakeLists.txt @@ -41,10 +41,17 @@ include("${CMAKE_CURRENT_LIST_DIR}/aiter_prebuilt.cmake") if(Python_EXECUTABLE) set(__QOLA_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../3rdparty/QoLA") + # Redirect GIT_CONFIG_GLOBAL to a tempfile carrying `safe.directory = *` so + # git operations inside the QoLA-nested AITER tree (and its recursive + # submodules) work in containerized builds where the bind-mounted .git is + # owned by a different UID than the build process. Mirrors the pattern in + # transformer_engine/common/CMakeLists.txt:get_git_commit(). execute_process( - COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${__QOLA_DIR}:$ENV{PYTHONPATH}" - ${Python_EXECUTABLE} -c - "from qola.build_tools.submodule import ensure_aiter_commit; ensure_aiter_commit(r'${__AITER_SOURCE_DIR}', r'${AITER_SHA}')" + COMMAND sh -c + "tmp=$(mktemp /tmp/gitconfig.XXXXXX) || exit 1; \ +GIT_CONFIG_GLOBAL=$tmp git config --global --add safe.directory '*' >/dev/null 2>&1; \ +GIT_CONFIG_GLOBAL=$tmp PYTHONPATH=\"${__QOLA_DIR}:$PYTHONPATH\" '${Python_EXECUTABLE}' -c 'from qola.build_tools.submodule import ensure_aiter_commit; ensure_aiter_commit(r\"${__AITER_SOURCE_DIR}\", r\"${AITER_SHA}\")'; \ +rc=$?; rm -f \"$tmp\"; exit $rc" RESULT_VARIABLE AITER_CHECKOUT_RESULT OUTPUT_VARIABLE AITER_CHECKOUT_OUTPUT ERROR_VARIABLE AITER_CHECKOUT_ERROR @@ -114,13 +121,19 @@ else() set(__QOLA_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../3rdparty/QoLA") set(__QOLA_BUILD_DIR "${__QOLA_DIR}/build") set(__QOLA_MANIFEST "${CMAKE_CURRENT_LIST_DIR}/qola_manifest.toml") + # Same GIT_CONFIG_GLOBAL trick as the early `ensure_aiter_commit` call: + # qola.cli build re-invokes ensure_aiter_commit internally and will hit + # the same dubious-ownership trap without it. execute_process( - COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${__QOLA_DIR}:$ENV{PYTHONPATH}" - ${Python_EXECUTABLE} -m qola.cli build - --manifest ${__QOLA_MANIFEST} - --aiter-root ${__AITER_SOURCE_DIR} - --output-dir ${__QOLA_BUILD_DIR} - --arch "${V3_ASM_ARCHS_STR}" + COMMAND sh -c + "tmp=$(mktemp /tmp/gitconfig.XXXXXX) || exit 1; \ +GIT_CONFIG_GLOBAL=$tmp git config --global --add safe.directory '*' >/dev/null 2>&1; \ +GIT_CONFIG_GLOBAL=$tmp PYTHONPATH=\"${__QOLA_DIR}:$PYTHONPATH\" '${Python_EXECUTABLE}' -m qola.cli build \ +--manifest '${__QOLA_MANIFEST}' \ +--aiter-root '${__AITER_SOURCE_DIR}' \ +--output-dir '${__QOLA_BUILD_DIR}' \ +--arch '${V3_ASM_ARCHS_STR}'; \ +rc=$?; rm -f \"$tmp\"; exit $rc" RESULT_VARIABLE QOLA_BUILD_RESULT ) if(NOT QOLA_BUILD_RESULT EQUAL 0) @@ -155,7 +168,8 @@ endforeach() add_library(ck_fused_attn SHARED ${ck_fused_attn_SOURCES}) set(CK_FUSED_ATTN_COMPILE_OPTIONS) list(APPEND CK_FUSED_ATTN_COMPILE_OPTIONS - -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=${CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT}) + -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=${CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT} + -DENABLE_CK=1) foreach(ARCH IN LISTS V3_ASM_ARCHS) list(APPEND CK_FUSED_ATTN_COMPILE_OPTIONS --offload-arch=${ARCH}) diff --git a/transformer_engine/common/ck_fused_attn/qola_manifest.toml b/transformer_engine/common/ck_fused_attn/qola_manifest.toml index 9bdce2e60..2b7e41537 100644 --- a/transformer_engine/common/ck_fused_attn/qola_manifest.toml +++ b/transformer_engine/common/ck_fused_attn/qola_manifest.toml @@ -9,9 +9,11 @@ architectures = ["gfx950", "gfx942"] [[modules]] name = "libmha_fwd" mode = "cpp_itfs" +receipt = 700 drop_srcs = ["mha_fwd_split.cu", "mha_fwd_batch_prefill.cu"] drop_directions = ["fwd_splitkv", "batch_prefill"] [[modules]] name = "libmha_bwd" mode = "cpp_itfs" +receipt = 700 From c7ecaf7a5209cb10cc5c423bc30529f222fca2dc Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Wed, 29 Apr 2026 15:41:15 +0000 Subject: [PATCH 05/20] Corrected logging --- .../ck_fused_attn/src/ck_fused_attn_bwd.cpp | 37 ++++++++----------- .../ck_fused_attn/src/ck_fused_attn_fwd.cpp | 17 +++------ 2 files changed, 22 insertions(+), 32 deletions(-) diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp index 86ea82388..7e5a529b2 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp @@ -331,9 +331,8 @@ __global__ void dbias_reduce_b1ss( } // print the fmha_traits and args passed into ck apis -void log_bwd_config(const char* func_name, const aiter::mha_bwd_args& fmha_args){ +void log_bwd_config(const char* func_name, const aiter::mha_bwd_args& fmha_args, std::ostream* log_file){ - std::ostream* log_file = get_ck_log_stream(); (*log_file) << "\n" << func_name << "\n"; // fmha_traits debug @@ -447,14 +446,10 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ bool has_dbias = args.dbias_ptr != nullptr; bool is_mqa_gqa = (args.h > args.hg); - bool ck_log_config = false; - if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { - if (env_p != nullptr && std::string(env_p) == "1") - ck_log_config = true; - } + auto* log_file = get_ck_log_stream(); const char* dump_path = std::getenv("NVTE_DUMP_AITER_RT"); // print kernel name on verbose mode - ck_tile::stream_config stream_config{stream, dump_path!=nullptr, get_ck_log_stream() != nullptr}; + ck_tile::stream_config stream_config{stream, dump_path!=nullptr, log_file != nullptr}; bias_enum bias_type = bias_enum::no_bias; BiasShape bias_shape = BiasShape::k11SS; @@ -584,8 +579,8 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ // lse_workspace_ptr used as buffer if(const char* env_p = std::getenv("NVTE_CK_RUNTIME_MAX_SEQLEN")) { if(args.is_group_mode() && std::string(env_p) == "1"){ - if(ck_log_config){ - std::cout << "attn_bwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization."; + if(log_file){ + *log_file << "attn_bwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization."; } fmha_args.max_seqlen_q = get_runtime_max_seqlen(args.b, args.cu_seqlen_q_ptr, nullptr, args.lse_workspace_ptr, stream); fmha_args.max_seqlen_k = get_runtime_max_seqlen(args.b, args.cu_seqlen_kv_ptr, nullptr, args.lse_workspace_ptr, stream); @@ -593,8 +588,8 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ } // print ck traits and args when needed - if(ck_log_config){ - log_bwd_config(__FUNCTION__, fmha_args); + if(log_file){ + log_bwd_config(__FUNCTION__, fmha_args, log_file); } float average_runtime = QOLA_NS(mha_bwd)(fmha_args, stream_config); if(average_runtime < 0){ @@ -612,7 +607,7 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ dim3 grid(args.max_tokens_kv, args.hg); if(args.d_qk == args.d_v){ dim3 block(args.d_qk); - if (auto* log_file = get_ck_log_stream()) { + if (log_file) { *log_file << "\n" << "run dk_dv_reduce_thd: " << "\n"; *log_file << "cu_seqlen_kv_ptr: " << args.cu_seqlen_kv_ptr << "\n"; *log_file << "cu_seqlen_kv_padded_ptr: " << args.cu_seqlen_kv_padded_ptr << "\n"; @@ -639,7 +634,7 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ args.stride_h_dk, args.stride_s_dk);); } else { dim3 block_dk(args.d_qk); - if (auto* log_file = get_ck_log_stream()) { + if (log_file) { *log_file << "\n" << "run dk_or_dv_reduce_thd on dk: " << "\n"; *log_file << "cu_seqlen_kv_ptr: " << args.cu_seqlen_kv_ptr << "\n"; *log_file << "cu_seqlen_kv_padded_ptr: " << args.cu_seqlen_kv_padded_ptr << "\n"; @@ -662,7 +657,7 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ args.stride_h_dk, args.stride_s_dk);); dim3 block_dv(args.d_v); - if (auto* log_file = get_ck_log_stream()) { + if (log_file) { *log_file << "\n" << "run dk_or_dv_reduce_thd on dv: " << "\n"; *log_file << "cu_seqlen_kv_ptr: " << args.cu_seqlen_kv_ptr << "\n"; *log_file << "cu_seqlen_kv_padded_ptr: " << args.cu_seqlen_kv_padded_ptr << "\n"; @@ -688,7 +683,7 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ dim3 grid(args.b, args.s_kv, args.hg); if(args.d_qk == args.d_v){ dim3 block(args.d_qk); - if (auto* log_file = get_ck_log_stream()) { + if (log_file) { *log_file << "\n" << "run dk_dv_reduce: " << "\n"; *log_file << "dk_expanded_ptr: " << args.dk_expanded_ptr << "\n"; *log_file << "dv_expanded_ptr: " << args.dv_expanded_ptr << "\n"; @@ -713,7 +708,7 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ args.stride_b_dk, args.stride_h_dk, args.stride_s_dk);); } else { dim3 block_dk(args.d_qk); - if (auto* log_file = get_ck_log_stream()) { + if (log_file) { *log_file << "\n" << "run dk_or_dv_reduce on dk: " << "\n"; *log_file << "dk_expanded_ptr: " << args.dk_expanded_ptr << "\n"; *log_file << "stride_b_dk_expanded: " << args.stride_b_dk_expanded << "\n"; @@ -734,7 +729,7 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ args.stride_b_dk, args.stride_h_dk, args.stride_s_dk);); dim3 block_dv(args.d_v); - if (auto* log_file = get_ck_log_stream()) { + if (log_file) { *log_file << "\n" << "run dk_or_dv_reduce on dv: " << "\n"; *log_file << "dv_expanded_ptr: " << args.dv_expanded_ptr << "\n"; *log_file << "stride_b_dv_expanded: " << args.stride_b_dv_expanded << "\n"; @@ -764,7 +759,7 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ dim3 block(THREADS_PER_BLOCK); dim3 grid(ceil(1.0 * args.s_q * args.s_kv / THREADS_PER_BLOCK)); if(bias_shape==BiasShape::k11SS){ - if (auto* log_file = get_ck_log_stream()) { + if (log_file) { *log_file << "\n" << "run dbias_reduce_11SS: " << "\n"; *log_file << "dbias_ptr: " << args.dbias_ptr << "\n"; *log_file << "dbias_expanded_ptr: " << args.dbias_expanded_ptr << "\n"; @@ -776,7 +771,7 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ static_cast(args.dbias_expanded_ptr), static_cast(args.dbias_ptr));); }else if(bias_shape==BiasShape::k1HSS){ - if (auto* log_file = get_ck_log_stream()) { + if (log_file) { *log_file << "\n" << "run dbias_reduce_1HSS: " << "\n"; *log_file << "dbias_ptr: " << args.dbias_ptr << "\n"; *log_file << "dbias_expanded_ptr: " << args.dbias_expanded_ptr << "\n"; @@ -788,7 +783,7 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ static_cast(args.dbias_expanded_ptr), static_cast(args.dbias_ptr));); }else if(bias_shape==BiasShape::kB1SS){ - if (auto* log_file = get_ck_log_stream()) { + if (log_file) { *log_file << "\n" << "run dbias_reduce_B1SS: " << "\n"; *log_file << "dbias_ptr: " << args.dbias_ptr << "\n"; *log_file << "dbias_expanded_ptr: " << args.dbias_expanded_ptr << "\n"; diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp index 0f407230c..0f4e9a424 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp @@ -15,9 +15,8 @@ namespace ck_fused_attn{ // print the fmha traits and fmha_args when calling ck apis -void log_fwd_config(const char* func_name, bool has_dropout, const aiter::mha_fwd_args& fmha_args){ +void log_fwd_config(const char* func_name, bool has_dropout, const aiter::mha_fwd_args& fmha_args, std::ostream* log_file){ - std::ostream* log_file = get_ck_log_stream(); (*log_file) << "\n" << func_name << "\n"; // debug fmha_traits @@ -103,11 +102,7 @@ hipError_t ck_attn_fwd(const CKAttnFwdArgs& args, hipStream_t stream){ bool has_dropout = (args.is_training && args.dropout_probability > 0.f); - bool ck_log_config = false; - if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { - if (env_p != nullptr && std::string(env_p) == "1") - ck_log_config = true; - } + auto* log_file = get_ck_log_stream(); const char* dump_path = std::getenv("NVTE_DUMP_AITER_RT"); // print kernel name on verbose mode ck_tile::stream_config stream_config{stream, dump_path!=nullptr, get_ck_log_stream() != nullptr}; @@ -218,16 +213,16 @@ hipError_t ck_attn_fwd(const CKAttnFwdArgs& args, hipStream_t stream){ if(const char* env_p = std::getenv("NVTE_CK_RUNTIME_MAX_SEQLEN")){ if(args.is_group_mode() && std::string(env_p) == "1"){ - if(ck_log_config){ - std::cout << "attn_fwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization."; + if(log_file){ + *log_file << "attn_fwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization."; } fmha_args.max_seqlen_q = get_runtime_max_seqlen(args.b, args.cu_seqlen_q_ptr, args.cu_seqlen_q_padded_ptr, args.lse_ptr, stream); } } // print ck traits and fmha_args when needed - if(ck_log_config){ - log_fwd_config(__FUNCTION__, has_dropout, fmha_args); + if(log_file){ + log_fwd_config(__FUNCTION__, has_dropout, fmha_args, log_file); } float average_runtime = QOLA_NS(mha_fwd)(fmha_args, stream_config); if(average_runtime < 0){ From da6e9a63bfa79b379fabe7b525a38e73661ce00c Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Wed, 29 Apr 2026 21:01:46 +0000 Subject: [PATCH 06/20] Updated qola to build aiter w/ new third_party spec --- 3rdparty/QoLA | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/QoLA b/3rdparty/QoLA index a597de03f..aac57fec6 160000 --- a/3rdparty/QoLA +++ b/3rdparty/QoLA @@ -1 +1 @@ -Subproject commit a597de03f36bf4ea59fe3681675c45e24c441669 +Subproject commit aac57fec69b37a8b51922246a4497275987f9a68 From e7ed124f2e829a0321c2316f590ae6112f90012c Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Wed, 29 Apr 2026 21:02:09 +0000 Subject: [PATCH 07/20] Added guards against AITER known buggy implementations --- .../common/ck_fused_attn/src/ck_fused_attn_bwd.cpp | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp index 7e5a529b2..638e4b877 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp @@ -461,7 +461,18 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ fmha_args.sink_ptr = nullptr; fmha_args.d_sink_ptr = nullptr; fmha_args.mask_type = static_cast(static_cast(args.attn_mask_type)); - fmha_args.use_asm_v3 = args.uses_bwd_v3; + // Mirrors AITER's small-seqlen guard at aiter/ops/mha.py:1689. + const bool buggy_small_sq = (args.s_q < 16); + // Predicate matches exactly bwd_hd128_bf16_causal_br_a32_psskddv_group.co + // (broken by AITER PR #2189). Other psskddv_group variants are unaffected. + const bool buggy_br_psskddv_group = + args.is_group_mode() && + args.attn_mask_type == ck_fused_attn::MaskType::mask_bottom_right && + args.dtype == ck_fused_attn::DType::kBFloat16 && + args.d_qk == 128 && args.d_v == 128 && + args.is_v3_atomic_fp32; + fmha_args.use_asm_v3 = + (buggy_small_sq || buggy_br_psskddv_group) ? false : args.uses_bwd_v3; fmha_args.v3_atomic_fp32 = args.is_v3_atomic_fp32; fmha_args.v3_bf16_cvt = args.how_v3_bf16_cvt; fmha_args.v3_api_check = false; From 6241f99d2a6c34c14cec4e4169925451aaf08b1e Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 12 May 2026 19:58:49 +0000 Subject: [PATCH 08/20] Updated build --- 3rdparty/QoLA | 2 +- .../common/ck_fused_attn/CMakeLists.txt | 56 ++++++++++--------- .../common/ck_fused_attn/qola_manifest.toml | 2 +- 3 files changed, 33 insertions(+), 27 deletions(-) diff --git a/3rdparty/QoLA b/3rdparty/QoLA index aac57fec6..9c13e77ef 160000 --- a/3rdparty/QoLA +++ b/3rdparty/QoLA @@ -1 +1 @@ -Subproject commit aac57fec69b37a8b51922246a4497275987f9a68 +Subproject commit 9c13e77ef3cf89053aad61ed3a0f27470f123ee5 diff --git a/transformer_engine/common/ck_fused_attn/CMakeLists.txt b/transformer_engine/common/ck_fused_attn/CMakeLists.txt index 5aea4db7d..0f78311f2 100644 --- a/transformer_engine/common/ck_fused_attn/CMakeLists.txt +++ b/transformer_engine/common/ck_fused_attn/CMakeLists.txt @@ -8,41 +8,30 @@ project(ck_fused_attn LANGUAGES HIP CXX) set(AITER_MHA_INSTALL_PREFIX "transformer_engine" CACHE STRING "aiter mha shared lib install prefix in TE") -set(__AITER_SOURCE_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../3rdparty/QoLA/3rdparty/aiter") +# QoLA no longer vendors AITER as a submodule; it clones on demand into +# build/third_party/aiter (git-ignored) via `qola checkout`. Mirror that +# default here so the source-build path and the header-include paths +# resolve to the same tree. +set(__QOLA_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../3rdparty/QoLA") +set(__AITER_SOURCE_DIR "${__QOLA_DIR}/build/third_party/aiter") set(__CK_SOURCE_DIR "${__AITER_SOURCE_DIR}/3rdparty/composable_kernel") - set(CK_INCLUDE_DIR "${__CK_SOURCE_DIR}/include") -message(STATUS "ck_include_dir: ${CK_INCLUDE_DIR}") -if(NOT EXISTS "${CK_INCLUDE_DIR}") - message(FATAL_ERROR - "Could not find CK API. " - "Try running 'git submodule update --init --recursive' " - "within the Transformer Engine source.") -endif() - set(AITER_INCLUDE_DIR "${__AITER_SOURCE_DIR}/csrc/include") -message(STATUS "aiter_include_dir: ${AITER_INCLUDE_DIR}") -if(NOT EXISTS "${AITER_INCLUDE_DIR}") - message(FATAL_ERROR - "Could not find AITER API. " - "Try running 'git submodule update --init --recursive' " - "within the Transformer Engine source.") -endif() if(NOT Python_EXECUTABLE) find_package(Python COMPONENTS Interpreter QUIET) endif() # Resolve the manifest-pinned AITER commit (defines AITER_SHA) and bring the -# QoLA-nested AITER source tree to that commit before any consumer reads it +# QoLA-managed AITER source tree to that commit before any consumer reads it # (header validation below, header includes for the .cpp build later, and # QoLA's own kernel build if the prebuilt cache misses). include("${CMAKE_CURRENT_LIST_DIR}/aiter_prebuilt.cmake") if(Python_EXECUTABLE) - set(__QOLA_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../3rdparty/QoLA") + set(__QOLA_MANIFEST "${CMAKE_CURRENT_LIST_DIR}/qola_manifest.toml") # Redirect GIT_CONFIG_GLOBAL to a tempfile carrying `safe.directory = *` so - # git operations inside the QoLA-nested AITER tree (and its recursive + # git operations inside the QoLA-managed AITER tree (and its recursive # submodules) work in containerized builds where the bind-mounted .git is # owned by a different UID than the build process. Mirrors the pattern in # transformer_engine/common/CMakeLists.txt:get_git_commit(). @@ -50,7 +39,9 @@ if(Python_EXECUTABLE) COMMAND sh -c "tmp=$(mktemp /tmp/gitconfig.XXXXXX) || exit 1; \ GIT_CONFIG_GLOBAL=$tmp git config --global --add safe.directory '*' >/dev/null 2>&1; \ -GIT_CONFIG_GLOBAL=$tmp PYTHONPATH=\"${__QOLA_DIR}:$PYTHONPATH\" '${Python_EXECUTABLE}' -c 'from qola.build_tools.submodule import ensure_aiter_commit; ensure_aiter_commit(r\"${__AITER_SOURCE_DIR}\", r\"${AITER_SHA}\")'; \ +GIT_CONFIG_GLOBAL=$tmp PYTHONPATH=\"${__QOLA_DIR}:$PYTHONPATH\" '${Python_EXECUTABLE}' -m qola.cli checkout \ +--manifest '${__QOLA_MANIFEST}' \ +--aiter-root '${__AITER_SOURCE_DIR}'; \ rc=$?; rm -f \"$tmp\"; exit $rc" RESULT_VARIABLE AITER_CHECKOUT_RESULT OUTPUT_VARIABLE AITER_CHECKOUT_OUTPUT @@ -88,6 +79,23 @@ else() message(WARNING "Python interpreter not found; skipping AITER source-tree sync and API validation.") endif() +# Sanity-check the resolved include directories now that `qola checkout` has +# materialized the AITER tree. +message(STATUS "ck_include_dir: ${CK_INCLUDE_DIR}") +if(NOT EXISTS "${CK_INCLUDE_DIR}") + message(FATAL_ERROR + "Could not find CK API at ${CK_INCLUDE_DIR}. " + "Re-run the build to let `qola checkout` clone AITER and its " + "composable_kernel submodule.") +endif() + +message(STATUS "aiter_include_dir: ${AITER_INCLUDE_DIR}") +if(NOT EXISTS "${AITER_INCLUDE_DIR}") + message(FATAL_ERROR + "Could not find AITER API at ${AITER_INCLUDE_DIR}. " + "Re-run the build to let `qola checkout` clone AITER.") +endif() + # so far, there are only gfx942 and gfx950 v3 kernels SET(V3_ASM_ARCHS_SUPPORTED "gfx942;gfx950") @@ -118,11 +126,9 @@ else() if(__AITER_MHA_PATH STREQUAL "") # If not available, fallback: Build from source via QoLA message(STATUS "[AITER-BUILD] Building AITER kernels via QoLA.") - set(__QOLA_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../3rdparty/QoLA") set(__QOLA_BUILD_DIR "${__QOLA_DIR}/build") - set(__QOLA_MANIFEST "${CMAKE_CURRENT_LIST_DIR}/qola_manifest.toml") - # Same GIT_CONFIG_GLOBAL trick as the early `ensure_aiter_commit` call: - # qola.cli build re-invokes ensure_aiter_commit internally and will hit + # Same GIT_CONFIG_GLOBAL trick as the earlier `qola.cli checkout` call: + # `qola.cli build` re-invokes ensure_aiter_commit internally and will hit # the same dubious-ownership trap without it. execute_process( COMMAND sh -c diff --git a/transformer_engine/common/ck_fused_attn/qola_manifest.toml b/transformer_engine/common/ck_fused_attn/qola_manifest.toml index 2b7e41537..255d11c18 100644 --- a/transformer_engine/common/ck_fused_attn/qola_manifest.toml +++ b/transformer_engine/common/ck_fused_attn/qola_manifest.toml @@ -1,5 +1,5 @@ [qola] -aiter_commit = "8f816a049449f39609ee7daca8c21d63aa4274ed" # pinned AITER submodule commit +aiter_commit = "4b00d2ea91e88b5381ea7051521956a716485f30" # pinned AITER submodule commit namespace = "te" rocm_versions = ["7.2"] From 82ebdb6c97a6adf1b0d1721bc4167329b3e9e926 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 19 May 2026 14:45:31 +0000 Subject: [PATCH 09/20] Drop guard for corrected bug --- .../common/ck_fused_attn/src/ck_fused_attn_bwd.cpp | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp index 638e4b877..0ff659949 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp @@ -462,17 +462,7 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ fmha_args.d_sink_ptr = nullptr; fmha_args.mask_type = static_cast(static_cast(args.attn_mask_type)); // Mirrors AITER's small-seqlen guard at aiter/ops/mha.py:1689. - const bool buggy_small_sq = (args.s_q < 16); - // Predicate matches exactly bwd_hd128_bf16_causal_br_a32_psskddv_group.co - // (broken by AITER PR #2189). Other psskddv_group variants are unaffected. - const bool buggy_br_psskddv_group = - args.is_group_mode() && - args.attn_mask_type == ck_fused_attn::MaskType::mask_bottom_right && - args.dtype == ck_fused_attn::DType::kBFloat16 && - args.d_qk == 128 && args.d_v == 128 && - args.is_v3_atomic_fp32; - fmha_args.use_asm_v3 = - (buggy_small_sq || buggy_br_psskddv_group) ? false : args.uses_bwd_v3; + fmha_args.use_asm_v3 = (args.s_q < 16) ? false : args.uses_bwd_v3; fmha_args.v3_atomic_fp32 = args.is_v3_atomic_fp32; fmha_args.v3_bf16_cvt = args.how_v3_bf16_cvt; fmha_args.v3_api_check = false; From f9ab59c28fc8222859a6bb8d9ea9859cf49742ac Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Thu, 28 May 2026 18:29:26 +0000 Subject: [PATCH 10/20] Update AITER commit, adopt new API --- .../ck_fused_attn/check_aiter_mha_args.py | 2 +- .../include/ck_fused_attn/ck_fused_attn.hpp | 1 - .../common/ck_fused_attn/qola_manifest.toml | 2 +- .../ck_fused_attn/src/ck_fused_attn_bwd.cpp | 67 +++++++++++++++---- .../common/fused_attn_rocm/fused_attn_ck.cpp | 9 --- 5 files changed, 57 insertions(+), 24 deletions(-) diff --git a/transformer_engine/common/ck_fused_attn/check_aiter_mha_args.py b/transformer_engine/common/ck_fused_attn/check_aiter_mha_args.py index 2cce9484f..6bae3091d 100644 --- a/transformer_engine/common/ck_fused_attn/check_aiter_mha_args.py +++ b/transformer_engine/common/ck_fused_attn/check_aiter_mha_args.py @@ -31,7 +31,7 @@ def parse_with_skip_comments(buffer, line, regex, outputs): def extract_fields_from_header(text: str, struct_name: str) -> List[str]: - struct_field_re = re.compile(r"([A-Za-z_][A-Za-z0-9_]*)\s*(?:=[^;]*)?;\s*$") + struct_field_re = re.compile(r"([A-Za-z_][A-Za-z0-9_]*)\s*(?:=[^;]*|\{[^;]*\})?;\s*$") struct_end_re = re.compile(r"^\s*};\s*$") struct_start_re = re.compile(rf"\bstruct\s+{re.escape(struct_name)}\b") diff --git a/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp b/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp index 127d75b4c..abd1ec371 100644 --- a/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp +++ b/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp @@ -113,7 +113,6 @@ struct CkAttnBwdArgs : CKAttnCommonArgs { // dQ void* dq_ptr = nullptr; uint64_t stride_b_dq = 0, stride_h_dq = 0, stride_s_dq = 0; - void* dq_acc_ptr = nullptr; // dK / dV expanded (MQA/GQA reduction inputs; null when h==hg) void* dk_expanded_ptr = nullptr; diff --git a/transformer_engine/common/ck_fused_attn/qola_manifest.toml b/transformer_engine/common/ck_fused_attn/qola_manifest.toml index 255d11c18..2b445ae08 100644 --- a/transformer_engine/common/ck_fused_attn/qola_manifest.toml +++ b/transformer_engine/common/ck_fused_attn/qola_manifest.toml @@ -1,5 +1,5 @@ [qola] -aiter_commit = "4b00d2ea91e88b5381ea7051521956a716485f30" # pinned AITER submodule commit +aiter_commit = "e3940660b40f4764cdf09147af96a2a764f264be" # pinned AITER submodule commit namespace = "te" rocm_versions = ["7.2"] diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp index 0ff659949..145d9b139 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp @@ -6,9 +6,13 @@ #include #include +#include #include #include +#include +#include #include "ck_fused_attn/ck_fused_attn.hpp" +#include "ck_tile/host/pinned_host_releaser.hpp" #include "qola_mha_bwd.h" #include "ck_fused_attn_utils.hpp" @@ -364,7 +368,6 @@ void log_bwd_config(const char* func_name, const aiter::mha_bwd_args& fmha_args, log_value(log_file, "dk_ptr", fmha_args.dk_ptr); log_value(log_file, "dv_ptr", fmha_args.dv_ptr); log_value(log_file, "dbias_ptr", fmha_args.dbias_ptr); - log_value(log_file, "dq_acc_ptr", fmha_args.dq_acc_ptr); log_value(log_file, "seqstart_q_ptr", fmha_args.seqstart_q_ptr); log_value(log_file, "seqstart_k_ptr", fmha_args.seqstart_k_ptr); @@ -389,7 +392,6 @@ void log_bwd_config(const char* func_name, const aiter::mha_bwd_args& fmha_args, log_value(log_file, "stride_o", fmha_args.stride_o); log_value(log_file, "stride_randval", fmha_args.stride_randval); log_value(log_file, "stride_do", fmha_args.stride_do); - log_value(log_file, "stride_dq_acc", fmha_args.stride_dq_acc); log_value(log_file, "stride_dq", fmha_args.stride_dq); log_value(log_file, "stride_dk", fmha_args.stride_dk); log_value(log_file, "stride_dv", fmha_args.stride_dv); @@ -402,7 +404,6 @@ void log_bwd_config(const char* func_name, const aiter::mha_bwd_args& fmha_args, log_value(log_file, "nhead_stride_randval", fmha_args.nhead_stride_randval); log_value(log_file, "nhead_stride_do", fmha_args.nhead_stride_do); log_value(log_file, "nhead_stride_lsed", fmha_args.nhead_stride_lsed); - log_value(log_file, "nhead_stride_dq_acc", fmha_args.nhead_stride_dq_acc); log_value(log_file, "nhead_stride_dq", fmha_args.nhead_stride_dq); log_value(log_file, "nhead_stride_dk", fmha_args.nhead_stride_dk); log_value(log_file, "nhead_stride_dv", fmha_args.nhead_stride_dv); @@ -415,7 +416,6 @@ void log_bwd_config(const char* func_name, const aiter::mha_bwd_args& fmha_args, log_value(log_file, "batch_stride_randval", fmha_args.batch_stride_randval); log_value(log_file, "batch_stride_do", fmha_args.batch_stride_do); log_value(log_file, "batch_stride_lsed", fmha_args.batch_stride_lsed); - log_value(log_file, "batch_stride_dq_acc", fmha_args.batch_stride_dq_acc); log_value(log_file, "batch_stride_dq", fmha_args.batch_stride_dq); log_value(log_file, "batch_stride_dk", fmha_args.batch_stride_dk); log_value(log_file, "batch_stride_dv", fmha_args.batch_stride_dv); @@ -493,7 +493,6 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ fmha_args.dbias_ptr = ((!args.is_group_mode()) && has_dbias) ? (bias_shape==BiasShape::kBHSS ? args.dbias_ptr : args.dbias_expanded_ptr) : nullptr; - fmha_args.dq_acc_ptr = args.dq_acc_ptr; if (args.is_group_mode()) { fmha_args.seqstart_q_ptr = args.cu_seqlen_q_padded_ptr==nullptr? args.cu_seqlen_q_ptr : args.cu_seqlen_q_padded_ptr; @@ -509,8 +508,13 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ fmha_args.seqlen_q_ptr = nullptr; fmha_args.seqlen_k_ptr = nullptr; - fmha_args.seqlen_q = args.s_q; - fmha_args.seqlen_k = args.s_kv; + // Group mode contract (matches aiter asm_mha_varlen_bwd.cu): seqlen_q/k + // carry the total token counts, max_seqlen_q/k the per-sequence maximum. + // aiter sizes dq_acc and related workspaces from seqlen_q; passing the + // per-sequence length in group mode under-sizes them and the kernel writes + // past the end. + fmha_args.seqlen_q = args.is_group_mode() ? args.max_tokens_q : args.s_q; + fmha_args.seqlen_k = args.is_group_mode() ? args.max_tokens_kv : args.s_kv; fmha_args.batch = args.b; fmha_args.max_seqlen_q = args.s_q; fmha_args.max_seqlen_k = args.s_kv; @@ -527,8 +531,6 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ fmha_args.stride_o = args.stride_s_o; fmha_args.stride_randval = args.s_kv; fmha_args.stride_do = args.stride_s_do; - //dq_acc of shape (nsplits, B, H, S, D) - fmha_args.stride_dq_acc = args.d_qk; fmha_args.stride_dq = args.stride_s_dq; fmha_args.stride_dk = is_mqa_gqa? args.stride_s_dk_expanded : args.stride_s_dk; fmha_args.stride_dv = is_mqa_gqa? args.stride_s_dv_expanded : args.stride_s_dv; @@ -546,7 +548,6 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ fmha_args.nhead_stride_randval = args.is_group_mode() ? 0 : args.s_q * args.s_kv; fmha_args.nhead_stride_do = args.stride_h_do; fmha_args.nhead_stride_lsed = args.is_group_mode() ? args.max_tokens_q : args.s_q; - fmha_args.nhead_stride_dq_acc = static_cast((args.is_group_mode() ? args.max_tokens_q : args.s_q) * args.d_qk); fmha_args.nhead_stride_dq = args.stride_h_dq; fmha_args.nhead_stride_dk = is_mqa_gqa? args.stride_h_dk_expanded : args.stride_h_dk; fmha_args.nhead_stride_dv = is_mqa_gqa? args.stride_h_dv_expanded : args.stride_h_dv; @@ -562,13 +563,11 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ fmha_args.batch_stride_randval = args.is_group_mode() ? 0 : args.h * args.s_q * args.s_kv; fmha_args.batch_stride_do = args.is_group_mode() ? 0 : args.stride_b_do; fmha_args.batch_stride_lsed = args.is_group_mode() ? 0 : args.h * args.s_q; - fmha_args.batch_stride_dq_acc = args.is_group_mode() ? 0 : static_cast(args.h * args.s_q * args.d_qk); fmha_args.batch_stride_dq = args.is_group_mode() ? 0 : args.stride_b_dq; fmha_args.batch_stride_dk = args.is_group_mode() ? 0 : (is_mqa_gqa? args.stride_b_dk_expanded : args.stride_b_dk); fmha_args.batch_stride_dv = args.is_group_mode() ? 0 : (is_mqa_gqa? args.stride_b_dv_expanded : args.stride_b_dv); // for dbias, use h since h can be different from bias_h fmha_args.batch_stride_dbias = args.is_group_mode() ? 0 : args.h * args.s_q * args.s_kv; - fmha_args.split_stride_dq_acc = static_cast(args.is_group_mode() ? (args.max_tokens_q * args.h * args.d_qk) : (args.b * args.h * args.s_q * args.d_qk)); fmha_args.window_size_left = args.window_size_left; fmha_args.window_size_right = args.window_size_right; @@ -588,11 +587,55 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ } } + // Device-side workspace allocations made inside mha_bwd (launcher metadata + // and the dq_acc accumulator). aiter only contracts that the pointer remain + // valid for the duration of the kernels it enqueues; hipFreeAsync on the + // same stream defers the free until that work completes. + std::vector mha_bwd_workspaces; + fmha_args.workspace_alloc = [&mha_bwd_workspaces, stream](size_t bytes, bool zero_init) -> void* { + if(bytes == 0){ + return nullptr; + } + void* ptr = nullptr; + if(hipMallocAsync(&ptr, bytes, stream) != hipSuccess){ + throw std::runtime_error("ck_fused_attn bwd: hipMallocAsync failed for AITER workspace."); + } + if(zero_init){ + if(hipMemsetAsync(ptr, 0, bytes, stream) != hipSuccess){ + hipFreeAsync(ptr, stream); + throw std::runtime_error("ck_fused_attn bwd: hipMemsetAsync failed for AITER workspace."); + } + } + mha_bwd_workspaces.push_back(ptr); + return ptr; + }; + // Group mode requires a pinned host buffer for the async D2H seqstart + // pipeline; aiter keeps the shared_ptr alive past kernel completion via a + // stream-tail hipLaunchHostFunc keepalive. The deleter fires from that HIP + // callback thread, which holds runtime locks — calling any HIP API from it + // (including hipHostFree) deadlocks against concurrent main-thread HIP + // calls. Defer the free to ck_tile::pinned_host_releaser's worker thread. + fmha_args.pinned_host_alloc = [](size_t bytes) -> std::shared_ptr { + if(bytes == 0){ + return {}; + } + void* ptr = nullptr; + if(hipHostMalloc(&ptr, bytes, hipHostMallocDefault) != hipSuccess){ + throw std::runtime_error("ck_fused_attn bwd: hipHostMalloc failed for AITER pinned host buffer."); + } + return std::shared_ptr(ptr, [](void* p){ + ck_tile::pinned_host_releaser::instance().enqueue(p); + }); + }; + // print ck traits and args when needed if(log_file){ log_bwd_config(__FUNCTION__, fmha_args, log_file); } float average_runtime = QOLA_NS(mha_bwd)(fmha_args, stream_config); + for(void* ws_ptr : mha_bwd_workspaces){ + hipFreeAsync(ws_ptr, stream); + } if(average_runtime < 0){ //TODO: better error out system throw std::runtime_error("fused attn configs not supported in ck_fused_attn bwd pass."); diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp index 744d0575a..7586a8388 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp @@ -743,9 +743,6 @@ void fused_attn_ck_bwd_impl( bool is_mqa_gqa = (h > hg); - size_t kN0 = (d_qk <= 128)? 128:64; - size_t nsplits = deterministic? ceil(1.0*s_kv/kN0):1; - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(layout); bool is_ragged = qkv_format==NVTE_QKV_Format::NVTE_THD; bool is_SBHD = qkv_format==NVTE_QKV_Format::NVTE_SBHD || qkv_format==NVTE_QKV_Format::NVTE_SBHD_2BSHD; @@ -770,9 +767,6 @@ void fused_attn_ck_bwd_impl( // First h*max_tokens_q*sizeof(float) is the lse-d buffer (passed as softmax_lsed) void* lse_workspace = planner.allocate(h*max_tokens_q*sizeof(float)); - // CK requires dq_acc ptr; size depends on deterministic mode - void* dq_acc_ptr = planner.allocate(nsplits*h*max_tokens_q*d_qk*sizeof(float)); - void* dk_expanded_ptr = nullptr; void* dv_expanded_ptr = nullptr; std::array dk_expanded_stride; @@ -913,8 +907,6 @@ void fused_attn_ck_bwd_impl( } // Initialize workspace buffers. - // dq_acc is of shape (nsplits, B, S, H, D_qk); CK requires zeroing - NVTE_CHECK_CUDA(cudaMemsetAsync(dq_acc_ptr, 0, sizeof(float)*nsplits*h*max_tokens_q*d_qk, stream)); if(devPtrAlibiSlope){ dim3 block, grid; block.x = 1024; @@ -992,7 +984,6 @@ void fused_attn_ck_bwd_impl( ck_args.attn_mask_type = set_ck_mask(mask_type, window_size_left, window_size_right); ck_args.window_size_left = window_size_left; ck_args.window_size_right = window_size_right; - ck_args.dq_acc_ptr = dq_acc_ptr; ck_args.dk_expanded_ptr = dk_expanded_ptr; ck_args.dv_expanded_ptr = dv_expanded_ptr; ck_args.lse_workspace_ptr = lse_workspace; From e68b4354c74fff4c373736b466fcc9b7bed3fd3d Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Fri, 29 May 2026 19:08:10 +0000 Subject: [PATCH 11/20] Initial two-tier lib w/ runtime dispatch -- WIP --- .../common/ck_fused_attn/CMakeLists.txt | 166 ++++++++++++------ .../common/ck_fused_attn/qola_manifest.toml | 2 +- .../ck_fused_attn/qola_manifest_gfx1250.toml | 34 ++++ .../ck_fused_attn/src/ck_fused_attn_bwd.cpp | 41 ++++- .../ck_fused_attn/src/ck_fused_attn_fwd.cpp | 11 ++ .../common/fused_attn_rocm/fused_attn.cpp | 7 - .../fused_attn_rocm/fused_attn_aotriton.cpp | 5 + .../common/fused_attn_rocm/fused_attn_ck.cpp | 22 ++- .../common/fused_attn_rocm/fused_attn_ck.h | 6 +- 9 files changed, 230 insertions(+), 64 deletions(-) create mode 100644 transformer_engine/common/ck_fused_attn/qola_manifest_gfx1250.toml diff --git a/transformer_engine/common/ck_fused_attn/CMakeLists.txt b/transformer_engine/common/ck_fused_attn/CMakeLists.txt index 09329ddeb..3a1913a96 100644 --- a/transformer_engine/common/ck_fused_attn/CMakeLists.txt +++ b/transformer_engine/common/ck_fused_attn/CMakeLists.txt @@ -8,20 +8,8 @@ project(ck_fused_attn LANGUAGES HIP CXX) set(AITER_MHA_INSTALL_PREFIX "transformer_engine" CACHE STRING "aiter mha shared lib install prefix in TE") -#Corresponding runtime check is in nvte_get_fused_attn_backend() -list(FIND CMAKE_HIP_ARCHITECTURES "gfx1250" _gfx1250_idx) -if(NOT _gfx1250_idx EQUAL -1) - message(WARNING - "Removing unsupported gfx1250 from CMAKE_HIP_ARCHITECTURES for ck_fused_attn build.") - list(REMOVE_ITEM CMAKE_HIP_ARCHITECTURES "gfx1250") - list(LENGTH CMAKE_HIP_ARCHITECTURES _hip_arch_count) - if(_hip_arch_count EQUAL 0) - message(FATAL_ERROR - "No supported architectures remain for the ck_fused_attn build. " - "Re-run the build with FUSED_ATTN_CK backend disabled.") - endif() - set(GPU_TARGETS ${CMAKE_HIP_ARCHITECTURES}) -endif() +# gfx1250 carries AITER V3 bwd kernels only (hd128, bf16, batch mode). The +# runtime envelope is enforced in nvte_get_fused_attn_backend(). set(__QOLA_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../3rdparty/QoLA") set(__AITER_SOURCE_DIR "${__QOLA_DIR}/build/third_party/aiter") set(__CK_SOURCE_DIR "${__AITER_SOURCE_DIR}/3rdparty/composable_kernel") @@ -106,47 +94,105 @@ if(NOT EXISTS "${AITER_INCLUDE_DIR}") "Re-run the build to let `qola checkout` clone AITER.") endif() -if(DEFINED AITER_MHA_PATH) - message(STATUS "[AITER-BUILD] Using AITER_MHA_PATH=${AITER_MHA_PATH}") - # use pre-built te_libmha_fwd.so te_libmha_bwd.so - set(__AITER_MHA_PATH ${AITER_MHA_PATH}) -else() - set(__AITER_MHA_PATH "") - get_prebuilt_aiter(__AITER_MHA_PATH) +# Partition the requested HIP architectures into the CK-full set (CDNA, where +# the AITER CK FMHA template headers compile) and the V3-asm-only set. gfx1250 +# (RDNA4) has AITER V3 *backward* asm kernels but no CK FMHA support and no +# forward kernels, so it is built as a separate CK-free library (namespace +# te_v3, manifest qola_manifest_gfx1250.toml) and dispatched at runtime in +# ck_attn_bwd. The two tiers coexist via distinct QoLA namespaces. +set(__CK_FULL_ARCHS ${CMAKE_HIP_ARCHITECTURES}) +set(__HAS_GFX1250 FALSE) +list(FIND __CK_FULL_ARCHS "gfx1250" __GFX1250_IDX) +if(NOT __GFX1250_IDX EQUAL -1) + set(__HAS_GFX1250 TRUE) + list(REMOVE_ITEM __CK_FULL_ARCHS "gfx1250") +endif() +list(LENGTH __CK_FULL_ARCHS __CK_FULL_ARCH_COUNT) +if(__CK_FULL_ARCH_COUNT EQUAL 0 AND NOT __HAS_GFX1250) + message(FATAL_ERROR "ck_fused_attn: no target architectures requested.") +endif() - if(__AITER_MHA_PATH STREQUAL "") - # If not available, fallback: Build from source via QoLA - list(JOIN CMAKE_HIP_ARCHITECTURES ";" GPU_ARCHS_STR) - message(STATUS "[AITER-BUILD] Building AITER kernels for ${GPU_ARCHS_STR} via QoLA.") - set(__QOLA_BUILD_DIR "${__QOLA_DIR}/build") - # Same GIT_CONFIG_GLOBAL trick as the earlier `qola.cli checkout` call: - # `qola.cli build` re-invokes ensure_aiter_commit internally and will hit - # the same dubious-ownership trap without it. - execute_process( - COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${__QOLA_DIR}:$ENV{PYTHONPATH}" - ${Python_EXECUTABLE} -m qola.cli build - --manifest ${__QOLA_MANIFEST} - --aiter-root ${__AITER_SOURCE_DIR} - --output-dir ${__QOLA_BUILD_DIR} - --arch "${GPU_ARCHS_STR}" - RESULT_VARIABLE QOLA_BUILD_RESULT - ) - if(NOT QOLA_BUILD_RESULT EQUAL 0) - message(FATAL_ERROR "[AITER-BUILD] QoLA build failed.") +set(__AITER_MHA_PATH "") +set(__HAVE_CK_FULL FALSE) + +# --- CK-full libraries (CDNA): te_libmha_fwd.so / te_libmha_bwd.so --- +if(__CK_FULL_ARCH_COUNT GREATER 0) + set(__HAVE_CK_FULL TRUE) + if(DEFINED AITER_MHA_PATH) + message(STATUS "[AITER-BUILD] Using AITER_MHA_PATH=${AITER_MHA_PATH}") + # use pre-built te_libmha_fwd.so te_libmha_bwd.so + set(__AITER_MHA_PATH ${AITER_MHA_PATH}) + else() + get_prebuilt_aiter(__AITER_MHA_PATH) + + if(__AITER_MHA_PATH STREQUAL "") + # If not available, fallback: Build from source via QoLA + list(JOIN __CK_FULL_ARCHS ";" GPU_ARCHS_STR) + message(STATUS "[AITER-BUILD] Building CK-full AITER kernels for ${GPU_ARCHS_STR} via QoLA.") + set(__QOLA_BUILD_DIR "${__QOLA_DIR}/build") + execute_process( + COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${__QOLA_DIR}:$ENV{PYTHONPATH}" + ${Python_EXECUTABLE} -m qola.cli build + --manifest ${__QOLA_MANIFEST} + --aiter-root ${__AITER_SOURCE_DIR} + --output-dir ${__QOLA_BUILD_DIR} + --arch "${GPU_ARCHS_STR}" + RESULT_VARIABLE QOLA_BUILD_RESULT + ) + if(NOT QOLA_BUILD_RESULT EQUAL 0) + message(FATAL_ERROR "[AITER-BUILD] QoLA build failed.") + endif() + + # Copy the final .so libs and exported public headers into the aiter + # prebuilt cache so downstream consumers see a self-contained tree. + get_default_aiter_cache_dir(__QOLA_CACHE_DIR) + set(__QOLA_CACHE_LIB "${__QOLA_CACHE_DIR}/lib") + file(MAKE_DIRECTORY ${__QOLA_CACHE_LIB}) + file(GLOB __QOLA_BUILT_LIBS "${__QOLA_BUILD_DIR}/lib/*.so") + file(COPY ${__QOLA_BUILT_LIBS} DESTINATION ${__QOLA_CACHE_LIB}) + file(COPY "${__QOLA_BUILD_DIR}/include" DESTINATION "${__QOLA_CACHE_DIR}") + set(__AITER_MHA_PATH "${__QOLA_CACHE_LIB}") + else() + message(STATUS "[AITER-BUILD] Using pre-built AITER from ${__AITER_MHA_PATH}") endif() + endif() +endif() + +# --- V3-asm-only backward library (gfx1250): te_v3_libmha_bwd.so --- +# There is no prebuilt cache path for gfx1250 (no public prebuilt, and a CK-free +# asm build is cheap), so always build it from source via QoLA. Both manifests +# pin the same AITER commit and share the already-checked-out source tree. +if(__HAS_GFX1250) + set(__QOLA_MANIFEST_V3 "${CMAKE_CURRENT_LIST_DIR}/qola_manifest_gfx1250.toml") + set(__QOLA_BUILD_DIR_V3 "${__QOLA_DIR}/build_gfx1250") + message(STATUS "[AITER-BUILD] Building CK-free V3 backward (gfx1250) via QoLA.") + # The asm-only / CK-free flags (ONLY_FAV3=1, ENABLE_CK=0) are carried by the + # gfx1250 manifest's libmha_bwd module, so no special env is needed here. + execute_process( + COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${__QOLA_DIR}:$ENV{PYTHONPATH}" + ${Python_EXECUTABLE} -m qola.cli build + --manifest ${__QOLA_MANIFEST_V3} + --aiter-root ${__AITER_SOURCE_DIR} + --output-dir ${__QOLA_BUILD_DIR_V3} + --arch "gfx1250" + RESULT_VARIABLE QOLA_V3_BUILD_RESULT + ) + if(NOT QOLA_V3_BUILD_RESULT EQUAL 0) + message(FATAL_ERROR "[AITER-BUILD] QoLA gfx1250 V3 build failed.") + endif() - # Copy the final .so libs and exported public headers into the aiter - # prebuilt cache so downstream consumers see a self-contained tree. + # Stage the v3 lib next to the CK-full libs so a single link/-L/install path + # covers both. For a gfx1250-only build there are no CK-full libs, so set up + # the cache lib dir here and stage the v3 public headers too. + if(__AITER_MHA_PATH STREQUAL "") get_default_aiter_cache_dir(__QOLA_CACHE_DIR) set(__QOLA_CACHE_LIB "${__QOLA_CACHE_DIR}/lib") file(MAKE_DIRECTORY ${__QOLA_CACHE_LIB}) - file(GLOB __QOLA_BUILT_LIBS "${__QOLA_BUILD_DIR}/lib/*.so") - file(COPY ${__QOLA_BUILT_LIBS} DESTINATION ${__QOLA_CACHE_LIB}) - file(COPY "${__QOLA_BUILD_DIR}/include" DESTINATION "${__QOLA_CACHE_DIR}") + file(COPY "${__QOLA_BUILD_DIR_V3}/include" DESTINATION "${__QOLA_CACHE_DIR}") set(__AITER_MHA_PATH "${__QOLA_CACHE_LIB}") - else() - message(STATUS "[AITER-BUILD] Using pre-built AITER from ${__AITER_MHA_PATH}") endif() + file(GLOB __QOLA_V3_LIBS "${__QOLA_BUILD_DIR_V3}/lib/te_v3_*.so") + file(COPY ${__QOLA_V3_LIBS} DESTINATION ${__AITER_MHA_PATH}) endif() set(ck_fused_attn_SOURCES) @@ -165,6 +211,16 @@ set(CK_FUSED_ATTN_COMPILE_OPTIONS) list(APPEND CK_FUSED_ATTN_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=${CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT} -DENABLE_CK=1) +# Tier guards consumed by src/ck_fused_attn_{fwd,bwd}.cpp: +# NVTE_AITER_CK_FULL -> qola::te::{mha_fwd,mha_bwd} (CDNA) are linked +# NVTE_AITER_V3_BWD_GFX1250 -> qola::te_v3::mha_bwd (gfx1250) is linked + +# runtime-dispatched in ck_attn_bwd +if(__HAVE_CK_FULL) + list(APPEND CK_FUSED_ATTN_COMPILE_OPTIONS -DNVTE_AITER_CK_FULL) +endif() +if(__HAS_GFX1250) + list(APPEND CK_FUSED_ATTN_COMPILE_OPTIONS -DNVTE_AITER_V3_BWD_GFX1250) +endif() # Public QoLA headers ship alongside the .so libs in ${__AITER_MHA_PATH}/../include # (emitted by qola.cli build, or copied from the QoLA build dir above for the @@ -181,10 +237,22 @@ target_include_directories(ck_fused_attn PRIVATE ${__QOLA_INCLUDE_DIR}) find_package(hip) target_link_directories(ck_fused_attn PUBLIC ${__AITER_MHA_PATH}) -list(APPEND ck_fused_attn_LINKER_LIBS hip::host hip::device roctx64 -l:te_libmha_fwd.so -l:te_libmha_bwd.so) +list(APPEND ck_fused_attn_LINKER_LIBS hip::host hip::device roctx64) +set(__INSTALL_AITER_LIBS) +if(__HAVE_CK_FULL) + list(APPEND ck_fused_attn_LINKER_LIBS -l:te_libmha_fwd.so -l:te_libmha_bwd.so) + list(APPEND __INSTALL_AITER_LIBS + ${__AITER_MHA_PATH}/te_libmha_fwd.so + ${__AITER_MHA_PATH}/te_libmha_bwd.so) +endif() +if(__HAS_GFX1250) + list(APPEND ck_fused_attn_LINKER_LIBS -l:te_v3_libmha_bwd.so) + list(APPEND __INSTALL_AITER_LIBS + ${__AITER_MHA_PATH}/te_v3_libmha_bwd.so) +endif() target_link_libraries(ck_fused_attn PUBLIC ${ck_fused_attn_LINKER_LIBS}) target_compile_options(ck_fused_attn PRIVATE ${CK_FUSED_ATTN_COMPILE_OPTIONS}) set_target_properties(ck_fused_attn PROPERTIES INSTALL_RPATH "$ORIGIN") -install(FILES ${__AITER_MHA_PATH}/te_libmha_fwd.so ${__AITER_MHA_PATH}/te_libmha_bwd.so DESTINATION ${CMAKE_INSTALL_PREFIX}/${AITER_MHA_INSTALL_PREFIX}/lib) +install(FILES ${__INSTALL_AITER_LIBS} DESTINATION ${CMAKE_INSTALL_PREFIX}/${AITER_MHA_INSTALL_PREFIX}/lib) install(TARGETS ck_fused_attn DESTINATION ${CMAKE_INSTALL_PREFIX}/${AITER_MHA_INSTALL_PREFIX}/lib) diff --git a/transformer_engine/common/ck_fused_attn/qola_manifest.toml b/transformer_engine/common/ck_fused_attn/qola_manifest.toml index 2b445ae08..b81e47ceb 100644 --- a/transformer_engine/common/ck_fused_attn/qola_manifest.toml +++ b/transformer_engine/common/ck_fused_attn/qola_manifest.toml @@ -1,5 +1,5 @@ [qola] -aiter_commit = "e3940660b40f4764cdf09147af96a2a764f264be" # pinned AITER submodule commit +aiter_commit = "f03a4ec572bb3d9e15da3b346763c8f126feec0d" # pinned AITER submodule commit namespace = "te" rocm_versions = ["7.2"] diff --git a/transformer_engine/common/ck_fused_attn/qola_manifest_gfx1250.toml b/transformer_engine/common/ck_fused_attn/qola_manifest_gfx1250.toml new file mode 100644 index 000000000..23b59027e --- /dev/null +++ b/transformer_engine/common/ck_fused_attn/qola_manifest_gfx1250.toml @@ -0,0 +1,34 @@ +# gfx1250 (RDNA4) carries AITER V3 *backward* asm kernels only — there are no +# forward kernels at the pinned commit, and the CK FMHA template headers do not +# compile for gfx1250. This manifest therefore builds a CK-free (ENABLE_CK=0), +# asm-v3-only backward library under a distinct namespace (te_v3) so it can +# coexist with the CK-full te_* libraries in a multi-arch build. TE selects +# between qola::te::mha_bwd and qola::te_v3::mha_bwd at runtime by device arch. +# +# Keep aiter_commit in lockstep with qola_manifest.toml — both consume the same +# checked-out AITER source tree. +[qola] +aiter_commit = "f03a4ec572bb3d9e15da3b346763c8f126feec0d" # pinned AITER submodule commit +namespace = "te_v3" +rocm_versions = ["7.2"] + +[build] +architectures = ["gfx1250"] + +# Reuse the torch-free libmha_bwd module (sources = mha_bwd.cu only; the same +# source the CK-full te_libmha_bwd.so builds from). Do NOT use +# module_fmha_v3_bwd here — it pulls in mha_common.cu, which includes +# and is therefore torch-dependent. To make this build +# CK-free and asm-only, two independent gates in mha_bwd.cu must both be set: +# - ONLY_FAV3=1 selects the asm-only dispatch (`#if ONLY_FAV3` returns the +# fmha_v3_bwd result; the `#else` branch instantiates CK fmha_bwd_traits), +# - ENABLE_CK=0 strips the CK fmha_bwd.hpp include and uses the ck_tile shim, +# mirroring AITER's own module_fmha_v3_bwd. drop_directions=["bwd"] removes the +# CK `generate.py -d bwd` codegen (the HSA codegen has no -d and is kept). +# flags_extra_cc values are eval'd as Python expressions (optCompilerConfig +# convention), hence the inner quotes; this replaces libmha_bwd's empty list. +[[modules]] +name = "libmha_bwd" +mode = "cpp_itfs" +drop_directions = ["bwd"] +flags_extra_cc = ["'-DONLY_FAV3=1'", "'-DENABLE_CK=0'"] diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp index 145d9b139..26a2def60 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp @@ -16,8 +16,33 @@ #include "qola_mha_bwd.h" #include "ck_fused_attn_utils.hpp" +// Staged gfx1250 backward dispatch. When this build includes the CK-free V3 +// backward library (te_v3_libmha_bwd.so, built for gfx1250), declare its +// namespaced entry point so ck_attn_bwd can route to it on gfx1250 devices at +// runtime. The CK-full path (QOLA_NS(mha_bwd) == qola::te::mha_bwd) is used on +// all other archs. +#if defined(NVTE_AITER_V3_BWD_GFX1250) +namespace qola { namespace te_v3 { +float mha_bwd(const aiter::mha_bwd_args& args, const ck_tile::stream_config& stream_config); +}} // namespace qola::te_v3 +#endif + namespace ck_fused_attn{ +#if defined(NVTE_AITER_V3_BWD_GFX1250) +namespace { +// True when the active device is gfx1250 (gcnArchName may carry feature +// suffixes, e.g. "gfx1250:sramecc+", so match on prefix). +bool is_gfx1250_device(){ + int dev = 0; + if(hipGetDevice(&dev) != hipSuccess){ return false; } + hipDeviceProp_t prop{}; + if(hipGetDeviceProperties(&prop, dev) != hipSuccess){ return false; } + return std::string(prop.gcnArchName).rfind("gfx1250", 0) == 0; +} +} // namespace +#endif + // TODO: unify with binary search in TE/common/fused_attn(rocm)/util // no device std::upper_bound // in an increasing array with given size len, search for the index that: @@ -632,7 +657,21 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ if(log_file){ log_bwd_config(__FUNCTION__, fmha_args, log_file); } - float average_runtime = QOLA_NS(mha_bwd)(fmha_args, stream_config); + float average_runtime; +#if defined(NVTE_AITER_V3_BWD_GFX1250) + if(is_gfx1250_device()){ + average_runtime = qola::te_v3::mha_bwd(fmha_args, stream_config); + } else +#endif + { +#if defined(NVTE_AITER_CK_FULL) + average_runtime = QOLA_NS(mha_bwd)(fmha_args, stream_config); +#else + throw std::runtime_error( + "ck_fused_attn bwd: this build has no CK-full AITER backward library " + "(no CDNA archs built); only the staged gfx1250 V3 path is present."); +#endif + } for(void* ws_ptr : mha_bwd_workspaces){ hipFreeAsync(ws_ptr, stream); } diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp index 0f4e9a424..074ad0042 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp @@ -224,7 +224,18 @@ hipError_t ck_attn_fwd(const CKAttnFwdArgs& args, hipStream_t stream){ if(log_file){ log_fwd_config(__FUNCTION__, has_dropout, fmha_args, log_file); } +#if defined(NVTE_AITER_CK_FULL) float average_runtime = QOLA_NS(mha_fwd)(fmha_args, stream_config); +#else + // gfx1250-only build: no CK-full forward library exists (gfx1250 has no + // forward kernels). The unified backend selector never picks CK on gfx1250, + // so this path is unreachable at runtime; the guard only keeps the link + // closed when te_libmha_fwd.so is absent. + float average_runtime = -1.0f; + throw std::runtime_error( + "ck_fused_attn fwd: no CK-full AITER forward library in this build " + "(gfx1250 has no forward kernels)."); +#endif if(average_runtime < 0){ //TODO: better error out system throw std::runtime_error("fused attn configs not supported in ck_fused_attn fwd pass."); diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp index 1f837be41..93b8ff0c5 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp @@ -11,7 +11,6 @@ #include "fused_attn_aotriton.h" #include "fused_attn_ck.h" #include "../common.h" -#include "../util/cuda_runtime.h" //cuda::sm_arch #include "utils.h" // map NVTE_QKV_Layout to NVTE_QKV_Layout_Group @@ -283,12 +282,6 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( int64_t window_size_right, bool return_max_logit, bool cuda_graph) { using namespace transformer_engine; - //gfx1250 is disabled in ck_fused_attn/CMakeLists.txt and is not supported by curretnt aotriton - const int gpu_arch = cuda::sm_arch(cuda::current_device()); - if (gpu_arch == 125) { - return NVTE_Fused_Attn_Backend::NVTE_No_Backend; - } - // TODO: Add return_max_logit support if (return_max_logit) return NVTE_Fused_Attn_Backend::NVTE_No_Backend; diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp index 9a0161ca5..0c78f5a24 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp @@ -54,6 +54,11 @@ bool is_aotriton_backend_supported( int64_t window_size_right) { #ifdef USE_FUSED_ATTN_AOTRITON + // AOTriton has no gfx1250 support. + if(cuda::sm_arch(cuda::current_device()) == 125){ + return false; + } + //TODO: release after AOTriton support support Multi-latent attention if(head_dim_qk != head_dim_v){ return false; diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp index 7586a8388..d4b9b005e 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp @@ -30,9 +30,9 @@ bool is_ck_backend_supported( float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, - size_t head_dim_qk, - size_t head_dim_v, - int64_t window_size_left, + size_t head_dim_qk, + size_t head_dim_v, + int64_t window_size_left, int64_t window_size_right) { #ifdef USE_FUSED_ATTN_CK @@ -154,6 +154,22 @@ bool is_ck_backend_supported( } return false; } + + // gfx1250 (RDNA4) ships AITER V3 *backward* asm kernels only — there are no + // forward kernels at the pinned commit. TE selects one fused-attn backend per + // op and uses it for both directions (backward inherits the forward's choice), + // so selecting CK here would route the forward into a kernel-less path. + // Until the forward is handled (direction-aware backend selection, or gfx1250 + // forward kernels), do not select CK on gfx1250 through this unified path. + // The CK-free V3 backward library (te_v3_module_fmha_v3_bwd.so) and the + // runtime dispatch in ck_attn_bwd are built and staged for that activation. + if(cuda::sm_arch(cuda::current_device()) == 125){ + if(nvte_log_ck_config){ + std::cout<<"gfx1250 CK fused attn is staged (backward-only); not selected via the unified backend yet"< Date: Fri, 5 Jun 2026 19:21:55 +0000 Subject: [PATCH 12/20] Added guard against graph-unsafe CK V2 kernels --- .../ck_fused_attn/src/ck_fused_attn_bwd.cpp | 31 ++++++++++++ .../common/fused_attn_rocm/fused_attn.cpp | 3 +- .../common/fused_attn_rocm/fused_attn_ck.cpp | 50 +++++++++++++++++-- .../common/fused_attn_rocm/fused_attn_ck.h | 9 ++-- .../attention/dot_product_attention/utils.py | 18 +++++++ 5 files changed, 102 insertions(+), 9 deletions(-) diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp index 145d9b139..0cc8c7689 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp @@ -632,6 +632,37 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ if(log_file){ log_bwd_config(__FUNCTION__, fmha_args, log_file); } + + // Graph-capture safety net. The CK v2 launcher (fmha_bwd / prepare_workspace_async) + // schedules self-deleting hipLaunchHostFunc nodes that re-run and double-free on + // every graph replay, so it must never be captured. Only the v3 asm path is + // graph-replay-safe. Backend selection already steers graph-captured training off + // these configs, but context-parallel and direct callers bypass that path, so we + // refuse a v2-bound dispatch under active capture rather than corrupt memory on + // replay. Conditions mirror AITER's fmha_v3_bwd gate (csrc/cpp_itfs/mha_bwd.cu). + hipStreamCaptureStatus capture_status = hipStreamCaptureStatusNone; + if(hipStreamIsCapturing(stream, &capture_status) == hipSuccess && + capture_status != hipStreamCaptureStatusNone){ + int dev = 0; + hipDeviceProp_t prop{}; + bool is_v3_arch = false; + if(hipGetDevice(&dev) == hipSuccess && hipGetDeviceProperties(&prop, dev) == hipSuccess){ + std::string arch_name(prop.gcnArchName); + is_v3_arch = arch_name.find("gfx942") != std::string::npos || + arch_name.find("gfx950") != std::string::npos; + } + bool resolves_to_v3 = fmha_args.use_asm_v3 && !fmha_args.is_deterministic && + !fmha_args.has_dbias && fmha_args.bias_type == 0 && + !fmha_args.has_dropout && is_v3_arch; + if(!resolves_to_v3){ + throw std::runtime_error( + "ck_fused_attn bwd: this configuration dispatches to the CK v2 launcher, which " + "is not HIP-graph-replay-safe (self-deleting host nodes in prepare_workspace_async). " + "Disable determinism/dropout/bias and run on gfx942/gfx950 with NVTE_CK_USES_BWD_V3=1 " + "to use the v3 asm path, or set NVTE_FUSED_ATTN_CK=0 under CUDA graphs."); + } + } + float average_runtime = QOLA_NS(mha_bwd)(fmha_args, stream_config); for(void* ws_ptr : mha_bwd_workspaces){ hipFreeAsync(ws_ptr, stream); diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp index 1f837be41..fe189f521 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp @@ -329,7 +329,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( head_dim_qk, head_dim_v, window_size_left, - window_size_right)){ + window_size_right, + is_training, cuda_graph)){ return NVTE_Fused_Attn_Backend::NVTE_CK; }else if(nvte_fused_attn_aotriton && fused_attn_rocm::is_aotriton_backend_supported( q_dtype, diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp index 7586a8388..460d89be4 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp @@ -18,6 +18,33 @@ namespace transformer_engine { namespace fused_attn_rocm { +#ifdef USE_FUSED_ATTN_CK +// Returns false when a CK backward for this config would dispatch to the CK v2 +// launcher (fmha_bwd / prepare_workspace_async), which schedules self-deleting +// hipLaunchHostFunc nodes that double-free on graph replay. Only the v3 asm bwd +// path is HIP-graph-replay-safe. Mirrors AITER's fmha_v3_bwd gate (mha_bwd.cu) +// for the conditions visible at backend-selection time; determinism is applied +// separately on the framework side. +static bool is_ck_bwd_graph_capture_safe( + NVTE_Bias_Type bias_type, + float dropout, + size_t max_seqlen_q) { + // The CK v2 launcher is reached whenever the v3 asm bwd path is not taken. + // v3 requires gfx942/gfx950, no dropout, no bias, and (per TE's use_asm_v3 rule + // in ck_fused_attn_bwd.cpp) max_seqlen_q >= 16. NVTE_CK_USES_BWD_V3 can force the + // v2 path off entirely. + bool uses_bwd_v3 = getenv("NVTE_CK_USES_BWD_V3", 1); + const std::string& arch = cuda::sm_arch_name(); + bool is_v3_arch = arch.find("gfx942") != std::string::npos || + arch.find("gfx950") != std::string::npos; + return uses_bwd_v3 && + is_v3_arch && + dropout == 0.f && + bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && + max_seqlen_q >= 16; +} +#endif // USE_FUSED_ATTN_CK + // check the fused attn config to see whether it's ck backend supported // single filtering followed by joint filtering bool is_ck_backend_supported( @@ -30,10 +57,11 @@ bool is_ck_backend_supported( float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, - size_t head_dim_qk, - size_t head_dim_v, - int64_t window_size_left, - int64_t window_size_right) { + size_t head_dim_qk, + size_t head_dim_v, + int64_t window_size_left, + int64_t window_size_right, + bool is_training, bool cuda_graph) { #ifdef USE_FUSED_ATTN_CK @@ -154,6 +182,20 @@ bool is_ck_backend_supported( } return false; } + + // Under HIP-graph capture, CK backward must take the graph-replay-safe v3 asm + // path; a config that would fall back to the CK v2 launcher is not graph-safe. + // Reject such graph-captured training configs so selection falls through to a + // graph-safe backend (the v2 host-pack hazard is backward-only, so inference is + // unaffected). Determinism also forces v2 but is invisible here, so it is handled + // on the framework side. + if(is_training && cuda_graph && + !is_ck_bwd_graph_capture_safe(bias_type, dropout, max_seqlen_q)){ + if(nvte_log_ck_config){ + std::cout<<"CK backward would use the v2 launcher, which is not HIP-graph-replay-safe"< Date: Fri, 5 Jun 2026 20:43:17 +0000 Subject: [PATCH 13/20] Update AITER commit --- transformer_engine/common/ck_fused_attn/qola_manifest.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/common/ck_fused_attn/qola_manifest.toml b/transformer_engine/common/ck_fused_attn/qola_manifest.toml index b81e47ceb..e3a5bb2b7 100644 --- a/transformer_engine/common/ck_fused_attn/qola_manifest.toml +++ b/transformer_engine/common/ck_fused_attn/qola_manifest.toml @@ -1,5 +1,5 @@ [qola] -aiter_commit = "f03a4ec572bb3d9e15da3b346763c8f126feec0d" # pinned AITER submodule commit +aiter_commit = "6aeba412fa057a3d1bf9e1811ddecc9e9cb2af7a" # pinned AITER submodule commit namespace = "te" rocm_versions = ["7.2"] From 77164823778b5b6c51d665d2c5b6ec72276a2cbb Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Fri, 5 Jun 2026 20:52:56 +0000 Subject: [PATCH 14/20] Updated QoLA commit --- 3rdparty/QoLA | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/QoLA b/3rdparty/QoLA index 9c13e77ef..331971c17 160000 --- a/3rdparty/QoLA +++ b/3rdparty/QoLA @@ -1 +1 @@ -Subproject commit 9c13e77ef3cf89053aad61ed3a0f27470f123ee5 +Subproject commit 331971c17de638c21184fcaae239cdbad6f8e26e From eaf6b94eb33d1fb7e5d58ad3bb3ab43de86b1426 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Thu, 11 Jun 2026 19:18:10 +0000 Subject: [PATCH 15/20] Bumped aiter commit --- transformer_engine/common/ck_fused_attn/qola_manifest.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/common/ck_fused_attn/qola_manifest.toml b/transformer_engine/common/ck_fused_attn/qola_manifest.toml index e3a5bb2b7..eefbe2d6b 100644 --- a/transformer_engine/common/ck_fused_attn/qola_manifest.toml +++ b/transformer_engine/common/ck_fused_attn/qola_manifest.toml @@ -1,5 +1,5 @@ [qola] -aiter_commit = "6aeba412fa057a3d1bf9e1811ddecc9e9cb2af7a" # pinned AITER submodule commit +aiter_commit = "bb1010b249377e53c6bad264b2be28525f0fc06e" # pinned AITER submodule commit namespace = "te" rocm_versions = ["7.2"] From bbe3e3b504e04c4f4b48f01711c2f42ce812905e Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Thu, 11 Jun 2026 20:41:16 +0000 Subject: [PATCH 16/20] CK-free build for gfx1250 --- .../common/ck_fused_attn/CMakeLists.txt | 11 ++++++++++- .../ck_fused_attn/src/ck_fused_attn_fwd.cpp | 6 ++++++ .../ck_fused_attn/src/ck_fused_attn_utils.cpp | 6 ++++++ .../ck_fused_attn/src/ck_fused_attn_utils.hpp | 18 ++++++++++++++++++ 4 files changed, 40 insertions(+), 1 deletion(-) diff --git a/transformer_engine/common/ck_fused_attn/CMakeLists.txt b/transformer_engine/common/ck_fused_attn/CMakeLists.txt index 3a1913a96..b7fdaf74c 100644 --- a/transformer_engine/common/ck_fused_attn/CMakeLists.txt +++ b/transformer_engine/common/ck_fused_attn/CMakeLists.txt @@ -207,10 +207,19 @@ foreach(file ${ck_fused_attn_SOURCES}) endforeach() add_library(ck_fused_attn SHARED ${ck_fused_attn_SOURCES}) +# ENABLE_CK gates the wrapper's CK-tile dependency: with a CK-full arch present +# (CDNA) the headers pull real ck_tile; for a gfx1250-only build the real CK +# template headers do not compile, so build CK-free (ENABLE_CK=0) and let the +# wrapper fall back to the ck_tile shim + HIP-native 16-bit numerics. +if(__HAVE_CK_FULL) + set(__ENABLE_CK 1) +else() + set(__ENABLE_CK 0) +endif() set(CK_FUSED_ATTN_COMPILE_OPTIONS) list(APPEND CK_FUSED_ATTN_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=${CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT} - -DENABLE_CK=1) + -DENABLE_CK=${__ENABLE_CK}) # Tier guards consumed by src/ck_fused_attn_{fwd,bwd}.cpp: # NVTE_AITER_CK_FULL -> qola::te::{mha_fwd,mha_bwd} (CDNA) are linked # NVTE_AITER_V3_BWD_GFX1250 -> qola::te_v3::mha_bwd (gfx1250) is linked + diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp index 074ad0042..c1576999f 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp @@ -201,7 +201,13 @@ hipError_t ck_attn_fwd(const CKAttnFwdArgs& args, hipStream_t stream){ fmha_args.is_group_mode = args.is_group_mode(); fmha_args.bias_type = static_cast(bias_type); fmha_args.has_lse = args.lse_ptr!=nullptr; +#if ENABLE_CK fmha_args.qscale_type = static_cast(quant_scale_enum::no_scale); +#else + // quant_scale_enum lives in the CK example headers (quant.hpp), absent in the + // CK-free build. no_scale == 0; this fwd path is unused on gfx1250 anyway. + fmha_args.qscale_type = 0; +#endif fmha_args.has_sink = false; fmha_args.q_descale_ptr = nullptr; fmha_args.k_descale_ptr = nullptr; diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp index 936639f27..d99f4f44b 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp @@ -15,8 +15,14 @@ #include #include "ck_fused_attn_utils.hpp" #include "ck_fused_attn/ck_fused_attn.hpp" +#if ENABLE_CK #include "mask.hpp" #include "bias.hpp" +#else +// CK-free (gfx1250) build: mask.hpp/bias.hpp are CK example headers that do not +// compile for this arch. mask_enum / bias_enum come from the ck_tile shim +// (pulled via ck_fused_attn_utils.hpp -> aiter_hip_common.h) instead. +#endif namespace ck_fused_attn{ diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp index a926d230d..95010e7cb 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp @@ -10,7 +10,25 @@ #include #include #include +#if ENABLE_CK #include "ck_tile/host.hpp" +#else +// CK-free (gfx1250) build: the real ck_tile template headers do not compile for +// this arch. Pull the lightweight shim (ck_tile::index_t / stream_config, plus +// mask_enum / bias_enum) via aiter_hip_common.h, and provide the few ck_tile +// numeric symbols the reduction kernels use. half_t is the _Float16 scalar +// (unambiguous float arithmetic, unlike __half which is ambiguous in the +// kernels' implicit-conversion else-branch); bf16_t is __hip_bfloat16. Both are +// bit-compatible with fp16/bf16 so the kernel sources stay unchanged. +#include "aiter_hip_common.h" +#include +namespace ck_tile { +using half_t = _Float16; +using bf16_t = __hip_bfloat16; +__host__ __device__ inline float bf16_to_float(bf16_t x) { return __bfloat162float(x); } +__host__ __device__ inline bf16_t float_to_bf16(float x) { return __float2bfloat16(x); } +} // namespace ck_tile +#endif #include "ck_fused_attn/ck_fused_attn.hpp" //forward declaration for ck_tile enum From 794bf7de308c62295514e8503d64cec9a9c751fe Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Thu, 11 Jun 2026 21:01:42 +0000 Subject: [PATCH 17/20] Add exclusions for CK-free build --- transformer_engine/common/CMakeLists.txt | 30 ++++++++++++++----- .../common/gemm/cublaslt_gemm.cu | 8 ++++- 2 files changed, 30 insertions(+), 8 deletions(-) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 02eaaea93..39201397d 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -291,14 +291,27 @@ if(USE_ROCM) fused_attn_rocm/fused_attn_aotriton.cpp fused_attn_rocm/fused_attn_ck.cpp fused_attn_rocm/utils.cpp - gemm/ck_grouped_gemm/ck_grouped_gemm.cpp - gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp - gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp - gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_nn.cpp - gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_nt.cpp - gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_tn.cpp - gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_tt.cpp gemm/rocm_gemm.cu) + + # CK grouped GEMM instantiates CK-tile GEMM kernels, which do not compile for + # gfx1250 (RDNA4 - CK-tile has no arch definition for it). Build it only when + # no gfx1250 target is requested; otherwise the runtime falls back to the + # hipBLASLt grouped-GEMM path. NVTE_CK_GROUPED_GEMM gates the matching call + # site in gemm/cublaslt_gemm.cu. + if(NOT CMAKE_HIP_ARCHITECTURES MATCHES "gfx1250") + set(NVTE_CK_GROUPED_GEMM ON) + list(INSERT transformer_engine_SOURCES 0 + gemm/ck_grouped_gemm/ck_grouped_gemm.cpp + gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp + gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp + gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_nn.cpp + gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_nt.cpp + gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_tn.cpp + gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_tt.cpp) + else() + message(STATUS "ck_grouped_gemm disabled (gfx1250 target; CK-tile unsupported) - " + "using hipBLASLt grouped GEMM fallback") + endif() endif() if(USE_CUDA) @@ -385,6 +398,9 @@ set_property( else() set(CK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/aiter/3rdparty/composable_kernel) target_include_directories(transformer_engine PRIVATE ${CK_ROOT}/include) + if(NVTE_CK_GROUPED_GEMM) + target_compile_definitions(transformer_engine PRIVATE NVTE_CK_GROUPED_GEMM) + endif() endif() #USE_CUDA # Configure dependencies diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 35cad5092..c05032f0e 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -31,7 +31,7 @@ #include "./config.h" #ifndef __HIP_PLATFORM_AMD__ #include "./cutlass_grouped_gemm.cuh" -#else +#elif defined(NVTE_CK_GROUPED_GEMM) #include "ck_grouped_gemm/ck_grouped_gemm.h" #endif @@ -1194,12 +1194,18 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor if (is_empty_arr(bias) && is_empty_arr(pre_gelu_out) && is_supported_dtype() && #ifdef __HIP_PLATFORM_AMD__ true) { +#ifdef NVTE_CK_GROUPED_GEMM if (!ck_tile_grouped_gemm(A, B, D, num_gemms, transa, transb, workspace, accumulate, stream)) { if (warn_fallback) { NVTE_WARN("Fallback to cuBLAS grouped GEMM."); } cublas_path(); } +#else + // CK grouped GEMM is not built for this arch (e.g. gfx1250, where CK-tile + // does not compile); use the hipBLASLt grouped-GEMM path directly. + cublas_path(); +#endif #else all_groups_uniform_k128(B, transb)) { cutlass_grouped_gemm(A, B, D, num_gemms, transa, transb, grad, workspace, accumulate, From 86c08f9b4df42c5bcba06470b4c9b1886ab10a7c Mon Sep 17 00:00:00 2001 From: Ilya Panfilov Date: Thu, 11 Jun 2026 18:49:50 -0400 Subject: [PATCH 18/20] FWD V3 support --- 3rdparty/QoLA | 2 +- .../common/ck_fused_attn/CMakeLists.txt | 12 +++-- .../ck_fused_attn/qola_manifest_gfx1250.toml | 27 +++++++++- .../ck_fused_attn/src/ck_fused_attn_fwd.cpp | 51 +++++++++++++++---- .../common/fused_attn_rocm/fused_attn_ck.cpp | 15 ------ 5 files changed, 77 insertions(+), 30 deletions(-) diff --git a/3rdparty/QoLA b/3rdparty/QoLA index 331971c17..5349b3d27 160000 --- a/3rdparty/QoLA +++ b/3rdparty/QoLA @@ -1 +1 @@ -Subproject commit 331971c17de638c21184fcaae239cdbad6f8e26e +Subproject commit 5349b3d27e33c5beb7f2304ed00cc44a03990ec7 diff --git a/transformer_engine/common/ck_fused_attn/CMakeLists.txt b/transformer_engine/common/ck_fused_attn/CMakeLists.txt index b7fdaf74c..4b4a46718 100644 --- a/transformer_engine/common/ck_fused_attn/CMakeLists.txt +++ b/transformer_engine/common/ck_fused_attn/CMakeLists.txt @@ -137,6 +137,7 @@ if(__CK_FULL_ARCH_COUNT GREATER 0) --aiter-root ${__AITER_SOURCE_DIR} --output-dir ${__QOLA_BUILD_DIR} --arch "${GPU_ARCHS_STR}" + --skip-checkout RESULT_VARIABLE QOLA_BUILD_RESULT ) if(NOT QOLA_BUILD_RESULT EQUAL 0) @@ -158,14 +159,14 @@ if(__CK_FULL_ARCH_COUNT GREATER 0) endif() endif() -# --- V3-asm-only backward library (gfx1250): te_v3_libmha_bwd.so --- +# --- V3-asm-only libraries for gfx1250: te_v3_libmha_fwd.so and te_v3_libmha_bwd.so --- # There is no prebuilt cache path for gfx1250 (no public prebuilt, and a CK-free # asm build is cheap), so always build it from source via QoLA. Both manifests # pin the same AITER commit and share the already-checked-out source tree. if(__HAS_GFX1250) set(__QOLA_MANIFEST_V3 "${CMAKE_CURRENT_LIST_DIR}/qola_manifest_gfx1250.toml") set(__QOLA_BUILD_DIR_V3 "${__QOLA_DIR}/build_gfx1250") - message(STATUS "[AITER-BUILD] Building CK-free V3 backward (gfx1250) via QoLA.") + message(STATUS "[AITER-BUILD] Building CK-free V3 (gfx1250) via QoLA.") # The asm-only / CK-free flags (ONLY_FAV3=1, ENABLE_CK=0) are carried by the # gfx1250 manifest's libmha_bwd module, so no special env is needed here. execute_process( @@ -224,11 +225,13 @@ list(APPEND CK_FUSED_ATTN_COMPILE_OPTIONS # NVTE_AITER_CK_FULL -> qola::te::{mha_fwd,mha_bwd} (CDNA) are linked # NVTE_AITER_V3_BWD_GFX1250 -> qola::te_v3::mha_bwd (gfx1250) is linked + # runtime-dispatched in ck_attn_bwd +# NVTE_AITER_V3_FWD_GFX1250 -> qola::te_v3::mha_fwd (gfx1250) is linked + +# runtime-dispatched in ck_attn_fwd if(__HAVE_CK_FULL) list(APPEND CK_FUSED_ATTN_COMPILE_OPTIONS -DNVTE_AITER_CK_FULL) endif() if(__HAS_GFX1250) - list(APPEND CK_FUSED_ATTN_COMPILE_OPTIONS -DNVTE_AITER_V3_BWD_GFX1250) + list(APPEND CK_FUSED_ATTN_COMPILE_OPTIONS -DNVTE_AITER_V3_BWD_GFX1250 -DNVTE_AITER_V3_FWD_GFX1250) endif() # Public QoLA headers ship alongside the .so libs in ${__AITER_MHA_PATH}/../include @@ -255,8 +258,9 @@ if(__HAVE_CK_FULL) ${__AITER_MHA_PATH}/te_libmha_bwd.so) endif() if(__HAS_GFX1250) - list(APPEND ck_fused_attn_LINKER_LIBS -l:te_v3_libmha_bwd.so) + list(APPEND ck_fused_attn_LINKER_LIBS -l:te_v3_libmha_fwd.so -l:te_v3_libmha_bwd.so) list(APPEND __INSTALL_AITER_LIBS + ${__AITER_MHA_PATH}/te_v3_libmha_fwd.so ${__AITER_MHA_PATH}/te_v3_libmha_bwd.so) endif() target_link_libraries(ck_fused_attn PUBLIC ${ck_fused_attn_LINKER_LIBS}) diff --git a/transformer_engine/common/ck_fused_attn/qola_manifest_gfx1250.toml b/transformer_engine/common/ck_fused_attn/qola_manifest_gfx1250.toml index 23b59027e..a1d7a6706 100644 --- a/transformer_engine/common/ck_fused_attn/qola_manifest_gfx1250.toml +++ b/transformer_engine/common/ck_fused_attn/qola_manifest_gfx1250.toml @@ -8,13 +8,38 @@ # Keep aiter_commit in lockstep with qola_manifest.toml — both consume the same # checked-out AITER source tree. [qola] -aiter_commit = "f03a4ec572bb3d9e15da3b346763c8f126feec0d" # pinned AITER submodule commit +aiter_commit = "c4c4faa4789cb3bf6d972192d96c5461b71728be" # pinned AITER submodule commit namespace = "te_v3" rocm_versions = ["7.2"] [build] architectures = ["gfx1250"] +[[modules]] +name = "libmha_fwd" +mode = "cpp_itfs" +drop_directions = ["fwd", "fwd_splitkv", "batch_prefill"] +flags_extra_cc = ["'-DFAV3_ON=1'", "'-DENABLE_CK=0'", "'-DFAV1250_ON=1'"] +hsa_subdirs = ["fmha_fwd_bf16"] +# Override the source list from optCompilerConfig.json to add the two +# gfx1250 ASM sink TUs. mha_fwd.cu (FAV1250_ON=1) declares +# fmha_fwd_with_sink_asm() via extern "C" and calls it; +srcs = [ + "f'{AITER_CSRC_DIR}/cpp_itfs/mha_fwd.cu'", + "f'{AITER_CSRC_DIR}/cpp_itfs/mha_fwd_split.cu'", + "f'{AITER_CSRC_DIR}/cpp_itfs/mha_fwd_batch_prefill.cu'", + "f'{AITER_CSRC_DIR}/py_itfs_cu/asm_fmha_fwd_with_sink.cu'", +] +# Override blob_gen_cmd to drop the CK generate.py -d {fwd,fwd_splitkv,batch_prefill} +# commands (not needed / would fail for gfx1250) and add the codegen run that +# produce the config headers consumed by the sink TUs: +# asm_fmha_fwd_bf16_configs.hpp <- -m fmha_fwd_bf16 +# The registry's add_blob_gen_cmd (-m fmha_v3_fwd) is still appended by +# _apply_cpp_itfs on top of this list. +blob_gen_cmd = [ + "f'{AITER_META_DIR}/hsa/codegen.py -m fmha_fwd_bf16 --output_dir {{}}'", +] + # Reuse the torch-free libmha_bwd module (sources = mha_bwd.cu only; the same # source the CK-full te_libmha_bwd.so builds from). Do NOT use # module_fmha_v3_bwd here — it pulls in mha_common.cu, which includes diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp index c1576999f..8509d0737 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp @@ -12,8 +12,33 @@ #include "qola_mha_fwd.h" #include "ck_fused_attn_utils.hpp" +// Staged gfx1250 forward dispatch. When this build includes the CK-free V3 +// forward library (te_v3_libmha_fwd.so, built for gfx1250), declare its +// namespaced entry point so ck_attn_fwd can route to it on gfx1250 devices at +// runtime. The CK-full path (QOLA_NS(mha_fwd) == qola::te::mha_fwd) is used on +// all other archs. +#if defined(NVTE_AITER_V3_FWD_GFX1250) +namespace qola { namespace te_v3 { +float mha_fwd(const aiter::mha_fwd_args& args, const ck_tile::stream_config& stream_config); +}} // namespace qola::te_v3 +#endif + namespace ck_fused_attn{ +#if defined(NVTE_AITER_V3_FWD_GFX1250) +namespace { +// True when the active device is gfx1250 (gcnArchName may carry feature +// suffixes, e.g. "gfx1250:sramecc+", so match on prefix). +bool is_gfx1250_device(){ + int dev = 0; + if(hipGetDevice(&dev) != hipSuccess){ return false; } + hipDeviceProp_t prop{}; + if(hipGetDeviceProperties(&prop, dev) != hipSuccess){ return false; } + return prop.major == 12 && prop.minor == 5; +} +} // namespace +#endif + // print the fmha traits and fmha_args when calling ck apis void log_fwd_config(const char* func_name, bool has_dropout, const aiter::mha_fwd_args& fmha_args, std::ostream* log_file){ @@ -230,18 +255,26 @@ hipError_t ck_attn_fwd(const CKAttnFwdArgs& args, hipStream_t stream){ if(log_file){ log_fwd_config(__FUNCTION__, has_dropout, fmha_args, log_file); } + + float average_runtime; +#if defined(NVTE_AITER_V3_FWD_GFX1250) + if(is_gfx1250_device()){ + average_runtime = qola::te_v3::mha_fwd(fmha_args, stream_config); + } else +#endif + { #if defined(NVTE_AITER_CK_FULL) - float average_runtime = QOLA_NS(mha_fwd)(fmha_args, stream_config); + average_runtime = QOLA_NS(mha_fwd)(fmha_args, stream_config); #else - // gfx1250-only build: no CK-full forward library exists (gfx1250 has no - // forward kernels). The unified backend selector never picks CK on gfx1250, - // so this path is unreachable at runtime; the guard only keeps the link - // closed when te_libmha_fwd.so is absent. - float average_runtime = -1.0f; - throw std::runtime_error( - "ck_fused_attn fwd: no CK-full AITER forward library in this build " - "(gfx1250 has no forward kernels)."); + // gfx1250-only build: no CK-full forward library exists (gfx1250 has no + // forward kernels). The unified backend selector never picks CK on gfx1250, + // so this path is unreachable at runtime; the guard only keeps the link + // closed when te_libmha_fwd.so is absent. + throw std::runtime_error( + "ck_fused_attn fwd: no CK-full AITER forward library in this build " + "(gfx1250 has no forward kernels)."); #endif + } if(average_runtime < 0){ //TODO: better error out system throw std::runtime_error("fused attn configs not supported in ck_fused_attn fwd pass."); diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp index 50dd6a406..5b9f34790 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp @@ -183,21 +183,6 @@ bool is_ck_backend_supported( return false; } - // gfx1250 (RDNA4) ships AITER V3 *backward* asm kernels only — there are no - // forward kernels at the pinned commit. TE selects one fused-attn backend per - // op and uses it for both directions (backward inherits the forward's choice), - // so selecting CK here would route the forward into a kernel-less path. - // Until the forward is handled (direction-aware backend selection, or gfx1250 - // forward kernels), do not select CK on gfx1250 through this unified path. - // The CK-free V3 backward library (te_v3_module_fmha_v3_bwd.so) and the - // runtime dispatch in ck_attn_bwd are built and staged for that activation. - if(cuda::sm_arch(cuda::current_device()) == 125){ - if(nvte_log_ck_config){ - std::cout<<"gfx1250 CK fused attn is staged (backward-only); not selected via the unified backend yet"< Date: Sun, 14 Jun 2026 18:03:18 -0400 Subject: [PATCH 19/20] Fix FWD V3 run and test --- tests/pytorch/attention/test_attention.py | 8 +- .../attention/test_attention_gfx1250.py | 382 ++++++++++++++++++ .../common/ck_fused_attn/CMakeLists.txt | 59 ++- .../ck_fused_attn/qola_manifest_gfx1250.toml | 15 +- .../ck_fused_attn/src/ck_fused_attn_fwd.cpp | 48 +++ 5 files changed, 481 insertions(+), 31 deletions(-) create mode 100644 tests/pytorch/attention/test_attention_gfx1250.py diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index d1bf0d6c4..619802711 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -297,7 +297,9 @@ def test_dot_product_attention( # Skip if only unfused backend is supported # Double-count the CK backend since we want to compare V2/V3 kernels - has_ck_backend = IS_HIP_EXTENSION and FusedAttnBackend["CK"] in fused_attn_backends + # Issue #16948 CK V2 is disabled for gfx1250 + has_ck_backend = ( IS_HIP_EXTENSION and FusedAttnBackend["CK"] in fused_attn_backends and + get_device_compute_capability() != (12, 5) ) if not has_ck_backend and ( len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported ) < 2: @@ -1494,7 +1496,9 @@ def test_transformer_layer( # Skip if only unfused backend is supported # Double-count the CK backend since we want to compare V2/V3 kernels - has_ck_backend = IS_HIP_EXTENSION and FusedAttnBackend["CK"] in fused_attn_backends + # Issue #16948 CK V2 is disabled for gfx1250 + has_ck_backend = ( IS_HIP_EXTENSION and FusedAttnBackend["CK"] in fused_attn_backends and + get_device_compute_capability() != (12, 5) ) if not has_ck_backend and ( len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported ) < 2: diff --git a/tests/pytorch/attention/test_attention_gfx1250.py b/tests/pytorch/attention/test_attention_gfx1250.py new file mode 100644 index 000000000..17787b7c6 --- /dev/null +++ b/tests/pytorch/attention/test_attention_gfx1250.py @@ -0,0 +1,382 @@ +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +""" +Gfx1250-targeted attention tests. + +Configs mirror the smoke tests in: + 3rdparty/aiter/op_tests/cpp/mha/smoke_test_fwd_v3_gfx1250.sh (fwd V3) + +FWD V3 notes (fmha_fwd_gfx1250_batched / fmha_fwd_with_sink_asm): + - Both D64 and D128 require a non-null sink_addr (fixed kernarg layout). + TE supplies a static [256] fp32 buffer initialized to -1e30f so that + exp(-1e30f) ≈ 0.0f adds no effective weight for non-fully-masked rows. + - D64 (ENABLE_SINK=1): kernel reads and uses the sink values as a logit + floor. Top-left causal; sq ≤ sk (rectangular) safe because even with + sink≈0 the real attention weights dominate. + - D128 (ENABLE_SINK=0): kernel ignores the sink values entirely. + Causal (top-left or bottom-right); sq == sk only — rectangular shapes + risk NaN on fully-masked KV tiles because the sink floor is disabled. + - No SWA (window_size_left must be -1). + +Each test forces CK V3 and compares against a pure-PyTorch scaled dot product +attention reference (torch.nn.functional.scaled_dot_product_attention) computed +in float32 for numerical stability. +""" +import os +import sys +import pathlib + +import pytest +import torch +import torch.nn.functional as F +from torch.utils.cpp_extension import IS_HIP_EXTENSION + +from transformer_engine.pytorch import DotProductAttention, get_device_compute_capability +from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend +from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends +from transformer_engine.pytorch.distributed import CudaRNGStatesTracker + +_current_file = pathlib.Path(__file__).resolve() +sys.path.append(str(_current_file.parent.parent)) +from utils import ( + reset_rng_states, + ModelConfig, + get_available_attention_backends, +) + +# Whole file is ROCm + gfx1250 only. +pytestmark = [ + pytest.mark.skipif(not IS_HIP_EXTENSION, reason="ROCm TE specific test."), + pytest.mark.skipif( + get_device_compute_capability() != (12, 5), + reason="gfx1250 (compute capability 12.5) required.", + ), +] + +_SEED = 1234 + + +def _make_rng_tracker(): + tracker = CudaRNGStatesTracker() + tracker.add("model-parallel-rng", _SEED) + return tracker + + +def _run_dpa( + config: ModelConfig, + backend_env: dict, + dtype: torch.dtype, + is_training: bool, + q: torch.Tensor = None, + k: torch.Tensor = None, + v: torch.Tensor = None, +): + """Run one forward (and optional backward) pass, return (out, grads). + + If q/k/v are provided they are used directly (requires_grad set appropriately); + otherwise fresh random tensors are created. + """ + reset_rng_states() + + for key in [ + "NVTE_FLASH_ATTN", "NVTE_FUSED_ATTN", "NVTE_UNFUSED_ATTN", + "NVTE_FUSED_ATTN_CK", "NVTE_FUSED_ATTN_AOTRITON", + "NVTE_CK_USES_FWD_V3", "NVTE_CK_USES_BWD_V3", + ]: + os.environ.pop(key, None) + for key, val in backend_env.items(): + os.environ[key] = str(val) + _attention_backends["backend_selection_requires_update"] = True + + device = "cuda" + b = config.batch_size + sq = config.max_seqlen_q + sk = config.max_seqlen_kv + hq = config.num_heads + hk = config.num_gqa_groups + dqk = config.head_dim_qk + + if q is None: + q = torch.randn(b, sq, hq, dqk, dtype=dtype, device=device) + k = torch.randn(b, sk, hk, dqk, dtype=dtype, device=device) + v = torch.randn(b, sk, hk, dqk, dtype=dtype, device=device) + + q = q.detach().requires_grad_(is_training) + k = k.detach().requires_grad_(is_training) + v = v.detach().requires_grad_(is_training) + + block = DotProductAttention( + num_attention_heads=hq, + kv_channels=dqk, + num_gqa_groups=hk, + attention_dropout=0.0, + qkv_format="bshd", + attn_mask_type=config.attn_mask_type, + tp_size=1, + tp_group=None, + get_rng_state_tracker=_make_rng_tracker, + ).to(dtype=dtype, device=device) + if not is_training: + block.eval() + + out = block( + q, k, v, + qkv_format="bshd", + attn_mask_type=config.attn_mask_type, + window_size=config.window_size, + max_seqlen_q=sq, + max_seqlen_kv=sk, + ) + + grads = None + if is_training: + out.sum().backward() + grads = (q.grad.clone(), k.grad.clone(), v.grad.clone()) + + return out.detach(), grads + + +def _pytorch_ref( + config: ModelConfig, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + is_training: bool, +) -> tuple: + """Pure-PyTorch SDPA reference computed in float32. + + Inputs are BSHD; GQA heads are expanded before calling SDPA. + Returns (out_bf16, (dq, dk, dv)) where dq/dk/dv are None when not training. + The reference runs in float32 for numerical stability so that the comparison + is against a correct high-precision result rather than a TE-specific backend + that may itself have issues on gfx1250. + """ + b, sq, hq, d = q.shape + hk = k.shape[2] + + # Upcast to float32 for reference stability. + qf = q.float() + kf = k.float() + vf = v.float() + + if is_training: + qf = qf.detach().requires_grad_(True) + kf = kf.detach().requires_grad_(True) + vf = vf.detach().requires_grad_(True) + + # SDPA expects [B, H, S, D]. + qf_t = qf.permute(0, 2, 1, 3) # [b, hq, sq, d] + kf_t = kf.permute(0, 2, 1, 3) # [b, hk, sk, d] + vf_t = vf.permute(0, 2, 1, 3) + + # Expand KV heads for GQA so SDPA sees [b, hq, sk, d]. + # Use a separate leaf for the expanded tensor so grad flows back to hk heads. + gqa = hk != hq + if gqa: + kf_t_exp = kf_t.repeat_interleave(hq // hk, dim=1).detach().requires_grad_(is_training) + vf_t_exp = vf_t.repeat_interleave(hq // hk, dim=1).detach().requires_grad_(is_training) + else: + kf_t_exp = kf_t + vf_t_exp = vf_t + + attn_mask_type = config.attn_mask_type + window_size = config.window_size + sk_len = kf_t.shape[2] + + # Use explicit mask for bottom-right causal or SWA; otherwise let SDPA handle it. + has_swa = window_size is not None and window_size not in ((-1, -1), (-1, 0)) + needs_explicit_mask = attn_mask_type == "causal_bottom_right" or has_swa + + if needs_explicit_mask: + # Start with all-attend, then apply causal + SWA constraints. + rows = torch.arange(sq, device=q.device).unsqueeze(1) # [sq, 1] + cols = torch.arange(sk_len, device=q.device).unsqueeze(0) # [1, sk] + mask = torch.ones(sq, sk_len, dtype=torch.bool, device=q.device) + if attn_mask_type in ("causal", "causal_bottom_right"): + offset = sk_len - sq if attn_mask_type == "causal_bottom_right" else 0 + mask = mask & (cols <= rows + offset) + if has_swa: + left, right = window_size + lo = (rows - left).clamp(min=0) if left >= 0 else torch.zeros_like(rows) + hi = (rows + right) if right >= 0 else torch.full_like(rows, sk_len - 1) + mask = mask & (cols >= lo) & (cols <= hi) + float_mask = torch.zeros(sq, sk_len, dtype=torch.float32, device=q.device) + float_mask[~mask] = float("-inf") + out_t = F.scaled_dot_product_attention(qf_t, kf_t_exp, vf_t_exp, attn_mask=float_mask) + elif attn_mask_type == "no_mask": + out_t = F.scaled_dot_product_attention(qf_t, kf_t_exp, vf_t_exp, is_causal=False) + else: + # causal top-left: PyTorch is_causal=True is exactly this + out_t = F.scaled_dot_product_attention(qf_t, kf_t_exp, vf_t_exp, is_causal=True) + + out_bf16 = out_t.permute(0, 2, 1, 3).to(dtype=q.dtype).detach() # [b, sq, hq, d] + + dq = dk = dv = None + if is_training: + out_t.sum().backward() + dq = qf.grad.to(dtype=q.dtype).detach() # [b, sq, hq, d] + if gqa: + # kf_t_exp.grad: [b, hq, sk, d] — reduce over the hq/hk groups back to hk heads + group = hq // hk + dk_exp = kf_t_exp.grad.view(b, hk, group, sk_len, d).sum(dim=2) # [b, hk, sk, d] + dv_exp = vf_t_exp.grad.view(b, hk, group, sk_len, d).sum(dim=2) + dk = dk_exp.permute(0, 2, 1, 3).to(dtype=k.dtype).detach() # [b, sk, hk, d] + dv = dv_exp.permute(0, 2, 1, 3).to(dtype=v.dtype).detach() + else: + dk = kf.grad.to(dtype=k.dtype).detach() # [b, sk, hk, d] + dv = vf.grad.to(dtype=v.dtype).detach() + + return out_bf16, (dq, dk, dv) + + +def _compare(config: ModelConfig, dtype: torch.dtype = torch.bfloat16, is_training: bool = True): + """Run CK V3 and compare against a float32 PyTorch SDPA reference.""" + tols = dict(atol=2e-2, rtol=2e-2) + + _, _, fused_backends = get_available_attention_backends( + config, + qkv_dtype=dtype, + qkv_layout="bshd_bshd_bshd", + is_training=is_training, + ) + if FusedAttnBackend["CK"] not in fused_backends: + pytest.skip("CK backend not available for this config") + + reset_rng_states() + for _evar in [ + "NVTE_FLASH_ATTN", "NVTE_FUSED_ATTN", "NVTE_UNFUSED_ATTN", + "NVTE_FUSED_ATTN_CK", "NVTE_FUSED_ATTN_AOTRITON", + "NVTE_CK_USES_FWD_V3", "NVTE_CK_USES_BWD_V3", + ]: + os.environ.pop(_evar, None) + + device = "cuda" + b = config.batch_size + sq = config.max_seqlen_q + sk = config.max_seqlen_kv + hq = config.num_heads + hk = config.num_gqa_groups + dqk = config.head_dim_qk + + q = torch.randn(b, sq, hq, dqk, dtype=dtype, device=device) + k = torch.randn(b, sk, hk, dqk, dtype=dtype, device=device) + v = torch.randn(b, sk, hk, dqk, dtype=dtype, device=device) + + ref_out, ref_grads = _pytorch_ref(config, q, k, v, is_training) + + ck_v3_env = { + "NVTE_FUSED_ATTN": "1", + "NVTE_FLASH_ATTN": "0", + "NVTE_UNFUSED_ATTN": "0", + "NVTE_FUSED_ATTN_CK": "1", + "NVTE_FUSED_ATTN_AOTRITON": "0", + "NVTE_CK_USES_FWD_V3": "1", + "NVTE_CK_USES_BWD_V3": "1", + } + ck_out, ck_grads = _run_dpa(config, ck_v3_env, dtype, is_training, q=q, k=k, v=v) + + # TE DotProductAttention returns [b, sq, hq*d] for bshd format; reshape to [b, sq, hq, d]. + ck_out = ck_out.view(b, sq, hq, dqk) + + torch.testing.assert_close(ck_out, ref_out, **tols) + if is_training and ref_grads is not None: + dq_ref, dk_ref, dv_ref = ref_grads + dq_ck, dk_ck, dv_ck = ck_grads + # grads from _run_dpa are already [b, s, h, d] since q/k/v were passed as BSHD leaves + torch.testing.assert_close(dq_ck, dq_ref, **tols) + torch.testing.assert_close(dk_ck, dk_ref, **tols) + torch.testing.assert_close(dv_ck, dv_ref, **tols) + + +# --------------------------------------------------------------------------- +# FWD V3 D64 configs (smoke_test_fwd_sink.sh — run_d64) +# +# Kernel: fmha_fwd_with_sink_asm, ENABLE_SINK=1, bottom-right causal. +# Sink floor prevents div-by-zero, so sq < sk (rectangular) is safe. +# h=8, h_k∈{1,2,4}, b∈{1,2}. FWD-only (no backward for this path). +# +# NOTE: the smoke test labels some cases "mha" but actually uses h_k=1 +# (GQA-1 / MQA). The TE test mirrors that: "mha" suffix = num_gqa_groups=1. +# The square cases (sq==sk) additionally exercise num_gqa_groups=8 (true MHA). +# --------------------------------------------------------------------------- + +_fwd_v3_d64: dict = {} +# The gfx1250 ASM kernel implements bottom-right causal (mask_type 1 and 2 both +# map to is_causal=1 inside fmha_fwd_with_sink_asm, which follows bottom-right +# causal semantics). Square shapes (sq==sk) are equivalent for both variants; +# rectangular shapes (sq < sk) must use "causal_bottom_right" to match. +for _s in (512, 1024, 2048): + # h_k=1 matches smoke test "mha" label; also add true MHA (h_k=h=8) for square. + _fwd_v3_d64[f"d64_sq{_s}_sk{_s}_b1_mqa"] = ModelConfig( + 1, _s, 8, 64, max_seqlen_kv=_s, num_gqa_groups=1, attn_mask_type="causal_bottom_right") + _fwd_v3_d64[f"d64_sq{_s}_sk{_s}_b1_mha"] = ModelConfig( + 1, _s, 8, 64, max_seqlen_kv=_s, num_gqa_groups=8, attn_mask_type="causal_bottom_right") + _fwd_v3_d64[f"d64_sq{_s}_sk{_s}_b2_gqa2"] = ModelConfig( + 2, _s, 8, 64, max_seqlen_kv=_s, num_gqa_groups=2, attn_mask_type="causal_bottom_right") +for _sq in (128, 256, 512): + for _sk in (512, 2048): + for _b in (1, 2): + for _hq, _hk in ((8, 1), (8, 2), (4, 4)): + _key = f"d64_sq{_sq}_sk{_sk}_b{_b}_h{_hq}k{_hk}" + _fwd_v3_d64[_key] = ModelConfig( + _b, _sq, _hq, 64, max_seqlen_kv=_sk, num_gqa_groups=_hk, + attn_mask_type="causal_bottom_right") +# Rectangular tail configs: smoke test uses h_k=1 for these non-standard shapes. +for _sq in (130, 300): + _fwd_v3_d64[f"d64_sq{_sq}_sk2048_b1_mha"] = ModelConfig( + 1, _sq, 8, 64, max_seqlen_kv=2048, num_gqa_groups=1, attn_mask_type="causal_bottom_right") +for _sk in (768, 2300): + _fwd_v3_d64[f"d64_sq128_sk{_sk}_b1_mha"] = ModelConfig( + 1, 128, 8, 64, max_seqlen_kv=_sk, num_gqa_groups=1, attn_mask_type="causal_bottom_right") + + +@pytest.mark.parametrize("model", sorted(_fwd_v3_d64.keys())) +def test_gfx1250_fwd_v3_d64(model): + """FWD V3 D64 correctness — run_d64 configs from smoke_test_fwd_v3_gfx1250.sh.""" + _compare(_fwd_v3_d64[model], dtype=torch.bfloat16, is_training=False) + + +# --------------------------------------------------------------------------- +# FWD V3 D128 configs (smoke_test_fwd_sink.sh — run_d128) +# +# Kernel: fmha_fwd_with_sink_asm, ENABLE_SINK=0 (sink_ptr ignored/nullptr). +# D128 uses bottom-right causal (kernel is_causal=1, same path as D64). +# h=8, h_k∈{1,2,4}, b∈{1,2}. FWD-only. +# +# NOTE: smoke test "mha" cases use h_k=1 (MQA); true MHA (h_k=8) is added +# for square shapes only. +# --------------------------------------------------------------------------- + +_fwd_v3_d128: dict = {} +for _s in (512, 1024, 2048): + # h_k=1 matches smoke test; also add true MHA for square shapes. + _fwd_v3_d128[f"d128_sq{_s}_sk{_s}_b1_mqa"] = ModelConfig( + 1, _s, 8, 128, max_seqlen_kv=_s, num_gqa_groups=1, + attn_mask_type="causal_bottom_right") + _fwd_v3_d128[f"d128_sq{_s}_sk{_s}_b1_mha"] = ModelConfig( + 1, _s, 8, 128, max_seqlen_kv=_s, num_gqa_groups=8, + attn_mask_type="causal_bottom_right") + _fwd_v3_d128[f"d128_sq{_s}_sk{_s}_b2_gqa2"] = ModelConfig( + 2, _s, 8, 128, max_seqlen_kv=_s, num_gqa_groups=2, + attn_mask_type="causal_bottom_right") +# Rectangular configs from smoke test run_d128 (bottom-right causal, sq < sk safe). +for _sq in (128, 256): + for _hq, _hk in ((8, 1), (8, 2), (4, 4)): + _fwd_v3_d128[f"d128_sq{_sq}_sk2048_h{_hq}k{_hk}"] = ModelConfig( + 1, _sq, _hq, 128, max_seqlen_kv=2048, num_gqa_groups=_hk, + attn_mask_type="causal_bottom_right") +# Unaligned tail configs: h_k=1 matching smoke test. +_fwd_v3_d128["d128_sq130_sk2048_b1_mha"] = ModelConfig( + 1, 130, 8, 128, max_seqlen_kv=2048, num_gqa_groups=1, + attn_mask_type="causal_bottom_right") +_fwd_v3_d128["d128_sq128_sk2300_b1_mha"] = ModelConfig( + 1, 128, 8, 128, max_seqlen_kv=2300, num_gqa_groups=1, + attn_mask_type="causal_bottom_right") + + +@pytest.mark.parametrize("model", sorted(_fwd_v3_d128.keys())) +def test_gfx1250_fwd_v3_d128(model): + """FWD V3 D128 correctness — run_d128 configs from smoke_test_fwd_v3_gfx1250.sh.""" + _compare(_fwd_v3_d128[model], dtype=torch.bfloat16, is_training=False) diff --git a/transformer_engine/common/ck_fused_attn/CMakeLists.txt b/transformer_engine/common/ck_fused_attn/CMakeLists.txt index 4b4a46718..5e77eb093 100644 --- a/transformer_engine/common/ck_fused_attn/CMakeLists.txt +++ b/transformer_engine/common/ck_fused_attn/CMakeLists.txt @@ -8,16 +8,28 @@ project(ck_fused_attn LANGUAGES HIP CXX) set(AITER_MHA_INSTALL_PREFIX "transformer_engine" CACHE STRING "aiter mha shared lib install prefix in TE") +# Issue #16948 # gfx1250 carries AITER V3 bwd kernels only (hd128, bf16, batch mode). The # runtime envelope is enforced in nvte_get_fused_attn_backend(). set(__QOLA_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../3rdparty/QoLA") set(__AITER_SOURCE_DIR "${__QOLA_DIR}/build/third_party/aiter") +if (DEFINED ENV{NVTE_AITER_SOURCE_DIR} AND NOT $ENV{NVTE_AITER_SOURCE_DIR} STREQUAL "") + set(__AITER_SOURCE_DIR $ENV{NVTE_AITER_SOURCE_DIR}) + message(STATUS "Using AITER source from NVTE_AITER_SOURCE_DIR=${__AITER_SOURCE_DIR}. Disable AITER checkout.") + set(__SKIP_AITER_CHECKOUT TRUE) +else() + set(__SKIP_AITER_CHECKOUT FALSE) +endif() set(__CK_SOURCE_DIR "${__AITER_SOURCE_DIR}/3rdparty/composable_kernel") set(CK_INCLUDE_DIR "${__CK_SOURCE_DIR}/include") set(AITER_INCLUDE_DIR "${__AITER_SOURCE_DIR}/csrc/include") if(NOT Python_EXECUTABLE) find_package(Python COMPONENTS Interpreter QUIET) + if (NOT Python_EXECUTABLE) + #Don't fail the entire build if Python isn't found since it may pass if pre-built AITER is used + message(ERROR "Python interpreter not found. Some build steps will be skipped or fail. Please install Python and ensure it is on the PATH.") + endif() endif() # Resolve the manifest-pinned AITER commit (defines AITER_SHA) and bring the @@ -33,27 +45,26 @@ if(Python_EXECUTABLE) # submodules) work in containerized builds where the bind-mounted .git is # owned by a different UID than the build process. Mirrors the pattern in # transformer_engine/common/CMakeLists.txt:get_git_commit(). - execute_process( - COMMAND sh -c - "tmp=$(mktemp /tmp/gitconfig.XXXXXX) || exit 1; \ -GIT_CONFIG_GLOBAL=$tmp git config --global --add safe.directory '*' >/dev/null 2>&1; \ -GIT_CONFIG_GLOBAL=$tmp PYTHONPATH=\"${__QOLA_DIR}:$PYTHONPATH\" '${Python_EXECUTABLE}' -m qola.cli checkout \ ---manifest '${__QOLA_MANIFEST}' \ ---aiter-root '${__AITER_SOURCE_DIR}'; \ -rc=$?; rm -f \"$tmp\"; exit $rc" - RESULT_VARIABLE AITER_CHECKOUT_RESULT - OUTPUT_VARIABLE AITER_CHECKOUT_OUTPUT - ERROR_VARIABLE AITER_CHECKOUT_ERROR - OUTPUT_STRIP_TRAILING_WHITESPACE - ERROR_STRIP_TRAILING_WHITESPACE - ) - if(NOT AITER_CHECKOUT_RESULT EQUAL 0) - message(FATAL_ERROR - "Failed to sync AITER source tree at ${__AITER_SOURCE_DIR} to " - "manifest-pinned commit ${AITER_SHA}.\n" - "${AITER_CHECKOUT_OUTPUT}\n${AITER_CHECKOUT_ERROR}") + if (NOT __SKIP_AITER_CHECKOUT) + execute_process( + COMMAND sh -c + "PYTHONPATH=\"${__QOLA_DIR}:$PYTHONPATH\" '${Python_EXECUTABLE}' -m qola.cli checkout \ + --manifest '${__QOLA_MANIFEST}' \ + --aiter-root '${__AITER_SOURCE_DIR}'" + RESULT_VARIABLE AITER_CHECKOUT_RESULT + OUTPUT_VARIABLE AITER_CHECKOUT_OUTPUT + ERROR_VARIABLE AITER_CHECKOUT_ERROR + OUTPUT_STRIP_TRAILING_WHITESPACE + ERROR_STRIP_TRAILING_WHITESPACE + ) + if(NOT AITER_CHECKOUT_RESULT EQUAL 0) + message(FATAL_ERROR + "Failed to sync AITER source tree at ${__AITER_SOURCE_DIR} to " + "manifest-pinned commit ${AITER_SHA}.\n" + "${AITER_CHECKOUT_OUTPUT}\n${AITER_CHECKOUT_ERROR}") + endif() + message(STATUS "[AITER] Synced ${__AITER_SOURCE_DIR} to ${AITER_SHA}") endif() - message(STATUS "[AITER] Synced ${__AITER_SOURCE_DIR} to ${AITER_SHA}") execute_process( COMMAND ${Python_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/check_aiter_mha_args.py @@ -73,8 +84,6 @@ rc=$?; rm -f \"$tmp\"; exit $rc" "${AITER_ARG_CHECK_OUTPUT}\n${AITER_ARG_CHECK_ERROR}") endif() message(STATUS "AITER API validation passed via check_aiter_mha_args.py") -else() - message(WARNING "Python interpreter not found; skipping AITER source-tree sync and API validation.") endif() # Sanity-check the resolved include directories now that `qola checkout` has @@ -166,6 +175,11 @@ endif() if(__HAS_GFX1250) set(__QOLA_MANIFEST_V3 "${CMAKE_CURRENT_LIST_DIR}/qola_manifest_gfx1250.toml") set(__QOLA_BUILD_DIR_V3 "${__QOLA_DIR}/build_gfx1250") + if (__SKIP_AITER_CHECKOUT) + set(__SKIP_CHECKOUT_ARGS "--skip-checkout") + else() + set(__SKIP_CHECKOUT_ARGS "") + endif() message(STATUS "[AITER-BUILD] Building CK-free V3 (gfx1250) via QoLA.") # The asm-only / CK-free flags (ONLY_FAV3=1, ENABLE_CK=0) are carried by the # gfx1250 manifest's libmha_bwd module, so no special env is needed here. @@ -176,6 +190,7 @@ if(__HAS_GFX1250) --aiter-root ${__AITER_SOURCE_DIR} --output-dir ${__QOLA_BUILD_DIR_V3} --arch "gfx1250" + ${__SKIP_CHECKOUT_ARGS} RESULT_VARIABLE QOLA_V3_BUILD_RESULT ) if(NOT QOLA_V3_BUILD_RESULT EQUAL 0) diff --git a/transformer_engine/common/ck_fused_attn/qola_manifest_gfx1250.toml b/transformer_engine/common/ck_fused_attn/qola_manifest_gfx1250.toml index a1d7a6706..99d2bae36 100644 --- a/transformer_engine/common/ck_fused_attn/qola_manifest_gfx1250.toml +++ b/transformer_engine/common/ck_fused_attn/qola_manifest_gfx1250.toml @@ -1,14 +1,15 @@ -# gfx1250 (RDNA4) carries AITER V3 *backward* asm kernels only — there are no -# forward kernels at the pinned commit, and the CK FMHA template headers do not -# compile for gfx1250. This manifest therefore builds a CK-free (ENABLE_CK=0), -# asm-v3-only backward library under a distinct namespace (te_v3) so it can -# coexist with the CK-full te_* libraries in a multi-arch build. TE selects -# between qola::te::mha_bwd and qola::te_v3::mha_bwd at runtime by device arch. +# gfx1250 (RDNA4) AITER V3 ASM kernels: forward (fmha_fwd_with_sink_asm, +# D64/D128) and backward. CK FMHA template headers do not compile for gfx1250, +# so this manifest builds CK-free (ENABLE_CK=0) asm-v3-only libraries under a +# distinct namespace (te_v3) so they can coexist with the CK-full te_* libraries +# in a multi-arch build. TE selects between qola::te::mha_{fwd,bwd} and +# qola::te_v3::mha_{fwd,bwd} at runtime by device arch. +# Issue #16948 # # Keep aiter_commit in lockstep with qola_manifest.toml — both consume the same # checked-out AITER source tree. [qola] -aiter_commit = "c4c4faa4789cb3bf6d972192d96c5461b71728be" # pinned AITER submodule commit +aiter_commit = "0403bcdfd0b507e0459012b82e00785a818f774a" # pinned AITER submodule commit namespace = "te_v3" rocm_versions = ["7.2"] diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp index 8509d0737..c2555711d 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp @@ -6,8 +6,10 @@ #include #include +#include #include #include +#include #include "ck_fused_attn/ck_fused_attn.hpp" #include "qola_mha_fwd.h" #include "ck_fused_attn_utils.hpp" @@ -36,6 +38,31 @@ bool is_gfx1250_device(){ if(hipGetDeviceProperties(&prop, dev) != hipSuccess){ return false; } return prop.major == 12 && prop.minor == 5; } + +// D64 gfx1250 fmha_fwd_with_sink_asm (ENABLE_SINK=1): requires non-null sink_ptr +// of shape [nhead] fp32 in "AITER post-scale domain". The kernel adds +// exp(sink_val[h]) to every row's softmax denominator. We initialize to +// -1e30f so expf(-1e30f)=0.0f in fp32 — zero contribution, matching the +// UnfusedDotProductAttention reference which has no sink term. +// D128 (ENABLE_SINK=0): dispatch guard rejects sink_ptr!=nullptr; leave null. +// +// Single static buffer, allocated once, kept for the process lifetime. +constexpr int kSinkBufMaxHeads = 256; +static float* s_sink_buf = nullptr; +static std::once_flag s_sink_once; + +const void* get_gfx1250_sink_buf(){ + std::call_once(s_sink_once, [](){ + if(hipMalloc(&s_sink_buf, kSinkBufMaxHeads * sizeof(float)) != hipSuccess){ + s_sink_buf = nullptr; + return; + } + std::vector fill(kSinkBufMaxHeads, -1e30f); + hipMemcpy(s_sink_buf, fill.data(), + kSinkBufMaxHeads * sizeof(float), hipMemcpyHostToDevice); + }); + return s_sink_buf; +} } // namespace #endif @@ -181,6 +208,16 @@ hipError_t ck_attn_fwd(const CKAttnFwdArgs& args, hipStream_t stream){ fmha_args.block_scale_seqstart_q_ptr = nullptr; fmha_args.block_scale_seqstart_k_ptr = nullptr; fmha_args.sink_ptr = nullptr; +#if defined(NVTE_AITER_V3_FWD_GFX1250) + // D64 (ENABLE_SINK=1): fmha_fwd_gfx1250_batched requires non-null sink_ptr and + // reads each element as a per-head logit added to the softmax denominator. + // We pass -1e30f so exp(-1e30f)=0.0f in fp32 — zero contribution, matching + // the UnfusedDotProductAttention reference. + // D128 (ENABLE_SINK=0): dispatch guard rejects sink_ptr!=nullptr, so leave null. + if(is_gfx1250_device() && args.d_qk == 64 && args.h <= kSinkBufMaxHeads) { + fmha_args.sink_ptr = get_gfx1250_sink_buf(); + } +#endif fmha_args.seqlen_k = args.s_kv; // unused in group mode (or kvcache enabled) fmha_args.max_seqlen_q = args.s_q; @@ -258,7 +295,18 @@ hipError_t ck_attn_fwd(const CKAttnFwdArgs& args, hipStream_t stream){ float average_runtime; #if defined(NVTE_AITER_V3_FWD_GFX1250) + // Pre-fill O and LSE before calling the gfx1250 ASM forward kernel. + // The kernel ABI requires a valid (allocated) LSE buffer regardless of + // return_lse; the kernel may touch lse_ptr even when return_lse=0. + // O/LSE pre-initialization is handled inside fmha_fwd_gfx1250_batched + // (in aiter/csrc/cpp_itfs/mha_fwd.cu) as part of the kernel calling convention. if(is_gfx1250_device()){ + if(fmha_args.lse_ptr == nullptr) + throw std::runtime_error( + "ck_fused_attn fwd: lse_ptr is null on gfx1250 — caller must allocate softmax LSE."); + if(fmha_args.o_ptr == nullptr) + throw std::runtime_error( + "ck_fused_attn fwd: o_ptr is null on gfx1250 — caller must allocate output."); average_runtime = qola::te_v3::mha_fwd(fmha_args, stream_config); } else #endif From 8fd79f9275ad82f818ac98fd42944d7e941344ac Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 16 Jun 2026 16:20:51 +0000 Subject: [PATCH 20/20] Updated with AOT memory handling (dynamic on host) --- .../include/ck_fused_attn/ck_fused_attn.hpp | 12 ++ .../ck_fused_attn/src/ck_fused_attn_bwd.cpp | 108 ++++++++++++++---- .../common/fused_attn_rocm/fused_attn_ck.cpp | 30 +++++ 3 files changed, 129 insertions(+), 21 deletions(-) diff --git a/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp b/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp index abd1ec371..d0472899d 100644 --- a/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp +++ b/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp @@ -133,6 +133,13 @@ struct CkAttnBwdArgs : CKAttnCommonArgs { // Workspace shared with forward LSE void* lse_workspace_ptr = nullptr; + // AOT scratch for AITER's internal bwd allocations (launcher metadata + dq_acc + // accumulator). Carved from the caller's workspace and handed to aiter through + // the workspace_alloc callback; ck_attn_bwd_workspace_size() reports the bytes + // to reserve. aiter_workspace_bytes bounds the bump allocator. + void* aiter_workspace_ptr = nullptr; + size_t aiter_workspace_bytes = 0; + // V3 ASM kernel selection bool deterministic = false; bool uses_bwd_v3 = false; @@ -142,6 +149,11 @@ struct CkAttnBwdArgs : CKAttnCommonArgs { hipError_t ck_attn_fwd(const CKAttnFwdArgs& args, hipStream_t stream); hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream); +// Bytes of AOT device scratch ck_attn_bwd needs for AITER's internal bwd +// workspace (launcher metadata + dq_acc), covering both the v2 (CK launcher) and +// v3 (asm) dispatch paths. Pure host-side computation; no kernel launch. +size_t ck_attn_bwd_workspace_size(const CkAttnBwdArgs& args); + }//namespace ck_fused_attn #endif // CK_FUSED_ATTN_H diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp index 0cc8c7689..a35cc9c4e 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp @@ -5,11 +5,11 @@ ************************************************************************/ #include +#include #include #include #include #include -#include #include #include "ck_fused_attn/ck_fused_attn.hpp" #include "ck_tile/host/pinned_host_releaser.hpp" @@ -440,6 +440,69 @@ void dump_bwd_timings(const char* dump_path, float average_runtime){ file << average_runtime << "\n"; } +namespace { + +// Trait subset that determines AITER's internal bwd workspace footprint. Mirrors +// the fields ck_attn_bwd sets on mha_bwd_args so the size query and the dispatch +// stay in lockstep. +::fmha_bwd_traits make_bwd_traits(const CkAttnBwdArgs& args){ + bool has_dropout = (args.dropout_probability > 0.f); + bool has_dbias = args.dbias_ptr != nullptr; + bias_enum bias_type = bias_enum::no_bias; + if(!args.is_group_mode()){ + bias_type = get_ck_bias_type_shape(&args).first; + } + return ::fmha_bwd_traits{ + /* seqlen_q */ static_cast(args.is_group_mode() ? args.max_tokens_q : args.s_q), + /* seqlen_k */ static_cast(args.is_group_mode() ? args.max_tokens_kv : args.s_kv), + /* batch */ static_cast(args.b), + /* max_seqlen_q */ static_cast(args.s_q), + /* max_seqlen_k */ static_cast(args.s_kv), + /* hdim_q */ static_cast(args.d_qk), + /* hdim_v */ static_cast(args.d_v), + /* nhead_q */ static_cast(args.h), + /* nhead_k */ static_cast(args.hg), + /* data_type */ get_data_type_str(args.dtype), + /* is_group_mode */ args.is_group_mode(), + /* mask_type */ static_cast(args.attn_mask_type), + /* bias_type */ bias_type, + /* has_dbias */ (!args.is_group_mode()) && has_dbias, + /* has_dropout */ has_dropout, + /* is_store_randval */ false, + /* is_deterministic */ args.deterministic, + }; +} + +// dq_acc bytes the v3 asm path allocates via workspace_alloc. Mirrors aiter's +// fmha_v3_bwd sizing (csrc/cpp_itfs/mha_bwd.cu); returns 0 when v3 can't run so +// the CK launcher size dominates. Gating mirrors ck_attn_bwd's use_asm_v3. +size_t v3_dq_acc_bytes(const CkAttnBwdArgs& args){ + const bool use_asm_v3 = (args.s_q < 16) ? false : args.uses_bwd_v3; + if(!use_asm_v3){ + return 0; + } + const size_t seqlen_q = args.is_group_mode() ? args.max_tokens_q : args.s_q; + const size_t elem = args.is_v3_atomic_fp32 ? 4 : 2; + const size_t a16_seq = (args.s_q + 15) / 16 * 16; + const size_t a16_hdim = (args.d_qk == 192) ? 192 : 128; + const size_t dq_acc_seq = args.is_v3_atomic_fp32 ? seqlen_q : a16_seq; + const size_t dq_acc_hdim = args.is_v3_atomic_fp32 ? args.d_qk : a16_hdim; + const size_t eff_batch = (args.is_group_mode() && args.is_v3_atomic_fp32) ? 1 : args.b; + return eff_batch * args.h * dq_acc_seq * dq_acc_hdim * elem; +} + +} // namespace + +size_t ck_attn_bwd_workspace_size(const CkAttnBwdArgs& args){ + // v2 (CK launcher) reports its full device workspace (host metadata + dq_acc) + // host-side; v3 (asm) allocates only dq_acc. v3 is tried first but may fall + // back to v2, so reserve the larger of the two. The launcher symbol is forced + // local by QoLA's export script, so the v2 size is queried through QoLA. + const size_t v2_bytes = QOLA_NS(mha_bwd_workspace_size)(make_bwd_traits(args)); + const size_t v3_bytes = v3_dq_acc_bytes(args); + return v2_bytes > v3_bytes ? v2_bytes : v3_bytes; +} + hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ bool has_dropout = (args.dropout_probability > 0.f); @@ -587,34 +650,40 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ } } - // Device-side workspace allocations made inside mha_bwd (launcher metadata - // and the dq_acc accumulator). aiter only contracts that the pointer remain - // valid for the duration of the kernels it enqueues; hipFreeAsync on the - // same stream defers the free until that work completes. - std::vector mha_bwd_workspaces; - fmha_args.workspace_alloc = [&mha_bwd_workspaces, stream](size_t bytes, bool zero_init) -> void* { + // Device-side workspace for mha_bwd's internal allocations (launcher metadata + // and the dq_acc accumulator) is reserved ahead of time by the caller (see + // ck_attn_bwd_workspace_size) and carved here, matching the AOTriton bwd path. + // workspace_alloc bump-allocates from that buffer instead of allocating per + // call; only one allocation happens per dispatch, but the bump allocator stays + // correct if aiter splits the request. + void* ws_base = args.aiter_workspace_ptr; + const size_t ws_capacity = args.aiter_workspace_bytes; + size_t ws_offset = 0; + fmha_args.workspace_alloc = [ws_base, ws_capacity, &ws_offset, stream](size_t bytes, bool zero_init) -> void* { if(bytes == 0){ return nullptr; } - void* ptr = nullptr; - if(hipMallocAsync(&ptr, bytes, stream) != hipSuccess){ - throw std::runtime_error("ck_fused_attn bwd: hipMallocAsync failed for AITER workspace."); + constexpr size_t kAlign = 256; + const size_t base = (ws_offset + kAlign - 1) & ~(kAlign - 1); + if(ws_base == nullptr || base + bytes > ws_capacity){ + throw std::runtime_error("ck_fused_attn bwd: AITER workspace request exceeds reserved AOT buffer."); } + void* ptr = static_cast(ws_base) + base; + ws_offset = base + bytes; if(zero_init){ if(hipMemsetAsync(ptr, 0, bytes, stream) != hipSuccess){ - hipFreeAsync(ptr, stream); throw std::runtime_error("ck_fused_attn bwd: hipMemsetAsync failed for AITER workspace."); } } - mha_bwd_workspaces.push_back(ptr); return ptr; }; - // Group mode requires a pinned host buffer for the async D2H seqstart - // pipeline; aiter keeps the shared_ptr alive past kernel completion via a - // stream-tail hipLaunchHostFunc keepalive. The deleter fires from that HIP - // callback thread, which holds runtime locks — calling any HIP API from it - // (including hipHostFree) deadlocks against concurrent main-thread HIP - // calls. Defer the free to ck_tile::pinned_host_releaser's worker thread. + // Group mode needs a pinned host buffer for the async D2H seqstart pipeline. + // aiter keeps the shared_ptr alive past kernel completion via a stream-tail + // release; that release (and thus the deleter) fires from a HIP callback thread + // holding runtime locks, so calling any HIP API from it (including hipHostFree) + // would deadlock against concurrent main-thread HIP calls. Defer the free to + // ck_tile::pinned_host_releaser's worker thread, which frees each buffer once + // it is no longer in flight — small and group-mode-v2 only, but never leaked. fmha_args.pinned_host_alloc = [](size_t bytes) -> std::shared_ptr { if(bytes == 0){ return {}; @@ -664,9 +733,6 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ } float average_runtime = QOLA_NS(mha_bwd)(fmha_args, stream_config); - for(void* ws_ptr : mha_bwd_workspaces){ - hipFreeAsync(ws_ptr, stream); - } if(average_runtime < 0){ //TODO: better error out system throw std::runtime_error("fused attn configs not supported in ck_fused_attn bwd pass."); diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp index 5b9f34790..3ea1913d7 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp @@ -809,6 +809,34 @@ void fused_attn_ck_bwd_impl( // First h*max_tokens_q*sizeof(float) is the lse-d buffer (passed as softmax_lsed) void* lse_workspace = planner.allocate(h*max_tokens_q*sizeof(float)); + // Reserve AOT scratch for AITER's internal bwd workspace (launcher metadata + + // dq_acc accumulator), carved by ck_attn_bwd's workspace_alloc callback instead + // of allocated per call. The size is a pure host-side query; the full ck_args + // built below (execution pass only) carries the pointer into the dispatch, so + // assemble just the trait-bearing fields here. + ck_fused_attn::CkAttnBwdArgs ws_size_args; + ws_size_args.dtype = nvte_to_ck_dtype(dtype); + ws_size_args.b = b; ws_size_args.h = h; ws_size_args.hg = hg; + ws_size_args.s_q = s_q; ws_size_args.s_kv = s_kv; ws_size_args.d_qk = d_qk; ws_size_args.d_v = d_v; + ws_size_args.max_tokens_q = max_tokens_q; ws_size_args.max_tokens_kv = max_tokens_kv; + ws_size_args.attn_mask_type = set_ck_mask(mask_type, window_size_left, window_size_right); + ws_size_args.dropout_probability = dropout_probability; + ws_size_args.deterministic = deterministic; + ws_size_args.uses_bwd_v3 = nvte_ck_uses_bwd_v3; + ws_size_args.is_v3_atomic_fp32 = nvte_ck_is_v3_atomic_fp32; + ws_size_args.how_v3_bf16_cvt = nvte_ck_how_v3_bf16_cvt; + ws_size_args.dbias_ptr = devPtrdBias; + if((is_SBHD && is_padding) || bshd_to_thd || is_ragged){ + // group mode: a non-null cu_seqlen flips is_group_mode(); bias is forced off + ws_size_args.cu_seqlen_q_ptr = devPtrCuSeqlensQ; + }else{ + // batch mode: bias shape feeds the trait/workspace sizing + ws_size_args.attn_bias_type = nvte_to_ck_bias_type(bias_type); + ws_size_args.bias_b = bias_b; ws_size_args.bias_h = bias_h; + } + const size_t aiter_workspace_bytes = ck_fused_attn::ck_attn_bwd_workspace_size(ws_size_args); + void* aiter_workspace = planner.allocate(aiter_workspace_bytes); + void* dk_expanded_ptr = nullptr; void* dv_expanded_ptr = nullptr; std::array dk_expanded_stride; @@ -1029,6 +1057,8 @@ void fused_attn_ck_bwd_impl( ck_args.dk_expanded_ptr = dk_expanded_ptr; ck_args.dv_expanded_ptr = dv_expanded_ptr; ck_args.lse_workspace_ptr = lse_workspace; + ck_args.aiter_workspace_ptr = aiter_workspace; + ck_args.aiter_workspace_bytes = aiter_workspace_bytes; ck_args.deterministic = deterministic; ck_args.uses_bwd_v3 = nvte_ck_uses_bwd_v3; ck_args.is_v3_atomic_fp32 = nvte_ck_is_v3_atomic_fp32;