Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
eed42a9
Add host-side Sparge block-map pipeline for sparse attention examples
gino-lu Mar 20, 2026
9317fc4
Support 64x128 tile size in sparge fwd for Jenga and VSA paths
gino-lu Mar 24, 2026
643ad35
Merge remote-tracking branch 'origin/develop' into ginolu/sparge_atte…
gino-lu Apr 13, 2026
d1d457b
Add sparge gpu pipeline in tile_example_sparge_vsa_sparse_attn
gino-lu Apr 13, 2026
c7e6e4f
fix extra host side operations.
gino-lu Apr 14, 2026
ab44b83
refactor to combine two kernel
gino-lu Apr 22, 2026
eca3cb3
sparse_attn: add bm0 dispatch for sparge blockmap compatibility
gino-lu Apr 24, 2026
b00e544
sparse_attn: split KStats kernel, add README + perf charts
gino-lu May 5, 2026
668e107
fix(sparse_attn): backport PR #4742 LDS s_barrier
gino-lu May 17, 2026
7103eac
refactor(sparse_attn): caller-owned workspace + dtype-aware sizing
gino-lu May 17, 2026
879d508
cleanup(sparse_attn): R-tag rename + clang-format sweep
gino-lu May 17, 2026
840b8a3
test(sparse_attn): CPU-ref cross-check + BLKQ cite
gino-lu May 17, 2026
0f8b58a
sparse_attn: R25 Step 1 A1 — per-warp PV-skip (paper Algorithm 1) + V…
gino-lu May 18, 2026
304c1f9
Merge remote-tracking branch 'origin/develop' into ginolu/sparge_atte…
gino-lu May 20, 2026
d939c3b
sparse_attn: split-launch dispatch + 3-mode PV-skip
gino-lu May 20, 2026
9e3f883
sparse_attn: drop stale FMHA-vs-sparge perf section from README
gino-lu May 20, 2026
b3ea819
sparse_attn: annotate PV-skip chart with speedup vs dense baseline
gino-lu May 20, 2026
fb75da2
sparse_attn: wire -mask and -attention_sink (block-map prune + attn m…
gino-lu May 20, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions example/ck_tile/01_fmha/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -206,13 +206,15 @@ message(DEBUG "adding example ${EXAMPLE_FMHA_FWD}")
# Forward FMHA example executable. EXCLUDE_FROM_ALL keeps it out of the default
# build, so it is only built when requested by name.
add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL example_fmha_fwd.cpp)
target_link_libraries(${EXAMPLE_FMHA_FWD} ${FMHA_FWD_INSTANCES})
target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
# Compile the example's own HIP code for the same architecture list (INST_TARGETS).
set_property(TARGET ${EXAMPLE_FMHA_FWD} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})

message(DEBUG "adding example ${EXAMPLE_FMHA_BWD}")
# not using add_example_executable() to add this target, since we don't want this to be included in
# "make all/install/check"
add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL example_fmha_bwd.cpp)
target_link_libraries(${EXAMPLE_FMHA_BWD} ${FMHA_BWD_INSTANCES})
target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
# Compile the example's own HIP code for the same architecture list (INST_TARGETS).
set_property(TARGET ${EXAMPLE_FMHA_BWD} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})

# TODO: we have to turn off this global prop, otherwise the progress bar generated
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
Expand Down
109 changes: 104 additions & 5 deletions example/ck_tile/50_sparse_attn/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# CMakeLists.txt for sparse attention (Jenga and VSA)
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# CMakeLists.txt for sparse attention (Jenga and VSA)

# Use SUPPORTED_GPU_TARGETS directly
# Use SUPPORTED_GPU_TARGETS directly.
# INST_TARGETS is the architecture list assigned to HIP_ARCHITECTURES on the
# targets defined in this file.
set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
# NOTE(review): GPU_TARGETS is not read in the visible part of this file;
# presumably it is consumed by the HIP toolchain or a parent scope - confirm.
set(GPU_TARGETS ${SUPPORTED_GPU_TARGETS})

