diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index 7b0fdf80fd1..f045908b422 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -206,6 +206,7 @@ message(DEBUG "adding example ${EXAMPLE_FMHA_FWD}") add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL example_fmha_fwd.cpp) target_link_libraries(${EXAMPLE_FMHA_FWD} ${FMHA_FWD_INSTANCES}) target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +set_property(TARGET ${EXAMPLE_FMHA_FWD} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) message(DEBUG "adding example ${EXAMPLE_FMHA_BWD}") # not using add_example_executable() to add this target, since we don't want this to be included in @@ -213,6 +214,7 @@ message(DEBUG "adding example ${EXAMPLE_FMHA_BWD}") add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL example_fmha_bwd.cpp) target_link_libraries(${EXAMPLE_FMHA_BWD} ${FMHA_BWD_INSTANCES}) target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +set_property(TARGET ${EXAMPLE_FMHA_BWD} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) # TODO: we have to turn off this global prop, otherwise the progress bar generated # by cmake will print too many files, execvp: /bin/sh: Argument list too long diff --git a/example/ck_tile/50_sparse_attn/CMakeLists.txt b/example/ck_tile/50_sparse_attn/CMakeLists.txt index 65bb2077642..532285ce5a6 100644 --- a/example/ck_tile/50_sparse_attn/CMakeLists.txt +++ b/example/ck_tile/50_sparse_attn/CMakeLists.txt @@ -1,8 +1,8 @@ -# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -# SPDX-License-Identifier: MIT -# CMakeLists.txt for sparse attention (Jenga and VSA) +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+# SPDX-License-Identifier: MIT +# CMakeLists.txt for sparse attention (Jenga and VSA) -# Use SUPPORTED_GPU_TARGETS directly +# Use SUPPORTED_GPU_TARGETS directly set(INST_TARGETS ${SUPPORTED_GPU_TARGETS}) set(GPU_TARGETS ${SUPPORTED_GPU_TARGETS}) @@ -16,7 +16,7 @@ endif() message(STATUS "Building Sparse Attention (Jenga & VSA) for targets: ${INST_TARGETS}") -# Code generation scripts +# Code generation scripts file(GLOB_RECURSE CODE_GEN_SCRIPTS CONFIGURE_DEPENDS ${CMAKE_CURRENT_LIST_DIR}/generate.py ${CMAKE_CURRENT_LIST_DIR}/codegen/*.py @@ -83,6 +83,7 @@ message(DEBUG "adding example ${EXAMPLE_JENGA_SPARSE_ATTN}") add_executable(${EXAMPLE_JENGA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_jenga_sparse_attn.cpp) target_link_libraries(${EXAMPLE_JENGA_SPARSE_ATTN} ${SPARSE_ATTN_JENGA_INSTANCES}) target_include_directories(${EXAMPLE_JENGA_SPARSE_ATTN} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +set_property(TARGET ${EXAMPLE_JENGA_SPARSE_ATTN} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) target_compile_options(${EXAMPLE_JENGA_SPARSE_ATTN} PRIVATE -Wno-undefined-func-template -Wno-float-equal @@ -148,9 +149,107 @@ message(DEBUG "adding example ${EXAMPLE_VSA_SPARSE_ATTN}") add_executable(${EXAMPLE_VSA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_vsa_sparse_attn.cpp) target_link_libraries(${EXAMPLE_VSA_SPARSE_ATTN} ${SPARSE_ATTN_VSA_INSTANCES}) target_include_directories(${EXAMPLE_VSA_SPARSE_ATTN} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +set_property(TARGET ${EXAMPLE_VSA_SPARSE_ATTN} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) target_compile_options(${EXAMPLE_VSA_SPARSE_ATTN} PRIVATE -Wno-undefined-func-template -Wno-float-equal ) +# ============================================================================ +# Sparge Sparse Attention (PV-skip enabled, derived from VSA) +# ============================================================================ +set(SPARSE_ATTN_SPARGE_CODE_GEN_ARGS + ${CMAKE_CURRENT_LIST_DIR}/generate.py + --api fwd_sparge + --receipt 600 +) + +# Generate list of Sparge kernels 
(at 
configure time, only list) +execute_process( + COMMAND ${Python3_EXECUTABLE} ${SPARSE_ATTN_SPARGE_CODE_GEN_ARGS} + --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/sparge_blob_list.txt + RESULT_VARIABLE ret +) +if(ret AND NOT ret EQUAL 0) + message(FATAL_ERROR "Failed to generate Sparge kernel list") +endif() + +file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/sparge_blob_list.txt SPARSE_ATTN_SPARGE_GEN_BLOBS) + +# Generate Sparge kernel source files at build time +add_custom_command( + OUTPUT ${SPARSE_ATTN_SPARGE_GEN_BLOBS} + COMMAND ${Python3_EXECUTABLE} ${SPARSE_ATTN_SPARGE_CODE_GEN_ARGS} + --output_dir ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${CODE_GEN_SCRIPTS} + COMMENT "Generate CK Tile Sparge Sparse Attention kernels" +) + +message(STATUS "Sparge kernel files to be generated: ${SPARSE_ATTN_SPARGE_GEN_BLOBS}") + +# Sparge Instances +set(SPARSE_ATTN_SPARGE_INSTANCES "tile_sparse_attn_sparge_instances") + +add_library(${SPARSE_ATTN_SPARGE_INSTANCES} OBJECT EXCLUDE_FROM_ALL + ${SPARSE_ATTN_SPARGE_GEN_BLOBS} +) +target_include_directories(${SPARSE_ATTN_SPARGE_INSTANCES} PRIVATE + ${CMAKE_CURRENT_LIST_DIR} + ${PROJECT_SOURCE_DIR}/include/ck_tile/ops/sparse_attn +) +set_source_files_properties(${SPARSE_ATTN_SPARGE_GEN_BLOBS} PROPERTIES LANGUAGE HIP) +set_property(TARGET ${SPARSE_ATTN_SPARGE_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) + +target_compile_options(${SPARSE_ATTN_SPARGE_INSTANCES} PRIVATE + -DCK_TILE_USE_BUFFER_ADDRESSING_BUILTIN + -DCK_TILE_FMHA_FWD_FAST_EXP2 + -Wno-undefined-func-template + -Wno-float-equal +) + +# ============================================================================ +# Sparge BlockMap GPU Kernel (hand-written instantiation, no codegen) +# ============================================================================ +set(SPARGE_BLOCKMAP_INSTANCES "tile_sparge_blockmap_instances") + +add_library(${SPARGE_BLOCKMAP_INSTANCES} OBJECT EXCLUDE_FROM_ALL + ${CMAKE_CURRENT_LIST_DIR}/sparge_blockmap_inst.cpp +) 
+target_include_directories(${SPARGE_BLOCKMAP_INSTANCES} PRIVATE + ${CMAKE_CURRENT_LIST_DIR} + ${PROJECT_SOURCE_DIR}/include/ck_tile/ops/sparse_attn +) +set_source_files_properties( + ${CMAKE_CURRENT_LIST_DIR}/sparge_blockmap_inst.cpp + PROPERTIES LANGUAGE HIP +) +set_property(TARGET ${SPARGE_BLOCKMAP_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) + +target_compile_options(${SPARGE_BLOCKMAP_INSTANCES} PRIVATE + -DCK_TILE_USE_BUFFER_ADDRESSING_BUILTIN + -DCK_TILE_FMHA_FWD_FAST_EXP2 + -Wno-undefined-func-template + -Wno-float-equal +) + +# ---------------------------------------------------------------------------- +# Build unified Sparge test: combines blockmap, Jenga, and VSA attention +# for end-to-end evaluation and timing in a single executable. +# ---------------------------------------------------------------------------- +set(EXAMPLE_SPARGE "tile_example_sparge") +message(DEBUG "adding example ${EXAMPLE_SPARGE}") +add_executable(${EXAMPLE_SPARGE} EXCLUDE_FROM_ALL test_sparge.cpp) +target_link_libraries(${EXAMPLE_SPARGE} + ${SPARSE_ATTN_JENGA_INSTANCES} + ${SPARSE_ATTN_VSA_INSTANCES} + ${SPARSE_ATTN_SPARGE_INSTANCES} + ${SPARGE_BLOCKMAP_INSTANCES} +) +target_include_directories(${EXAMPLE_SPARGE} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +set_property(TARGET ${EXAMPLE_SPARGE} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) +target_compile_options(${EXAMPLE_SPARGE} PRIVATE + -Wno-undefined-func-template + -Wno-float-equal +) + set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) diff --git a/example/ck_tile/50_sparse_attn/README.md b/example/ck_tile/50_sparse_attn/README.md new file mode 100644 index 00000000000..593f4a85ef5 --- /dev/null +++ b/example/ck_tile/50_sparse_attn/README.md @@ -0,0 +1,52 @@ +# Sparge Attention (Composable Kernel) + +A Composable Kernel port of [SpargeAttn](https://github.com/thu-ml/SpargeAttn) for AMD GPU. Both the block-map pipeline (mean-pool → cosine sim → pooled QK → top-k LUT) and the sparse FMHA stage run on-GPU. 
Two attention backends are exposed via `-pipeline=vsa` (default, faster) and `-pipeline=jenga` (async K/V load variant). + +## Status vs Upstream + +Implemented: +- per-block mean-pool, cosine similarity, pooled QK +- top-k / `cdfthreshd` block selection, BlockMap LUT +- sparse FMHA (both `vsa` and `jenga` backends) +- per-head `topk` / `simthreshd1` / `cdfthreshd` +- **is_causal mask in pooled score** (top-left only at block-map grain) ([spas_sage_attn/utils.py:L338](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/utils.py#L338)) +- **attention_sink** — block-map column 0 force-on ([spas_sage_attn/autotune.py:L355](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/autotune.py#L355)) + +Not yet ported (upstream pinned to commit [`ae5b629`](https://github.com/thu-ml/SpargeAttn/tree/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a)): +- **K smoothing** — pre-pool `k -= km`; required for diffusion / video checkpoints (CogVideoX, Mochi-1, Flux, OpenSora, SD 3.5) ([spas_sage_attn/core.py:L53](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/core.py#L53)) +- **Sort-based top-k selection** — replaces our O(N_k^2) iterative argmax; matters at long seqlen (s ≥ 16k) ([spas_sage_attn/utils.py:L345](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/utils.py#L345)) +- **Q/K int8 quant fusion in pool kernel** — enables a downstream int8 GEMM0 in the attn kernel ([spas_sage_attn/utils.py:L371](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/utils.py#L371)) + +## PV-skip modes + +`pv_threshold` per-Q-tile skip in the attention kernel is implemented in three variants, selectable at runtime via `-pv_mode={none|warp|block}`: + +- **`none`** — skip disabled; baseline matching the no-PV-skip codegen instance. 
+- **`warp`** (per-wavefront) — each wavefront votes locally via `__shfl_xor` butterfly AND; SGPR-resident flag. CK-tile-specific variant, not in upstream. +- **`block`** (per-block) — block-wide consensus vote via LDS broadcast; aligned with upstream sm80 ([`qk_int_sv_f16_cuda_sm80.cuh:L334`](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/csrc/qattn/qk_int_sv_f16_cuda_sm80.cuh#L334)). V loads stay unconditional in all modes — the guard wraps the PV MMA only, matching upstream and paper Algorithm 1. + +![PV-skip mode comparison](docs/pv_skip_mode_comparison.png) + +*MI300X, b=2 h=16 s=8192 d=128 fp16, 5 seeds × 9 sparsity points. All three modes dispatch to the `kM0=64 padK=0` tile bucket at this shape.* + +On the canonical recipe shape, `none > warp > block` at every measured sparsity, with no crossover. The per-block guard adds +33..+35 VGPR (6..9 spills) on this tile configuration, depressing occupancy. `warp` is +0..+4 VGPR. The default is `-pv_mode=warp`; switch to `none` for the no-skip baseline or `block` to exercise the upstream-aligned variant. A shape sweep is needed before recommending `block` as default — the `kM0=128` path has Δ ≈ 0 VGPR for per-block and is a candidate. + +## Usage + +```bash +ninja tile_example_sparge +./bin/tile_example_sparge -pipeline=vsa -b=2 -h=32 -s=16384 -d=128 -topk=0.4 -simthreshd1=0.001 +``` + +Select a PV-skip variant with `-pv_mode={none|warp|block}` (default `warp`); finite `-pv_threshold=20` lets the per-Q-tile skip predicate fire. + +Mask + attention sink: +- `-mask` accepts the `01_fmha` grammar (`0` / `t` / `b` / `t:l,r` / `xt:N` / `g:y,x`, default `0`). The block-map selection prunes past-diagonal blocks only under `mask_top_left` (`t`); `b` / SWA / generic are forwarded to the attention kernel and emit a stderr WARN that the block-map selection is unchanged. +- `-attention_sink {0,1}` forces block-map column `kb=0` ON for every Q-block (default `0`). 
Under `-mask t` this is degenerate since `kb=0` is always causal-valid. + +Add `-v=1` for CPU validation; use a small shape (`-b=1 -h=2 -s=512`), since full-shape CPU reference scales O(s²) and runs 30+ minutes at s=8k, hours at s=16k. When `-mask != 0` or `-attention_sink == 1`, the `[block_map cross-check]` and `[VSA LUT self-consistency]` cells are SKIPPED (the CPU reference does not model causal mask or sink); the `[attention output]` cell still runs but the dense reference applies no mask, so it will report FAIL on the kernel-correct output. Treat `-v=1` correctness as **block-map level only** in those configurations. + +## References + +- [SpargeAttn upstream](https://github.com/thu-ml/SpargeAttn) (pinned to [`ae5b629`](https://github.com/thu-ml/SpargeAttn/tree/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a)) +- [Paper — Zhang et al., arXiv:2502.18137](https://arxiv.org/abs/2502.18137) diff --git a/example/ck_tile/50_sparse_attn/codegen/cpp_symbol_map.py b/example/ck_tile/50_sparse_attn/codegen/cpp_symbol_map.py index 8614a1ff3ba..0f2866cbf42 100644 --- a/example/ck_tile/50_sparse_attn/codegen/cpp_symbol_map.py +++ b/example/ck_tile/50_sparse_attn/codegen/cpp_symbol_map.py @@ -58,11 +58,13 @@ def get_mask_check_map(mask: str): PIPELINE_MAP = { "qr_async": "ck_tile::BlockFmhaPipelineQRKSVSAsyncJenga", "qr_async_vsa": "ck_tile::BlockFmhaPipelineQRKSVSAsyncVSA", + "qr_async_sparge": "ck_tile::BlockFmhaPipelineQRKSVSAsyncSparge", } PIPELINE_ENUM_MAP = { "qr_async": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC", "qr_async_vsa": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC", + "qr_async_sparge": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC", } BOOL_MAP = { diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py index a3d32652a98..fc4b8642ddd 100644 --- a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py +++ b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py @@ 
-141,6 +141,17 @@ def update_file(file_path, content): constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); }} + +template<> +void fmha_jenga_fwd_oneshot_(const ck_tile::stream_config& s, fmha_jenga_fwd_args a) +{{ + using k_ = fmha_kernel_{F_idx}; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( + ck_tile::stream_config{{s.stream_id_}}); +}} """ FMHA_FWD_API_FILENAME = "fmha_jenga_fwd_api.cpp" @@ -219,6 +230,45 @@ def update_file(file_path, content): }} """ +FMHA_FWD_ONESHOT_API_FILENAME = "fmha_jenga_fwd_oneshot_api.cpp" +FMHA_FWD_ONESHOT_API = """ +#include "fmha_fwd_trek.hpp" +#include + +void fmha_jenga_fwd_oneshot(fmha_jenga_fwd_traits t, fmha_jenga_fwd_args a, const ck_tile::stream_config& s){{ + + const bool has_load_tr = ck_tile::is_load_tr_supported(); + +{F_dispatch} + std::cerr << "fmha_jenga_fwd_oneshot: no matching dispatch (dtype=" << t.data_type + << " hdim_q=" << t.hdim_q << " hdim_v=" << t.hdim_v + << " seqlen_q=" << a.seqlen_q << " seqlen_k=" << a.seqlen_k + << " mask=" << static_cast(t.mask_type) << ")" << std::endl; +}} +""" + +FMHA_FWD_ONESHOT_API_PER_TRLOAD = """ {F_if}({F_trload_cond}){{ +{F_dtype_case} + }} +""" + +FMHA_FWD_ONESHOT_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +{F_hdim_case} + }} +""" +FMHA_FWD_ONESHOT_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ +{F_inner_dispatch} + }} +""" + +FMHA_FWD_ONESHOT_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && + ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ + using trait_ = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, 
{F_vlayout}, {F_pipeline_enum}, false/*logits*/, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; + fmha_jenga_fwd_oneshot_(s, a); + return; + }} +""" + @dataclass class CppConstraint: @@ -274,10 +324,7 @@ def scheck(self) -> str: @property def seqtune(self) -> str: - if self.bm0 == 128: - return "true/*fall back to largest tile*/" # group mode only generate spad/skpad == true - else: - return f"a.seqlen_q <= {self.bm0}" + return "true" @property def skcheck(self) -> str: @@ -447,6 +494,67 @@ def api(self) -> str: per_tr_load += " (void)t ; (void)s ; (void)a;" return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_tr_load) + @property + def oneshot_api(self) -> str: + tr_load_cond_map = {"t": "has_load_tr", "f": "true"} + + per_tr_load = str() + for tr_load in ["t", "f"]: + per_dtypes = str() + for i, dtype in enumerate(self.pool.keys()): + per_hdim_case = str() + for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()): + traits = [ + t + for t in self.pool[dtype][(hdim, hdim_v)] + if tr_load == t.tr_load + ] + inners = str() + for k, trait in enumerate(traits): + if_k = "if" if k == 0 else "else if" + inners = inners + FMHA_FWD_ONESHOT_API_INNER_DISPATCH.format( + F_if=if_k, + F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], + F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], + F_trload=BOOL_MAP[trait.tr_load], + F_scheck=trait.scheck, + F_seqtune=trait.seqtune, + F_skcheck=trait.skcheck, + F_dcheck=trait.dcheck, + F_dvcheck=trait.dvcheck, + F_constraint=trait.constraint, + F_spad=BOOL_MAP[trait.spad], + F_skpad=BOOL_MAP[trait.skpad], + F_dpad=BOOL_MAP[trait.dpad], + F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, + F_bn0=trait.bn0, + F_bk0=trait.bk0, + F_bn1=trait.bn1, + F_bk1=trait.bk1, + F_bk0max=trait.bk0max, + F_hdim=hdim, + F_dtype=FWD_DTYPE_MAP[dtype], + ) + if_j = "if" if j == 0 else "else if" + per_hdim_case = 
per_hdim_case + FMHA_FWD_ONESHOT_API_PER_HDIM_CASE.format( + F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners + ) + if_i = "if" if i == 0 else "else if" + per_dtypes = per_dtypes + FMHA_FWD_ONESHOT_API_PER_DTYPE.format( + F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case + ) + per_tr_load += FMHA_FWD_ONESHOT_API_PER_TRLOAD.format( + F_if="if", + F_trload_cond=tr_load_cond_map[tr_load], + F_dtype_case=per_dtypes, + ) + if not per_tr_load: + per_tr_load += " (void)t ; (void)s ; (void)a;" + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_ONESHOT_API.format(F_dispatch=per_tr_load) + @dataclass class FmhaFwdTileSize: @@ -582,38 +690,39 @@ def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: # FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], # (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (128, 128): [ - FmhaFwdTileSize( # fmt: skip - 16, + FmhaFwdTileSize( # fmt: skip -- 128x128 tile (original, for old sparse attn test) + 128, + 128, 32, - 64, 128, 32, 128, + 4, 1, 1, + 4, 1, 1, - 1, - 1, - 16, - 16, 32, - 16, + 32, 16, 32, + 32, + 16, -1, + CppConstraint("t.bm0 == 0 || t.bm0 == 128"), ), - FmhaFwdTileSize( # fmt: skip - 32, - 32, + FmhaFwdTileSize( # fmt: skip -- 64x128 tile (for sparge blockmap kM0=64) + 64, 128, + 32, 128, 32, 128, + 2, 1, 1, - 1, - 1, + 2, 1, 1, 32, @@ -623,18 +732,40 @@ def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: 32, 16, -1, + CppConstraint("t.bm0 == 64"), ), FmhaFwdTileSize( # fmt: skip - 128, + 16, + 32, 64, + 128, 32, 128, + 1, + 1, + 1, + 1, + 1, + 1, + 16, + 16, + 32, 16, + 16, + 32, + -1, + ), + FmhaFwdTileSize( # fmt: skip + 32, + 32, + 128, + 128, + 32, 128, - 4, 1, 1, - 4, + 1, + 1, 1, 1, 32, @@ -647,10 +778,10 @@ def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: ), FmhaFwdTileSize( # fmt: skip 128, - 128, + 64, 32, 128, - 32, + 16, 128, 4, 1, @@ -780,7 +911,7 @@ def get_fwd_blobs( for tile, pipeline in 
itertools.product( tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) ): - if tile.F_bm0 != 128 or tile.F_bn0 != 128: + if tile.F_bm0 not in (64, 128) or tile.F_bn0 != 128: continue if pipeline.tag != "qr_async": continue @@ -846,6 +977,7 @@ def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None: update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api) + update_file(autogen_dir / FMHA_FWD_ONESHOT_API_FILENAME, api_pool.oneshot_api) def write_blobs( @@ -865,3 +997,4 @@ def list_blobs( for kernel in kernels: f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n") f.write((file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME).as_posix() + "\n") + f.write((file_path.parent / GEN_DIR / FMHA_FWD_ONESHOT_API_FILENAME).as_posix() + "\n") diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_sparge.py b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_sparge.py new file mode 100644 index 00000000000..e5182c3dc89 --- /dev/null +++ b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_sparge.py @@ -0,0 +1,1088 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT +# generate kernel instances to speed up compilation + +import copy +from dataclasses import dataclass, field +import fnmatch +import itertools +import os +import os.path as path +from pathlib import Path +from typing import List, Optional, Tuple + +from codegen.cpp_symbol_map import ( + BOOL_MAP, + FWD_DTYPE_MAP, + LAYOUT_MAP, + MODE_MAP, + PIPELINE_ENUM_MAP, + PIPELINE_MAP, + get_mask_check_map, + get_mask_map, +) + +GEN_DIR = "" + + +def update_file(file_path, content): + """Update the file at file_path with the given content if it differs from the existing content. 
+ + It avoids unnecessary touching of the file which triggers rebuilds + """ + + existing_content = "" + if path.exists(file_path): + with open(file_path, "r") as file: + existing_content = file.read() + if existing_content == content: + return + with open(file_path, "w") as file: + file.write(content) + + +DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16} + +K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 192: 192, 256: 256} + +FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.\n +// auto generated by generate.py +#include "ck_tile/ops/fmha/block/variants.hpp" +#include "fmha_fwd_trek.hpp" +#include "pipeline/block_fmha_pipeline_qr_ks_vs_async_sparge.hpp" +#include "kernel/fmha_fwd_sparge_kernel.hpp" + +""" + +# NOTE: Sparge sparse attention kernel has the following restrictions enforced by static_assert: +# - Group mode: NOT supported (batch mode only) +# - Bias: NOT supported (NO_BIAS only) +# - LSE output: NOT supported (false only) +# - Dropout: NOT supported (false only) +# - Logits soft-cap: NOT supported (false only) +# - FP8 static quantization: NOT supported (NO_SCALE only) +# The template below hardcodes these unsupported features accordingly. 
+ +FMHA_FWD_KERNEL_BODY = """ +using fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; + +using fmha_shape_{F_idx} = ck_tile::TileFmhaShape, + ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>, + ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, + ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, + {F_vlayout}>; + +// TileFmhaTraits: spad, skpad, dpad, dvpad, has_logits_soft_cap, bias_enum, +// store_lse, has_dropout, has_randval, quant_scale_enum, occupancy, is_v_rowmajor_skip +using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, + false, // has_logits_soft_cap - NOT supported + ck_tile::BlockAttentionBiasEnum::NO_BIAS, // bias - NOT supported + false, // store_lse - NOT supported + false, // has_dropout - NOT supported + false, // has_randval - NOT supported + ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE, // FP8 quant - NOT supported + {F_occupancy}, + false>; + +using fmha_variant_{F_idx} = ck_tile::ComposedAttention<0, CK_TILE_FMHA_FWD_FAST_EXP2>; // logits_soft_cap=0 (NOT supported) + +using fmha_mask_{F_idx} = {F_mask}; + +using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< + typename FmhaSparseFwdTypeConfig::QDataType, + typename FmhaSparseFwdTypeConfig::KDataType, + typename FmhaSparseFwdTypeConfig::VDataType, + typename FmhaSparseFwdTypeConfig::SaccDataType, + typename FmhaSparseFwdTypeConfig::SMPLComputeDataType, + typename FmhaSparseFwdTypeConfig::BiasDataType, + typename FmhaSparseFwdTypeConfig::RandValOutputDataType, + typename FmhaSparseFwdTypeConfig::LSEDataType, + typename FmhaSparseFwdTypeConfig::PDataType, + typename FmhaSparseFwdTypeConfig::OaccDataType, + typename FmhaSparseFwdTypeConfig::ODataType, + fmha_shape_{F_idx}, + {F_mode}, + fmha_variant_{F_idx}, + fmha_mask_{F_idx}, + {F_trload}, + fmha_trait_{F_idx}>; + +// R30: emit 3 pipeline / kernel instances per traits combo — kNone (PV-skip +// 
AST removed; source-equivalent to VSA), kPerWave (R25 A1 shipped path), +// kPerBlock (R30 added: block-wide AND vote gates gemm_1). The host dispatch +// in fmha_sparge_fwd_api.cpp picks one based on +// fmha_sparge_fwd_args::pv_mode_compile (0/1/2). +// R26 split-launch: fmha_fwd_create_kargs_and_grids(a) forwards the new +// fmha_sparge_fwd_args fields (pv_threshold_per_head_ptr, head_remap_ptr, +// nhead_in_launch) to MakeKargs. When head_remap_ptr is non-null the wrapper +// also shrinks grids.y to nhead_in_launch so each bucket fires its own kernel. +// Suffixes: +// _pvsf = PV-Skip OFF (kNone) +// _pvst = PV-Skip per-WAVE (kPerWave; preserved R25 A1 binary name) +// _pvsb = PV-Skip per-BLOCK (kPerBlock; R30 new) +using fmha_pipeline_{F_idx}_pvsf = ck_tile::BlockFmhaPipelineQRKSVSAsyncSparge< + fmha_pipeline_problem_{F_idx}, + ck_tile::BlockFmhaPipelineQRKSVSAsyncDefaultPolicy, + ck_tile::PVSkipMode::kNone>; +using fmha_pipeline_{F_idx}_pvst = ck_tile::BlockFmhaPipelineQRKSVSAsyncSparge< + fmha_pipeline_problem_{F_idx}, + ck_tile::BlockFmhaPipelineQRKSVSAsyncDefaultPolicy, + ck_tile::PVSkipMode::kPerWave>; +using fmha_pipeline_{F_idx}_pvsb = ck_tile::BlockFmhaPipelineQRKSVSAsyncSparge< + fmha_pipeline_problem_{F_idx}, + ck_tile::BlockFmhaPipelineQRKSVSAsyncDefaultPolicy, + ck_tile::PVSkipMode::kPerBlock>; + +using fmha_epilogue_{F_idx} = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaSparseFwdTypeConfig<{F_dtype}>::ODataType, + {F_spad}, {F_dvpad}>>; + +using fmha_kernel_{F_idx}_pvsf = + ck_tile::FmhaFwdSpargeKernel; +using fmha_kernel_{F_idx}_pvst = + ck_tile::FmhaFwdSpargeKernel; +using fmha_kernel_{F_idx}_pvsb = + ck_tile::FmhaFwdSpargeKernel; + +using trait_{F_idx} = fmha_sparge_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, + {F_pipeline_enum}, false/*logits*/, fmha_mask_{F_idx}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; + +#include + +// R30: 3 specializations per traits 
combo — int kPVMode values: +// 0 = kNone (pvsf binary) +// 1 = kPerWave (pvst binary; R25 A1 path) +// 2 = kPerBlock (pvsb binary; R30 new) +template<> +float fmha_sparge_fwd_(const ck_tile::stream_config& s, fmha_sparge_fwd_args a) +{{ + using k_ = fmha_kernel_{F_idx}_pvsf; + if(s.log_level_ > 0) + std::cout << ", " << "{F_kernel_name}_pvsf" << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} + +template<> +float fmha_sparge_fwd_(const ck_tile::stream_config& s, fmha_sparge_fwd_args a) +{{ + using k_ = fmha_kernel_{F_idx}_pvst; + if(s.log_level_ > 0) + std::cout << ", " << "{F_kernel_name}_pvst" << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} + +template<> +float fmha_sparge_fwd_(const ck_tile::stream_config& s, fmha_sparge_fwd_args a) +{{ + using k_ = fmha_kernel_{F_idx}_pvsb; + if(s.log_level_ > 0) + std::cout << ", " << "{F_kernel_name}_pvsb" << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} + +template<> +void fmha_sparge_fwd_oneshot_(const ck_tile::stream_config& s, fmha_sparge_fwd_args a) +{{ + using k_ = fmha_kernel_{F_idx}_pvsf; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( + ck_tile::stream_config{{s.stream_id_}}); +}} + +template<> +void 
fmha_sparge_fwd_oneshot_(const ck_tile::stream_config& s, fmha_sparge_fwd_args a) +{{ + using k_ = fmha_kernel_{F_idx}_pvst; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( + ck_tile::stream_config{{s.stream_id_}}); +}} + +template<> +void fmha_sparge_fwd_oneshot_(const ck_tile::stream_config& s, fmha_sparge_fwd_args a) +{{ + using k_ = fmha_kernel_{F_idx}_pvsb; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( + ck_tile::stream_config{{s.stream_id_}}); +}} +""" + +FMHA_FWD_API_FILENAME = "fmha_sparge_fwd_api.cpp" +FMHA_FWD_API = """ +#include + +#include + +namespace {{ +bool get_num_cus(unsigned& num_cus) {{ + int device; + auto status = hipGetDevice(&device); + if(status != hipSuccess) {{ + fprintf(stderr, "failed to get device"); + return false; + }} + + hipDeviceProp_t props{{}}; + status = hipGetDeviceProperties(&props, device); + if(status != hipSuccess) {{ + fprintf(stderr, "failed to get device properties"); + return false; + }} + + num_cus = props.multiProcessorCount; + return true; +}} + +unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{ + const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0; + const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1 + + return batch * nheads * num_m_blocks * num_n_blocks; +}} +}} // namespace + +float fmha_sparge_fwd(fmha_sparge_fwd_traits t, fmha_sparge_fwd_args a, const ck_tile::stream_config& s){{ + float r = -1; + + [[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate + + unsigned num_cus; + if (!get_num_cus(num_cus)) {{ + return r; + }} + + [[maybe_unused]] auto get_num_blocks 
= [&](unsigned kM0) {{ + return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0); + }}; + + const bool has_load_tr = ck_tile::is_load_tr_supported(); + +{F_dispatch} + return r; +}} +""" + +FMHA_FWD_API_PER_TRLOAD = """ {F_if}({F_trload_cond}){{ +{F_dtype_case} + }} +""" + +FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +{F_hdim_case} + }} +""" +FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ +{F_inner_dispatch} + }} +""" + +FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && + ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ + using trait_ = fmha_sparge_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, false/*logits*/, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; + // R30: pv_mode_compile selects 0=kNone / 1=kPerWave / 2=kPerBlock. 
+ switch(a.pv_mode_compile) {{ + case 0: return fmha_sparge_fwd_(s, a); + case 1: return fmha_sparge_fwd_(s, a); + case 2: return fmha_sparge_fwd_(s, a); + default: return fmha_sparge_fwd_(s, a); // legacy default = per-wave + }} + }} +""" + +FMHA_FWD_ONESHOT_API_FILENAME = "fmha_sparge_fwd_oneshot_api.cpp" +FMHA_FWD_ONESHOT_API = """ +#include "fmha_fwd_trek.hpp" +#include + +void fmha_sparge_fwd_oneshot(fmha_sparge_fwd_traits t, fmha_sparge_fwd_args a, const ck_tile::stream_config& s){{ + + const bool has_load_tr = ck_tile::is_load_tr_supported(); + +{F_dispatch} + std::cerr << "fmha_sparge_fwd_oneshot: no matching dispatch (dtype=" << t.data_type + << " hdim_q=" << t.hdim_q << " hdim_v=" << t.hdim_v + << " seqlen_q=" << a.seqlen_q << " seqlen_k=" << a.seqlen_k + << " mask=" << static_cast(t.mask_type) << ")" << std::endl; +}} +""" + +FMHA_FWD_ONESHOT_API_PER_TRLOAD = """ {F_if}({F_trload_cond}){{ +{F_dtype_case} + }} +""" + +FMHA_FWD_ONESHOT_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +{F_hdim_case} + }} +""" +FMHA_FWD_ONESHOT_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ +{F_inner_dispatch} + }} +""" + +FMHA_FWD_ONESHOT_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && + ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ + using trait_ = fmha_sparge_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, false/*logits*/, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; + // R30: pv_mode_compile selects 0=kNone / 1=kPerWave / 2=kPerBlock. 
+ switch(a.pv_mode_compile) {{ + case 0: fmha_sparge_fwd_oneshot_(s, a); return; + case 1: fmha_sparge_fwd_oneshot_(s, a); return; + case 2: fmha_sparge_fwd_oneshot_(s, a); return; + default: fmha_sparge_fwd_oneshot_(s, a); return; + }} + }} +""" + + +@dataclass +class CppConstraint: + bool_expr: str = None + + def __str__(self): + if self.bool_expr is None: + return "true" + else: + return f"{self.bool_expr}" + + def __and__(self, other): + return CppConstraint(f"({str(self)}) && ({str(other)})") + + +@dataclass +class FmhaFwdApiTrait: + pipeline_tag: str + # sync with fmha_fwd_traits<>, to generate fallback calls + hdim: str + dtype: str # data type + mode: str # value from MODE_MAP + bm0: int # tile size along q seqlen (block size) + bn0: int # tile size along qk seqlen + bk0: int # tile size along qk gemm unroll + bn1: int # tile size along v head_dim + bk1: int # tile size along kv gemm unroll + bk0max: int + vlayout: str + logits: str + mask: str + spad: str + skpad: str + dpad: str + dvpad: str + tr_load: str + constraint: CppConstraint + + @property + def name(self) -> str: + return ( + f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-" + + f"{self.vlayout}-{self.logits}-{self.mask}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}" + ) + + @property + def scheck(self) -> str: + if self.mode == "group": + return "true/*group mode spad always true*/" # group mode only generate spad/skpad == true + if self.spad == "t": + return "true" # always support + return "true" + + @property + def seqtune(self) -> str: + return "true" + + @property + def skcheck(self) -> str: + if self.mode == "group": + return "true/*group mode skpad always true*/" # group mode only generate spad/skpad == true + if self.skpad == "t": + return f"a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0" + return f"a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0" + + @property + def dcheck(self) -> str: + vec = int((32 * 4) / 
DTYPE_BITS[self.dtype]) + if self.dpad == "t": + return f"a.hdim_q % {vec} == 0" + assert False + + @property + def dvcheck(self) -> str: + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dvpad == "t": + return f"a.hdim_v % {vec} == 0" + assert False + + +@dataclass +class FmhaFwdPipeline: + tag: str + + F_vlayout: str # row/col + F_spad: str # true/false + F_skpad: str # + F_dpad: str # + F_dvpad: str # + F_logits: str # t/f + F_mask: str # value from MASK_MAP + F_trload: str # true/false + F_constraint: CppConstraint = field(default_factory=CppConstraint) + + @property + def name(self) -> str: + def pad_name() -> str: + n = "" + if self.F_spad == "t": + n += "s" + if self.F_skpad == "t": + n += "sk" + if self.F_dpad == "t": + n += "d" + if self.F_dvpad == "t": + n += "dv" + if n != "": + n = "p" + n + return n + + pn = pad_name() + n = f"{self.tag}_v{self.F_vlayout[0]}" + if pn != "": + n += f"_{pn}" + else: + n += "_npad" + + if self.F_logits == "t": + n += "_logits" + else: + n += "_nlogits" + + n += "_nbias" + + if self.F_mask[0:2] == "s_": + if self.F_mask == "s_mask": + n += "_mask" + else: + n += "_nmask" + else: + if self.F_mask != "no": + n += f"_m{self.F_mask[0]}" + else: + n += "_nmask" + + n += "_nskip" + + n += "_nsquant" + + if self.F_trload == "t": + n += "_trload" + else: + n += "_ntrload" + + return n + + +class FmhaFwdApiPool: + def __init__(self, mask_impl): + self.pool = dict() + self.mask_impl = mask_impl + + def register_traits(self, trait: FmhaFwdApiTrait) -> None: + # TODO: do we need to check duplication? 
+ if trait.dtype not in self.pool.keys(): + self.pool[trait.dtype] = dict() + hdim = trait.hdim, trait.bn1 + if hdim not in self.pool[trait.dtype].keys(): + self.pool[trait.dtype][hdim] = list() + + self.pool[trait.dtype][hdim].append(copy.copy(trait)) + + @property + def api(self) -> str: + tr_load_cond_map = {"t": "has_load_tr", "f": "true"} + + per_tr_load = str() + for tr_load in ["t", "f"]: + per_dtypes = str() + for i, dtype in enumerate(self.pool.keys()): + per_hdim_case = str() + for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()): + traits = [ + t + for t in self.pool[dtype][(hdim, hdim_v)] + if tr_load == t.tr_load + ] + inners = str() + for k, trait in enumerate(traits): + if_k = "if" if k == 0 else "else if" + inners = inners + FMHA_FWD_API_INNER_DISPATCH.format( + F_if=if_k, + F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], + # F_logits removed - hardcoded to false (NOT supported) + F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], + F_trload=BOOL_MAP[trait.tr_load], + F_scheck=trait.scheck, + F_seqtune=trait.seqtune, + F_skcheck=trait.skcheck, + F_dcheck=trait.dcheck, + F_dvcheck=trait.dvcheck, + F_constraint=trait.constraint, + F_spad=BOOL_MAP[trait.spad], + F_skpad=BOOL_MAP[trait.skpad], + F_dpad=BOOL_MAP[trait.dpad], + F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, + F_bn0=trait.bn0, + F_bk0=trait.bk0, + F_bn1=trait.bn1, + F_bk1=trait.bk1, + F_bk0max=trait.bk0max, + F_hdim=hdim, + F_dtype=FWD_DTYPE_MAP[dtype], + ) + if_j = "if" if j == 0 else "else if" + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format( + F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners + ) + if_i = "if" if i == 0 else "else if" + per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format( + F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case + ) + per_tr_load += FMHA_FWD_API_PER_TRLOAD.format( + F_if="if", + 
F_trload_cond=tr_load_cond_map[tr_load], + F_dtype_case=per_dtypes, + ) + if not per_tr_load: + # empty string we add some ignore to suppress warning in api + per_tr_load += " (void)t ; (void)s ; (void)a;" + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_tr_load) + + @property + def oneshot_api(self) -> str: + tr_load_cond_map = {"t": "has_load_tr", "f": "true"} + + per_tr_load = str() + for tr_load in ["t", "f"]: + per_dtypes = str() + for i, dtype in enumerate(self.pool.keys()): + per_hdim_case = str() + for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()): + traits = [ + t + for t in self.pool[dtype][(hdim, hdim_v)] + if tr_load == t.tr_load + ] + inners = str() + for k, trait in enumerate(traits): + if_k = "if" if k == 0 else "else if" + inners = inners + FMHA_FWD_ONESHOT_API_INNER_DISPATCH.format( + F_if=if_k, + F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], + F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], + F_trload=BOOL_MAP[trait.tr_load], + F_scheck=trait.scheck, + F_seqtune=trait.seqtune, + F_skcheck=trait.skcheck, + F_dcheck=trait.dcheck, + F_dvcheck=trait.dvcheck, + F_constraint=trait.constraint, + F_spad=BOOL_MAP[trait.spad], + F_skpad=BOOL_MAP[trait.skpad], + F_dpad=BOOL_MAP[trait.dpad], + F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, + F_bn0=trait.bn0, + F_bk0=trait.bk0, + F_bn1=trait.bn1, + F_bk1=trait.bk1, + F_bk0max=trait.bk0max, + F_hdim=hdim, + F_dtype=FWD_DTYPE_MAP[dtype], + ) + if_j = "if" if j == 0 else "else if" + per_hdim_case = per_hdim_case + FMHA_FWD_ONESHOT_API_PER_HDIM_CASE.format( + F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners + ) + if_i = "if" if i == 0 else "else if" + per_dtypes = per_dtypes + FMHA_FWD_ONESHOT_API_PER_DTYPE.format( + F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case + ) + per_tr_load += FMHA_FWD_ONESHOT_API_PER_TRLOAD.format( + F_if="if", + 
F_trload_cond=tr_load_cond_map[tr_load], + F_dtype_case=per_dtypes, + ) + if not per_tr_load: + per_tr_load += " (void)t ; (void)s ; (void)a;" + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_ONESHOT_API.format(F_dispatch=per_tr_load) + + +@dataclass +class FmhaFwdTileSize: + F_bm0: int # tile size along q seqlen (block size) + F_bn0: int # tile size along k seqlen + F_bk0: int # tile size along qk gemm unroll + F_bn1: int # tile size along v head_dim + F_bk1: int # tile size along kv gemm unroll + F_bk0max: int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) + F_rm0: int # number of warps for gemm0 along q seqlen + F_rn0: int # number of warps for gemm0 along k seqlen + F_rk0: int # number of warps for gemm0 along head dim q (not used) + F_rm1: int # number of warps for gemm1 along q seqlen + F_rn1: int # number of warps for gemm1 along head dim v + F_rk1: int # number of warps for gemm1 along k seqlen (not used) + F_wm0: int # gemm0 warp size along m + F_wn0: int # gemm0 warp size along n + F_wk0: int # gemm0 warp size along k + F_wm1: int # gemm1 warp size along m + F_wn1: int # gemm1 warp size along n + F_wk1: int # gemm1 warp size along k + F_occupancy: int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + F_constraint: CppConstraint = field(default_factory=CppConstraint) + + @property + def name(self) -> str: + return ( + f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" + + f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" + + f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" + + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + ) + + +@dataclass +class FmhaFwdKernel: + F_idx: int # this is not a tunable, but a counter to differentiate symbol + F_hdim: int # hdim + F_dtype: str # data type + F_mode: str # value from MODE_MAP + F_tile: FmhaFwdTileSize + 
F_pipeline: FmhaFwdPipeline + mask_impl: str + + @property + def template(self) -> str: + # kernel_body removed - unused + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format( + F_idx=self.F_idx, + F_hdim=self.F_hdim, + F_dtype=FWD_DTYPE_MAP[self.F_dtype], + F_bm0=self.F_tile.F_bm0, + F_bn0=self.F_tile.F_bn0, + F_bk0=self.F_tile.F_bk0, + F_bn1=self.F_tile.F_bn1, + F_bk1=self.F_tile.F_bk1, + F_bk0max=self.F_tile.F_bk0max, + F_rm0=self.F_tile.F_rm0, + F_rn0=self.F_tile.F_rn0, + F_rk0=self.F_tile.F_rk0, + F_rm1=self.F_tile.F_rm1, + F_rn1=self.F_tile.F_rn1, + F_rk1=self.F_tile.F_rk1, + F_wm0=self.F_tile.F_wm0, + F_wn0=self.F_tile.F_wn0, + F_wk0=self.F_tile.F_wk0, + F_wm1=self.F_tile.F_wm1, + F_wn1=self.F_tile.F_wn1, + F_wk1=self.F_tile.F_wk1, + F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout], + F_spad=BOOL_MAP[self.F_pipeline.F_spad], + F_skpad=BOOL_MAP[self.F_pipeline.F_skpad], + F_dpad=BOOL_MAP[self.F_pipeline.F_dpad], + F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad], + # F_logits removed - hardcoded to false in template (NOT supported) + F_occupancy=self.F_tile.F_occupancy, + F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag], + F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], + F_mode=MODE_MAP[self.F_mode], + F_pipeline=PIPELINE_MAP[self.F_pipeline.tag], + F_trload=BOOL_MAP[self.F_pipeline.F_trload], + F_kernel_name=self.name, + ) + + @property + def name(self) -> str: + # TODO: we don't encode idx here + return ( + f"fmha_sparge_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + + self.F_tile.name + + "_" + + self.F_pipeline.name + ) + + @property + def filename(self) -> str: + return self.name + ".cpp" + + def api_trait(self) -> FmhaFwdApiTrait: + return FmhaFwdApiTrait( + pipeline_tag=self.F_pipeline.tag, + hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bk0=self.F_tile.F_bk0, + bn1=self.F_tile.F_bn1, + bk1=self.F_tile.F_bk1, + bk0max=self.F_tile.F_bk0max, + 
vlayout=self.F_pipeline.F_vlayout, + mask=self.F_pipeline.F_mask, + logits=self.F_pipeline.F_logits, + spad=self.F_pipeline.F_spad, + skpad=self.F_pipeline.F_skpad, + dpad=self.F_pipeline.F_dpad, + dvpad=self.F_pipeline.F_dvpad, + tr_load=self.F_pipeline.F_trload, + constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint, + ) + + +class KernelComponentFactory: + # TODO: design a more practical way to do it + # this is current supported tile size per hdim + @staticmethod + def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: + if dtype == "fp16" or dtype == "bf16": + return { + # (32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + # (64, 64) : [FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), + # FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), + # FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + # (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + (128, 128): [ + FmhaFwdTileSize( # fmt: skip -- 128x128 tile (original, for old sparse attn test) + 128, + 128, + 32, + 128, + 32, + 128, + 4, + 1, + 1, + 4, + 1, + 1, + 32, + 32, + 16, + 32, + 32, + 16, + -1, + CppConstraint("t.bm0 == 0 || t.bm0 == 128"), + ), + FmhaFwdTileSize( # fmt: skip -- 64x128 tile (for sparge blockmap kM0=64) + 64, + 128, + 32, + 128, + 32, + 128, + 2, + 1, + 1, + 2, + 1, + 1, + 32, + 32, + 16, + 32, + 32, + 16, + -1, + CppConstraint("t.bm0 == 64"), + ), + FmhaFwdTileSize( # fmt: skip + 16, + 32, + 64, + 128, + 32, + 128, + 1, + 1, + 1, + 1, + 1, + 1, + 16, + 16, + 32, + 16, + 16, + 32, + -1, + ), + FmhaFwdTileSize( # fmt: skip + 32, + 32, + 128, + 128, + 32, + 128, + 1, + 1, + 1, + 1, + 1, + 1, + 32, + 32, + 16, + 32, + 32, + 16, + -1, + ), + FmhaFwdTileSize( # fmt: skip + 128, + 64, + 32, + 128, + 16, + 128, + 4, + 1, + 1, + 4, + 1, + 1, + 32, + 32, + 16, + 
32, + 32, + 16, + -1, + ), + ], + # (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], + # (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + # (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], + # (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + } + else: + return None + + # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad + # support this in future + @staticmethod + def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]: + # this function will populate a list possible pipelines + # TODO: the order of List matters! the later in this list will be also be checked later + # NOTE: logits soft-cap is NOT supported by Sparge sparse attention (enforced by static_assert) + pipelines = [] + if dtype in ["fp16", "bf16"]: + for logits, mask in itertools.product( + ["f"], # logits soft-cap NOT supported, always false + get_mask_map(mask_impl).keys(), + ): + if hdim == 256 and hdim_v == 256: + # sparge fmha only supports dim <= 192 for now. 
+ continue + pipelines.append( + FmhaFwdPipeline( + "qr_async_sparge", + "row", + "t", + "f", + "t", + "t", + logits, + mask, + "f", + ) + ) + pipelines.append( + FmhaFwdPipeline( + "qr_async_sparge", + "row", + "t", + "t", + "t", + "t", + logits, + mask, + "f", + ) + ) + else: + assert False + return pipelines + + +class CustomFactory(KernelComponentFactory): + @staticmethod + def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: + result = KernelComponentFactory.get_hdim_tile_size_dict(dtype) + if dtype == "fp16" or dtype == "bf16": + if (128, 128) in result.keys(): + result[(128, 128)].insert( + 0, + FmhaFwdTileSize( + 64, + 128, + 64, + 128, + 64, + 128, + 4, + 1, + 1, + 4, + 1, + 1, + 16, + 16, + 16, + 16, + 16, + 16, + -1, + CppConstraint( + "get_num_blocks(128) < num_cus * min_cu_util_rate" + ), + ), + ) + return result + + +def get_fwd_blobs( + kernel_filter: Optional[str], receipt, optdim_list, mask_impl +) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: + gen = list() + api_pool = FmhaFwdApiPool(mask_impl) + + factory = ( + CustomFactory + if os.environ.get("CK_TILE_FMHA_FWD_CUSTOM_FACTORY", "0") == "1" + else KernelComponentFactory + ) + + # Only generate fp16/bf16 kernels for now. 
+ # NOTE: Sparge sparse attention only supports batch mode (group mode NOT supported, enforced by static_assert) + for dtype in ["fp16", "bf16"]: + d = factory.get_hdim_tile_size_dict(dtype) + if d is None: + continue + for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), ["batch"]): + for tile, pipeline in itertools.product( + tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) + ): + if tile.F_bm0 not in (64, 128) or tile.F_bn0 != 128: + continue + if pipeline.tag != "qr_async_sparge": + continue + k = FmhaFwdKernel( + F_idx=1, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline, + mask_impl=mask_impl, + ) + if kernel_filter != "": + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue + # 2 - Flash attention integration + if receipt in (2, 3): + cond = dtype in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" + if not cond: + continue + # PyTorch integration + elif receipt == 4: + cond = dtype in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" + cond &= mode == "batch" + cond &= pipeline.F_logits == "f" + if not cond: + continue + # Aiter(mha_fwd) integration + elif receipt == 100: + cond = dtype in ["fp16", "bf16"] + cond &= mode == "batch" + cond &= pipeline.F_vlayout == "row" + if not cond: + continue + # Aiter(mha_varlen_fwd) integration + elif receipt == 200: + cond = dtype in ["fp16", "bf16"] + cond &= mode == "group" + cond &= pipeline.F_vlayout == "row" + if not cond: + continue + # aiter::mha_fwd C++ api integration + elif receipt == 600: + cond = dtype in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" + if not cond: + continue + + api_pool.register_traits(k.api_trait()) + gen.append(k) + + return (api_pool, gen) + + +def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: + update_file(autogen_dir / kernel.filename, kernel.template) + + +def write_fwd_api(api_pool: 
FmhaFwdApiPool, autogen_dir: Path) -> None: + update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api) + update_file(autogen_dir / FMHA_FWD_ONESHOT_API_FILENAME, api_pool.oneshot_api) + + +def write_blobs( + output_dir: Path, kernel_filter: str, receipt, optdim_list, mask_impl +) -> None: + api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) + for kernel in kernels: + write_single_fwd_kernel(kernel, output_dir) + write_fwd_api(api_pool, output_dir) + + +def list_blobs( + file_path: Path, kernel_filter: str, receipt, optdim_list, mask_impl +) -> None: + with file_path.open("a") as f: + _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) + for kernel in kernels: + f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n") + f.write((file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME).as_posix() + "\n") + f.write((file_path.parent / GEN_DIR / FMHA_FWD_ONESHOT_API_FILENAME).as_posix() + "\n") diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py index 038738de246..208877037f1 100644 --- a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py +++ b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py @@ -141,6 +141,17 @@ def update_file(file_path, content): constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); }} + +template<> +void fmha_vsa_fwd_oneshot_(const ck_tile::stream_config& s, fmha_vsa_fwd_args a) +{{ + using k_ = fmha_kernel_{F_idx}; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( + ck_tile::stream_config{{s.stream_id_}}); +}} """ FMHA_FWD_API_FILENAME = "fmha_vsa_fwd_api.cpp" @@ -219,6 +230,45 @@ def update_file(file_path, content): }} """ 
+FMHA_FWD_ONESHOT_API_FILENAME = "fmha_vsa_fwd_oneshot_api.cpp" +FMHA_FWD_ONESHOT_API = """ +#include "fmha_fwd_trek.hpp" +#include + +void fmha_vsa_fwd_oneshot(fmha_vsa_fwd_traits t, fmha_vsa_fwd_args a, const ck_tile::stream_config& s){{ + + const bool has_load_tr = ck_tile::is_load_tr_supported(); + +{F_dispatch} + std::cerr << "fmha_vsa_fwd_oneshot: no matching dispatch (dtype=" << t.data_type + << " hdim_q=" << t.hdim_q << " hdim_v=" << t.hdim_v + << " seqlen_q=" << a.seqlen_q << " seqlen_k=" << a.seqlen_k + << " mask=" << static_cast(t.mask_type) << ")" << std::endl; +}} +""" + +FMHA_FWD_ONESHOT_API_PER_TRLOAD = """ {F_if}({F_trload_cond}){{ +{F_dtype_case} + }} +""" + +FMHA_FWD_ONESHOT_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +{F_hdim_case} + }} +""" +FMHA_FWD_ONESHOT_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ +{F_inner_dispatch} + }} +""" + +FMHA_FWD_ONESHOT_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && + ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ + using trait_ = fmha_vsa_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, false/*logits*/, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; + fmha_vsa_fwd_oneshot_(s, a); + return; + }} +""" + @dataclass class CppConstraint: @@ -274,10 +324,7 @@ def scheck(self) -> str: @property def seqtune(self) -> str: - if self.bm0 == 128: - return "true/*fall back to largest tile*/" # group mode only generate spad/skpad == true - else: - return f"a.seqlen_q <= {self.bm0}" + return "true" @property def skcheck(self) -> str: @@ -447,6 +494,67 @@ def api(self) -> str: per_tr_load += " (void)t ; (void)s ; (void)a;" return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_tr_load) + @property + def oneshot_api(self) -> str: + tr_load_cond_map = {"t": 
"has_load_tr", "f": "true"} + + per_tr_load = str() + for tr_load in ["t", "f"]: + per_dtypes = str() + for i, dtype in enumerate(self.pool.keys()): + per_hdim_case = str() + for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()): + traits = [ + t + for t in self.pool[dtype][(hdim, hdim_v)] + if tr_load == t.tr_load + ] + inners = str() + for k, trait in enumerate(traits): + if_k = "if" if k == 0 else "else if" + inners = inners + FMHA_FWD_ONESHOT_API_INNER_DISPATCH.format( + F_if=if_k, + F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], + F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], + F_trload=BOOL_MAP[trait.tr_load], + F_scheck=trait.scheck, + F_seqtune=trait.seqtune, + F_skcheck=trait.skcheck, + F_dcheck=trait.dcheck, + F_dvcheck=trait.dvcheck, + F_constraint=trait.constraint, + F_spad=BOOL_MAP[trait.spad], + F_skpad=BOOL_MAP[trait.skpad], + F_dpad=BOOL_MAP[trait.dpad], + F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, + F_bn0=trait.bn0, + F_bk0=trait.bk0, + F_bn1=trait.bn1, + F_bk1=trait.bk1, + F_bk0max=trait.bk0max, + F_hdim=hdim, + F_dtype=FWD_DTYPE_MAP[dtype], + ) + if_j = "if" if j == 0 else "else if" + per_hdim_case = per_hdim_case + FMHA_FWD_ONESHOT_API_PER_HDIM_CASE.format( + F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners + ) + if_i = "if" if i == 0 else "else if" + per_dtypes = per_dtypes + FMHA_FWD_ONESHOT_API_PER_DTYPE.format( + F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case + ) + per_tr_load += FMHA_FWD_ONESHOT_API_PER_TRLOAD.format( + F_if="if", + F_trload_cond=tr_load_cond_map[tr_load], + F_dtype_case=per_dtypes, + ) + if not per_tr_load: + per_tr_load += " (void)t ; (void)s ; (void)a;" + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_ONESHOT_API.format(F_dispatch=per_tr_load) + @dataclass class FmhaFwdTileSize: @@ -582,38 +690,39 @@ def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: # FmhaFwdTileSize(128, 
64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], # (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (128, 128): [ - FmhaFwdTileSize( # fmt: skip - 16, + FmhaFwdTileSize( # fmt: skip -- 128x128 tile (original, for old sparse attn test) + 128, + 128, 32, - 64, 128, 32, 128, + 4, 1, 1, + 4, 1, 1, - 1, - 1, - 16, - 16, 32, - 16, + 32, 16, 32, + 32, + 16, -1, + CppConstraint("t.bm0 == 0 || t.bm0 == 128"), ), - FmhaFwdTileSize( # fmt: skip - 32, - 32, + FmhaFwdTileSize( # fmt: skip -- 64x128 tile (for sparge blockmap kM0=64) + 64, 128, + 32, 128, 32, 128, + 2, 1, 1, - 1, - 1, + 2, 1, 1, 32, @@ -623,18 +732,40 @@ def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: 32, 16, -1, + CppConstraint("t.bm0 == 64"), ), FmhaFwdTileSize( # fmt: skip - 128, + 16, + 32, 64, + 128, 32, 128, + 1, + 1, + 1, + 1, + 1, + 1, + 16, + 16, + 32, 16, + 16, + 32, + -1, + ), + FmhaFwdTileSize( # fmt: skip + 32, + 32, + 128, + 128, + 32, 128, - 4, 1, 1, - 4, + 1, + 1, 1, 1, 32, @@ -647,10 +778,10 @@ def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: ), FmhaFwdTileSize( # fmt: skip 128, - 128, + 64, 32, 128, - 32, + 16, 128, 4, 1, @@ -780,7 +911,7 @@ def get_fwd_blobs( for tile, pipeline in itertools.product( tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) ): - if tile.F_bm0 != 128 or tile.F_bn0 != 128: + if tile.F_bm0 not in (64, 128) or tile.F_bn0 != 128: continue if pipeline.tag != "qr_async_vsa": continue @@ -846,6 +977,7 @@ def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None: update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api) + update_file(autogen_dir / FMHA_FWD_ONESHOT_API_FILENAME, api_pool.oneshot_api) def write_blobs( @@ -865,3 +997,4 @@ def list_blobs( for kernel in kernels: f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n") 
f.write((file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME).as_posix() + "\n") + f.write((file_path.parent / GEN_DIR / FMHA_FWD_ONESHOT_API_FILENAME).as_posix() + "\n") diff --git a/example/ck_tile/50_sparse_attn/docs/plot_sparge_perf.py b/example/ck_tile/50_sparse_attn/docs/plot_sparge_perf.py new file mode 100644 index 00000000000..95a13d5f65c --- /dev/null +++ b/example/ck_tile/50_sparse_attn/docs/plot_sparge_perf.py @@ -0,0 +1,258 @@ +#!/usr/bin/env python3 +"""Plot sparge perf charts from full_grid.csv. + +Re-run with different fixed (b, h, s, dtype, topk) by editing the constants below. +No GPU / no srun / no rebuild — pure matplotlib from CSV. +""" +import os +import sys +import pandas as pd +import matplotlib.pyplot as plt +import numpy as np + +# ---------------------------------------------------------------------- +# Tunable constants — edit these to regenerate for a different point. +# ---------------------------------------------------------------------- +CSV_PATH = "/home/AMD/ginolu12/gino_tmp/full_grid.csv" +OUT_DIR = os.path.dirname(os.path.abspath(__file__)) + +# Chart 1 — speedup vs topk for one fixed (b, h, s, dtype) +CHART1_B = 2 +CHART1_H = 32 +CHART1_S = 16384 +CHART1_DTYPE = "fp16" +CHART1_HEAD_DIM = 128 # for title only + +# Chart 2 — kernel breakdown across s for fixed (b, h, dtype, topk) +CHART2_B = 2 +CHART2_H = 32 +CHART2_DTYPE = "fp16" +CHART2_TOPK = 0.4 +CHART2_S_LIST = [2048, 4096, 8192, 16384] +CHART2_HEAD_DIM = 128 # for title only + +DPI = 140 + +# ---------------------------------------------------------------------- +# Helpers +# ---------------------------------------------------------------------- +def is_fail(note): + if not isinstance(note, str): + return False + return "FAIL" in note + +def is_high_spread(note): + if not isinstance(note, str): + return False + return "HIGH_SPREAD" in note + +def load_data(): + df = pd.read_csv(CSV_PATH) + return df + +# ---------------------------------------------------------------------- 
+# Chart 1 +# ---------------------------------------------------------------------- +def plot_chart1(df, out_path): + sel = df[ + (df["b"] == CHART1_B) + & (df["h"] == CHART1_H) + & (df["s"] == CHART1_S) + & (df["dtype"] == CHART1_DTYPE) + ].copy() + sel = sel.sort_values("topk").reset_index(drop=True) + + if sel.empty: + print(f"[chart1] WARNING: no rows for b={CHART1_B} h={CHART1_H} s={CHART1_S} dtype={CHART1_DTYPE}") + return [], 0 + + # Drop fully failed rows but keep partial-fail rows; we'll mask per-series. + # Convert numeric columns + for col in ["sparge_jenga", "sparge_vsa", "sparse_jenga", "sparse_vsa", "fmha_us"]: + sel[col] = pd.to_numeric(sel[col], errors="coerce") + + fmha = sel["fmha_us"] + + # Compute speedups; rows with FAIL on a given column will have NaN already. + series = { + "sparge_vsa": fmha / sel["sparge_vsa"], + "sparge_jenga": fmha / sel["sparge_jenga"], + "sparse_vsa": fmha / sel["sparse_vsa"], + "sparse_jenga": fmha / sel["sparse_jenga"], + } + + style = { + "sparge_vsa": {"color": "#1f77b4", "marker": "o", "lw": 2.0}, + "sparge_jenga": {"color": "#ff7f0e", "marker": "s", "lw": 2.0}, + "sparse_vsa": {"color": "#2ca02c", "marker": "^", "lw": 1.5, "ls": "--"}, + "sparse_jenga": {"color": "#d62728", "marker": "v", "lw": 1.5, "ls": "--"}, + } + + fig, ax = plt.subplots(figsize=(8.5, 5.5), dpi=DPI) + + x = sel["topk"].to_numpy() + + # HIGH_SPREAD overlay first (under main markers) + hs_mask = sel["note"].apply(is_high_spread) + high_spread_cells = [] + if hs_mask.any(): + for _, row in sel[hs_mask].iterrows(): + high_spread_cells.append((row["topk"], row["max_spread_pct"])) + # gray ring underneath every series's data point at that x + for label, sp in series.items(): + xs_hs = x[hs_mask.to_numpy()] + ys_hs = sp[hs_mask.to_numpy()].to_numpy() + ax.scatter(xs_hs, ys_hs, s=180, facecolors="none", + edgecolors="gray", linewidths=1.5, zorder=2) + + for label, sp in series.items(): + st = style[label] + ax.plot(x, sp.to_numpy(), label=label, + 
color=st["color"], marker=st["marker"], + linewidth=st["lw"], linestyle=st.get("ls", "-"), + markersize=7, zorder=3) + + ax.axhline(1.0, color="black", linestyle=":", linewidth=1.2, label="fmha (baseline)", zorder=1) + + ax.set_xlabel("topk (kept fraction)") + ax.set_ylabel("speedup vs FMHA dense (×)") + ax.set_title( + f"Speedup vs FMHA " + f"(b={CHART1_B} h={CHART1_H} s={CHART1_S} d={CHART1_HEAD_DIM} {CHART1_DTYPE})" + ) + ax.grid(True, which="both", linestyle=":", alpha=0.6) + ax.set_xticks(np.arange(0.1, 0.71, 0.1)) + ax.legend(loc="best", framealpha=0.9) + + # Footnote about HIGH_SPREAD overlay + if high_spread_cells: + ax.text(0.01, -0.16, + "Gray rings: HIGH_SPREAD cells (high run-to-run variance)", + transform=ax.transAxes, fontsize=8, color="gray") + + fig.tight_layout() + fig.savefig(out_path, dpi=DPI, bbox_inches="tight") + plt.close(fig) + return high_spread_cells, os.path.getsize(out_path) + + +# ---------------------------------------------------------------------- +# Chart 2 +# ---------------------------------------------------------------------- +def plot_chart2(df, out_path): + sel = df[ + (df["b"] == CHART2_B) + & (df["h"] == CHART2_H) + & (df["dtype"] == CHART2_DTYPE) + & (np.isclose(df["topk"], CHART2_TOPK)) + & (df["s"].isin(CHART2_S_LIST)) + ].copy() + sel = sel.sort_values("s").reset_index(drop=True) + + if sel.empty: + print(f"[chart2] WARNING: no rows for b={CHART2_B} h={CHART2_H} dtype={CHART2_DTYPE} topk={CHART2_TOPK}") + return 0 + + for col in ["sparge_jenga_pre", "sparge_jenga_attn", + "sparge_vsa_pre", "sparge_vsa_attn", "fmha_us"]: + sel[col] = pd.to_numeric(sel[col], errors="coerce") + + s_vals = sel["s"].to_numpy() + n = len(s_vals) + idx = np.arange(n, dtype=float) + + width = 0.35 + offset = width / 2 + 0.02 + + fig, ax = plt.subplots(figsize=(9.0, 5.8), dpi=DPI) + + # Jenga bars (left of group) + jenga_pre = sel["sparge_jenga_pre"].to_numpy() + jenga_attn = sel["sparge_jenga_attn"].to_numpy() + vsa_pre = 
sel["sparge_vsa_pre"].to_numpy() + vsa_attn = sel["sparge_vsa_attn"].to_numpy() + fmha_vals = sel["fmha_us"].to_numpy() + + color_jenga_pre = "#fdbf6f" # light orange + color_jenga_attn = "#ff7f0e" # orange + color_vsa_pre = "#a6cee3" # light blue + color_vsa_attn = "#1f77b4" # blue + + bj_pre = ax.bar(idx - offset, jenga_pre, width, + color=color_jenga_pre, edgecolor="black", linewidth=0.6, + label="sparge_jenga _pre (BlockMap)") + bj_at = ax.bar(idx - offset, jenga_attn, width, bottom=jenga_pre, + color=color_jenga_attn, edgecolor="black", linewidth=0.6, + label="sparge_jenga _attn") + bv_pre = ax.bar(idx + offset, vsa_pre, width, + color=color_vsa_pre, edgecolor="black", linewidth=0.6, + label="sparge_vsa _pre (BlockMap)") + bv_at = ax.bar(idx + offset, vsa_attn, width, bottom=vsa_pre, + color=color_vsa_attn, edgecolor="black", linewidth=0.6, + label="sparge_vsa _attn") + + # Add total labels on top of each stack + totals_jenga = jenga_pre + jenga_attn + totals_vsa = vsa_pre + vsa_attn + for i in range(n): + ax.text(idx[i] - offset, totals_jenga[i], f"{totals_jenga[i]:.0f}", + ha="center", va="bottom", fontsize=8) + ax.text(idx[i] + offset, totals_vsa[i], f"{totals_vsa[i]:.0f}", + ha="center", va="bottom", fontsize=8) + + # FMHA reference: short horizontal dashed segment per group + seg_half = 0.40 + fmha_label_done = False + for i in range(n): + ax.hlines(fmha_vals[i], idx[i] - seg_half, idx[i] + seg_half, + colors="black", linestyles="dashed", linewidth=1.2, + label="fmha dense (reference)" if not fmha_label_done else None, + zorder=5) + ax.text(idx[i] + seg_half + 0.02, fmha_vals[i], + f"fmha {fmha_vals[i]:.0f}", fontsize=7, va="center", color="black") + fmha_label_done = True + + ax.set_xticks(idx) + ax.set_xticklabels([f"s={s}" for s in s_vals.astype(int)]) + ax.set_xlabel("sequence length (s)") + ax.set_ylabel("kernel time (µs)") + ax.set_title( + f"Sparge kernel time breakdown " + f"(b={CHART2_B} h={CHART2_H} d={CHART2_HEAD_DIM} {CHART2_DTYPE}, 
topk={CHART2_TOPK})" + ) + ax.grid(True, axis="y", linestyle=":", alpha=0.6) + ax.legend(loc="upper left", framealpha=0.9, fontsize=9) + + # log-y is too aggressive — leave linear; bars will just be tall. + fig.tight_layout() + fig.savefig(out_path, dpi=DPI, bbox_inches="tight") + plt.close(fig) + return os.path.getsize(out_path) + + +# ---------------------------------------------------------------------- +# Main +# ---------------------------------------------------------------------- +def main(): + os.makedirs(OUT_DIR, exist_ok=True) + df = load_data() + + chart1_path = os.path.join(OUT_DIR, "speedup_vs_sparsity.png") + chart2_path = os.path.join(OUT_DIR, "kernel_breakdown.png") + + hs_cells, size1 = plot_chart1(df, chart1_path) + size2 = plot_chart2(df, chart2_path) + + print(f"Wrote {chart1_path} ({size1} bytes)") + print(f"Wrote {chart2_path} ({size2} bytes)") + + if hs_cells: + print("HIGH_SPREAD cells in chart-1 selection:") + for topk, pct in hs_cells: + print(f" topk={topk} max_spread_pct={pct}") + else: + print("No HIGH_SPREAD cells in chart-1 selection.") + + +if __name__ == "__main__": + main() diff --git a/example/ck_tile/50_sparse_attn/docs/pv_skip_mode_comparison.png b/example/ck_tile/50_sparse_attn/docs/pv_skip_mode_comparison.png new file mode 100644 index 00000000000..28eec99e79a Binary files /dev/null and b/example/ck_tile/50_sparse_attn/docs/pv_skip_mode_comparison.png differ diff --git a/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp b/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp index 7349c3576e8..071d0409b09 100644 --- a/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp +++ b/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp @@ -272,7 +272,7 @@ struct fmha_jenga_fwd_traits std::string data_type; bool is_v_rowmajor; mask_enum mask_type; - // TODO: padding check is inside this api + int bm0 = 0; // preferred Q-tile size; 0 = don't care (dispatch picks largest) }; float fmha_jenga_fwd(fmha_jenga_fwd_traits, fmha_jenga_fwd_args, const 
ck_tile::stream_config&); @@ -280,7 +280,12 @@ float fmha_jenga_fwd(fmha_jenga_fwd_traits, fmha_jenga_fwd_args, const ck_tile:: template float fmha_jenga_fwd_(const ck_tile::stream_config&, fmha_jenga_fwd_args); -float fmha_jenga_fwd(fmha_jenga_fwd_args, const ck_tile::stream_config&); +template +void fmha_jenga_fwd_oneshot_(const ck_tile::stream_config&, fmha_jenga_fwd_args); + +void fmha_jenga_fwd_oneshot(fmha_jenga_fwd_traits, + fmha_jenga_fwd_args, + const ck_tile::stream_config&); // VSA uses the same traits structure as Jenga; aliases for clarity template float fmha_vsa_fwd_(const ck_tile::stream_config&, fmha_vsa_fwd_args); -float fmha_vsa_fwd(fmha_vsa_fwd_args, const ck_tile::stream_config&); +template +void fmha_vsa_fwd_oneshot_(const ck_tile::stream_config&, fmha_vsa_fwd_args); + +void fmha_vsa_fwd_oneshot(fmha_vsa_fwd_traits, fmha_vsa_fwd_args, const ck_tile::stream_config&); + +// sparge: same args as vsa plus a scalar PV-skip threshold (Step 1). +struct fmha_sparge_fwd_args +{ + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* lut_ptr; // delta-encoded K-block indices per Q-block, int32 [B,H,Q_blk,K_blk] + const void* valid_block_num_ptr; // valid K-block count per Q-block, int32 [B,H,Q_blk] + void* o_ptr; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t batch; + ck_tile::index_t max_seqlen_q; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + + float scale_s; + float pv_threshold; // SpargeAttn §4.4 PV-skip per-Q-tile threshold (scalar mode) + + // R26 split-launch: when non-null, per-head pv_threshold buffer (length nhead_q) + // is read on device instead of the scalar. Combined with head_remap_ptr the + // host can issue two launches (finite-threshold bucket + sentinel bucket) at + // different binaries. 
+ const float* pv_threshold_per_head_ptr = nullptr; + const int* head_remap_ptr = nullptr; + int nhead_in_launch = 0; // 0 = identity (full nhead_q grid) + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_o; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_o; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_o; + + ck_tile::index_t window_size_left; + ck_tile::index_t window_size_right; + ck_tile::index_t mask_type; + + // R25 V0: select between kEnablePVSkip=true / =false template instantiations + // at host dispatch time. Default true preserves existing behaviour (binary + // shipped pre-R25-V0 only had the true instance). Profiler can flip this to + // false to measure the source-equivalent-to-VSA baseline (`if constexpr` + // removes the entire PV-skip AST). + // + // R30: superseded by pv_mode_compile (int 0/1/2). Kept for source compat — + // when callers only set pv_skip_compile, the split-launch wrapper derives + // pv_mode_compile = (pv_skip_compile ? 1 : 0). + bool pv_skip_compile = true; + + // R30: 3-mode PV-skip select. + // 0 = kNone (no PV-skip; AST removed; equivalent to VSA baseline) + // 1 = kPerWave (R25 A1 shipped path; per-wavefront butterfly vote) + // 2 = kPerBlock (R30 added; block-wide AND vote through 1 LDS slot) + // Default 1 preserves R25 A1 behaviour for any caller that doesn't set it. 
+ int pv_mode_compile = 1; +}; + +template +auto fmha_fwd_create_kargs_and_grids(fmha_sparge_fwd_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = FmhaKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.lut_ptr, + args.valid_block_num_ptr, + args.o_ptr, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale_s, + args.pv_threshold, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_o, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type, + // R26 split-launch extras + args.pv_threshold_per_head_ptr, + args.head_remap_ptr, + args.nhead_in_launch); + + // R26 split-launch: when head_remap is active, gridDim.y shrinks to bucket size. + const ck_tile::index_t grid_nhead = (args.head_remap_ptr != nullptr && args.nhead_in_launch > 0) + ? args.nhead_in_launch + : args.nhead_q; + dim3 grids = FmhaKernel::GridSize(args.batch, grid_nhead, args.max_seqlen_q, args.hdim_v); + return ck_tile::make_tuple(kargs, grids); +} + +template +using fmha_sparge_fwd_traits_ = fmha_jenga_fwd_traits_; + +using fmha_sparge_fwd_traits = fmha_jenga_fwd_traits; + +float fmha_sparge_fwd(fmha_sparge_fwd_traits, fmha_sparge_fwd_args, const ck_tile::stream_config&); + +// R25 V0 / R30: PV-skip mode is a template non-type param so codegen can emit +// all 3 instantiations from the same source tree. The host dispatch +// (fmha_sparge_fwd_api.cpp) selects the right specialization based on +// fmha_sparge_fwd_args::pv_mode_compile at runtime. +// 0 = kNone, 1 = kPerWave, 2 = kPerBlock (matches ck_tile::PVSkipMode). 
+template +float fmha_sparge_fwd_(const ck_tile::stream_config&, fmha_sparge_fwd_args); + +template +void fmha_sparge_fwd_oneshot_(const ck_tile::stream_config&, fmha_sparge_fwd_args); + +void fmha_sparge_fwd_oneshot(fmha_sparge_fwd_traits, + fmha_sparge_fwd_args, + const ck_tile::stream_config&); diff --git a/example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp b/example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp new file mode 100644 index 00000000000..10a58ae05f6 --- /dev/null +++ b/example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp @@ -0,0 +1,409 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +// Hand-written template instantiation for SpargeBlockMapKernel (fp16, D=128). + +#include "sparge_blockmap_trek.hpp" +#include "ck_tile/ops/fmha/block/variants.hpp" +#include "ck_tile/host/device_memory.hpp" + +#include +#include +#include +#include +#include +#include + +// ============================================================================ +// Type configuration for block map kernel (reuses FmhaSparseFwdTypeConfig) +// ============================================================================ + +// fp16: D=128, kM0=64, kN0=128 +using bmap_fp16_block_tile = ck_tile::sequence<64, 128, 128, 128, 128, 128>; +// kM0 kN0 kK0 kN1 kK1 kQKHeaddim(D) + +using bmap_fp16_shape = + ck_tile::TileFmhaShape, // Gemm0BlockWarps + ck_tile::sequence<16, 16, 16>, // Gemm0WarpTile (unused by blockmap, but + // needed by shape) + ck_tile::sequence<4, 1, 1>, // Gemm1BlockWarps + ck_tile::sequence<16, 16, 16>, // Gemm1WarpTile + true>; // VLayout row-major + +using bmap_fp16_trait = ck_tile::TileFmhaTraits; // kIsVRowMajorSkip + +using bmap_fp16_variant = ck_tile::ComposedAttention<0, CK_TILE_FMHA_FWD_FAST_EXP2>; +using bmap_fp16_mask = ck_tile::GenericAttentionMask; + +using bmap_fp16_problem = ck_tile::BlockFmhaPipelineProblem; + +using bmap_fp16_pipeline = ck_tile::SpargeBlockMapPipeline; +using bmap_fp16_kernel 
= ck_tile::SpargeBlockMapKernel; + +using kstats_fp16_pipeline = ck_tile::SpargeKStatsPipeline; +using kstats_fp16_kernel = ck_tile::SpargeKStatsKernel; + +// bf16: dtype-independent aliases share fp16 chain; only problem differs. +using bmap_bf16_block_tile = bmap_fp16_block_tile; +using bmap_bf16_shape = bmap_fp16_shape; +using bmap_bf16_trait = bmap_fp16_trait; +using bmap_bf16_variant = bmap_fp16_variant; +using bmap_bf16_mask = bmap_fp16_mask; + +using bmap_bf16_problem = ck_tile::BlockFmhaPipelineProblem; + +using bmap_bf16_pipeline = ck_tile::SpargeBlockMapPipeline; +using bmap_bf16_kernel = ck_tile::SpargeBlockMapKernel; + +using kstats_bf16_pipeline = ck_tile::SpargeKStatsPipeline; +using kstats_bf16_kernel = ck_tile::SpargeKStatsKernel; + +// ============================================================================ +// Workspace layout: caller owns the buffer; we just compute size + offsets. +// Layout = [pooled_k (KDataType) | sim_k (uint8)]. sim_k follows pooled_k with +// no padding (uint8 has alignment 1). +// ============================================================================ + +namespace { + +constexpr int sparge_kN0_for(int hdim_q) +{ + // d=128 instances use kN0=128 (see bmap_fp16_block_tile). + return (hdim_q == 128) ? 128 : 0; +} + +size_t dtype_bytes(const std::string& dt) +{ + if(dt == "fp16" || dt == "bf16") + return 2; + return 0; +} + +} // namespace + +sparge_blockmap_workspace_layout +sparge_blockmap_compute_workspace_layout(sparge_blockmap_traits traits, sparge_blockmap_args args) +{ + const int kN0 = sparge_kN0_for(traits.hdim_q); + const int N_k = (kN0 > 0) ? 
ck_tile::integer_divide_ceil(args.seqlen_k, kN0) : 0; + const int D = traits.hdim_q; + const size_t element_bytes = dtype_bytes(traits.data_type); + + sparge_blockmap_workspace_layout layout{}; + layout.pooled_k_offset = 0; + layout.pooled_k_bytes = + static_cast(args.batch) * args.nhead_k * N_k * D * element_bytes; + layout.sim_k_offset = layout.pooled_k_bytes; + layout.sim_k_bytes = static_cast(args.batch) * args.nhead_k * N_k * sizeof(uint8_t); + layout.total_bytes = layout.sim_k_offset + layout.sim_k_bytes; + return layout; +} + +// ============================================================================ +// Stage launchers: read args.workspace_ptr split per layout, run one kernel. +// ============================================================================ + +namespace { + +template +void launch_kstats_only(sparge_blockmap_traits traits, + sparge_blockmap_args args, + const ck_tile::stream_config& s) +{ + const auto layout = sparge_blockmap_compute_workspace_layout(traits, args); + auto* ws_base = static_cast(args.workspace_ptr); + void* pooled_k_ptr = ws_base + layout.pooled_k_offset; + void* sim_k_ptr = ws_base + layout.sim_k_offset; + + auto [kargs, grids] = + sparge_kstats_create_kargs_and_grids(args, pooled_k_ptr, sim_k_ptr); + const dim3 blocks = KStatsKernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = KStatsKernel::kBlockPerCu; + ck_tile::make_kernel(KStatsKernel{}, grids, blocks, 0, kargs)(s); +} + +template +void launch_blockmap_only(sparge_blockmap_traits traits, + sparge_blockmap_args args, + const ck_tile::stream_config& s) +{ + const auto layout = sparge_blockmap_compute_workspace_layout(traits, args); + auto* ws_base = static_cast(args.workspace_ptr); + void* pooled_k_ptr = ws_base + layout.pooled_k_offset; + void* sim_k_ptr = ws_base + layout.sim_k_offset; + + auto [kargs, grids] = + sparge_blockmap_create_kargs_and_grids(args, pooled_k_ptr, sim_k_ptr); + const dim3 blocks = BlockMapKernel::BlockSize(); + constexpr 
ck_tile::index_t kBlockPerCu = BlockMapKernel::kBlockPerCu; + ck_tile::make_kernel(BlockMapKernel{}, grids, blocks, 0, kargs)(s); +} + +} // namespace + +// ============================================================================ +// Oneshot stages (no timing): caller chains them via launch_kernel. +// ============================================================================ + +void sparge_kstats_fwd_oneshot(sparge_blockmap_traits traits, + sparge_blockmap_args args, + const ck_tile::stream_config& s) +{ + if(traits.data_type == "fp16" && traits.hdim_q == 128) + { + launch_kstats_only(traits, args, s); + return; + } + if(traits.data_type == "bf16" && traits.hdim_q == 128) + { + launch_kstats_only(traits, args, s); + return; + } + std::cerr << "sparge_kstats_fwd_oneshot: unsupported config (data_type=" << traits.data_type + << ", hdim_q=" << traits.hdim_q << ")" << std::endl; +} + +void sparge_blockmap_only_fwd_oneshot(sparge_blockmap_traits traits, + sparge_blockmap_args args, + const ck_tile::stream_config& s) +{ + if(traits.data_type == "fp16" && traits.hdim_q == 128) + { + launch_blockmap_only(traits, args, s); + return; + } + if(traits.data_type == "bf16" && traits.hdim_q == 128) + { + launch_blockmap_only(traits, args, s); + return; + } + std::cerr << "sparge_blockmap_only_fwd_oneshot: unsupported config (data_type=" + << traits.data_type << ", hdim_q=" << traits.hdim_q << ")" << std::endl; +} + +// ============================================================================ +// Combined functions: kstats + blockmap + attention timed together. 
+// ============================================================================ + +float sparge_jenga_fwd(sparge_blockmap_traits bmap_t, + sparge_blockmap_args bmap_a, + fmha_jenga_fwd_traits attn_t, + fmha_jenga_fwd_args attn_a, + const ck_tile::stream_config& s) +{ + if(s.log_level_ > 0) + std::cout << ", sparge_kstats_" << bmap_t.data_type << "_d" << bmap_t.hdim_q + << ", sparge_blockmap_" << bmap_t.data_type << "_d" << bmap_t.hdim_q + << ", fmha_jenga_fwd_" << attn_t.data_type << "_d" << attn_t.hdim_q << std::flush; + + return ck_tile::launch_kernel( + s, + [=](const ck_tile::stream_config& s_) { sparge_kstats_fwd_oneshot(bmap_t, bmap_a, s_); }, + [=](const ck_tile::stream_config& s_) { + sparge_blockmap_only_fwd_oneshot(bmap_t, bmap_a, s_); + }, + [=](const ck_tile::stream_config& s_) { fmha_jenga_fwd_oneshot(attn_t, attn_a, s_); }); +} + +float sparge_vsa_fwd_combined(sparge_blockmap_traits bmap_t, + sparge_blockmap_args bmap_a, + fmha_vsa_fwd_traits attn_t, + fmha_vsa_fwd_args attn_a, + const ck_tile::stream_config& s) +{ + if(s.log_level_ > 0) + std::cout << ", sparge_kstats_" << bmap_t.data_type << "_d" << bmap_t.hdim_q + << ", sparge_blockmap_" << bmap_t.data_type << "_d" << bmap_t.hdim_q + << ", fmha_vsa_fwd_" << attn_t.data_type << "_d" << attn_t.hdim_q << std::flush; + + return ck_tile::launch_kernel( + s, + [=](const ck_tile::stream_config& s_) { sparge_kstats_fwd_oneshot(bmap_t, bmap_a, s_); }, + [=](const ck_tile::stream_config& s_) { + sparge_blockmap_only_fwd_oneshot(bmap_t, bmap_a, s_); + }, + [=](const ck_tile::stream_config& s_) { fmha_vsa_fwd_oneshot(attn_t, attn_a, s_); }); +} + +// R26 split-launch: partition heads into two buckets by per-head pv_threshold +// (sentinel >= 1e29f vs finite), materialise device-side remap LUTs, then issue +// one fmha launch per non-empty bucket. Bucket selection happens entirely on +// the host; the kernel just reads head_remap_ptr[blockIdx.y] to recover the +// original head index. 
+// +// R30: the "finite" bucket binary is selected by attn_a.pv_mode_compile (1 = +// per-wave (R25 A1 default); 2 = per-block (R30)). The "sentinel" bucket is +// always kNone (mode 0) — sentinel heads requested PV-skip OFF so the per-head +// per-mode choice is degenerate. Per-head per-mode bucket (3+ buckets) is a +// future R31 extension; this commit keeps the 2-bucket scheme and routes the +// active mode through attn_true.pv_mode_compile. +// +// Backward compat: if pv_threshold_per_head_ptr is null, fall back to the +// original single-launch path using attn_a.pv_threshold scalar. +float sparge_sparge_fwd_combined(sparge_blockmap_traits bmap_t, + sparge_blockmap_args bmap_a, + fmha_sparge_fwd_traits attn_t, + fmha_sparge_fwd_args attn_a, + const ck_tile::stream_config& s) +{ + if(s.log_level_ > 0) + std::cout << ", sparge_kstats_" << bmap_t.data_type << "_d" << bmap_t.hdim_q + << ", sparge_blockmap_" << bmap_t.data_type << "_d" << bmap_t.hdim_q + << ", fmha_sparge_fwd_" << attn_t.data_type << "_d" << attn_t.hdim_q + << std::flush; + + // Decide bucket plan. Pull per-head thresholds from device buffer when set, + // else broadcast the scalar across all heads to a single bucket. 
+ const int nhead_q = attn_a.nhead_q; + std::vector false_heads; // pv_threshold >= 1e29f -> kNone binary (mode 0) + std::vector true_heads; // finite -> kPerWave or kPerBlock binary + false_heads.reserve(nhead_q); + true_heads.reserve(nhead_q); + + if(attn_a.pv_threshold_per_head_ptr != nullptr) + { + std::vector pv_host(nhead_q); + auto err = hipMemcpy(pv_host.data(), + attn_a.pv_threshold_per_head_ptr, + static_cast(nhead_q) * sizeof(float), + hipMemcpyDeviceToHost); + if(err != hipSuccess) + { + std::cerr << "sparge_sparge_fwd_combined: hipMemcpy pv_threshold_per_head failed: " + << hipGetErrorString(err) << std::endl; + return -1.f; + } + for(int h = 0; h < nhead_q; ++h) + { + if(pv_host[h] >= 1e29f) + false_heads.push_back(h); + else + true_heads.push_back(h); + } + } + else + { + // Scalar mode: identity remap, single binary picked by pv_mode_compile + // (R30) or the legacy pv_skip_compile bool (R25 A1). When the scalar + // pv_threshold is the sentinel, force the kNone binary regardless of + // mode_compile — the mode is then irrelevant because no skip happens. + if(attn_a.pv_threshold >= 1e29f) + for(int h = 0; h < nhead_q; ++h) + false_heads.push_back(h); + else + for(int h = 0; h < nhead_q; ++h) + true_heads.push_back(h); + } + + // R26-R3 gate: skip empty buckets so we never schedule a zero-grid launch. + const bool need_false = !false_heads.empty(); + const bool need_true = !true_heads.empty(); + + // Materialise per-bucket head-remap device buffers (one int32 each, freed at + // end of this function -- before that we keep them alive across the launch). + ck_tile::DeviceMem false_remap_dev(std::max(1, false_heads.size() * sizeof(int32_t))); + ck_tile::DeviceMem true_remap_dev(std::max(1, true_heads.size() * sizeof(int32_t))); + if(need_false) + false_remap_dev.ToDevice(false_heads.data()); + if(need_true) + true_remap_dev.ToDevice(true_heads.data()); + + // Build per-bucket attn args. 
Scalar pv_threshold field is left as-is so the + // device fallback (when pv_threshold_per_head is null and remap is null) + // remains correct; per-head buffer takes priority when remap is active. + fmha_sparge_fwd_args attn_false = attn_a; + fmha_sparge_fwd_args attn_true = attn_a; + // R30: derive the effective per-bucket mode. The "true" (finite-threshold) + // bucket inherits attn_a.pv_mode_compile so the CLI --pv_mode picks per-wave + // (1) or per-block (2). The "false" (sentinel) bucket is always mode 0 + // (kNone). If a caller still sets only the legacy pv_skip_compile bool + // (R25-A1-era) and leaves pv_mode_compile at its default 1, the behaviour + // is unchanged. + if(need_false) + { + attn_false.head_remap_ptr = static_cast(false_remap_dev.GetDeviceBuffer()); + attn_false.nhead_in_launch = static_cast(false_heads.size()); + attn_false.pv_skip_compile = false; // legacy bool — kept consistent + attn_false.pv_mode_compile = 0; // route to kNone binary (R30) + } + if(need_true) + { + attn_true.head_remap_ptr = static_cast(true_remap_dev.GetDeviceBuffer()); + attn_true.nhead_in_launch = static_cast(true_heads.size()); + attn_true.pv_skip_compile = true; // legacy bool — kept consistent + // R30: pv_mode_compile carries through unchanged from attn_a (CLI choice). + // attn_true is a copy of attn_a, so attn_true.pv_mode_compile already + // holds the user's selection (0 = kNone, 1 = per-wave, 2 = per-block). + // We deliberately do NOT override mode 0 here: if the user passes + // --pv_mode=none together with a finite pv_threshold, that is an + // explicit "build the bucket but don't skip" request (useful as a + // control measurement). Routing it to kNone keeps the CLI honest. + } + + // Chain callables: kstats -> blockmap -> [fmha_false?] -> [fmha_true?]. + // Empty buckets are skipped by emitting an empty lambda; the wrapped path + // never issues a kernel launch in that branch. 
+ auto cb_kstats = [=](const ck_tile::stream_config& s_) { + sparge_kstats_fwd_oneshot(bmap_t, bmap_a, s_); + }; + auto cb_bmap = [=](const ck_tile::stream_config& s_) { + sparge_blockmap_only_fwd_oneshot(bmap_t, bmap_a, s_); + }; + auto cb_fmha_false = [=](const ck_tile::stream_config& s_) { + if(need_false) + fmha_sparge_fwd_oneshot(attn_t, attn_false, s_); + }; + auto cb_fmha_true = [=](const ck_tile::stream_config& s_) { + if(need_true) + fmha_sparge_fwd_oneshot(attn_t, attn_true, s_); + }; + + // launch_kernel returns elapsed ms for the whole chain when timing is on. + // We always pass 4 callables and gate execution inside the lambda; this + // keeps the timing contract stable, while a no-op lambda has negligible + // (~ns) cost compared to the saved 5-15us host launch. + return ck_tile::launch_kernel(s, cb_kstats, cb_bmap, cb_fmha_false, cb_fmha_true); +} diff --git a/example/ck_tile/50_sparse_attn/sparge_blockmap_trek.hpp b/example/ck_tile/50_sparse_attn/sparge_blockmap_trek.hpp new file mode 100644 index 00000000000..7591b94e54d --- /dev/null +++ b/example/ck_tile/50_sparse_attn/sparge_blockmap_trek.hpp @@ -0,0 +1,184 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp" +#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp" +#include "ck_tile/ops/sparse_attn/pipeline/sparge_blockmap_pipeline.hpp" +#include "ck_tile/ops/sparse_attn/pipeline/sparge_kstats_pipeline.hpp" +#include "ck_tile/ops/sparse_attn/kernel/sparge_blockmap_kernel.hpp" +#include "ck_tile/ops/sparse_attn/kernel/sparge_kstats_kernel.hpp" + +#include "fmha_fwd_trek.hpp" + +#include +#include + +// ============================================================================ +// Args and traits for sparge block map GPU kernel +// ============================================================================ +struct sparge_blockmap_args +{ + const void* q_ptr; + const void* k_ptr; + + ck_tile::index_t batch; + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t hdim_q; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + + float simthreshd1; + float cdfthreshd; + float topk; + float scale; + + void* block_map_ptr; + void* lut_ptr; + void* valid_block_num_ptr; + + // Caller-owned K-stats workspace; size from sparge_blockmap_get_workspace_size. + // Internal layout (pooled_k then sim_k) given by sparge_blockmap_workspace_layout. + void* workspace_ptr = nullptr; + + // size = nhead_q to match SpargeAttn upstream hyperparameter_check + const float* simthreshd1_per_head_ptr = nullptr; + const float* cdfthreshd_per_head_ptr = nullptr; + const float* topk_per_head_ptr = nullptr; + + // R32 Items 2+3. Pipeline only honours mask_enum::mask_top_left; CLI warns + // on other types. 
Defaults preserve back-compat for callers not yet setting. + mask_enum mask_type = mask_enum::no_mask; + bool attention_sink = false; +}; + +struct sparge_blockmap_workspace_layout +{ + size_t pooled_k_offset; // bytes from workspace_ptr + size_t pooled_k_bytes; + size_t sim_k_offset; // bytes from workspace_ptr + size_t sim_k_bytes; + size_t total_bytes; +}; + +struct sparge_blockmap_traits +{ + std::string data_type; + int hdim_q; +}; + +// ============================================================================ +// Create kernel args and grid dimensions +// ============================================================================ +template +auto sparge_blockmap_create_kargs_and_grids(sparge_blockmap_args args, + const void* pooled_k_ws_ptr, + const void* sim_k_ws_ptr) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = BlockMapKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.stride_q, + args.stride_k, + args.nhead_stride_q, + args.nhead_stride_k, + args.batch_stride_q, + args.batch_stride_k, + args.simthreshd1, + args.cdfthreshd, + args.topk, + args.scale, + args.block_map_ptr, + args.lut_ptr, + args.valid_block_num_ptr, + pooled_k_ws_ptr, + sim_k_ws_ptr, + args.topk_per_head_ptr, + args.cdfthreshd_per_head_ptr, + static_cast(args.mask_type), + args.attention_sink); + + dim3 grids = BlockMapKernel::GridSize(args.batch, args.nhead_q, args.seqlen_q); + return ck_tile::make_tuple(kargs, grids); +} + +template +auto sparge_kstats_create_kargs_and_grids(sparge_blockmap_args args, + void* pooled_k_ws_ptr, + void* sim_k_ws_ptr) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = KStatsKernel::MakeKargs(args.k_ptr, + args.seqlen_k, + args.hdim_q, + args.nhead_k, + args.stride_k, + args.nhead_stride_k, + args.batch_stride_k, + args.simthreshd1, + pooled_k_ws_ptr, + sim_k_ws_ptr, + args.simthreshd1_per_head_ptr); + + dim3 grids = 
KStatsKernel::GridSize(args.batch, args.nhead_k, args.seqlen_k); + return ck_tile::make_tuple(kargs, grids); +} + +// ============================================================================ +// Hand-written template instantiation dispatch +// ============================================================================ + +// Workspace sizing helpers (host, no template instantiation needed). +sparge_blockmap_workspace_layout +sparge_blockmap_compute_workspace_layout(sparge_blockmap_traits traits, sparge_blockmap_args args); + +inline size_t sparge_blockmap_get_workspace_size(sparge_blockmap_traits traits, + sparge_blockmap_args args) +{ + return sparge_blockmap_compute_workspace_layout(traits, args).total_bytes; +} + +// Stage 1: K-stats only. Writes pooled_k + sim_k into args.workspace_ptr. +void sparge_kstats_fwd_oneshot(sparge_blockmap_traits traits, + sparge_blockmap_args args, + const ck_tile::stream_config& stream_config); + +// Stage 2: block_map only. Reads pooled_k + sim_k from args.workspace_ptr. +void sparge_blockmap_only_fwd_oneshot(sparge_blockmap_traits traits, + sparge_blockmap_args args, + const ck_tile::stream_config& stream_config); + +// Combined functions: kstats + blockmap + attention with unified timing. 
+float sparge_jenga_fwd(sparge_blockmap_traits, + sparge_blockmap_args, + fmha_jenga_fwd_traits, + fmha_jenga_fwd_args, + const ck_tile::stream_config&); + +float sparge_vsa_fwd_combined(sparge_blockmap_traits, + sparge_blockmap_args, + fmha_vsa_fwd_traits, + fmha_vsa_fwd_args, + const ck_tile::stream_config&); + +float sparge_sparge_fwd_combined(sparge_blockmap_traits, + sparge_blockmap_args, + fmha_sparge_fwd_traits, + fmha_sparge_fwd_args, + const ck_tile::stream_config&); diff --git a/example/ck_tile/50_sparse_attn/sparge_tool.hpp b/example/ck_tile/50_sparse_attn/sparge_tool.hpp new file mode 100644 index 00000000000..94426c6fd8a --- /dev/null +++ b/example/ck_tile/50_sparse_attn/sparge_tool.hpp @@ -0,0 +1,411 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" + +namespace sparge { + +struct SpargeParams +{ + // BLKQ=64, BLKK=128 align with SpargeAttn SM90 (Hopper) convention; + // cf. upstream csrc/qattn/qk_int_sv_f8_cuda_sm90.cu:143-144. + // SM80/SM89 path uses the inverse 128/64 (cf. qk_int_sv_f16_cuda_sm80.cu:137-138). + int BLKQ = 64; + int BLKK = 128; + + // Similarity gate threshold (TODO: per-head support). + float simthreshd1 = 0.6f; + + // Exactly one of the following should be used: + // - Use CDF threshold if topk < 0 + // - Both should be in [0, 1] <-- NEED TO CHECK THIS + float cdfthreshd = 0.98f; + float topk = -1.0f; + + // If true, treat Q/K as BHSD; otherwise BSHD (same convention as CK examples). + bool i_perm = true; +}; + +// Output format CK VSA expects. +struct VSALut +{ + ck_tile::HostTensor lut; // [B, Hq, Q_blk, K_blk] delta-encoded + ck_tile::HostTensor valid_block_num; // [B, Hq, Q_blk] +}; + +namespace detail { + +template +inline float to_f32(const T& x) +{ + return ck_tile::type_convert(x); +} + +// Read element from HostTensor with either BHSD or BSHD layout. 
+// Q: [B, Hq, Sq, D] if i_perm else [B, Sq, Hq, D] +// K: [B, Hk, Sk, D] if i_perm else [B, Sk, Hk, D] +template +inline float load(const ck_tile::HostTensor& X, bool i_perm, int b, int h, int s, int d) +{ + return i_perm ? to_f32(X(b, h, s, d)) : to_f32(X(b, s, h, d)); +} + +// Compute pooled mean vector of one block: mean over tokens in [s0, s1). +template +std::vector +pooled_mean_block(const ck_tile::HostTensor& X, bool i_perm, int b, int h, int s0, int s1, int d) +{ + std::vector mean(d, 0.0f); + const int bs = std::max(0, s1 - s0); + if(bs == 0) + return mean; + + for(int s = s0; s < s1; ++s) + { + for(int d_ = 0; d_ < d; ++d_) + { + mean[d_] += load(X, i_perm, b, h, s, d_); + } + } + const float inv = 1.0f / static_cast(bs); + for(int d_ = 0; d_ < d; ++d_) + mean[d_] *= inv; + return mean; +} + +// Compute "sim" flag of one block following SpargeAttn's intent: +// mean_sim = sum(Gram(x_hat)) / (BS_*BS_), where x_hat are token vectors normalized along D. +// +// Important: sum(Gram) = ||sum_i x_hat_i||^2, so we can compute it in O(BS_*D) exactly +// instead of O(BS_^2 * D). +template +bool sim_block_flag(const ck_tile::HostTensor& X, + bool i_perm, + int b, + int h, + int s0, + int s1, + int d, + float simthreshd1) +{ + const int bs = std::max(0, s1 - s0); + if(bs == 0) + return false; + + std::vector sum_hat(d, 0.0f); + + for(int s = s0; s < s1; ++s) + { + // Compute L2 norm over D. + float norm2 = 0.0f; + for(int d_ = 0; d_ < d; ++d_) + { + const float v = load(X, i_perm, b, h, s, d_); + norm2 += v * v; + } + float inv_norm = 1.0f; + // spargeAttn use eps to prevent division by zero + if(norm2 > 0.0f) + inv_norm = 1.0f / std::sqrt(norm2); + + // Accumulate normalized vector. 
+ for(int d_ = 0; d_ < d; ++d_) + { + sum_hat[d_] += load(X, i_perm, b, h, s, d_) * inv_norm; + } + } + + float sum_gram = 0.0f; + for(int d_ = 0; d_ < d; ++d_) + sum_gram += sum_hat[d_] * sum_hat[d_]; + + const float denom = static_cast(bs) * static_cast(bs); + const float mean_sim = sum_gram / denom; + + return mean_sim > simthreshd1; +} + +inline int select_count_from_cdf(const std::vector& sorted_probs, float cdfthreshd) +{ + // Choose the smallest n such that cdf[n-1] >= cdfthreshd. + // Ensure at least 1. + if(sorted_probs.empty()) + return 0; + if(cdfthreshd <= 0.0f) + return 1; + + float c = 0.0f; + for(int i = 0; i < static_cast(sorted_probs.size()); ++i) + { + c += sorted_probs[i]; + if(c >= cdfthreshd) + return i + 1; + } + return static_cast(sorted_probs.size()); +} + +inline int select_count_from_topk(int K_blk, float topk) +{ + if(K_blk <= 0) + return 0; + int n = static_cast(std::floor(topk * static_cast(K_blk))); + n = std::max(1, n); + return n; +} + +} // namespace detail + +// Build one-hot block_map[b,hq,qb,kb] in {0,1}. +// - No causal mask +// - No attention sink +// - Logic matches SpargeAttn's structure: +// - score softmax is only over sim_kblocks; ~sim_kblocks are forced ON later +// - if a Q-block is not "similar", force the whole row ON +template +ck_tile::HostTensor build_block_map_meansim(const ck_tile::HostTensor& Q, + const ck_tile::HostTensor& K, + const SpargeParams& p) +{ + const auto qlens = Q.get_lengths(); + const auto klens = K.get_lengths(); + + const int B = static_cast(qlens[0]); + const int Hq = p.i_perm ? static_cast(qlens[1]) : static_cast(qlens[2]); + const int Sq = p.i_perm ? static_cast(qlens[2]) : static_cast(qlens[1]); + const int D = static_cast(qlens[3]); + + [[maybe_unused]] const int Bk = static_cast(klens[0]); + const int Hk = p.i_perm ? static_cast(klens[1]) : static_cast(klens[2]); + const int Sk = p.i_perm ? 
static_cast(klens[2]) : static_cast(klens[1]); + [[maybe_unused]] const int Dk = static_cast(klens[3]); + + assert(B == Bk && D == Dk && Hq % Hk == 0); + assert(p.BLKQ > 0 && p.BLKK > 0); + + const int nhead_ratio_qk = Hq / Hk; + const int Q_blk = ck_tile::integer_divide_ceil(Sq, p.BLKQ); + const int K_blk = ck_tile::integer_divide_ceil(Sk, p.BLKK); + + ck_tile::HostTensor block_map({B, Hq, Q_blk, K_blk}); + + // pooled_q: [B,Hq,Q_blk,D], pooled_k: [B,Hk,K_blk,D] + // sim_q: [B,Hq,Q_blk], sim_k: [B,Hk,K_blk] + std::vector pooled_q(static_cast(B) * Hq * Q_blk * D, 0.0f); + std::vector pooled_k(static_cast(B) * Hk * K_blk * D, 0.0f); + std::vector sim_q(static_cast(B) * Hq * Q_blk, 0); + std::vector sim_k(static_cast(B) * Hk * K_blk, 0); + + auto idx_pq = [&](int b, int hq, int qb, int d) { + return (((b * Hq + hq) * Q_blk + qb) * D + d); + }; + auto idx_pk = [&](int b, int hk, int kb, int d) { + return (((b * Hk + hk) * K_blk + kb) * D + d); + }; + auto idx_sq = [&](int b, int hq, int qb) { return ((b * Hq + hq) * Q_blk + qb); }; + auto idx_sk = [&](int b, int hk, int kb) { return ((b * Hk + hk) * K_blk + kb); }; + + for(int b = 0; b < B; ++b) + { + for(int hq = 0; hq < Hq; ++hq) + { + // Q blocks + for(int qb = 0; qb < Q_blk; ++qb) + { + const int s0 = qb * p.BLKQ; + const int s1 = std::min(Sq, (qb + 1) * p.BLKQ); + + // pooled mean + auto mean = detail::pooled_mean_block(Q, p.i_perm, b, hq, s0, s1, D); + for(int d = 0; d < D; ++d) + pooled_q[idx_pq(b, hq, qb, d)] = mean[d]; + + // sim flag + sim_q[idx_sq(b, hq, qb)] = + detail::sim_block_flag(Q, p.i_perm, b, hq, s0, s1, D, p.simthreshd1) ? 
1 : 0; + } + } + + for(int hk = 0; hk < Hk; ++hk) + { + // K blocks + for(int kb = 0; kb < K_blk; ++kb) + { + const int s0 = kb * p.BLKK; + const int s1 = std::min(Sk, (kb + 1) * p.BLKK); + + auto mean = detail::pooled_mean_block(K, p.i_perm, b, hk, s0, s1, D); + for(int d = 0; d < D; ++d) + pooled_k[idx_pk(b, hk, kb, d)] = mean[d]; + + sim_k[idx_sk(b, hk, kb)] = + detail::sim_block_flag(K, p.i_perm, b, hk, s0, s1, D, p.simthreshd1) ? 1 : 0; + } + } + } + + const float scale = 1.0f / std::sqrt(static_cast(D)); + + // Main loop + for(int b = 0; b < B; ++b) + { + for(int hq = 0; hq < Hq; ++hq) + { + const int hk = hq / nhead_ratio_qk; + + for(int qb = 0; qb < Q_blk; ++qb) + { + const bool q_is_sim = (sim_q[idx_sq(b, hq, qb)] != 0); + + // If Q-block is not "similar", force dense row. + if(!q_is_sim) + { + for(int kb = 0; kb < K_blk; ++kb) + block_map(b, hq, qb, kb) = 1; + continue; + } + + // Compute scores over K blocks (only sim_kblocks participate in softmax; others set + // to -inf). + std::vector score(K_blk, -std::numeric_limits::infinity()); + for(int kb = 0; kb < K_blk; ++kb) + { + const bool k_is_sim = (sim_k[idx_sk(b, hk, kb)] != 0); + if(!k_is_sim) + { + block_map(b, hq, qb, kb) = 1; + continue; + } + + float dot = 0.0f; + for(int d = 0; d < D; ++d) + { + dot += pooled_q[idx_pq(b, hq, qb, d)] * pooled_k[idx_pk(b, hk, kb, d)]; + } + score[kb] = dot * scale; + } + + // Softmax over K_blk (numerically stable). If all -inf, probs become all zeros. 
+ float maxv = -std::numeric_limits::infinity(); + for(int kb = 0; kb < K_blk; ++kb) + maxv = std::max(maxv, score[kb]); + + std::vector prob(K_blk, 0.0f); + if(std::isfinite(maxv)) + { + float sumexp = 0.0f; + for(int kb = 0; kb < K_blk; ++kb) + { + if(!std::isfinite(score[kb])) + continue; + const float e = std::exp(score[kb] - maxv); + prob[kb] = e; + sumexp += e; + } + if(sumexp > 0.0f) + { + const float inv = 1.0f / sumexp; + for(int kb = 0; kb < K_blk; ++kb) + prob[kb] *= inv; + } + else + { + // All exponentials underflowed: keep zeros. + std::fill(prob.begin(), prob.end(), 0.0f); + } + } + + // Sort indices by prob descending. + std::vector order(K_blk); + std::iota(order.begin(), order.end(), 0); + std::sort(order.begin(), order.end(), [&](int a, int c) { + if(prob[a] != prob[c]) + return prob[a] > prob[c]; + return a < c; // tie-breaker for determinism + }); + + // Determine how many to select. + int num_to_select = 0; + if(p.topk > 0.0f) + { + num_to_select = detail::select_count_from_topk(K_blk, p.topk); + } + else + { + // Use CDF threshold selection (smallest n s.t. cumulative prob >= cdfthreshd). + std::vector sorted_probs(K_blk); + for(int i = 0; i < K_blk; ++i) + sorted_probs[i] = prob[order[i]]; + num_to_select = detail::select_count_from_cdf(sorted_probs, p.cdfthreshd); + num_to_select = std::max(1, num_to_select); + } + + // Select top-kb blocks by order[0..num_to_select-1]. + for(int i = 0; i < num_to_select; ++i) + { + const int kb = order[i]; + block_map(b, hq, qb, kb) = 1; + } + } + } + } + + return block_map; +} + +// Convert one-hot block_map -> delta-encoded LUT + valid_block_num (CK VSA format). 
+template +VSALut block_map_to_vsa_lut_delta(const ck_tile::HostTensor& block_map) +{ + const auto lens = block_map.get_lengths(); + const int B = static_cast(lens[0]); + const int H = static_cast(lens[1]); + const int Q = static_cast(lens[2]); + const int K = static_cast(lens[3]); + + VSALut out{ + ck_tile::HostTensor({B, H, Q, K}), + ck_tile::HostTensor({B, H, Q}), + }; + + for(int b = 0; b < B; ++b) + { + for(int h = 0; h < H; ++h) + { + for(int q = 0; q < Q; ++q) + { + int32_t valid = 0; + int32_t prev = 0; + + for(int k = 0; k < K; ++k) + { + const bool on = static_cast(block_map(b, h, q, k)) != 0; + if(on) + { + out.lut(b, h, q, valid) = static_cast(k - prev); + prev = static_cast(k); + ++valid; + } + } + + out.valid_block_num(b, h, q) = valid; + + // Optional: zero-fill the unused tail for determinism. + for(int i = valid; i < K; ++i) + out.lut(b, h, q, i) = 0; + } + } + } + + return out; +} + +} // namespace sparge diff --git a/example/ck_tile/50_sparse_attn/test_sparge.cpp b/example/ck_tile/50_sparse_attn/test_sparge.cpp new file mode 100644 index 00000000000..c368cb0304c --- /dev/null +++ b/example/ck_tile/50_sparse_attn/test_sparge.cpp @@ -0,0 +1,702 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +// Unified test for Sparge pipeline: blockmap generation + sparse attention (Jenga/VSA). 
+ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/host.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/host/reference/reference_blocked_attention.hpp" +#include "ck_tile/core/utility/bit_cast.hpp" + +#include "01_fmha/mask.hpp" // R32: mask_info::decode, mask_enum +#include "fmha_fwd_trek.hpp" +#include "sparge_blockmap_trek.hpp" +#include "sparge_tool.hpp" + +// ============================================================================ +// Helpers +// ============================================================================ + +template +ck_tile::HostTensor make_qkv_tensor(ck_tile::index_t batch, + ck_tile::index_t nhead, + ck_tile::index_t seqlen, + ck_tile::index_t hdim, + bool i_perm) +{ + if(i_perm) + return ck_tile::HostTensor({batch, nhead, seqlen, hdim}); + return ck_tile::HostTensor({batch, seqlen, nhead, hdim}); +} + +template +ck_tile::HostTensor to_bhsd(const ck_tile::HostTensor& tensor, bool is_bhsd) +{ + auto lens = tensor.get_lengths(); + ck_tile::index_t batch = lens[0]; + ck_tile::index_t seqlen = is_bhsd ? lens[2] : lens[1]; + ck_tile::index_t nhead = is_bhsd ? lens[1] : lens[2]; + ck_tile::index_t hdim = lens[3]; + + ck_tile::HostTensor out({batch, nhead, seqlen, hdim}); + for(ck_tile::index_t b = 0; b < batch; ++b) + for(ck_tile::index_t h = 0; h < nhead; ++h) + for(ck_tile::index_t s = 0; s < seqlen; ++s) + for(ck_tile::index_t d = 0; d < hdim; ++d) + out(b, h, s, d) = is_bhsd ? tensor(b, h, s, d) : tensor(b, s, h, d); + return out; +} + +template +auto get_error_tolerance() +{ + // Matches dense FMHA fp16/bf16 bounds; validated on (b=1,h=2,d=128, + // s in {512, 2048, 4096, 8192}) with maxdiff = 0.00 across both dtypes. 
+ double rtol = 1e-2; + double atol = 4e-2; + return ck_tile::make_tuple(rtol, atol); +} + +template +float to_float_for_compare(T value) +{ + return static_cast(value); +} + +template <> +float to_float_for_compare(ck_tile::bf16_t value) +{ + return ck_tile::type_convert(value); +} + +// ============================================================================ +// Arg parser +// ============================================================================ +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("v", "1", "0:no validation, 1:cpu validation") + .insert("pipeline", "jenga", "attention pipeline: jenga / vsa") + .insert("b", "1", "batch size") + .insert("h", "4", "num of head for q") + .insert("h_k", "-1", "num of head for k/v, -1 means equal to h") + .insert("s", "4096", "seqlen_q") + .insert("s_k", "-1", "seqlen_k, -1 means equal to s") + .insert("d", "128", "head dim for q, k") + .insert("d_v", "-1", "head dim for v, -1 means equal to d") + .insert("topk", "0.3", "topk ratio for blockmap (fraction of K-blocks to keep)") + .insert("cdfthreshd", "-1", "CDF threshold for blockmap (overrides topk if >= 0)") + .insert("simthreshd1", "0.6", "similarity threshold for blockmap") + .insert("prec", "fp16", "data type: fp16/bf16") + .insert("iperm", "1", "permute input, 1: b*h*s*d, 0: b*s*h*d") + .insert("operm", "1", "permute output") + .insert("seed", "42", "random seed") + .insert("warmup", "5", "warmup iterations") + .insert("repeat", "20", "benchmark iterations") + .insert("kname", "0", "print kernel name") + .insert("dump_o", + "", + "if non-empty, dump raw output buffer bytes to this path (for bit-identical " + "baseline comparison)") + .insert("pv_threshold", + "1e30", + "SpargeAttn PV-skip per-Q-tile threshold; default +1e30 disables skip") + .insert("pv_threshold_per_head", + "", + "R26 split-launch: comma-separated per-head pv_threshold list " + "(length must == h). 
Empty = scalar mode using -pv_threshold.") + .insert("pv_skip_compile", + "1", + "R25 V0: 1=use kEnablePVSkip=true template instance (existing path); 0=use " + "kEnablePVSkip=false instance (PV-skip AST removed at compile time, equivalent to " + "VSA baseline). Deprecated by -pv_mode; kept for back-compat scripts.") + .insert("pv_mode", + "warp", + "R30: PV-skip mode select. one of {none, warp, block}. " + "none = no skip (kNone binary; matches VSA baseline). " + "warp = per-wavefront butterfly vote (R25 A1; default). " + "block = per-block AND vote via 1 LDS slot + block_sync_lds (R30). " + "Overrides -pv_skip_compile when set explicitly.") + .insert("mask", + "0", + "0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n" + "'t', top-left causal mask, 'b', bottom-r causal mask\n" + "'t:l,r', top-left sliding window attn(swa) with FA style left right size\n" + "'b:l,r', bottom-r sliding window attn(swa) with FA style left right size\n" + "'xt:window_size', xformer style masking from top-left, " + "window_size negative is causal, positive is swa\n" + "'xb:window_size', xformer style masking from bottom-r, " + "window_size negative is causal, positive is swa\n" + "'g:y,x', generic attention mask coordinate with y/x size " + "(only debug purpose for now)") + .insert("attention_sink", + "0", + "SpargeAttn: force block-map column 0 ON (kb=0 always selected). " + "0=off, 1=on. 
Block-map level only; independent of -mask sink prefix."); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +// ============================================================================ +// Main test +// ============================================================================ +template +bool run_test(const ck_tile::ArgParser& arg_parser) +{ + int do_validation = arg_parser.get_int("v"); + std::string pipeline = arg_parser.get_str("pipeline"); + ck_tile::index_t batch = arg_parser.get_int("b"); + ck_tile::index_t nhead = arg_parser.get_int("h"); + ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); + ck_tile::index_t seqlen_q = arg_parser.get_int("s"); + ck_tile::index_t seqlen_k = arg_parser.get_int("s_k"); + ck_tile::index_t hdim_q = arg_parser.get_int("d"); + ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); + float topk = arg_parser.get_float("topk"); + float cdfthreshd = arg_parser.get_float("cdfthreshd"); + float simthreshd1 = arg_parser.get_float("simthreshd1"); + bool i_perm = arg_parser.get_bool("iperm"); + bool o_perm = arg_parser.get_bool("operm"); + uint32_t seed = arg_parser.get_uint32("seed"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + int kname = arg_parser.get_int("kname"); + std::string dump_o_path = arg_parser.get_str("dump_o"); + float pv_threshold = arg_parser.get_float("pv_threshold"); + int pv_skip_compile = arg_parser.get_int("pv_skip_compile"); + std::string pv_per_head_s = arg_parser.get_str("pv_threshold_per_head"); + std::string pv_mode_str = arg_parser.get_str("pv_mode"); + std::string mask_str = arg_parser.get_str("mask"); + bool attention_sink = arg_parser.get_bool("attention_sink"); + + // R30: --pv_mode maps to the int dispatched at host. + // none -> 0 (kNone), warp -> 1 (kPerWave), block -> 2 (kPerBlock). 
+ // Back-compat: if the user explicitly passed -pv_skip_compile=0 but left + // -pv_mode at default ("warp"), honour the legacy intent (mode=0). The CLI + // doesn't expose "was this passed explicitly", so we mirror the rule used + // pre-R30: bool 0 => kNone, bool 1 => kPerWave. + int pv_mode_compile; + if(pv_mode_str == "none") + pv_mode_compile = 0; + else if(pv_mode_str == "warp") + pv_mode_compile = 1; + else if(pv_mode_str == "block") + pv_mode_compile = 2; + else + { + std::cerr << "Unknown -pv_mode value: '" << pv_mode_str + << "' (expected one of: none, warp, block)" << std::endl; + return false; + } + // Legacy bool wins iff user explicitly disabled and pv_mode stayed warp. + if(pv_skip_compile == 0 && pv_mode_str == "warp") + pv_mode_compile = 0; + + if(nhead_k < 0) + nhead_k = nhead; + if(seqlen_k < 0) + seqlen_k = seqlen_q; + if(hdim_v < 0) + hdim_v = hdim_q; + + mask_info mask = mask_info::decode(mask_str, seqlen_q, seqlen_k); + if(mask.type != mask_enum::no_mask && mask.type != mask_enum::mask_top_left) + std::fprintf(stderr, + "[test_sparge] WARN: -mask='%s' (type=%d) - block-map only " + "filters mask_top_left; selection will not prune past-diagonal " + "blocks. attention kernel still applies the mask.\n", + mask_str.c_str(), + static_cast(mask.type)); + + // If cdfthreshd >= 0, use CDF mode; otherwise use topk mode + if(cdfthreshd >= 0.0f) + topk = -1.0f; + + constexpr ck_tile::index_t BLKQ = 64; + constexpr ck_tile::index_t BLKK = 128; + + if(hdim_q != 128 || hdim_v != 128) + { + std::cout << "\n>>> TEST SKIPPED <<<\n" + << "Kernel instances are generated for hdim=128 only.\n"; + return true; + } + + ck_tile::index_t num_q_blocks = (seqlen_q + BLKQ - 1) / BLKQ; + ck_tile::index_t num_k_blocks = (seqlen_k + BLKK - 1) / BLKK; + + std::string prec_str = std::is_same_v ? 
"fp16" : "bf16"; + std::cout << "[" << pipeline << "|" << prec_str << "] b=" << batch << " h=" << nhead + << " s=" << seqlen_q << " d=" << hdim_q << " topk=" << topk << " sim1=" << simthreshd1 + << std::flush; + + // ---- allocate host tensors ---- + auto q_host = make_qkv_tensor(batch, nhead, seqlen_q, hdim_q, i_perm); + auto k_host = make_qkv_tensor(batch, nhead_k, seqlen_k, hdim_q, i_perm); + auto v_host = make_qkv_tensor(batch, nhead_k, seqlen_k, hdim_v, i_perm); + auto output_host = o_perm ? ck_tile::HostTensor({batch, nhead, seqlen_q, hdim_v}) + : ck_tile::HostTensor({batch, seqlen_q, nhead, hdim_v}); + + ck_tile::HostTensor block_map_host({batch, nhead, num_q_blocks, num_k_blocks}); + ck_tile::HostTensor lut_host({batch, nhead, num_q_blocks, num_k_blocks}); + ck_tile::HostTensor valid_block_num_host({batch, nhead, num_q_blocks}); + + ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed}(q_host); + ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed + 1}(k_host); + ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed + 2}(v_host); + + // ---- device tensors ---- + ck_tile::DeviceMem q_dev(q_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem k_dev(k_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem v_dev(v_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem o_dev(output_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem block_map_dev(block_map_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem lut_dev(lut_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem valid_bn_dev(valid_block_num_host.get_element_space_size_in_bytes()); + + q_dev.ToDevice(q_host.data()); + k_dev.ToDevice(k_host.data()); + v_dev.ToDevice(v_host.data()); + o_dev.SetZero(); + block_map_dev.SetZero(); + lut_dev.SetZero(); + valid_bn_dev.SetZero(); + + // ---- strides (BHSD when i_perm=true) ---- + auto q_strides = q_host.get_strides(); + auto k_strides = k_host.get_strides(); + auto v_strides = v_host.get_strides(); + auto o_strides 
= output_host.get_strides(); + + float scale_s = 1.0f / std::sqrt(static_cast(hdim_q)); + + // ---- build blockmap args ---- + sparge_blockmap_traits bmap_traits; + bmap_traits.data_type = std::is_same_v ? "fp16" : "bf16"; + bmap_traits.hdim_q = hdim_q; + + sparge_blockmap_args bmap_args; + bmap_args.q_ptr = q_dev.GetDeviceBuffer(); + bmap_args.k_ptr = k_dev.GetDeviceBuffer(); + bmap_args.batch = batch; + bmap_args.seqlen_q = seqlen_q; + bmap_args.seqlen_k = seqlen_k; + bmap_args.hdim_q = hdim_q; + bmap_args.nhead_q = nhead; + bmap_args.nhead_k = nhead_k; + bmap_args.stride_q = q_strides[i_perm ? 2 : 1]; + bmap_args.stride_k = k_strides[i_perm ? 2 : 1]; + bmap_args.nhead_stride_q = q_strides[i_perm ? 1 : 2]; + bmap_args.nhead_stride_k = k_strides[i_perm ? 1 : 2]; + bmap_args.batch_stride_q = q_strides[0]; + bmap_args.batch_stride_k = k_strides[0]; + bmap_args.simthreshd1 = simthreshd1; + bmap_args.cdfthreshd = (topk < 0.0f) ? cdfthreshd : -1.0f; + bmap_args.topk = topk; + bmap_args.scale = scale_s; + bmap_args.block_map_ptr = block_map_dev.GetDeviceBuffer(); + bmap_args.lut_ptr = (pipeline == "vsa") ? lut_dev.GetDeviceBuffer() : nullptr; + bmap_args.valid_block_num_ptr = (pipeline == "vsa") ? valid_bn_dev.GetDeviceBuffer() : nullptr; + bmap_args.mask_type = mask.type; // R32 Item 2 + bmap_args.attention_sink = attention_sink; // R32 Item 3 + + // K-stats workspace: caller-owned, sized via host helper, allocated once outside any timing. + const size_t ws_bytes = sparge_blockmap_get_workspace_size(bmap_traits, bmap_args); + ck_tile::DeviceMem kstats_ws_dev(ws_bytes); + bmap_args.workspace_ptr = kstats_ws_dev.GetDeviceBuffer(); + + // Per-head superparam buffers, all sized [nhead_q] to match SpargeAttn upstream contract. + // K-side kernel reads only the first nhead_k entries via [hk]. + // Filled with scalar broadcast; per-head index correctness verified by separate unit test. 
+ ck_tile::DeviceMem topk_per_head_dev(static_cast(nhead) * sizeof(float)); + ck_tile::DeviceMem sim1_per_head_dev(static_cast(nhead) * sizeof(float)); + ck_tile::DeviceMem cdf_per_head_dev(static_cast(nhead) * sizeof(float)); + { + std::vector topk_h(nhead, topk); + std::vector sim1_h(nhead, simthreshd1); + std::vector cdf_h(nhead, cdfthreshd); + topk_per_head_dev.ToDevice(topk_h.data()); + sim1_per_head_dev.ToDevice(sim1_h.data()); + cdf_per_head_dev.ToDevice(cdf_h.data()); + bmap_args.topk_per_head_ptr = + static_cast(topk_per_head_dev.GetDeviceBuffer()); + bmap_args.simthreshd1_per_head_ptr = + static_cast(sim1_per_head_dev.GetDeviceBuffer()); + bmap_args.cdfthreshd_per_head_ptr = + static_cast(cdf_per_head_dev.GetDeviceBuffer()); + } + + // R26 split-launch: optional per-head pv_threshold buffer. Parse the CLI + // comma list (length must match nhead); empty list -> scalar broadcast + // (legacy path, single launch via host). + ck_tile::DeviceMem pv_per_head_dev(static_cast(nhead) * sizeof(float)); + std::vector pv_per_head_host; + bool use_pv_per_head = false; + if(!pv_per_head_s.empty()) + { + std::stringstream ss(pv_per_head_s); + std::string item; + while(std::getline(ss, item, ',')) + { + if(!item.empty()) + pv_per_head_host.push_back(std::stof(item)); + } + if(static_cast(pv_per_head_host.size()) != nhead) + { + std::cerr << "\n[pv_threshold_per_head] length " << pv_per_head_host.size() + << " != h=" << nhead << std::endl; + return false; + } + pv_per_head_dev.ToDevice(pv_per_head_host.data()); + use_pv_per_head = true; + } + + // ---- build attention args ---- + ck_tile::stream_config stream_cfg; + stream_cfg.stream_id_ = nullptr; + stream_cfg.time_kernel_ = true; + stream_cfg.log_level_ = kname; + stream_cfg.cold_niters_ = warmup; + stream_cfg.nrepeat_ = repeat; + + float avg_ms = -1.0f; + + if(pipeline == "jenga") + { + fmha_jenga_fwd_traits attn_traits; + attn_traits.hdim_q = hdim_q; + attn_traits.hdim_v = hdim_v; + attn_traits.data_type = 
std::is_same_v ? "fp16" : "bf16"; + attn_traits.is_v_rowmajor = true; + attn_traits.mask_type = mask.type; + attn_traits.bm0 = BLKQ; + + fmha_jenga_fwd_args attn_args; + attn_args.q_ptr = q_dev.GetDeviceBuffer(); + attn_args.k_ptr = k_dev.GetDeviceBuffer(); + attn_args.v_ptr = v_dev.GetDeviceBuffer(); + attn_args.block_relation_onehot_ptr = block_map_dev.GetDeviceBuffer(); + attn_args.o_ptr = o_dev.GetDeviceBuffer(); + attn_args.seqlen_q = seqlen_q; + attn_args.seqlen_k = seqlen_k; + attn_args.batch = batch; + attn_args.max_seqlen_q = seqlen_q; + attn_args.hdim_q = hdim_q; + attn_args.hdim_v = hdim_v; + attn_args.nhead_q = nhead; + attn_args.nhead_k = nhead_k; + attn_args.scale_s = scale_s; + attn_args.stride_q = q_strides[i_perm ? 2 : 1]; + attn_args.stride_k = k_strides[i_perm ? 2 : 1]; + attn_args.stride_v = v_strides[i_perm ? 2 : 1]; + attn_args.stride_o = o_strides[o_perm ? 2 : 1]; + attn_args.nhead_stride_q = q_strides[i_perm ? 1 : 2]; + attn_args.nhead_stride_k = k_strides[i_perm ? 1 : 2]; + attn_args.nhead_stride_v = v_strides[i_perm ? 1 : 2]; + attn_args.nhead_stride_o = o_strides[o_perm ? 1 : 2]; + attn_args.batch_stride_q = q_strides[0]; + attn_args.batch_stride_k = k_strides[0]; + attn_args.batch_stride_v = v_strides[0]; + attn_args.batch_stride_o = o_strides[0]; + attn_args.window_size_left = mask.left; + attn_args.window_size_right = mask.right; + attn_args.mask_type = static_cast(mask.type); + + avg_ms = sparge_jenga_fwd(bmap_traits, bmap_args, attn_traits, attn_args, stream_cfg); + } + else if(pipeline == "vsa") + { + // R25: -pipeline=vsa now dispatches to the sparge pipeline family that adds + // SpargeAttn §4.4 PV-skip; pass pv_threshold (+1e30 disables skip, matches old vsa). + fmha_sparge_fwd_traits attn_traits; + attn_traits.hdim_q = hdim_q; + attn_traits.hdim_v = hdim_v; + attn_traits.data_type = std::is_same_v ? 
"fp16" : "bf16"; + attn_traits.is_v_rowmajor = true; + attn_traits.mask_type = mask.type; + attn_traits.bm0 = BLKQ; + + fmha_sparge_fwd_args attn_args; + attn_args.q_ptr = q_dev.GetDeviceBuffer(); + attn_args.k_ptr = k_dev.GetDeviceBuffer(); + attn_args.v_ptr = v_dev.GetDeviceBuffer(); + attn_args.lut_ptr = lut_dev.GetDeviceBuffer(); + attn_args.valid_block_num_ptr = valid_bn_dev.GetDeviceBuffer(); + attn_args.o_ptr = o_dev.GetDeviceBuffer(); + attn_args.seqlen_q = seqlen_q; + attn_args.seqlen_k = seqlen_k; + attn_args.batch = batch; + attn_args.max_seqlen_q = seqlen_q; + attn_args.hdim_q = hdim_q; + attn_args.hdim_v = hdim_v; + attn_args.nhead_q = nhead; + attn_args.nhead_k = nhead_k; + attn_args.scale_s = scale_s; + attn_args.pv_threshold = pv_threshold; + attn_args.pv_skip_compile = (pv_skip_compile != 0); + attn_args.pv_mode_compile = pv_mode_compile; // R30: 0=none,1=warp,2=block + // R26 split-launch: when CLI provided per-head list, hand the device + // buffer to the combined wrapper; host code there will partition heads + // into 2 buckets and issue per-bucket launches. + attn_args.pv_threshold_per_head_ptr = + use_pv_per_head ? static_cast(pv_per_head_dev.GetDeviceBuffer()) + : nullptr; + attn_args.stride_q = q_strides[i_perm ? 2 : 1]; + attn_args.stride_k = k_strides[i_perm ? 2 : 1]; + attn_args.stride_v = v_strides[i_perm ? 2 : 1]; + attn_args.stride_o = o_strides[o_perm ? 2 : 1]; + attn_args.nhead_stride_q = q_strides[i_perm ? 1 : 2]; + attn_args.nhead_stride_k = k_strides[i_perm ? 1 : 2]; + attn_args.nhead_stride_v = v_strides[i_perm ? 1 : 2]; + attn_args.nhead_stride_o = o_strides[o_perm ? 
1 : 2]; + attn_args.batch_stride_q = q_strides[0]; + attn_args.batch_stride_k = k_strides[0]; + attn_args.batch_stride_v = v_strides[0]; + attn_args.batch_stride_o = o_strides[0]; + attn_args.window_size_left = mask.left; + attn_args.window_size_right = mask.right; + attn_args.mask_type = static_cast(mask.type); + + avg_ms = + sparge_sparge_fwd_combined(bmap_traits, bmap_args, attn_traits, attn_args, stream_cfg); + } + else + { + std::cerr << "Unknown pipeline: " << pipeline << " (use jenga or vsa)\n"; + return false; + } + + // ---- TFLOPS calculation (dense FMHA formula, so sparsity gains show as higher TFLOPS) ---- + std::size_t flop = static_cast(batch) * nhead * + (static_cast(2) * seqlen_q * seqlen_k * hdim_q + + static_cast(2) * seqlen_q * seqlen_k * hdim_v); + float tflops = (avg_ms > 0.f) ? static_cast(flop) / 1.E9f / avg_ms : 0.f; + + if(avg_ms > 0.f) + { + std::cout << std::fixed << ", " << std::setprecision(3) << avg_ms << " ms, " + << std::setprecision(2) << tflops << " TFlops" << std::flush; + } + + // ---- copy results back ---- + o_dev.FromDevice(output_host.data()); + block_map_dev.FromDevice(block_map_host.data()); + + // ---- optional raw output dump (for bit-identical baseline comparison) ---- + if(!dump_o_path.empty()) + { + std::ofstream ofs(dump_o_path, std::ios::binary | std::ios::trunc); + if(!ofs) + { + std::cerr << "\n [dump_o] failed to open " << dump_o_path << std::endl; + } + else + { + ofs.write(reinterpret_cast(output_host.data()), + static_cast(output_host.get_element_space_size_in_bytes())); + std::cout << "\n [dump_o] wrote " << output_host.get_element_space_size_in_bytes() + << " bytes to " << dump_o_path; + } + } + + // ---- count active blocks ---- + ck_tile::index_t total_blocks = batch * nhead * num_q_blocks * num_k_blocks; + ck_tile::index_t active_blocks = 0; + for(size_t i = 0; i < block_map_host.mData.size(); ++i) + if(block_map_host.mData[i]) + active_blocks++; + float actual_sparsity = + 1.0f - 
static_cast(active_blocks) / static_cast(total_blocks); + std::cout << ", sparsity=" << std::setprecision(2) << actual_sparsity << "(" << active_blocks + << "/" << total_blocks << ")" << std::flush; + + // ---- validation ---- + bool pass = true; + if(do_validation) + { + auto q_ref = to_bhsd(q_host, i_perm); + auto k_ref = to_bhsd(k_host, i_perm); + auto v_ref = to_bhsd(v_host, i_perm); + + // R32: CPU reference lacks causal mask + attention_sink; skip block_map + // cross-check + VSA LUT self-consistency when either is in effect. The + // attention-output check below still runs (consumes GPU bmap). + const bool skip_cpu_bm_check = (mask.type != mask_enum::no_mask) || attention_sink; + + bool bm_pass = true; + bool lut_pass = true; + if(!skip_cpu_bm_check) + { + + sparge::SpargeParams sp; + sp.BLKQ = BLKQ; + sp.BLKK = BLKK; + sp.simthreshd1 = simthreshd1; + sp.cdfthreshd = cdfthreshd; + sp.topk = topk; + sp.i_perm = i_perm; + + auto block_map_cpu = sparge::build_block_map_meansim(q_host, k_host, sp); + + size_t bm_total = block_map_host.mData.size(); + size_t bm_mismatch = 0; + size_t shown = 0; + constexpr size_t MAXSHOW = 10; + std::cout << "\n [block_map cross-check] total=" << bm_total; + for(size_t i = 0; i < bm_total; ++i) + { + uint8_t g = block_map_host.mData[i]; + uint8_t c = block_map_cpu.mData[i]; + if(g != c) + { + if(shown < MAXSHOW) + { + size_t k_idx = i % num_k_blocks; + size_t q_idx = (i / num_k_blocks) % num_q_blocks; + size_t h_idx = (i / (num_k_blocks * num_q_blocks)) % nhead; + size_t b_idx = i / (num_k_blocks * num_q_blocks * nhead); + std::cout << "\n miss[" << shown << "] (b=" << b_idx << ",h=" << h_idx + << ",qb=" << q_idx << ",kb=" << k_idx << ") gpu=" << int(g) + << " cpu=" << int(c); + ++shown; + } + ++bm_mismatch; + } + } + bm_pass = (bm_mismatch == 0); + float bm_ratio = bm_total ? 
100.0f * float(bm_mismatch) / float(bm_total) : 0.0f; + std::cout << "\n [block_map cross-check] mismatch=" << bm_mismatch << "/" << bm_total + << " (" << std::setprecision(4) << bm_ratio << "%) " + << (bm_pass ? "PASS" : "FAIL"); + + auto cpu_lut = sparge::block_map_to_vsa_lut_delta(block_map_cpu); + size_t lut_fails = 0; + for(ck_tile::index_t b = 0; b < batch && lut_fails < MAXSHOW; ++b) + { + for(ck_tile::index_t h = 0; h < nhead && lut_fails < MAXSHOW; ++h) + { + for(ck_tile::index_t qb = 0; qb < num_q_blocks && lut_fails < MAXSHOW; ++qb) + { + int32_t valid = cpu_lut.valid_block_num(b, h, qb); + int32_t active_count = 0; + for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb) + if(block_map_cpu(b, h, qb, kb)) + ++active_count; + int32_t recon_kb = 0; + bool delta_ok = true; + for(int32_t i = 0; i < valid; ++i) + { + int32_t d = cpu_lut.lut(b, h, qb, i); + if(d < 0) + { + delta_ok = false; + break; + } + recon_kb += d; + if(recon_kb >= num_k_blocks) + { + delta_ok = false; + break; + } + if(!block_map_cpu(b, h, qb, recon_kb)) + { + delta_ok = false; + break; + } + } + if(valid != active_count || !delta_ok) + { + lut_pass = false; + if(lut_fails < MAXSHOW) + std::cout << "\n lut_fail (b=" << b << ",h=" << h << ",qb=" << qb + << ") valid=" << valid << " active=" << active_count + << " delta_ok=" << delta_ok; + ++lut_fails; + } + } + } + } + std::cout << "\n [VSA LUT self-consistency] " << (lut_pass ? 
"PASS" : "FAIL"); + } // end if(!skip_cpu_bm_check) + else + { + std::cout << "\n [block_map cross-check] SKIPPED (mask/sink active; CPU ref lacks)"; + std::cout << "\n [VSA LUT self-consistency] SKIPPED"; + } + + ck_tile::HostTensor output_ref({batch, nhead, seqlen_q, hdim_v}); + ck_tile::reference_blocked_attention( + q_ref, k_ref, v_ref, block_map_host, output_ref, BLKQ, BLKK, scale_s); + + auto [rtol, atol] = get_error_tolerance(); + + float max_diff = 0.0f; + size_t num_errors = 0; + + auto output_host_bhsd = to_bhsd(output_host, o_perm); + for(size_t i = 0; i < output_host_bhsd.mData.size(); ++i) + { + float gpu_val = to_float_for_compare(output_host_bhsd.mData[i]); + float ref_val = to_float_for_compare(output_ref.mData[i]); + float diff = std::abs(gpu_val - ref_val); + float rel_diff = (std::abs(ref_val) > 1e-6f) ? diff / std::abs(ref_val) : diff; + + max_diff = std::max(max_diff, diff); + + if(diff > atol && rel_diff > rtol) + num_errors++; + } + + pass = (num_errors == 0) && bm_pass && lut_pass; + std::cout << "\n [attention output] " << ((num_errors == 0) ? "PASS" : "FAIL") + << "(err=" << num_errors << "/" << output_host_bhsd.mData.size() + << " maxdiff=" << max_diff << ")"; + } + + std::cout << std::endl; + return pass; +} + +// ============================================================================ +// Main +// ============================================================================ +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + { + std::cerr << "Failed to parse arguments\n"; + return -1; + } + + std::string prec = arg_parser.get_str("prec"); + + bool test_result = false; + if(prec == "fp16") + { + test_result = run_test(arg_parser); + } + else if(prec == "bf16") + { + test_result = run_test(arg_parser); + } + else + { + std::cerr << "Unsupported precision: " << prec << "\n"; + return -1; + } + + return test_result ? 
0 : -1; +} diff --git a/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_jenga_kernel.hpp b/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_jenga_kernel.hpp index cd3513530d4..e461f7d7435 100644 --- a/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_jenga_kernel.hpp +++ b/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_jenga_kernel.hpp @@ -8,6 +8,7 @@ #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/block/variants.hpp" +#include #include #include #include @@ -133,34 +134,41 @@ struct FmhaFwdJengaKernel }; // std::variant<> can't take in a list initializer, overload for backward compatibility - CK_TILE_HOST static constexpr Kargs MakeKargs(const void* q_ptr, - const void* k_ptr, - const void* v_ptr, - const void* block_relation_onehot_ptr, - void* o_ptr, - ck_tile::index_t seqlen_q, - ck_tile::index_t seqlen_k, - ck_tile::index_t hdim_q, - ck_tile::index_t hdim_v, - ck_tile::index_t num_head_q, - ck_tile::index_t nhead_ratio_qk, - float scale_s, - ck_tile::index_t stride_q, - ck_tile::index_t stride_k, - ck_tile::index_t stride_v, - ck_tile::index_t stride_o, - ck_tile::index_t nhead_stride_q, - ck_tile::index_t nhead_stride_k, - ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_o, - ck_tile::index_t batch_stride_q, - ck_tile::index_t batch_stride_k, - ck_tile::index_t batch_stride_v, - ck_tile::index_t batch_stride_o, - ck_tile::index_t window_size_left, - ck_tile::index_t window_size_right, - ck_tile::index_t mask_type) + // 256-bool LDS staging caps N_k <= 256 (for kN0=64 -> seqlen_k <= 16384). + // Not constexpr because the assert needs runtime evaluation. 
+ CK_TILE_HOST static Kargs MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* block_relation_onehot_ptr, + void* o_ptr, + ck_tile::index_t seqlen_q, + ck_tile::index_t seqlen_k, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + float scale_s, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t batch_stride_q, + ck_tile::index_t batch_stride_k, + ck_tile::index_t batch_stride_v, + ck_tile::index_t batch_stride_o, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type) { + // 256-bool LDS staging caps N_k <= 256 per Q-tile. + // For kN0=64 this means seqlen_k <= 16384. + assert(ck_tile::integer_divide_ceil(seqlen_k, FmhaPipeline::kN0) <= 256 && + "256-bool LDS staging caps N_k <= 256 (for kN0=64: seqlen_k <= 16384)"); + Kargs kargs{{q_ptr, k_ptr, v_ptr, @@ -248,7 +256,11 @@ struct FmhaFwdJengaKernel CK_TILE_DEVICE void operator()(Kargs kargs) const { // allocate LDS - // Extra LDS for staging block_relation_onehot (256 bools); keep 4B alignment for LDS loads. + // Extra LDS stages 256 bools (4B-aligned for LDS loads) — caps N_k <= 256 per Q-tile, + // i.e. seqlen_k <= 256 * kN0 (for kN0=64 -> seqlen_k <= 16384). MakeKargs asserts this. + // The extra 1024B is jenga-specific: pipeline (block_fmha_pipeline_qr_ks_vs_async_jenga + // .hpp:261) stages block_relation_onehot here. Do NOT copy this `+ 256*sizeof(int)` to + // other sparse kernels (e.g. VSA) without first wiring a real reader. 
__shared__ char smem_ptr[GetSmemSize() + 256 * sizeof(int)]; // if (threadIdx.x==0 && blockIdx.x==0 && blockIdx.z ==0) printf("smem size: %d", diff --git a/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_sparge_kernel.hpp b/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_sparge_kernel.hpp new file mode 100644 index 00000000000..cbca128ca6f --- /dev/null +++ b/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_sparge_kernel.hpp @@ -0,0 +1,494 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/variants.hpp" +// PVSkipMode enum lives in the sparge pipeline header; pull it in so the +// kernel template arg can name it (R30: promote bool kEnablePVSkip_ to 3-way enum). +#include "ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_sparge.hpp" + +#include +#include +#include +#include + +// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q] +// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1] +// S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k] +// P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k]) +// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k] + +namespace ck_tile { + +template +struct FmhaFwdSpargeKernel +{ + using FmhaPipeline = ck_tile::remove_cvref_t; + using EpiloguePipeline = ck_tile::remove_cvref_t; + static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize; + static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; + static_assert(kBlockPerCu > 0); + static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu; + static constexpr PVSkipMode kPVSkipMode = kPVSkipMode_; + // Legacy alias preserved: any non-kNone mode is "PV-skip enabled". 
+ static constexpr bool kEnablePVSkip = (kPVSkipMode_ != PVSkipMode::kNone); + + using QDataType = ck_tile::remove_cvref_t; + using KDataType = ck_tile::remove_cvref_t; + using VDataType = ck_tile::remove_cvref_t; + using BiasDataType = ck_tile::remove_cvref_t; + using RandValOutputDataType = + ck_tile::remove_cvref_t; + using LSEDataType = ck_tile::remove_cvref_t; + using ODataType = ck_tile::remove_cvref_t; + using SaccDataType = ck_tile::remove_cvref_t; + + using VLayout = ck_tile::remove_cvref_t; + + static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; + static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap; + static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; + static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; + static constexpr bool kHasDropout = FmhaPipeline::kHasDropout; + static constexpr auto QScaleEnum = FmhaPipeline::Problem::QScaleEnum; + static constexpr bool kDoFp8StaticQuant = + (QScaleEnum != ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE); + static_assert(!FmhaPipeline::kIsGroupMode, "Sparge sparse attention supports batch mode only."); + static_assert(BiasEnum == BlockAttentionBiasEnum::NO_BIAS, + "Sparge sparse attention does not support bias."); + static_assert(!kStoreLSE, "Sparge sparse attention does not support LSE output."); + static_assert(!kHasDropout, "Sparge sparse attention does not support dropout."); + static_assert(!kHasLogitsSoftCap, "Sparge sparse attention does not support logits soft-cap."); + static_assert(!kDoFp8StaticQuant, + "Sparge sparse attention does not support FP8 static quantization yet."); + + using AttentionVariant = ck_tile::remove_cvref_t; + using FmhaMask = ck_tile::remove_cvref_t; + static constexpr bool kHasMask = FmhaMask::IsMasking; + + static constexpr bool 
kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy; + + template // to avoid duplicated base class prblem, introduce an template + // arg + struct FmhaFwdEmptyKargs + { + }; + + // kargs use aggregate initializer, so no constructor will provided + // use inheritance to minimize karg size + // user need to use MakeKargs() function to create kargs. + struct FmhaFwdCommonKargs + { + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* lut_ptr; + const void* valid_block_num_ptr; + void* o_ptr; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + + ck_tile::index_t num_head_q; + // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k + // if this param is larger than 1, indicate MQA/GQA case + ck_tile::index_t nhead_ratio_qk; + float scale_s; + float pv_threshold; + // R26 split-launch: when non-null, indexed by remapped i_nhead (post head_remap), + // overrides scalar pv_threshold. Buffer length = num_head_q. + const float* pv_threshold_per_head; + // R26 split-launch: when non-null, i_nhead = head_remap_ptr[blockIdx.y]. + // Buffer length = nhead_in_launch. Null = identity (blockIdx.y directly). + const int* head_remap_ptr; + // R26 split-launch: gridDim.y when head_remap_ptr is active (== bucket size). + // Kept for future host-side asserts / debug; kernel reads via blockIdx.y. 
+ ck_tile::index_t nhead_in_launch; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_o; + + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_o; + }; + + struct FmhaFwdMaskKargs + { + ck_tile::index_t window_size_left, window_size_right; + ck_tile::GenericAttentionMaskEnum mask_type; + }; + + struct FmhaFwdBatchModeKargs + : FmhaFwdCommonKargs, + std::conditional_t> + { + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_o; + }; + + using Kargs = FmhaFwdBatchModeKargs; + + struct BlockIndices + { + ck_tile::index_t batch_idx; + ck_tile::index_t qo_head_idx; + ck_tile::index_t kv_head_idx; + }; + + // std::variant<> can't take in a list initializer, overload for backward compatibility + CK_TILE_HOST static constexpr Kargs MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* lut_ptr, + const void* valid_block_num_ptr, + void* o_ptr, + ck_tile::index_t seqlen_q, + ck_tile::index_t seqlen_k, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + float scale_s, + float pv_threshold, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t batch_stride_q, + ck_tile::index_t batch_stride_k, + ck_tile::index_t batch_stride_v, + ck_tile::index_t batch_stride_o, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, + // R26 split-launch (default-null preserves + // backward compat = scalar mode). 
+ const float* pv_threshold_per_head = nullptr, + const int* head_remap_ptr = nullptr, + ck_tile::index_t nhead_in_launch = 0) + { + Kargs kargs{{q_ptr, + k_ptr, + v_ptr, + lut_ptr, + valid_block_num_ptr, + o_ptr, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, +#if CK_TILE_FMHA_FWD_FAST_EXP2 + static_cast(scale_s * ck_tile::log2e_v<>), +#else + scale_s, +#endif + pv_threshold, + pv_threshold_per_head, + head_remap_ptr, + nhead_in_launch, + stride_q, + stride_k, + stride_v, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_o}, // FmhaFwdCommonKargs + {}, // FmhaFwdMaskKargs or FmhaFwdEmptyKargs<1> + batch_stride_q, + batch_stride_k, + batch_stride_v, + batch_stride_o}; + + if constexpr(kHasMask) + { + kargs.window_size_left = window_size_left; + kargs.window_size_right = window_size_right; + kargs.mask_type = static_cast(mask_type); + } + return kargs; + } + + CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, + ck_tile::index_t nhead_, + ck_tile::index_t seqlen_q_, + ck_tile::index_t hdim_v_) + { + return dim3(ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) * + ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1), + nhead_, + batch_size_); + } + + CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs) + { + const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1); + + const index_t i_block = blockIdx.x; + // R26 split-launch: if head_remap_ptr is set, translate the launch-local + // head index to the original num_head_q-space index. Null pointer -> + // identity (single-launch backward compat). The remap LUT load is uniform + // across the wavefront (same blockIdx.y for all lanes), but the compiler + // can't infer scalar-uniformity through a global ptr indirection, so we + // broadcast via readfirstlane. 
Without this, dependent offset/buffer- + // descriptor computations spill to VGPRs and buffer_load_dwordx4 inline + // asm rejects the VGPR operand. + const index_t i_nhead = + (kargs.head_remap_ptr != nullptr) + ? __builtin_amdgcn_readfirstlane(kargs.head_remap_ptr[blockIdx.y]) + : static_cast(blockIdx.y); + const index_t i_batch = blockIdx.z; + + const auto f = [](index_t dividend, index_t divisor) { + index_t quotient = dividend / divisor; + index_t modulus = dividend - quotient * divisor; + return ck_tile::make_tuple(quotient, modulus); + }; + + const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); + + if constexpr(kHasMask) + { + return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch); + } + else + { + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + } + } + + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + // allocate LDS + __shared__ char smem_ptr[GetSmemSize()]; + + // divide problem + const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs); + + const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); + const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); + + long_index_t batch_offset_q = 0; + long_index_t batch_offset_k = 0; + long_index_t batch_offset_v = 0; + long_index_t batch_offset_o = 0; + + batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; + batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; + batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; + batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; + + // for simplicity, batch stride we just modify the pointer + const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + + static_cast(i_nhead) 
* kargs.nhead_stride_q + + batch_offset_q; + const KDataType* k_ptr = + reinterpret_cast(kargs.k_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k + + batch_offset_k; + const VDataType* v_ptr = + reinterpret_cast(kargs.v_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v + + batch_offset_v; + + // sparse mask + const int* lut_ptr = + reinterpret_cast(kargs.lut_ptr) + + static_cast(i_batch * kargs.num_head_q + i_nhead) * + ck_tile::integer_divide_ceil(kargs.seqlen_q, FmhaPipeline::kM0) * + ck_tile::integer_divide_ceil(kargs.seqlen_k, FmhaPipeline::kN0) + + i_tile_m * ck_tile::integer_divide_ceil(kargs.seqlen_k, FmhaPipeline::kN0); + const int* valid_block_num_ptr = + reinterpret_cast(kargs.valid_block_num_ptr) + + static_cast(i_batch * kargs.num_head_q + i_nhead) * + ck_tile::integer_divide_ceil(kargs.seqlen_q, FmhaPipeline::kM0) + + i_tile_m; + const int valid_block_num_value = valid_block_num_ptr[0]; + + ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_o + + batch_offset_o; + + // Q/K/V DRAM and DRAM window + const auto q_dram = [&]() { + const auto q_dram_naive = make_naive_tensor_view( + q_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_q), + make_tuple(kargs.stride_q, 1), + number{}, + number<1>{}); + if constexpr(FmhaPipeline::kQLoadOnce) + { + return pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + const auto k_dram = [&]() { + const auto k_dram_naive = make_naive_tensor_view( + k_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_q), + make_tuple(kargs.stride_k, 1), + number{}, + number<1>{}); + + constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? 
kPadSeqLenK : false; + return pad_tensor_view( + k_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + const auto v_dram = [&]() { + if constexpr(std::is_same_v) + { + const auto v_dram_naive = make_naive_tensor_view( + v_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_v), + make_tuple(kargs.stride_v, 1), + number{}, + number<1>{}); + + const auto v_dram_transposed = + transform_tensor_view(v_dram_naive, + make_tuple(make_pass_through_transform(kargs.hdim_v), + make_pass_through_transform(kargs.seqlen_k)), + make_tuple(sequence<1>{}, sequence<0>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false; + return pad_tensor_view( + v_dram_transposed, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + + auto q_dram_window = make_tile_window( + q_dram, + [&]() { + if constexpr(FmhaPipeline::kQLoadOnce) + return make_tuple(number{}, + number{}); + else + return make_tuple(number{}, number{}); + }(), + {i_m0, 0}); + + auto k_dram_window = make_tile_window( + k_dram, make_tuple(number{}, number{}), {0, 0}); + + auto v_dram_window = + make_tile_window(v_dram, + make_tuple(number{}, number{}), + {i_n1, 0}); + + FmhaMask mask = [&]() { + if constexpr(kHasMask) + return ck_tile::make_generic_attention_mask_from_lr_window( + kargs.window_size_left, + kargs.window_size_right, + kargs.seqlen_q, + kargs.seqlen_k, + kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); + else + return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; + }(); + + AttentionVariant variant; + const auto variant_params = ck_tile::StandardAttentionParams{mask, kargs.scale_s}; + + BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk}; + + // R26 split-launch: per-head pv_threshold override (null = scalar mode). + // i_nhead is already scalar-broadcast in GetTileIndex; the load is uniform + // and the resulting float lands in SGPRs naturally. 
We additionally route + // via readfirstlane on the int representation as a defensive hint to keep + // it scalar even when the compiler is conservative about float traffic. + float pv_threshold_resolved; + if(kargs.pv_threshold_per_head != nullptr) + { + const int raw = __builtin_amdgcn_readfirstlane( + __builtin_bit_cast(int, kargs.pv_threshold_per_head[i_nhead])); + pv_threshold_resolved = __builtin_bit_cast(float, raw); + } + else + { + pv_threshold_resolved = kargs.pv_threshold; + } + + auto o_acc_tile = FmhaPipeline{}(q_dram_window, + k_dram_window, + v_dram_window, + lut_ptr, + valid_block_num_value, + mask, + kargs.scale_s, + pv_threshold_resolved, + variant, + variant_params, + block_indices, + smem_ptr); + + // O DRAM and O DRAM window + auto o_dram = [&]() { + const auto o_dram_naive = make_naive_tensor_view( + o_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_v), + make_tuple(kargs.stride_o, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + o_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto o_dram_window = + make_tile_window(o_dram, + make_tuple(number{}, number{}), + {i_m0, i_n1}); + + EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_vsa_kernel.hpp b/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_vsa_kernel.hpp index 5caf27756ff..14fd86e8d14 100644 --- a/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_vsa_kernel.hpp +++ b/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_vsa_kernel.hpp @@ -251,8 +251,7 @@ struct FmhaFwdVSAKernel CK_TILE_DEVICE void operator()(Kargs kargs) const { // allocate LDS - // Extra LDS for staging block_relation_onehot (256 bools); keep 4B alignment for LDS loads. 
- __shared__ char smem_ptr[GetSmemSize() + 256 * sizeof(int)]; + __shared__ char smem_ptr[GetSmemSize()]; // divide problem const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs); diff --git a/include/ck_tile/ops/sparse_attn/kernel/sparge_blockmap_kernel.hpp b/include/ck_tile/ops/sparse_attn/kernel/sparge_blockmap_kernel.hpp new file mode 100644 index 00000000000..bb1cdbfec46 --- /dev/null +++ b/include/ck_tile/ops/sparse_attn/kernel/sparge_blockmap_kernel.hpp @@ -0,0 +1,239 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#pragma once + +#include "ck_tile/core.hpp" +#include + +namespace ck_tile { + +template +struct SpargeBlockMapKernel +{ + using Pipeline = remove_cvref_t; + + static constexpr index_t kBlockSize = Pipeline::kBlockSize; + static constexpr index_t kBlockPerCu = Pipeline::kBlockPerCu; + + using QDataType = typename Pipeline::QDataType; + using KDataType = typename Pipeline::KDataType; + + static constexpr index_t kM0 = Pipeline::kM0; + static constexpr index_t kN0 = Pipeline::kN0; + static constexpr index_t D = Pipeline::D; + + static constexpr index_t kAlignment = 16 / sizeof(QDataType); // 16B = dwordx4 load width + + struct Kargs + { + const void* q_ptr; + const void* k_ptr; + + index_t seqlen_q; + index_t seqlen_k; + index_t hdim_q; + + index_t nhead_q; + index_t nhead_ratio_qk; + + index_t stride_q; + index_t stride_k; + index_t nhead_stride_q; + index_t nhead_stride_k; + index_t batch_stride_q; + index_t batch_stride_k; + + float simthreshd1; + float cdfthreshd; + float topk; + float scale; + + void* block_map_ptr; + void* lut_ptr; + void* valid_block_num_ptr; + + // K-block stats workspace produced by SpargeKStatsKernel + const void* + pooled_k_ws_ptr; // [batch, nhead_k, N_k, D] KDataType (fp16/bf16, matches K dtype) + const void* sim_k_ws_ptr; // [batch, nhead_k, N_k] uint8 + + index_t N_k; + + // Per-head topk (size = nhead_q floats). Required (non-null). 
+ const float* topk_per_head; + + // Per-head cdfthreshd (size = nhead_q floats). Required (non-null); + // only consulted on topk<=0 path. + const float* cdfthreshd_per_head; + + // R32 Items 2+3. mask_type stored as index_t (not mask_enum) to keep this + // include-tree header independent of example/01_fmha/mask.hpp. Magic + // constant 1 == mask_enum::mask_top_left (01_fmha/mask.hpp:13-19). + index_t mask_type; + bool attention_sink; + }; + + CK_TILE_HOST static constexpr auto MakeKargs(const void* q_ptr, + const void* k_ptr, + index_t seqlen_q, + index_t seqlen_k, + index_t hdim_q, + index_t nhead_q, + index_t nhead_ratio_qk, + index_t stride_q, + index_t stride_k, + index_t nhead_stride_q, + index_t nhead_stride_k, + index_t batch_stride_q, + index_t batch_stride_k, + float simthreshd1, + float cdfthreshd, + float topk, + float scale, + void* block_map_ptr, + void* lut_ptr, + void* valid_block_num_ptr, + const void* pooled_k_ws_ptr, + const void* sim_k_ws_ptr, + const float* topk_per_head, + const float* cdfthreshd_per_head, + index_t mask_type, + bool attention_sink) + { + const index_t N_k = integer_divide_ceil(seqlen_k, kN0); + return Kargs{q_ptr, + k_ptr, + seqlen_q, + seqlen_k, + hdim_q, + nhead_q, + nhead_ratio_qk, + stride_q, + stride_k, + nhead_stride_q, + nhead_stride_k, + batch_stride_q, + batch_stride_k, + simthreshd1, + cdfthreshd, + topk, + scale, + block_map_ptr, + lut_ptr, + valid_block_num_ptr, + pooled_k_ws_ptr, + sim_k_ws_ptr, + N_k, + topk_per_head, + cdfthreshd_per_head, + mask_type, + attention_sink}; + } + + CK_TILE_HOST static constexpr auto GridSize(index_t batch, index_t nhead_q, index_t seqlen_q) + { + const index_t Q_blk = integer_divide_ceil(seqlen_q, kM0); + return dim3(Q_blk, nhead_q, batch); + } + + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + const index_t qb = static_cast(blockIdx.x); + const index_t hq = static_cast(blockIdx.y); + const 
index_t b = static_cast(blockIdx.z); + + const index_t hk = hq / kargs.nhead_ratio_qk; + + // Q pointer for this (batch, head, q_block) + const auto* q_base = reinterpret_cast(kargs.q_ptr) + + b * kargs.batch_stride_q + hq * kargs.nhead_stride_q + + qb * kM0 * kargs.stride_q; + + // K pointer for this (batch, head_k) + const auto* k_base = reinterpret_cast(kargs.k_ptr) + + b * kargs.batch_stride_k + hk * kargs.nhead_stride_k; + + // Q DRAM view with OOB padding + const auto q_dram_naive = make_naive_tensor_view( + q_base, + make_tuple(kargs.seqlen_q - qb * kM0, D), + make_tuple(kargs.stride_q, 1), + number{}, + number<1>{}); + const auto q_dram = pad_tensor_view( + q_dram_naive, make_tuple(number{}, number{}), sequence{}); + + auto q_window = make_tile_window(q_dram, + make_tuple(number{}, number{}), + {0, 0}, + Pipeline::MakeQBlockDistribution()); + + // K DRAM view with OOB padding + const auto k_dram_naive = + make_naive_tensor_view(k_base, + make_tuple(kargs.seqlen_k, D), + make_tuple(kargs.stride_k, 1), + number{}, + number<1>{}); + const auto k_dram = pad_tensor_view( + k_dram_naive, make_tuple(number{}, number{}), sequence{}); + + auto k_window = make_tile_window(k_dram, + make_tuple(number{}, number{}), + {0, 0}, + Pipeline::MakeKBlockDistribution()); + + // Output pointers for this (batch, head, q_block) + const index_t N_k = kargs.N_k; + const index_t bmap_offset = + (b * kargs.nhead_q + hq) * integer_divide_ceil(kargs.seqlen_q, kM0) * N_k + qb * N_k; + auto* bmap_ptr = reinterpret_cast(kargs.block_map_ptr) + bmap_offset; + + int32_t* lut_out = nullptr; + int32_t* valid_out = nullptr; + if(kargs.lut_ptr != nullptr) + { + lut_out = reinterpret_cast(kargs.lut_ptr) + bmap_offset; + const index_t valid_offset = + (b * kargs.nhead_q + hq) * integer_divide_ceil(kargs.seqlen_q, kM0) + qb; + valid_out = reinterpret_cast(kargs.valid_block_num_ptr) + valid_offset; + } + + // Shared memory + __shared__ char smem[Pipeline::GetSmemSize()]; + + // K-stat workspace: 
pre-offset for this (b, hk). + const index_t nhead_k = kargs.nhead_q / kargs.nhead_ratio_qk; + const index_t khead_off = (b * nhead_k + hk) * N_k; + const auto* pooled_k_ws = + reinterpret_cast(kargs.pooled_k_ws_ptr) + khead_off * D; + const auto* sim_k_ws = reinterpret_cast(kargs.sim_k_ws_ptr) + khead_off; + + const float topk_eff = kargs.topk_per_head[hq]; + const float cdfthreshd_eff = kargs.cdfthreshd_per_head[hq]; + + Pipeline{}(q_window, + k_window, + kargs.seqlen_q, + kargs.seqlen_k, + qb, + N_k, + kargs.nhead_ratio_qk, + kargs.simthreshd1, + cdfthreshd_eff, + topk_eff, + kargs.scale, + bmap_ptr, + lut_out, + valid_out, + pooled_k_ws, + sim_k_ws, + static_cast(smem), + kargs.mask_type, + kargs.attention_sink); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/sparse_attn/kernel/sparge_kstats_kernel.hpp b/include/ck_tile/ops/sparse_attn/kernel/sparge_kstats_kernel.hpp new file mode 100644 index 00000000000..893e9a232e4 --- /dev/null +++ b/include/ck_tile/ops/sparse_attn/kernel/sparge_kstats_kernel.hpp @@ -0,0 +1,132 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#pragma once + +#include "ck_tile/core.hpp" +#include + +namespace ck_tile { + +// Kernel A wrapper: grid (N_k, nhead_k, batch). Each work-group precomputes +// K-block stats (pooled_k_mean[D], sim_k) for one (b, hk, kb) into a workspace +// that Kernel B (block_map) reads instead of recomputing per Q-block. 
+template +struct SpargeKStatsKernel +{ + using Pipeline = remove_cvref_t; + + static constexpr index_t kBlockSize = Pipeline::kBlockSize; + static constexpr index_t kBlockPerCu = Pipeline::kBlockPerCu; + + using QDataType = typename Pipeline::QDataType; + using KDataType = typename Pipeline::KDataType; + + static constexpr index_t kN0 = Pipeline::kN0; + static constexpr index_t D = Pipeline::D; + + static constexpr index_t kAlignment = 16 / sizeof(KDataType); + + struct Kargs + { + const void* k_ptr; + + index_t seqlen_k; + index_t hdim_q; + index_t nhead_k; + + index_t stride_k; + index_t nhead_stride_k; + index_t batch_stride_k; + + float simthreshd1; + + void* pooled_k_ptr; // [batch, nhead_k, N_k, D] KDataType (fp16/bf16, matches K dtype) + void* sim_k_ptr; // [batch, nhead_k, N_k] uint8 + + index_t N_k; + + // Per-head simthreshd1 pointer (size = nhead_q floats; kernel indexes [hk] only). + // Required (non-null); matches SpargeAttn upstream contract. + const float* simthreshd1_per_head; + }; + + CK_TILE_HOST static constexpr auto MakeKargs(const void* k_ptr, + index_t seqlen_k, + index_t hdim_q, + index_t nhead_k, + index_t stride_k, + index_t nhead_stride_k, + index_t batch_stride_k, + float simthreshd1, + void* pooled_k_ptr, + void* sim_k_ptr, + const float* simthreshd1_per_head) + { + const index_t N_k = integer_divide_ceil(seqlen_k, kN0); + return Kargs{k_ptr, + seqlen_k, + hdim_q, + nhead_k, + stride_k, + nhead_stride_k, + batch_stride_k, + simthreshd1, + pooled_k_ptr, + sim_k_ptr, + N_k, + simthreshd1_per_head}; + } + + CK_TILE_HOST static constexpr auto GridSize(index_t batch, index_t nhead_k, index_t seqlen_k) + { + const index_t N_k = integer_divide_ceil(seqlen_k, kN0); + return dim3(N_k, nhead_k, batch); + } + + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + const index_t kb = static_cast(blockIdx.x); + const index_t hk = static_cast(blockIdx.y); + const 
index_t b = static_cast(blockIdx.z); + + const auto* k_base = reinterpret_cast(kargs.k_ptr) + + b * kargs.batch_stride_k + hk * kargs.nhead_stride_k + + kb * kN0 * kargs.stride_k; + + const auto k_dram_naive = make_naive_tensor_view( + k_base, + make_tuple(kargs.seqlen_k - kb * kN0, D), + make_tuple(kargs.stride_k, 1), + number{}, + number<1>{}); + const auto k_dram = pad_tensor_view( + k_dram_naive, make_tuple(number{}, number{}), sequence{}); + + auto k_window = make_tile_window(k_dram, + make_tuple(number{}, number{}), + {0, 0}, + Pipeline::MakeKBlockDistribution()); + + const index_t N_k = kargs.N_k; + const index_t khead_off = (b * kargs.nhead_k + hk) * N_k; + auto* pooled_k_out = + reinterpret_cast(kargs.pooled_k_ptr) + (khead_off + kb) * D; + auto* sim_k_out = reinterpret_cast(kargs.sim_k_ptr) + (khead_off + kb); + + __shared__ char smem[Pipeline::GetSmemSize()]; + + const float simthreshd1_eff = kargs.simthreshd1_per_head[hk]; + + Pipeline{}(k_window, + kargs.seqlen_k, + kb, + simthreshd1_eff, + pooled_k_out, + sim_k_out, + static_cast(smem)); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp b/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp index 67936c4353f..717d82aca78 100644 --- a/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp +++ b/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp @@ -318,26 +318,26 @@ struct BlockFmhaPipelineQRKSVSAsyncJenga { if(!block_relation_onehot[i_total_loops]) { - i_total_loops++; - if(i_total_loops < num_total_loop) - { - // move K tile windows - move_tile_window(k_dram_block_window, {kN0, 0}); - k_dram_window.set_window_origin(k_dram_block_window.get_window_origin()); - - if(block_relation_onehot[i_total_loops]) - { - async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), - k_dram_window, - number<-1>{}, - k_oob_ck, - k_pre_np); - 
} - move_tile_window(k_dram_window, {0, kK0}); - move_tile_window(v_dram_window, {0, kN0}); - continue; - } - break; + // scan-ahead: find the next active block in one shot + index_t next = i_total_loops + 1; + while(next < num_total_loop && !block_relation_onehot[next]) + next++; + if(next >= num_total_loop) + break; + const index_t delta = next - i_total_loops; + i_total_loops = next; + // jump K/V windows to the next active block + move_tile_window(k_dram_block_window, {kN0 * delta, 0}); + k_dram_window.set_window_origin(k_dram_block_window.get_window_origin()); + move_tile_window(v_dram_window, {0, kN0 * delta}); + // immediately prefetch the active K tile + async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), + k_dram_window, + number<-1>{}, + k_oob_ck, + k_pre_np); + move_tile_window(k_dram_window, {0, kK0}); + continue; } // STAGE 1, QK gemm @@ -430,6 +430,12 @@ struct BlockFmhaPipelineQRKSVSAsyncJenga s.get_tile_distribution()); // Pcompute{j} __builtin_amdgcn_sched_barrier(0x7F); + // Ensure gemm_0's LDS reads (K tile) from all threads are completed before V store + // Only needed when K tail and V use the same LDS buffer + if constexpr(LdsSeq.at(number{}) == LdsSeq.at(number{})) + { + __builtin_amdgcn_s_barrier(); + } // store & prefetch next v, after the max reduction auto v_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledVRegBlockDescriptor()); diff --git a/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_sparge.hpp b/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_sparge.hpp new file mode 100644 index 00000000000..0a8baa4e623 --- /dev/null +++ b/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_sparge.hpp @@ -0,0 +1,855 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" + +namespace ck_tile { + +// R30: PV-skip mode enum. R25 A1 shipped a per-wavefront vote; R30 adds a +// per-block consensus vote (matches upstream SpargeAttn kPerBlock semantics; +// see R29 researcher report per_block_vload_guard.md). kNone disables the +// skip path entirely (AST removed). The legacy bool `kEnablePVSkip_=true` +// maps to kPerWave; `false` maps to kNone — preserved via codegen. +enum class PVSkipMode : int +{ + kNone = 0, + kPerWave = 1, + kPerBlock = 2, +}; + +// Sparge variant of qr/ks/vs/async pipeline. Cloned from BlockFmhaPipelineQRKSVSAsyncVSA; +// adds PV-skip per Q-tile (SpargeAttn paper 4.4). Kept as a separate file so the original +// _vsa.hpp can remain frozen as an A/B baseline. +// +// R30: kPVSkipMode_ promoted from bool to 3-value enum {kNone, kPerWave, kPerBlock}. +// kPerWave is the R25 A1 shipped path; kPerBlock adds a block-wide consensus AND vote +// (1 LDS slot + 1 block_sync_lds) so all waves in a block agree before skipping the +// PV mma. Per R29 audit, the V load / V->LDS store / cp_async pipeline stay +// unconditional in BOTH per-wave and per-block modes (only the gemm_1 is gated). +// +// QUANT-HOOK: future int8/sage variant will add QScaleEnum template arg + per-tile descale Kargs; +// _sparge_sage.hpp will live alongside this file and reuse the PV-skip path verbatim. +template +struct BlockFmhaPipelineQRKSVSAsyncSparge +{ + static constexpr PVSkipMode kPVSkipMode = kPVSkipMode_; + // Legacy alias: true iff any PV-skip mode (per-wave or per-block) is active. + // Kept so existing `if constexpr (kEnablePVSkip)` reads still compile. 
+ static constexpr bool kEnablePVSkip = (kPVSkipMode_ != PVSkipMode::kNone); + static constexpr bool kPerBlockPVSkip = (kPVSkipMode_ == PVSkipMode::kPerBlock); + + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using AttentionVariant = remove_cvref_t; + using FmhaMask = remove_cvref_t; + + using BlockFmhaShape = remove_cvref_t; + using VLayout = remove_cvref_t; + static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once + static_assert(kQLoadOnce == Policy::QLoadOnce); + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kN1 = BlockFmhaShape::kN1; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; + static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; + + static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + // TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x) + // only need special care about seq_k padding (oob need set -INF of p instead of zero) + static_assert(Problem::kPadSeqLenQ == true && Problem::kPadHeadDimQ == true && + Problem::kPadHeadDimV == true); + static constexpr bool kPadSeqLenQ = true; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + 
static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x) + static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x) + static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kHasDropout = Problem::kHasDropout; + + static_assert(BiasEnum == BlockAttentionBiasEnum::NO_BIAS, + "VSA sparse attention does not support bias."); + static_assert(!kHasDropout, "VSA sparse attention does not support dropout."); + static_assert(!kStoreLSE, "VSA sparse attention does not support LSE output."); + static_assert(!kHasLogitsSoftCap, "VSA sparse attention does not support logits soft-cap."); + + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr index_t kAlignmentQ = Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = []() { + if constexpr(std::is_same_v) + return Policy::template GetAlignmentV(); + else + return kPadSeqLenK ? 
1 : Policy::template GetAlignmentV(); + }(); + static constexpr index_t kAlignmentO = Policy::template GetAlignmentO(); + +#if CK_TILE_FMHA_FWD_FAST_EXP2 + static constexpr auto R_LOG2E = 1.0 / log2e_v; +#endif + + static constexpr index_t kBlockPerCu = []() { + if constexpr(Problem::kBlockPerCu != -1) + return Problem::kBlockPerCu; + else + { + // minimize occupancy + if constexpr(kQKHeaddim <= 32) + { + if constexpr(kPadSeqLenK && FmhaMask::IsMasking) + return 1; + else + return 2; + } + else if constexpr(kQKHeaddim <= 64) + { + if constexpr(kPadSeqLenK) + return 2; + else + return 3; + } + else if constexpr(kQKHeaddim <= 128) + { + if constexpr(kPadSeqLenK) + return 1; + else + return 2; + } + else if constexpr(kQKHeaddim <= 192) + { + if constexpr(kPadSeqLenK) + return 1; + else + return 2; + } + else if constexpr(kQKHeaddim <= 256) + { + return 1; + } + else + { + return 1; + }; + } + }(); + + static constexpr const char* name = "qr_async"; + + // R30: per-block PV-skip needs one int32 LDS slot to broadcast the AND-vote + // result across waves. Reserved at the TAIL of the pipeline's LDS budget + // (after the existing K + V allocations), 4 bytes, aligned. When mode is + // kNone or kPerWave the byte is unused; the sentinel cost is negligible + // (4 bytes vs the multi-kB K/V tiles) so we always reserve it to keep the + // smem layout uniform across modes — simpler than per-mode policy plumbing. + static constexpr ck_tile::index_t kPerBlockVoteSlotBytes = 4; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize() + kPerBlockVoteSlotBytes; + } + + // R30: byte offset of the per-block vote flag from `smem_ptr`. Lives just + // past the policy's K+V smem footprint. 
+ CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetPerBlockVoteSlotOffset() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const int* kv_block_idx_ptr, + int kv_blocks, + FmhaMask mask, + float scale_s, + float pv_threshold, // SpargeAttn PV-skip threshold; see §2 of pv_skip plan + const AttentionVariant& variant, + const AttentionVariantParams& variant_params, + const BlockIndices& block_indices, + void* smem_ptr) const + { + if constexpr(!kEnablePVSkip) + { + (void)pv_threshold; // silence unused-param when PV-skip is compiled out + } + // R25 Step 1 redesign D: PV-skip control is a compile-time gate + // (kEnablePVSkip). The entire PV-skip logic block below is wrapped in + // `if constexpr (kEnablePVSkip)`, so when this template parameter is + // false the AST contains no vote, no scalar gate, no extra LDS, and + // codegen converges with _vsa.hpp's FmhaFwdVSAKernel. + // + // Runtime fast-path (C3-lite): pv_threshold == +1e30 sentinel disables + // the skip at runtime via one scalar branch (sgpr); kept inside the + // `if constexpr` so the OFF instantiation pays zero cost. 
+ static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + constexpr auto LdsSeq = Policy::template GetLdsBufferSequence(); + + // K tile in LDS + auto k_lds_ptr = reinterpret_cast(smem_ptr); + auto k_lds_store = generate_tuple( + [&](auto i_buf) { + return make_tile_window( + make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsStoreBlockDescriptor(i_buf)), + Policy::template MakeKLdsStoreBlockDescriptor(i_buf).get_lengths(), + {0, 0, 0}); + }, + number{}); + + auto k_lds_Load_view = make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor()); + + auto k_lds_load = + make_tile_window(k_lds_Load_view, + Policy::template MakeKLdsLoadBlockDescriptor().get_lengths(), + {0, 0}); + + // V tile in LDS + auto v_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeVLdsBlockDescriptor()); + auto v_lds_window = make_tile_window( + v_lds, Policy::template MakeVLdsBlockDescriptor().get_lengths(), {0, 0}); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); + + int seqlen_k_start = kv_block_idx_ptr[0] * kN0; + auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), + q_dram_block_window_tmp.get_window_lengths(), + q_dram_block_window_tmp.get_window_origin(), + Policy::template MakeQRegTileDistribution()); + q_dram_window.init_raw(); + + // TODO: we use async Copy for K, which is inline asm + // a side effect is we have to use inline asm for q as well + auto q = decltype(load_tile(q_dram_window)){}; + // TODO: start from 
rocm-6.2, compiler will have problem if manually set clear of q. + // however, q would be cleared in the constructor of static distributed tensor + // set_tile(q, number<0>{}); // use per-dword clear to avoid scratch + load_tile_raw(q, q_dram_window); + __builtin_amdgcn_sched_barrier(0); + + using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); + auto s_acc = SaccBlockTileType{}; + + // reduction function for softmax + const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + // infer Sacc, S, P, M, L, Oacc type + using SBlockTileType = decltype(cast_tile(s_acc)); + + using MLBlockTileType = decltype(block_tile_reduce( + SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0})); + + using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); + + // init Oacc, M, L + auto o_acc = OaccBlockTileType{}; + auto m = MLBlockTileType{}; + auto l = MLBlockTileType{}; + + clear_tile(o_acc); + set_tile(m, -numeric::infinity()); + clear_tile(l); + + __builtin_amdgcn_sched_barrier(0); + const auto q_origin = q_dram_window.get_window_origin(); + const auto num_total_loop = kv_blocks; + + // check early exit if no work to do + if constexpr(FmhaMask::IsMasking || kPadSeqLenK) + { + if(num_total_loop <= 0) + { + buffer_load_fence(0); // rocm-6.1, if whole tile is masked out, need to fence(0) + // otherwise will have compute error(maybe compiler bug?) 
+ + // Note: here occ are all cleard, return it + return o_acc; + } + __builtin_amdgcn_sched_barrier(0); // make sure sched_barrier(0) for this check + } + + auto k_dram_block_window = + make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), + k_dram_block_window_tmp.get_window_lengths(), + {seqlen_k_start, 0}); + + auto k_dram_window = make_tile_window( + k_dram_block_window.get_bottom_tensor_view(), + k_dram_block_window.get_window_lengths(), + k_dram_block_window.get_window_origin(), + Policy::template MakeKDramTileDistribution()); // K DRAM tile window for + // load + k_dram_window.init_raw(); + constexpr auto k_oob_ck = bool_constant{}; + constexpr auto k_pre_np = bool_constant{}; + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), + v_dram_block_window_tmp.get_window_lengths(), + {0, seqlen_k_start}, // TODO: hdim split? + Policy::template MakeVDramTileDistribution()); + + // prefetch K tile + async_load_tile_raw( + k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, number<-1>{}, k_oob_ck, k_pre_np); + move_tile_window(k_dram_window, {0, kK0}); + __builtin_amdgcn_sched_barrier(0); + + // buffer_load_fence(k_dram_window.get_num_of_access(), q.get_thread_buffer()); + buffer_load_fence(k_dram_window.get_num_of_access()); + + index_t i_total_loops = 0; + constexpr index_t k0_loops = kQKHeaddim / kK0; + constexpr index_t k1_loops = kN0 / kK1; + + static_assert(1 <= k0_loops); + static_assert(1 <= k1_loops); + // main loop + do + { + // STAGE 1, QK gemm + clear_tile(s_acc); // initialize C + if constexpr(k0_loops > 1) + { + static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { + async_load_tile_raw(k_lds_store(number{})>{}), + k_dram_window, + number<-1>{}, + k_oob_ck, + k_pre_np); + if constexpr(i_k0 < k0_loops - 1) + move_tile_window(k_dram_window, {0, kK0}); + + async_load_fence(k_dram_window.get_num_of_access()); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + gemm_0(s_acc, + 
get_slice_tile( + q, sequence<0, i_k0 * kK0>{}, sequence{}), + get_slice_tile(k_lds_load, + sequence<(LdsSeq.at(number{})) * kN0, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN0, kK0>{})); + }); + } + + // TODO: this to fix a bug when loop smaller than 2, + // the following fence/barrier will be scheduled inside 1st loop + if constexpr(k0_loops <= 2) + __builtin_amdgcn_sched_barrier(0); + + async_load_fence(); + __builtin_amdgcn_s_barrier(); + + int block_idx = kv_block_idx_ptr[i_total_loops + 1]; + auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant{}); + __builtin_amdgcn_sched_barrier(0); + { // tail + gemm_0( + s_acc, + get_slice_tile( + q, sequence<0, (k0_loops - 1) * kK0>{}, sequence{}), + get_slice_tile(k_lds_load, + sequence<(LdsSeq.at(number{})) * kN0, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN0, kK0>{})); + } + __builtin_amdgcn_sched_barrier(1); + + // STAGE 2, scale_s, mask, softmax (no bias/soft-cap) +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); +#endif + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), + k_origin.at(number<0>{}), + number{}, + number{}); + + if(need_perpixel_check) + { + set_tile_if( + s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return !variant.LogitsMask(variant_params, + block_indices.batch_idx, + row, + col, + block_indices.qo_head_idx, + block_indices.kv_head_idx); + }); + } + } + + const auto s = cast_tile(s_acc); // S{j} + auto m_local = block_tile_reduce( + s, + sequence<1>{}, + f_max, + -numeric::infinity()); // m_local = rowmax(S{j}) + block_tile_reduce_sync(m_local, f_max, bool_constant{}); + + const auto m_old = m; // m{j-1} + tile_elementwise_inout( + 
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j} + + auto p_compute = make_static_distributed_tensor( + s.get_tile_distribution()); // Pcompute{j} + + __builtin_amdgcn_sched_barrier(0x7F); + // Ensure gemm_0's LDS reads (K tile) from all threads are completed before V store + // Only needed when K tail and V use the same LDS buffer + if constexpr(LdsSeq.at(number{}) == LdsSeq.at(number{})) + { + __builtin_amdgcn_s_barrier(); + } + // store & prefetch next v, after the max reduction. + // R25 Step 1 redesign D: V→LDS store and the next-V DRAM load are + // UNCONDITIONAL — per-warp PV-skip cannot gate them (cross-warp + // shared LDS state; see Researcher audit A.3/A.4/A.5). + if constexpr(std::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v_buf); + + auto v_lds_window_tmp = + get_slice_tile(v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + + store_tile(v_lds_window_tmp, v_shuffle_tmp); + } + else + { + auto v_lds_window_tmp = + get_slice_tile(v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + store_tile(v_lds_window_tmp, v_buf); + } + + if constexpr(k1_loops > 1) + { + move_tile_window( + v_dram_window, + {0, kK1}); // will have scratch if move this right after load_tile(v_dram)... + v_buf = load_tile( + v_dram_window, number<-1>{}, bool_constant{}); // load next v_buf + } + __builtin_amdgcn_sched_barrier(0); + + // ================================================================ + // PV-SKIP per Q-tile (SpargeAttn paper §4.4) + // R25 Step 1 redesign D — per-warp arithmetic-only: + // Compile-time `if constexpr (kEnablePVSkip)` wraps the entire + // block. When kEnablePVSkip=false the AST has zero PV-skip + // artifacts → codegen converges with _vsa.hpp. 
+ // + // When enabled, a per-warp predicate gates ONLY the per-row, + // VGPR-private work (exp2 → p_compute, rowsum, `l += rowsum_p`). + // V load / V→LDS store / gemm_1 / every `s_barrier` / + // `block_sync_lds` stay unconditional (cross-warp LDS dep — see + // Researcher audit A.7). + // + // On warp_skip, this warp's owned rows of p_compute are zeroed + // so the unconditional gemm_1 contributes 0 to o_acc (audit + // A.7 "simplest realisation"). The alpha-rescale `l *= tmp` and + // `o *= tmp` still apply. + // + // pv_threshold semantics shift: now per-warp max diff (slightly + // more aggressive than per-block at the same threshold; matches + // upstream SpargeAttn `kPerWarp` mode default). + // + // Skip iff: scale_s * (m_local - m_old) + pv_threshold <= 0 + // (where m_local/m_old are warp-uniform after block_tile_reduce_sync) + // ================================================================ + // Per-warp PV-skip predicate. Only declared when kEnablePVSkip; + // wrapped in a lambda so the false instantiation contains nothing. + auto compute_warp_skip = [&]() { + if constexpr(kEnablePVSkip) + { + // C3-lite scalar fast-path: pv_threshold == +1e30 sentinel + // disables skip; runtime cost is a single sgpr branch. + if(pv_threshold >= 1e29f) + return false; + // Per-row predicate: warp-AND over rows this warp owns. + int warp_skip_int = 1; + constexpr auto m_spans = decltype(m_local)::get_distributed_spans(); + sweep_tile_span(m_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + const float diff = scale_s * (static_cast(m_local[i_idx]) - + static_cast(m_old[i_idx])); + if(!(diff + pv_threshold <= 0.0f)) + warp_skip_int = 0; + }); + // Warp-level AND reduce (wave=64 on gfx942; xor butterfly). + // No LDS, no s_barrier, no cross-warp dependency. 
+ warp_skip_int &= __shfl_xor(warp_skip_int, 32); + warp_skip_int &= __shfl_xor(warp_skip_int, 16); + warp_skip_int &= __shfl_xor(warp_skip_int, 8); + warp_skip_int &= __shfl_xor(warp_skip_int, 4); + warp_skip_int &= __shfl_xor(warp_skip_int, 2); + warp_skip_int &= __shfl_xor(warp_skip_int, 1); + return warp_skip_int != 0; + } + else + { + return false; + } + }; + const bool warp_skip = compute_warp_skip(); + + // ================================================================ + // R30: per-block PV-skip — block-wide AND vote over warp_skip. + // Hand-rolled (no `block_and` primitive in CK-tile, no + // `__syncthreads_and` analog — see R30 idiom catalog §7.5). + // + // Protocol: + // 1. Lane 0 of each wave atomicAnd's its warp_skip int into a + // shared LDS sentinel (initialised to 1 by lane 0 of wave 0 + // before the vote). + // 2. block_sync_lds() — all stores visible, all waves rendezvous + // (uses the same s_waitcnt+s_barrier discipline as the K/V + // LDS chain; lgkmcnt accounting stays consistent — idiom + // §3.1 / §4.2). + // 3. All lanes read the sentinel back into a register. The + // result is wave-uniform (and effectively SGPR after + // readfirstlane) — used to gate gemm_1 at :607 / :665 below. + // + // Cost: 1 LDS init + 1 atomicAnd + 1 block_sync_lds + 1 LDS load. + // The vote slot lives at `smem_ptr + GetPerBlockVoteSlotOffset()`, + // 4 bytes past the policy K+V budget (see GetSmemSize override). + // No interaction with LdsSeq rotation slots. + // + // V load / V->LDS store / cp_async pipeline stay UNCONDITIONAL in + // both per-wave and per-block modes — matches upstream SpargeAttn + // (R29 audit) and CK-tile LDS-rotation discipline. + // ================================================================ + bool block_skip = false; + if constexpr(kPerBlockPVSkip) + { + // Carve a 4-byte uint32 slot at the LDS tail. 
The cast is safe: + // GetSmemSize() bumped the smem_ptr allocation by 4 bytes (see + // pipeline override above), so the slot is dedicated to this + // pipeline instance and never reused by K/V tiles. + auto* vote_slot = reinterpret_cast(static_cast(smem_ptr) + + GetPerBlockVoteSlotOffset()); + + const int lane_id = threadIdx.x % warpSize; + const int warp_id = threadIdx.x / warpSize; + + // Initialise the sentinel to 1 (skip-everything) before any + // wave votes. Only one thread does the init; the subsequent + // block_sync_lds() makes it visible to all waves. + if(warp_id == 0 && lane_id == 0) + { + *vote_slot = 1u; + } + block_sync_lds(); + + // Each wave contributes its warp_skip (already wave-uniform + // after the butterfly in compute_warp_skip). Lane 0 of each + // wave issues the atomicAnd; other lanes are idle. The atomic + // is on LDS (s_or_b32 / ds_and_b32), much cheaper than global. + if(lane_id == 0) + { + atomicAnd(vote_slot, warp_skip ? 1u : 0u); + } + block_sync_lds(); + + // Broadcast the consensus back to every lane. + const uint32_t consensus = *vote_slot; + block_skip = (consensus != 0u); + } + + static const auto get_validated_m = [](SMPLComputeDataType raw_m) { + if constexpr(FmhaMask::IsMasking) + { + return raw_m == -numeric::infinity() + ? type_convert(0.f) + : raw_m; + } + else + { + return raw_m; + } + }; + + // exp2 → p_compute and rowsum_p. + // R25 redesign D: when kEnablePVSkip + warp_skip, we zero this + // warp's owned rows of p_compute so the unconditional gemm_1 + // contributes zero to o_acc, and skip the rowsum. + // R30: per-block mode uses block_skip (uniform across waves) and + // additionally skips gemm_1 itself (see guard at the gemm_1 site + // below). The p_compute zeroing remains so rowsum_p -> 0 and + // `l += rowsum_p` is a no-op for skipped iters. 
+ constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + auto row_max = scale_s * get_validated_m(m[i_idx]); +#endif + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + if constexpr(kPerBlockPVSkip) + { + if(block_skip) + { + p_compute(i_j_idx) = SMPLComputeDataType{0}; + return; + } + } + else if constexpr(kEnablePVSkip) + { + if(warp_skip) + { + p_compute(i_j_idx) = SMPLComputeDataType{0}; + return; + } + } +#if CK_TILE_FMHA_FWD_FAST_EXP2 + p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); +#else + p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx])); +#endif + }); + }); + + auto rowsum_p = block_tile_reduce( + p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) + + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); + + // l{j}, Oacc{j}: alpha rescale of l / o always runs. + // When warp_skip, rowsum_p is already 0 for this + // warp's owned rows (p_compute zeroed above), so + // `l += rowsum_p` is a no-op — no extra branch needed. + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + const auto tmp = [&]() { + auto row_max = scale_s * get_validated_m(m[i_idx]); + return exp2(scale_s * m_old[i_idx] - row_max); + }(); +#else + const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx])); +#endif + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + // FIXME: this use different equation from FA v2 paper, + // but produce correc result. + // Is the equation wrong? 
+ o_acc(i_j_idx) *= tmp; + }); + }); + + const auto p = [&]() { + if constexpr(std::is_same_v) + return impl::cast_tile_pkrtz_fp16_fp32(p_compute); + else + return cast_tile(p_compute); + }(); + + // STAGE 3, KV gemm — always runs (block-wide LDS dep; per-warp + // skipping has been absorbed by zeroing p_compute rows above). + { + if constexpr(k1_loops > 1) + { + static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { + if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1) + { + v_buf = load_tile(v_dram_window, + number<-1>{}, + bool_constant{}); // load next v_buf + } + // block_sync_lds() stays UNCONDITIONAL — it is the + // workgroup barrier the V->LDS rotation chain requires + // (idiom catalog §3.1 / §4.1). Only the gemm_1 MFMA is + // gated on block_skip when in per-block mode. + block_sync_lds(); + if constexpr(kPerBlockPVSkip) + { + if(!block_skip) + { + gemm_1( + o_acc, + get_slice_tile(p, + sequence<0, i_k1 * kK1>{}, + sequence{}), + get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, + kK1>{})); + } + } + else + { + gemm_1(o_acc, + get_slice_tile(p, + sequence<0, i_k1 * kK1>{}, + sequence{}), + get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, + kK1>{})); + } + + if constexpr(std::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v_buf); + auto v_lds_window_tmp = get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, + kK1>{}); + store_tile(v_lds_window_tmp, v_shuffle_tmp); + } + else + { + auto v_lds_window_tmp = get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, + kK1>{}); + store_tile(v_lds_window_tmp, v_buf); + } + if constexpr(i_k1 < k1_loops - 1) + move_tile_window(v_dram_window, {0, 
kK1}); + }); + } + } + i_total_loops++; + if(i_total_loops < num_total_loop) + { + // V load runs unconditionally under redesign D, so no skip + // compensation needed (same offset arithmetic as _vsa.hpp). + move_tile_window(v_dram_window, {0, kN0 * (block_idx - 1)}); + move_tile_window(k_dram_block_window, {kN0 * block_idx, 0}); + k_dram_window.set_window_origin(k_dram_block_window.get_window_origin()); + + if constexpr(k1_loops >= 2 && + LdsSeq.at(number<0>{}) == LdsSeq.at(number{})) + __builtin_amdgcn_s_barrier(); + async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), + k_dram_window, + number<-1>{}, + k_oob_ck, + k_pre_np); + move_tile_window(k_dram_window, {0, kK0}); + } + // tail — gemm_1 runs unconditionally under redesign D (per-wave). + // R30: per-block mode gates the MFMA on block_skip; block_sync_lds + // still runs unconditionally (workgroup barrier for LDS rotation). + { + block_sync_lds(); + if constexpr(kPerBlockPVSkip) + { + if(!block_skip) + { + gemm_1( + o_acc, + get_slice_tile( + p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), + get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, + kK1>{})); + } + } + else + { + gemm_1(o_acc, + get_slice_tile( + p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), + get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, + kK1>{})); + } + } + } while(i_total_loops < num_total_loop); + + // finally, O + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + const auto tmp = [&]() { + if constexpr(FmhaMask::IsMasking) + { + return l[i_idx] == 0.f ? 
0.f : 1 / l[i_idx]; + } + else + return 1 / l[i_idx]; + }(); + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + }); + + return o_acc; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp b/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp index 2b097ae5827..507c91a585e 100644 --- a/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp +++ b/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp @@ -200,7 +200,7 @@ struct BlockFmhaPipelineQRKSVSAsyncVSA constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); - int seqlen_k_start = kv_block_idx_ptr[0] * kM0; + int seqlen_k_start = kv_block_idx_ptr[0] * kN0; auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), q_dram_block_window_tmp.get_window_lengths(), q_dram_block_window_tmp.get_window_origin(), @@ -387,6 +387,12 @@ struct BlockFmhaPipelineQRKSVSAsyncVSA s.get_tile_distribution()); // Pcompute{j} __builtin_amdgcn_sched_barrier(0x7F); + // Ensure gemm_0's LDS reads (K tile) from all threads are completed before V store + // Only needed when K tail and V use the same LDS buffer + if constexpr(LdsSeq.at(number{}) == LdsSeq.at(number{})) + { + __builtin_amdgcn_s_barrier(); + } // store & prefetch next v, after the max reduction if constexpr(std::is_same_v) { diff --git a/include/ck_tile/ops/sparse_attn/pipeline/sparge_blockmap_pipeline.hpp b/include/ck_tile/ops/sparse_attn/pipeline/sparge_blockmap_pipeline.hpp new file mode 100644 index 00000000000..176063cee1d --- /dev/null +++ b/include/ck_tile/ops/sparse_attn/pipeline/sparge_blockmap_pipeline.hpp @@ -0,0 +1,571 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/reduce.hpp" + +namespace ck_tile { + +template +struct SpargeBlockMapPipeline +{ + using Problem = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using BlockFmhaShape = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t D = BlockFmhaShape::kQKHeaddim; + static constexpr index_t NumWarps = BlockFmhaShape::NumWarps; + static constexpr index_t WarpSize = get_warp_size(); + + static constexpr index_t KPerThread = 16 / sizeof(QDataType); + static constexpr index_t KThreads = D / KPerThread; + static constexpr index_t SeqThreadPerWarp = WarpSize / KThreads; + static constexpr index_t MPerThread = kM0 / (SeqThreadPerWarp * NumWarps); + static constexpr index_t NPerThread = kN0 / (SeqThreadPerWarp * NumWarps); + + static constexpr index_t kBlockPerCu = 1; + static constexpr index_t kMaxKBlocks = 1024; + + // LDS layout (non-overlapping, all used simultaneously in K-block loop): + // [0 .. kReduceBytes) cross-warp reduction scratch slab 0 + // [kReduceBytes .. 2*kReduceBytes) cross-warp reduction scratch slab 1 + // (ping-pong for K-loop double buffer) + // [kScoreOffset ..) scores[N_k] + // [kBmapOffset ..) block_map[N_k] + // [kSmallOffset ..) softmax/selection argmax scratch (2*NumWarps + // floats) + // Column-stride pad: k_idx*(KPerThread+1) instead of k_idx*KPerThread to break + // the 4-way intra-warp bank conflict. Per-warp slab size: KThreads * (KPerThread + 1) floats. 
+ static constexpr index_t kColPaddedStride = KPerThread + 1; + static constexpr index_t kPerWarpFloats = KThreads * kColPaddedStride; + static constexpr index_t kReduceBytes = NumWarps * kPerWarpFloats * sizeof(float); + static constexpr index_t kReduceTotalBytes = 2 * kReduceBytes; // 2 slabs (K-loop ping-pong) + static constexpr index_t kScoreOffset = kReduceTotalBytes; + static constexpr index_t kBmapOffset = kScoreOffset + kMaxKBlocks * sizeof(float); + static constexpr index_t kSmallOffset = kBmapOffset + kMaxKBlocks * sizeof(uint8_t); + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return kSmallOffset + 2 * NumWarps * sizeof(float); + } + + CK_TILE_HOST_DEVICE static constexpr auto MakeQBlockDistribution() + { + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + CK_TILE_HOST_DEVICE static constexpr auto MakeKBlockDistribution() + { + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + // Extract tile data into a local float array via static_for (compile-time indices). + template + CK_TILE_DEVICE static void tile_to_float(const Tile& tile, float (&out)[BufSize]) + { + static_assert(Tile::get_thread_buffer_size() == BufSize); + const auto& buf = tile.get_thread_buffer(); + static_for<0, BufSize, 1>{}([&](auto i) { out[i.value] = type_convert(buf[i]); }); + } + + // Column-wise (dim=0) sum: accumulate SeqPerThread rows into KPerThread partial sums, + // then xor-shuffle across m_idx within warp. 
+ template + CK_TILE_DEVICE static void column_reduce_thread_and_warp(const float* __restrict__ data, + float (&col_acc)[KPerThread]) + { + for(index_t k = 0; k < KPerThread; ++k) + col_acc[k] = 0.f; + + for(index_t m = 0; m < SeqPerThread; ++m) + for(index_t k = 0; k < KPerThread; ++k) + col_acc[k] += data[m * KPerThread + k]; + + for(index_t stride = KThreads; stride < WarpSize; stride *= 2) + for(index_t k = 0; k < KPerThread; ++k) + col_acc[k] += warp_shuffle(col_acc[k], __lane_id() ^ stride); + } + + // Cross-warp LDS reduction for column sums. + // Templated TrailingSync flag: when false, the trailing __syncthreads() is dropped — + // only safe when the next access targets a *different* slab and the intervening work + // does not read smem_reduce. Used at the slab_b call in the K-loop, where the next + // iter's first cross-warp reduce writes to slab_a and is preceded by its own leading sync. + template + CK_TILE_DEVICE static void column_reduce_cross_warp(float (&col_acc)[KPerThread], + float* __restrict__ smem_reduce) + { + const index_t tid = static_cast(threadIdx.x); + const index_t warp_id = tid / WarpSize; + const index_t lane_id = tid % WarpSize; + const index_t k_idx = lane_id % KThreads; + const index_t m_idx = lane_id / KThreads; + + // Column-stride pad: stride k_idx by (KPerThread+1)=9 instead of 8, changing + // per-lane bank from (k_idx*8+k)%32 to (k_idx*9+k)%32. For k=0, lanes + // (k_idx={0,4,8,12}) hit banks {0,4,8,12} instead of all 0. 
+ if(m_idx == 0) + for(index_t k = 0; k < KPerThread; ++k) + smem_reduce[warp_id * kPerWarpFloats + k_idx * kColPaddedStride + k] = col_acc[k]; + __syncthreads(); + + for(index_t k = 0; k < KPerThread; ++k) + col_acc[k] = 0.f; + for(index_t w = 0; w < NumWarps; ++w) + for(index_t k = 0; k < KPerThread; ++k) + col_acc[k] += smem_reduce[w * kPerWarpFloats + k_idx * kColPaddedStride + k]; + if constexpr(TrailingSync) + __syncthreads(); + } + + // Compute ||v||^2 per row: sum along KPerThread then xor-shuffle across k_idx. + template + CK_TILE_DEVICE static void row_reduce_sq_norm(const float* __restrict__ data, + float (&row_norms)[SeqPerThread], + index_t actual_seq) + { + const index_t tid = static_cast(threadIdx.x); + const index_t warp_id = tid / WarpSize; + const index_t m_idx = (tid % WarpSize) / KThreads; + + for(index_t m = 0; m < SeqPerThread; ++m) + { + float sq = 0.f; + for(index_t k = 0; k < KPerThread; ++k) + { + float v = data[m * KPerThread + k]; + sq += v * v; + } + for(index_t stride = 1; stride < KThreads; stride *= 2) + sq += warp_shuffle(sq, __lane_id() ^ stride); + + index_t gsq = m * (SeqThreadPerWarp * NumWarps) + warp_id * SeqThreadPerWarp + m_idx; + row_norms[m] = (gsq < actual_seq) ? sq : 0.f; + } + } + + // Column reduce of normalised rows: sum_hat[d] = sum_i data[i,d] / ||data[i,:]||. + template + CK_TILE_DEVICE static void column_reduce_normalised(const float* __restrict__ data, + const float* __restrict__ row_norms, + float (&col_acc)[KPerThread], + index_t actual_seq) + { + const index_t tid = static_cast(threadIdx.x); + const index_t warp_id = tid / WarpSize; + const index_t m_idx = (tid % WarpSize) / KThreads; + + for(index_t k = 0; k < KPerThread; ++k) + col_acc[k] = 0.f; + + for(index_t m = 0; m < SeqPerThread; ++m) + { + // Round 12: hardware fast rsqrt (v_rsq_f32, ~1 ULP) replaces sw sqrt+rcp. + float inv_norm = (row_norms[m] > 0.f) ? 
rsqrtf(row_norms[m]) : 0.f; + index_t gsq = m * (SeqThreadPerWarp * NumWarps) + warp_id * SeqThreadPerWarp + m_idx; + if(gsq < actual_seq) + for(index_t k = 0; k < KPerThread; ++k) + col_acc[k] += data[m * KPerThread + k] * inv_norm; + } + + for(index_t stride = KThreads; stride < WarpSize; stride *= 2) + for(index_t k = 0; k < KPerThread; ++k) + col_acc[k] += warp_shuffle(col_acc[k], __lane_id() ^ stride); + } + + // Scalar reduce across k_idx lanes (within warp). + CK_TILE_DEVICE static float reduce_across_k(float v) + { + for(index_t stride = 1; stride < KThreads; stride *= 2) + v += warp_shuffle(v, __lane_id() ^ stride); + return v; + } + + // Full-block scalar reduce (warp xor + cross-warp LDS). + CK_TILE_DEVICE static float block_reduce_sum(float v, float* smem_small) + { + const index_t tid = static_cast(threadIdx.x); + const index_t warp_id = tid / WarpSize; + const index_t lane_id = tid % WarpSize; + + for(index_t stride = 1; stride < WarpSize; stride *= 2) + v += warp_shuffle(v, __lane_id() ^ stride); + if(lane_id == 0) + smem_small[warp_id] = v; + __syncthreads(); + if(tid == 0) + { + float s = 0.f; + for(index_t w = 0; w < NumWarps; ++w) + s += smem_small[w]; + smem_small[0] = s; + } + __syncthreads(); + return smem_small[0]; + } + + CK_TILE_DEVICE static float block_reduce_max(float v, float* smem_small) + { + const index_t tid = static_cast(threadIdx.x); + const index_t warp_id = tid / WarpSize; + const index_t lane_id = tid % WarpSize; + + for(index_t stride = 1; stride < WarpSize; stride *= 2) + v = max(v, warp_shuffle(v, __lane_id() ^ stride)); + if(lane_id == 0) + smem_small[warp_id] = v; + __syncthreads(); + if(tid == 0) + { + float s = smem_small[0]; + for(index_t w = 1; w < NumWarps; ++w) + s = max(s, smem_small[w]); + smem_small[0] = s; + } + __syncthreads(); + return smem_small[0]; + } + + // ====================================================================== + template + CK_TILE_DEVICE void operator()(const QWindowType& q_window_in, + 
const KWindowType& /*k_window_in*/, + index_t seqlen_q, + index_t /*seqlen_k*/, + index_t qb, + index_t N_k, + index_t /*nhead_ratio_qk*/, + float simthreshd1, + float cdfthreshd, + float topk, + float scale, + uint8_t* block_map_ptr, + int32_t* lut_ptr, + int32_t* valid_block_num_ptr, + const KDataType* __restrict__ pooled_k_ws_ptr, + const uint8_t* __restrict__ sim_k_ws_ptr, + void* smem_ptr, + index_t mask_type, + bool attention_sink) const + { + const index_t tid = static_cast(threadIdx.x); + + // mask_enum::mask_top_left == 1 (01_fmha/mask.hpp:16). Multiplicative + // form handles BLKQ=64,BLKK=128 (kM0=kN0 case. + const bool is_causal_tl = (mask_type == 1); + + // K-loop no longer reduces; only Q-stats uses smem_float0. + // smem_float1 slab is allocated for layout compat but unused. + auto* smem_float0 = reinterpret_cast(smem_ptr); + auto* smem_scores = + reinterpret_cast(reinterpret_cast(smem_ptr) + kScoreOffset); + auto* smem_bmap = + reinterpret_cast(reinterpret_cast(smem_ptr) + kBmapOffset); + auto* smem_small = + reinterpret_cast(reinterpret_cast(smem_ptr) + kSmallOffset); + + const index_t bs_q = min(static_cast(kM0), seqlen_q - qb * kM0); + const float inv_bs_q = (bs_q > 0) ? (1.0f / static_cast(bs_q)) : 0.f; + + // ================================================================== + // Q Block Statistics + // ================================================================== + auto q_tile = load_tile(q_window_in); + + float q_data[MPerThread * KPerThread]; + tile_to_float(q_tile, q_data); + + // 1a. L2 norm per token + float psq[MPerThread]; + row_reduce_sq_norm(q_data, psq, bs_q); + + // 1b. Column sum -> mean + // Drop trailing sync: next reduce reuses same slab (smem_float0) with its own + // leading __syncthreads() before reading. pooled_q_mean is register-only between reduces. 
+ float pooled_q_mean[KPerThread]; + column_reduce_thread_and_warp(q_data, pooled_q_mean); + column_reduce_cross_warp(pooled_q_mean, smem_float0); + for(index_t k = 0; k < KPerThread; ++k) + pooled_q_mean[k] *= inv_bs_q; + + // 1c. Normalised sum_hat + // Drop trailing sync: next cross-warp reduce in K-loop iter 0 writes + // slab_a=smem_float0 (kb=0 even); its leading __syncthreads() covers the WAR. + // sum_hat is register-only here. + float sum_hat[KPerThread]; + column_reduce_normalised(q_data, psq, sum_hat, bs_q); + column_reduce_cross_warp(sum_hat, smem_float0); + + // 1d. sim_q = ||sum_hat||^2 / bs_q^2 + float sh_sq = 0.f; + for(index_t k = 0; k < KPerThread; ++k) + sh_sq += sum_hat[k] * sum_hat[k]; + sh_sq = reduce_across_k(sh_sq); + const float denom_q = static_cast(bs_q) * static_cast(bs_q); + const bool sim_q = (denom_q > 0.f) && ((sh_sq / denom_q) > simthreshd1); + + // Not similar → force all K blocks ON, early exit + if(!sim_q) + { + // R32 Item 2: only fill causal-valid prefix when active. + const index_t causal_kb_end = + is_causal_tl ? min(N_k, integer_divide_ceil((qb + 1) * kM0, kN0)) : N_k; + + for(index_t i = tid; i < N_k; i += kBlockSize) + block_map_ptr[i] = (i < causal_kb_end) ? 1 : 0; + + // R32 Item 3: sink force. Under top-left causal, kb=0 always + // causal-valid for qb>=0 -> no-op; meaningful for mask=no + sink=1. 
+ if(attention_sink && tid == 0) + block_map_ptr[0] = 1; + __syncthreads(); // sink visible to LUT-build below + + if(lut_ptr != nullptr && tid == 0) + { + int32_t valid = 0, prev = 0; + for(index_t kb = 0; kb < causal_kb_end; ++kb) + { + lut_ptr[valid] = static_cast(kb) - prev; + prev = static_cast(kb); + ++valid; + } + for(index_t i = valid; i < N_k; ++i) + lut_ptr[i] = 0; + *valid_block_num_ptr = valid; + } + return; + } + + // ================================================================== + // K Block Loop + // ================================================================== + for(index_t i = tid; i < N_k; i += kBlockSize) + smem_bmap[i] = 0; + __syncthreads(); + + // K-stats precomputed by SpargeKStatsKernel. Each thread loads its own + // KPerThread-slice of pooled_k_mean from DRAM workspace; sim_k is a single byte. + // No K-tile load, no cross-warp reduce in the K-loop. + const index_t lane_id_kb = tid % WarpSize; + const index_t k_idx_kb = lane_id_kb % KThreads; + + for(index_t kb = 0; kb < N_k; ++kb) + { + // R32 Item 2: top-left causal at block grain. + // (qb,kb) past-diagonal iff kb*kN0 >= (qb+1)*kM0. + const bool causal_killed = is_causal_tl && (kb * kN0 >= (qb + 1) * kM0); + + const KDataType* p_kb = pooled_k_ws_ptr + kb * D + k_idx_kb * KPerThread; + float pooled_k_mean[KPerThread]; + for(index_t k = 0; k < KPerThread; ++k) + pooled_k_mean[k] = type_convert(p_kb[k]); + + float dot = 0.f; + for(index_t k = 0; k < KPerThread; ++k) + dot += pooled_q_mean[k] * pooled_k_mean[k]; + dot = reduce_across_k(dot); + + const bool sim_k = (sim_k_ws_ptr[kb] != 0); + + if(tid == 0) + { + // INVARIANT (mirrors SpargeAttn ref utils.py:175-180): + // ~sim_k blocks are forced ON in the bitmap (final_map[~sim_k]=1) + // AND have score = -inf so the selection step (topk / cdf) does NOT + // pick them again (would double-count toward topk budget). 
+ // R32: causal_killed gates the force-on so past-diagonal blocks are + // NOT forced ON; bmap stays 0, scores -inf so selection excludes them. + if(causal_killed) + smem_scores[kb] = -numeric::infinity(); // bmap stays 0 + else if(!sim_k) + { + smem_bmap[kb] = 1; + smem_scores[kb] = -numeric::infinity(); + } + else + smem_scores[kb] = dot * scale; + } + } + __syncthreads(); // guard selection's reads of smem_bmap / smem_scores + + // ================================================================== + // Softmax + Selection + // ================================================================== + + // max + float lmax = -numeric::infinity(); + for(index_t i = tid; i < N_k; i += kBlockSize) + lmax = max(lmax, smem_scores[i]); + const float max_score = block_reduce_max(lmax, smem_small); + + // exp + sum + float lsum = 0.f; + for(index_t i = tid; i < N_k; i += kBlockSize) + { + float e = (smem_scores[i] > -numeric::infinity()) + ? __builtin_expf(smem_scores[i] - max_score) + : 0.f; + smem_scores[i] = e; + lsum += e; + } + const float sum_exp = block_reduce_sum(lsum, smem_small); + + // Round 13i: argmax is invariant under positive scaling (inv_sum > 0). When + // topk > 0 we never read normalised values for cdfthreshd, so skip the + // normalise pass entirely (saves N_k LDS writes + 1 __syncthreads). The + // cdfthreshd path (topk <= 0) still requires normalised scores so the + // accumulator `cumulative_prob` matches probabilities. + const bool topk_active = (topk > 0.f); + const float inv_sum = (!topk_active && sum_exp > 0.f) ? (1.0f / sum_exp) : 0.f; + if(!topk_active) + { + for(index_t i = tid; i < N_k; i += kBlockSize) + smem_scores[i] *= inv_sum; + __syncthreads(); + } + + // Selection: iterative argmax + index_t num_to_select = + topk_active + ? 
max(static_cast(1), static_cast(topk * static_cast(N_k))) + : N_k; + + float cumulative_prob = 0.f; + for(index_t round = 0; round < num_to_select; ++round) + { + // thread-local argmax + float best_val = -1.f; + index_t best_idx = 0; + for(index_t i = tid; i < N_k; i += kBlockSize) + { + if(smem_scores[i] > best_val || (smem_scores[i] == best_val && i < best_idx)) + { + best_val = smem_scores[i]; + best_idx = i; + } + } + + // warp argmax + for(index_t stride = 1; stride < WarpSize; stride *= 2) + { + float rv = warp_shuffle(best_val, __lane_id() ^ stride); + index_t ri = warp_shuffle(best_idx, __lane_id() ^ stride); + if(rv > best_val || (rv == best_val && ri < best_idx)) + { + best_val = rv; + best_idx = ri; + } + } + + // cross-warp argmax via LDS + const index_t lane_id = tid % WarpSize; + const index_t warp_id = tid / WarpSize; + if(lane_id == 0) + { + smem_small[warp_id] = best_val; + smem_small[NumWarps + warp_id] = bit_cast(static_cast(best_idx)); + } + __syncthreads(); + + // Round 13g: collapse 2 syncs/round into 1. tid==0 computes the global + // winner AND writes the sentinel (smem_bmap=1, smem_scores=-1) in the same + // critical section, gated by bv>0. All threads then read smem_small[0] for + // the early break / cumulative_prob accumulation. Saves 1 __syncthreads per + // round (~32 syncs @ N_k=64 topk=0.5). + if(tid == 0) + { + float bv = smem_small[0]; + index_t bi = bit_cast(smem_small[NumWarps]); + for(index_t w = 1; w < NumWarps; ++w) + { + float wv = smem_small[w]; + index_t wi = bit_cast(smem_small[NumWarps + w]); + if(wv > bv || (wv == bv && wi < bi)) + { + bv = wv; + bi = wi; + } + } + // Write sentinel into bmap/scores in the same critical section. + // Guarded by bv > 0 so we never poison a valid score with -1. 
+ if(bv > 0.f) + { + smem_bmap[bi] = 1; + smem_scores[bi] = -1.f; + } + smem_small[0] = bv; + } + __syncthreads(); + + float g_val = smem_small[0]; + + if(g_val <= 0.f) + break; + + if(topk > 0.f) + { + if(round + 1 >= num_to_select) + break; + } + else + { + cumulative_prob += g_val; + if(cumulative_prob >= cdfthreshd) + break; + } + } + + // ================================================================== + // Write outputs to global memory + // ================================================================== + // R32 Item 3: force smem_bmap[0]=1 BEFORE LUT collation reads it. + // Reuses existing LUT-build loop (R31 §4: don't manually insert into + // delta stream). Causal post-multiply unnecessary: D.2 sets killed + // scores to -inf; selection gate L490 `bv > 0` excludes them, so + // smem_bmap[bi]=1 never fires for killed blocks. + if(attention_sink && tid == 0) + smem_bmap[0] = 1; + __syncthreads(); + + for(index_t i = tid; i < N_k; i += kBlockSize) + block_map_ptr[i] = smem_bmap[i]; + + if(lut_ptr != nullptr && tid == 0) + { + int32_t valid = 0, prev = 0; + for(index_t kb = 0; kb < N_k; ++kb) + { + if(smem_bmap[kb] != 0) + { + lut_ptr[valid] = static_cast(kb) - prev; + prev = static_cast(kb); + ++valid; + } + } + for(index_t i = valid; i < N_k; ++i) + lut_ptr[i] = 0; + *valid_block_num_ptr = valid; + } + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/sparse_attn/pipeline/sparge_kstats_pipeline.hpp b/include/ck_tile/ops/sparse_attn/pipeline/sparge_kstats_pipeline.hpp new file mode 100644 index 00000000000..9c122d8dea6 --- /dev/null +++ b/include/ck_tile/ops/sparse_attn/pipeline/sparge_kstats_pipeline.hpp @@ -0,0 +1,110 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/sparse_attn/pipeline/sparge_blockmap_pipeline.hpp" + +namespace ck_tile { + +// Kernel A of the K-stat precompute split: one work-group per (b, hk, kb) +// computes pooled_k_mean and sim_k for that K-block once. Kernel B then reads +// from the workspace instead of recomputing per Q-block. +template +struct SpargeKStatsPipeline +{ + using Problem = remove_cvref_t; + using Base = SpargeBlockMapPipeline; + using QDataType = typename Base::QDataType; + using KDataType = typename Base::KDataType; + + static constexpr index_t kBlockSize = Base::kBlockSize; + static constexpr index_t kM0 = Base::kM0; + static constexpr index_t kN0 = Base::kN0; + static constexpr index_t D = Base::D; + static constexpr index_t NumWarps = Base::NumWarps; + static constexpr index_t WarpSize = Base::WarpSize; + + static constexpr index_t KPerThread = Base::KPerThread; + static constexpr index_t KThreads = Base::KThreads; + static constexpr index_t SeqThreadPerWarp = Base::SeqThreadPerWarp; + static constexpr index_t NPerThread = Base::NPerThread; + + static constexpr index_t kBlockPerCu = 1; + + static constexpr index_t kColPaddedStride = Base::kColPaddedStride; + static constexpr index_t kPerWarpFloats = Base::kPerWarpFloats; + static constexpr index_t kReduceBytes = NumWarps * kPerWarpFloats * sizeof(float); + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return kReduceBytes; } + + CK_TILE_HOST_DEVICE static constexpr auto MakeKBlockDistribution() + { + return Base::MakeKBlockDistribution(); + } + + // operator(): one work-group, one K-block. Writes D fp32 + 1 uint8 to workspace. 
+ template + CK_TILE_DEVICE void operator()(const KWindowType& k_window, + index_t seqlen_k, + index_t kb, + float simthreshd1, + KDataType* __restrict__ pooled_k_out, // D KDataType (fp16/bf16) + uint8_t* __restrict__ sim_k_out, // 1 byte + void* smem_ptr) const + { + const index_t tid = static_cast(threadIdx.x); + auto* smem_reduce = reinterpret_cast(smem_ptr); + + const index_t bs_k = min(static_cast(kN0), seqlen_k - kb * kN0); + const float inv_bs_k = (bs_k > 0) ? (1.0f / static_cast(bs_k)) : 0.f; + + auto k_tile = load_tile(k_window); + + float k_data[NPerThread * KPerThread]; + Base::template tile_to_float(k_tile, k_data); + + const index_t warp_id = tid / WarpSize; + const index_t lane_id = tid % WarpSize; + const index_t k_idx = lane_id % KThreads; + const index_t m_idx = lane_id / KThreads; + + // pooled_k_mean: column sum then cross-warp reduce. + // Drop trailing sync (next cross_warp_reduce has its own leading sync). + float pooled_k_mean[KPerThread]; + Base::template column_reduce_thread_and_warp(k_data, pooled_k_mean); + Base::template column_reduce_cross_warp(pooled_k_mean, smem_reduce); + for(index_t k = 0; k < KPerThread; ++k) + pooled_k_mean[k] *= inv_bs_k; + + // Write pooled_k_mean to global early so its register liveness ends here, + // freeing VGPR before k_sum_hat becomes live. + if(warp_id == 0 && m_idx == 0) + { + for(index_t k = 0; k < KPerThread; ++k) + pooled_k_out[k_idx * KPerThread + k] = type_convert(pooled_k_mean[k]); + } + + // K row L2 norms + normalised column sum (k_sum_hat) + float k_psq[NPerThread]; + Base::template row_reduce_sq_norm(k_data, k_psq, bs_k); + + float k_sum_hat[KPerThread]; + Base::template column_reduce_normalised(k_data, k_psq, k_sum_hat, bs_k); + // Drop trailing sync (no further smem read; only intra-warp shuffle + global write). 
+ Base::template column_reduce_cross_warp(k_sum_hat, smem_reduce); + + // sim_k = (||k_sum_hat||^2 / bs_k^2) > simthreshd1 + float ksh_sq = 0.f; + for(index_t k = 0; k < KPerThread; ++k) + ksh_sq += k_sum_hat[k] * k_sum_hat[k]; + ksh_sq = Base::reduce_across_k(ksh_sq); + const float denom_k = static_cast(bs_k) * static_cast(bs_k); + const bool sim_k = (denom_k > 0.f) && ((ksh_sq / denom_k) > simthreshd1); + + if(tid == 0) + *sim_k_out = sim_k ? static_cast(1) : static_cast(0); + } +}; + +} // namespace ck_tile