Expand All @@ -16,7 +16,7 @@ endif()

message(STATUS "Building Sparse Attention (Jenga & VSA) for targets: ${INST_TARGETS}")

# Code generation scripts
# Code generation scripts
file(GLOB_RECURSE CODE_GEN_SCRIPTS CONFIGURE_DEPENDS
${CMAKE_CURRENT_LIST_DIR}/generate.py
${CMAKE_CURRENT_LIST_DIR}/codegen/*.py
Expand Down Expand Up @@ -83,6 +83,7 @@ message(DEBUG "adding example ${EXAMPLE_JENGA_SPARSE_ATTN}")
# Jenga sparse-attention test executable. EXCLUDE_FROM_ALL keeps it out of the
# default build, so it is only built when requested by name.
add_executable(${EXAMPLE_JENGA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_jenga_sparse_attn.cpp)
target_link_libraries(${EXAMPLE_JENGA_SPARSE_ATTN} ${SPARSE_ATTN_JENGA_INSTANCES})
target_include_directories(${EXAMPLE_JENGA_SPARSE_ATTN} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
# Compile the test driver's own HIP code for the same architecture list (INST_TARGETS).
set_property(TARGET ${EXAMPLE_JENGA_SPARSE_ATTN} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
target_compile_options(${EXAMPLE_JENGA_SPARSE_ATTN} PRIVATE
-Wno-undefined-func-template
-Wno-float-equal
Expand Down Expand Up @@ -148,9 +149,107 @@ message(DEBUG "adding example ${EXAMPLE_VSA_SPARSE_ATTN}")
# VSA sparse-attention test executable. EXCLUDE_FROM_ALL keeps it out of the
# default build, so it is only built when requested by name.
add_executable(${EXAMPLE_VSA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_vsa_sparse_attn.cpp)
target_link_libraries(${EXAMPLE_VSA_SPARSE_ATTN} ${SPARSE_ATTN_VSA_INSTANCES})
target_include_directories(${EXAMPLE_VSA_SPARSE_ATTN} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
# Compile the test driver's own HIP code for the same architecture list (INST_TARGETS).
set_property(TARGET ${EXAMPLE_VSA_SPARSE_ATTN} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
# Warning suppressions applied to the test driver only (PRIVATE).
target_compile_options(${EXAMPLE_VSA_SPARSE_ATTN} PRIVATE
-Wno-undefined-func-template
-Wno-float-equal
)

# ============================================================================
# Sparge Sparse Attention (PV-skip enabled, derived from VSA)
# ============================================================================
# Arguments shared by the configure-time listing pass and the build-time
# generation pass, so both generate.py invocations stay in sync.
set(SPARSE_ATTN_SPARGE_CODE_GEN_ARGS
    ${CMAKE_CURRENT_LIST_DIR}/generate.py
    --api fwd_sparge
    --receipt 600
)

# Generate list of Sparge kernels (at configure time, only list)
execute_process(
    COMMAND ${Python3_EXECUTABLE} ${SPARSE_ATTN_SPARGE_CODE_GEN_ARGS}
            --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/sparge_blob_list.txt
    RESULT_VARIABLE ret
)
# RESULT_VARIABLE holds either the exit code or an error string; report it so a
# failing generator can be diagnosed from the configure log.
if(ret AND NOT ret EQUAL 0)
    message(FATAL_ERROR "Failed to generate Sparge kernel list (generate.py result: ${ret})")
endif()

# One generated source path per line.
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/sparge_blob_list.txt SPARSE_ATTN_SPARGE_GEN_BLOBS)

# add_custom_command(OUTPUT ...) needs at least one output; fail here with a
# clear message instead of an unrelated syntax error further down.
if(NOT SPARSE_ATTN_SPARGE_GEN_BLOBS)
    message(FATAL_ERROR
        "Sparge kernel list is empty: ${CMAKE_CURRENT_BINARY_DIR}/sparge_blob_list.txt")
endif()

# Generate Sparge kernel source files at build time
add_custom_command(
    OUTPUT ${SPARSE_ATTN_SPARGE_GEN_BLOBS}
    COMMAND ${Python3_EXECUTABLE} ${SPARSE_ATTN_SPARGE_CODE_GEN_ARGS}
            --output_dir ${CMAKE_CURRENT_BINARY_DIR}
    DEPENDS ${CODE_GEN_SCRIPTS}
    COMMENT "Generate CK Tile Sparge Sparse Attention kernels"
    # VERBATIM makes argument escaping identical across generators/platforms.
    VERBATIM
)

message(STATUS "Sparge kernel files to be generated: ${SPARSE_ATTN_SPARGE_GEN_BLOBS}")

# Sparge Instances
set(SPARSE_ATTN_SPARGE_INSTANCES "tile_sparse_attn_sparge_instances")

# OBJECT library built from the generated sources; EXCLUDE_FROM_ALL so it is
# only built when a target that links it is requested.
add_library(${SPARSE_ATTN_SPARGE_INSTANCES} OBJECT EXCLUDE_FROM_ALL
    ${SPARSE_ATTN_SPARGE_GEN_BLOBS}
)
target_include_directories(${SPARSE_ATTN_SPARGE_INSTANCES} PRIVATE
    ${CMAKE_CURRENT_LIST_DIR}
    ${PROJECT_SOURCE_DIR}/include/ck_tile/ops/sparse_attn
)
# Compile the generated sources as HIP for the INST_TARGETS architectures.
set_source_files_properties(${SPARSE_ATTN_SPARGE_GEN_BLOBS} PROPERTIES LANGUAGE HIP)
set_property(TARGET ${SPARSE_ATTN_SPARGE_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})

# Same definitions and warning suppressions as the blockmap instance library.
target_compile_options(${SPARSE_ATTN_SPARGE_INSTANCES} PRIVATE
    -DCK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
    -DCK_TILE_FMHA_FWD_FAST_EXP2
    -Wno-undefined-func-template
    -Wno-float-equal
)

# ============================================================================
# Sparge BlockMap GPU kernel
# Instantiated by hand in sparge_blockmap_inst.cpp; there is no codegen pass.
# ============================================================================
set(SPARGE_BLOCKMAP_INSTANCES "tile_sparge_blockmap_instances")

# The instantiation file has a .cpp extension but must be compiled as HIP.
set_property(
    SOURCE "${CMAKE_CURRENT_LIST_DIR}/sparge_blockmap_inst.cpp"
    PROPERTY LANGUAGE HIP
)

# OBJECT library holding the single hand-written translation unit.
add_library(${SPARGE_BLOCKMAP_INSTANCES} OBJECT EXCLUDE_FROM_ALL
    "${CMAKE_CURRENT_LIST_DIR}/sparge_blockmap_inst.cpp"
)
set_property(TARGET ${SPARGE_BLOCKMAP_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})

target_include_directories(${SPARGE_BLOCKMAP_INSTANCES}
    PRIVATE
        "${CMAKE_CURRENT_LIST_DIR}"
        "${PROJECT_SOURCE_DIR}/include/ck_tile/ops/sparse_attn"
)

# Same definitions and warning suppressions as the Sparge attention instances.
target_compile_options(${SPARGE_BLOCKMAP_INSTANCES}
    PRIVATE
        -DCK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
        -DCK_TILE_FMHA_FWD_FAST_EXP2
        -Wno-undefined-func-template
        -Wno-float-equal
)

# ----------------------------------------------------------------------------
# Build unified Sparge test: links the blockmap, Jenga, VSA, and Sparge
# attention instance libraries into a single executable for end-to-end
# evaluation and timing.
# ----------------------------------------------------------------------------
set(EXAMPLE_SPARGE "tile_example_sparge")
message(DEBUG "adding example ${EXAMPLE_SPARGE}")
# EXCLUDE_FROM_ALL: only built when requested by name.
add_executable(${EXAMPLE_SPARGE} EXCLUDE_FROM_ALL test_sparge.cpp)
# PRIVATE: nothing links against an example executable, so no usage
# requirements need to propagate.
target_link_libraries(${EXAMPLE_SPARGE} PRIVATE
    ${SPARSE_ATTN_JENGA_INSTANCES}
    ${SPARSE_ATTN_VSA_INSTANCES}
    ${SPARSE_ATTN_SPARGE_INSTANCES}
    ${SPARGE_BLOCKMAP_INSTANCES}
)
target_include_directories(${EXAMPLE_SPARGE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
# Compile the test driver's own HIP code for the same architecture list (INST_TARGETS).
set_property(TARGET ${EXAMPLE_SPARGE} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
target_compile_options(${EXAMPLE_SPARGE} PRIVATE
    -Wno-undefined-func-template
    -Wno-float-equal
)

# Disable per-rule build messages. This is a GLOBAL property, so it affects the
# whole build, not just this directory.
# NOTE(review): presumably needed because the many generated kernel sources
# otherwise produce an over-long progress command line - confirm.
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
52 changes: 52 additions & 0 deletions example/ck_tile/50_sparse_attn/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Sparge Attention (Composable Kernel)

A Composable Kernel port of [SpargeAttn](https://github.com/thu-ml/SpargeAttn) for AMD GPU. Both the block-map pipeline (mean-pool → cosine sim → pooled QK → top-k LUT) and the sparse FMHA stage run on-GPU. Two attention backends are exposed via `-pipeline=vsa` (default, faster) and `-pipeline=jenga` (async K/V load variant).

## Status vs Upstream

Implemented:
- per-block mean-pool, cosine similarity, pooled QK
- top-k / `cdfthreshd` block selection, BlockMap LUT
- sparse FMHA (both `vsa` and `jenga` backends)
- per-head `topk` / `simthreshd1` / `cdfthreshd`
- **is_causal mask in pooled score** (top-left only at block-map grain) ([spas_sage_attn/utils.py:L338](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/utils.py#L338))
- **attention_sink** — block-map column 0 force-on ([spas_sage_attn/autotune.py:L355](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/autotune.py#L355))

Not yet ported (upstream pinned to commit [`ae5b629`](https://github.com/thu-ml/SpargeAttn/tree/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a)):
- **K smoothing** — pre-pool `k -= km`; required for diffusion / video checkpoints (CogVideoX, Mochi-1, Flux, OpenSora, SD 3.5) ([spas_sage_attn/core.py:L53](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/core.py#L53))
- **Sort-based top-k selection** — replaces our O(N_k^2) iterative argmax; matters at long seqlen (s ≥ 16k) ([spas_sage_attn/utils.py:L345](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/utils.py#L345))
- **Q/K int8 quant fusion in pool kernel** — enables a downstream int8 GEMM0 in the attn kernel ([spas_sage_attn/utils.py:L371](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/utils.py#L371))

## PV-skip modes

`pv_threshold` per-Q-tile skip in the attention kernel is implemented in three variants, selectable at runtime via `-pv_mode={none|warp|block}`:

- **`none`** — skip disabled; baseline matching the no-PV-skip codegen instance.
- **`warp`** (per-wavefront) — each wavefront votes locally via `__shfl_xor` butterfly AND; SGPR-resident flag. CK-tile-specific variant, not in upstream.
- **`block`** (per-block) — block-wide consensus vote via LDS broadcast; aligned with upstream sm80 ([`qk_int_sv_f16_cuda_sm80.cuh:L334`](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/csrc/qattn/qk_int_sv_f16_cuda_sm80.cuh#L334)). V loads stay unconditional in all modes — the guard wraps the PV MMA only, matching upstream and paper Algorithm 1.

![PV-skip mode comparison](docs/pv_skip_mode_comparison.png)

*MI300X, b=2 h=16 s=8192 d=128 fp16, 5 seeds × 9 sparsity points. All three modes dispatch to the `kM0=64 padK=0` tile bucket at this shape.*

On the canonical recipe shape, `none > warp > block` at every measured sparsity, with no crossover. The per-block guard adds +33..+35 VGPR (6..9 spills) on this tile configuration, depressing occupancy. `warp` is +0..+4 VGPR. The default is `-pv_mode=warp`; switch to `none` for the no-skip baseline or `block` to exercise the upstream-aligned variant. A shape sweep is needed before recommending `block` as default — the `kM0=128` path has Δ ≈ 0 VGPR for per-block and is a candidate.

## Usage

```bash
ninja tile_example_sparge
./bin/tile_example_sparge -pipeline=vsa -b=2 -h=32 -s=16384 -d=128 -topk=0.4 -simthreshd1=0.001
```

Select a PV-skip variant with `-pv_mode={none|warp|block}` (default `warp`), and pass a finite threshold such as `-pv_threshold=20` so that the per-Q-tile skip predicate can fire.

Mask + attention sink:
- `-mask` accepts the `01_fmha` grammar (`0` / `t` / `b` / `t:l,r` / `xt:N` / `g:y,x`, default `0`). The block-map selection prunes past-diagonal blocks only under `mask_top_left` (`t`); `b` / SWA / generic are forwarded to the attention kernel and emit a stderr WARN that the block-map selection is unchanged.
- `-attention_sink={0|1}`: when set to `1`, forces block-map column `kb=0` ON for every Q-block (default `0`). Under `-mask=t` this is degenerate since `kb=0` is always causal-valid.

Add `-v=1` for CPU validation; use a small shape (`-b=1 -h=2 -s=512`), since full-shape CPU reference scales O(s²) and runs 30+ minutes at s=8k, hours at s=16k. When `-mask != 0` or `-attention_sink == 1`, the `[block_map cross-check]` and `[VSA LUT self-consistency]` cells are SKIPPED (the CPU reference does not model causal mask or sink); the `[attention output]` cell still runs but the dense reference applies no mask, so it will report FAIL on the kernel-correct output. Treat `-v=1` correctness as **block-map level only** in those configurations.

## References

- [SpargeAttn upstream](https://github.com/thu-ml/SpargeAttn) (pinned to [`ae5b629`](https://github.com/thu-ml/SpargeAttn/tree/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a))
- [Paper — Zhang et al., arXiv:2502.18137](https://arxiv.org/abs/2502.18137)
2 changes: 2 additions & 0 deletions example/ck_tile/50_sparse_attn/codegen/cpp_symbol_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,13 @@ def get_mask_check_map(mask: str):
# Maps each pipeline tag to the fully qualified ck_tile C++ pipeline class name.
# NOTE(review): presumably substituted into generated kernel sources - confirm
# against the codegen templates.
PIPELINE_MAP = {
"qr_async": "ck_tile::BlockFmhaPipelineQRKSVSAsyncJenga",
"qr_async_vsa": "ck_tile::BlockFmhaPipelineQRKSVSAsyncVSA",
"qr_async_sparge": "ck_tile::BlockFmhaPipelineQRKSVSAsyncSparge",
}

# Maps the same pipeline tags to a ck_tile::BlockFmhaPipelineEnum value.
# Every tag listed here resolves to QRKSVS_ASYNC; keep the keys in sync with
# PIPELINE_MAP when adding a pipeline.
PIPELINE_ENUM_MAP = {
"qr_async": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
"qr_async_vsa": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
"qr_async_sparge": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
}

BOOL_MAP = {
Expand Down
Loading