diff --git a/benchmarks/standalone_indexer/_common.py b/benchmarks/standalone_indexer/_common.py new file mode 100644 index 000000000..1d258a89d --- /dev/null +++ b/benchmarks/standalone_indexer/_common.py @@ -0,0 +1,39 @@ +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +"""Shared helpers for the standalone indexer-kernel benchmarks (torch-only).""" + +import time + +import torch + + +def make_kernel_inputs(B, oH, T_t, T_s, H, d_i, dtype, device="cuda", seed=0): + """Build the projected (Hq, Hk, W_o) the kernels actually consume. + + The original profilers generated Q/K/W and ran the einsum projections + (C_q, H_q, H_k, W_o) before the kernel. Those projections are plain GEMMs, + not part of the Triton kernels under test, so here we sample the kernel + inputs directly. + + Hq: (B, oH, T_t, H, d_i) + Hk: (B, oH, T_s, d_i) + W_o: (B, oH, T_t, H) + """ + g = torch.Generator(device=device).manual_seed(seed) + Hq = torch.randn((B, oH, T_t, H, d_i), dtype=dtype, device=device, generator=g) + Hk = torch.randn((B, oH, T_s, d_i), dtype=dtype, device=device, generator=g) + W_o = torch.randn((B, oH, T_t, H), dtype=dtype, device=device, generator=g) + return Hq, Hk, W_o + + +def time_fn(fn, n_warmup=15, n_iter=50): + """Time a no-arg thunk that launches GPU work. Returns seconds/call.""" + for _ in range(n_warmup): + fn() + torch.cuda.synchronize() + t0 = time.perf_counter() + for _ in range(n_iter): + fn() + torch.cuda.synchronize() + return (time.perf_counter() - t0) / n_iter diff --git a/benchmarks/standalone_indexer/indexer_bridge.py b/benchmarks/standalone_indexer/indexer_bridge.py new file mode 100644 index 000000000..a16a8af32 --- /dev/null +++ b/benchmarks/standalone_indexer/indexer_bridge.py @@ -0,0 +1,164 @@ +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +"""Torch launchers that drive the ACTUAL Triton kernels in the source file +``transformer_engine/jax/triton_extensions/indexer.py``. + +The profiling scripts can import their launchers from here instead of from the +self-contained ``indexer_kernels`` copy, so the benchmarks exercise the real +source kernels (forward ``_score_reduce_kernel``, backward +``_score_dscores_chunk_kernel``, top-k ``_score_topk_kernel`` / +``_score_topk_single_kernel``). + +Loading trick: importing ``transformer_engine.jax.triton_extensions.indexer`` +the normal way runs ``transformer_engine/jax/__init__.py``, which loads the TE +core C library -- broken in a namespace-package / JAX-less setup. The Triton +kernels need none of that, so we ``importlib``-load ``indexer.py`` directly, +pre-stubbing the parent packages (and the ``.utils`` relative import, which is +only used by the JAX lowerings, never by a direct kernel launch). ``jax`` itself +must be importable (it is -- the kernels' module defines JAX primitives at load +time), but no JAX device/compute is used here. +""" + +import importlib.util +import os +import sys +import types + +import torch +import triton + + +def _load_origin_indexer(): + repo_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + idx_path = os.path.join( + repo_root, "transformer_engine", "jax", "triton_extensions", "indexer.py" + ) + if not os.path.isfile(idx_path): + raise FileNotFoundError(f"origin indexer.py not found at {idx_path}") + + # Stub parent packages so the relative `from .utils import ...` resolves + # WITHOUT running the real transformer_engine.jax.__init__ (broken core lib). + for name in ( + "transformer_engine", + "transformer_engine.jax", + "transformer_engine.jax.triton_extensions", + ): + if name not in sys.modules or not hasattr(sys.modules[name], "__path__"): + m = types.ModuleType(name) + m.__path__ = [] # mark as a package + sys.modules[name] = m + + util_name = "transformer_engine.jax.triton_extensions.utils" + if util_name not in sys.modules: + u = types.ModuleType(util_name) + # Only the JAX lowerings call this; direct kernel launches don't. + u.triton_call_lowering = lambda *a, **k: None + sys.modules[util_name] = u + + mod_name = "transformer_engine.jax.triton_extensions.indexer" + spec = importlib.util.spec_from_file_location(mod_name, idx_path) + mod = importlib.util.module_from_spec(spec) + sys.modules[mod_name] = mod + spec.loader.exec_module(mod) + return mod + + +_src = _load_origin_indexer() + +# Re-export the source objects the launchers / profilers reference. +_score_reduce_kernel = _src._score_reduce_kernel # triton Autotuner +_score_dscores_chunk_kernel = _src._score_dscores_chunk_kernel +_score_topk_kernel = _src._score_topk_kernel +_score_topk_single_kernel = _src._score_topk_single_kernel +_next_pow2 = _src._next_pow2 +_HBWD_BLOCK_T = _src._HBWD_BLOCK_T +_HBWD_BLOCK_S = _src._HBWD_BLOCK_S +_BWD_H_CHUNK = _src._BWD_H_CHUNK +_SINGLE_SORT_MAX = _src._SINGLE_SORT_MAX +_SCORE_TOPK_CONFIGS = _src._SCORE_TOPK_CONFIGS +_SINGLE_TOPK_CONFIGS = _src._SINGLE_TOPK_CONFIGS + +ORIGIN_FILE = _src.__file__ + + +# --- Forward ----------------------------------------------------------------- + +def score_reduce(Hq, Hk, W_o, out_dtype=None): + assert Hq.ndim == 5 and Hk.ndim == 4 and W_o.ndim == 4 + B, oH, T_t, H, d_i = Hq.shape + T_s = Hk.shape[2] + if out_dtype is None: + out_dtype = Hq.dtype + Hq, Hk, W_o = Hq.contiguous(), Hk.contiguous(), W_o.contiguous() + O = torch.empty((B, oH, T_t, T_s), dtype=out_dtype, device=Hq.device) + + def grid(meta): + return (triton.cdiv(T_s, meta["BLOCK_S"]), + triton.cdiv(T_t, meta["BLOCK_T"]), B * oH) + + _score_reduce_kernel[grid]( + Hq, Hk, W_o, O, B=B, oH=oH, T_t=T_t, T_s=T_s, H=H, d_i=d_i, + ) + return O + + +# --- Backward ---------------------------------------------------------------- + +def bwd_h_chunk(H): + """H_CHUNK selection identical to the source ``_score_reduce_bwd``.""" + if H % _BWD_H_CHUNK == 0: + return _BWD_H_CHUNK + for c in (4, 2): + if H % c == 0: + return c + return 1 + + +def score_dscores_chunk(Hq_chunk, Hk, W_o_chunk, dO): + B, oH, T, H_CHUNK, d_i = Hq_chunk.shape + T_s = dO.shape[-1] + + Hq_chunk = Hq_chunk.contiguous() + Hk = Hk.contiguous() + W_o_chunk = W_o_chunk.contiguous() + dO = dO.contiguous() + + dscores_chunk = torch.empty( + (B, oH, T, H_CHUNK, T_s), dtype=Hq_chunk.dtype, device=Hq_chunk.device) + dWo_chunk = torch.empty( + (B, oH, T, H_CHUNK), dtype=Hq_chunk.dtype, device=Hq_chunk.device) + + def grid(meta): + return ((T + meta["BLOCK_T"] - 1) // meta["BLOCK_T"], B * oH) + + _score_dscores_chunk_kernel[grid]( + Hq_chunk, Hk, W_o_chunk, dO, dscores_chunk, dWo_chunk, + B=B, oH=oH, T=T, T_s=T_s, H_CHUNK=H_CHUNK, d_i=d_i, + ) + return dscores_chunk, dWo_chunk + + +# --- Top-k ------------------------------------------------------------------- + +def score_topk(Hq, Hk, W_o, k): + B, oH, T_t, H, d_i = Hq.shape + T_s = Hk.shape[2] + if k <= 0 or (k & (k - 1)) != 0: + raise ValueError(f"k must be a positive power of 2; got {k}") + if k > T_s: + raise ValueError(f"k={k} must be <= T_s={T_s}") + S_PAD = _next_pow2(T_s) + Hq, Hk, W_o = Hq.contiguous(), Hk.contiguous(), W_o.contiguous() + out = torch.empty((B, oH, T_t, k), dtype=torch.int32, device=Hq.device) + + def grid(meta): + return (triton.cdiv(T_t, meta["BLOCK_T"]), B * oH) + + kernel = (_score_topk_single_kernel if S_PAD <= _SINGLE_SORT_MAX + else _score_topk_kernel) + kernel[grid]( + Hq, Hk, W_o, out, + B=B, oH=oH, T_t=T_t, T_s=T_s, H=H, d_i=d_i, K=k, S_PAD=S_PAD, + ) + return out diff --git a/benchmarks/standalone_indexer/profile_indexer.py b/benchmarks/standalone_indexer/profile_indexer.py new file mode 100644 index 000000000..163407e68 --- /dev/null +++ b/benchmarks/standalone_indexer/profile_indexer.py @@ -0,0 +1,88 @@ +"""Profile the standalone forward score-reduce Triton kernel (bf16). + +Extracted from ``transformer_engine/jax/triton_extensions/indexer.py`` -- +``_score_reduce_kernel`` only, no transformer_engine / jax dependency. + +Measures wall time and effective TFLOPS for the fused kernel: + + scores = relu(einsum("...thi,...si->...ths", Hq, Hk)) # never written + O = einsum("...ths,...th->...ts", scores, W_o) + +Run: + python benchmarks/standalone_indexer/profile_indexer.py +""" + +import torch + +from _common import make_kernel_inputs, time_fn +# Drive the ACTUAL source kernel in +# transformer_engine/jax/triton_extensions/indexer.py (via the indexer_bridge +# adapter). Swap to `indexer_kernels` to profile the self-contained copy. +from indexer_bridge import score_reduce + + +# --- FLOP accounting ------------------------------------------------------------ + +def kernel_flops(B, oH, T, S, H, d_i): + # The kernel does the score matmul + the weighted H-reduction; the four + # projection GEMMs (C_q, H_q, H_k, W_o) are excluded -- they're not in the + # kernel. 2 flops per multiply-add. + # scores = relu(Hq @ Hk^T) : 2 * B*oH * T * H * S * d_i + # O = sum_h scores * W_o : 2 * B*oH * T * S * H + n = B * oH + return 2 * (n * T * H * S * d_i + n * T * S * H) + + +# --- Reference (torch) for a correctness sanity check --------------------------- + +def torch_reference(Hq, Hk, W_o): + # Hq (B,oH,T,H,d_i), Hk (B,oH,S,d_i), W_o (B,oH,T,H) -> O (B,oH,T,S) + scores = torch.einsum("bothi,bosi->boths", Hq.float(), Hk.float()) + scores = torch.relu(scores) + O = torch.einsum("boths,both->bots", scores, W_o.float()) + return O + + +# --- Driver --------------------------------------------------------------------- + +CONFIGS = [ + #(B, oH, T, S, H, d_i) + ( 2, 64, 4096, 4096, 64, 128), +] + + +def check_correctness(): + B, oH, T, S, H, d_i = 1, 2, 128, 256, 8, 128 + Hq, Hk, W_o = make_kernel_inputs(B, oH, T, S, H, d_i, torch.bfloat16) + O = score_reduce(Hq, Hk, W_o).float() + ref = torch_reference(Hq, Hk, W_o) + rel = (O - ref).norm() / ref.norm().clamp_min(1e-9) + print(f" correctness: rel L2 err vs torch fp32 ref = {rel.item():.4e}") + + +def main(): + print(f"device: {torch.cuda.get_device_name(0)}\n") + check_correctness() + print() + for B, oH, T, S, H, d_i in CONFIGS: + Hq, Hk, W_o = make_kernel_inputs(B, oH, T, S, H, d_i, torch.bfloat16) + flops = kernel_flops(B, oH, T, S, H, d_i) + + print(f"--- B={B} oH={oH} T={T} S={S} H={H} d_i={d_i} bfloat16 ---") + print(f" kernel work = {flops/1e9:.2f} GFLOPs/call") + + try: + sec = time_fn(lambda: score_reduce(Hq, Hk, W_o)) + ms = sec * 1e3 + tflops = flops / sec / 1e12 + print(f" {'score_reduce':<14} {ms:8.3f} ms {tflops:6.2f} TFLOP/s") + except Exception as e: # noqa: BLE001 + print(f" {'score_reduce':<14} FAILED: {type(e).__name__}: {str(e).splitlines()[0]}") + + # Report the autotuner-selected config. + from indexer_bridge import _score_reduce_kernel + print(" Best config: ", _score_reduce_kernel.best_config) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/standalone_indexer/profile_indexer_bwd.py b/benchmarks/standalone_indexer/profile_indexer_bwd.py new file mode 100644 index 000000000..88c87ca59 --- /dev/null +++ b/benchmarks/standalone_indexer/profile_indexer_bwd.py @@ -0,0 +1,119 @@ +"""Profile the standalone backward score-chunk Triton kernel (bf16). + +Extracted from ``transformer_engine/jax/triton_extensions/indexer.py`` -- +``_score_dscores_chunk_kernel`` only, no transformer_engine / jax dependency. + +The original backward loops over H/H_CHUNK chunks, each chunk fusing +(score recompute + relu + mask + dO*W_o broadcast) into dscores_chunk and +reducing dWo_chunk in registers. The dHq/dHk reductions on dscores_chunk are +plain GEMMs (hipBLASLt einsums) and are NOT part of this kernel, so they are +excluded here. We time the per-chunk kernel and the full chunk sweep. + +Run: + python benchmarks/standalone_indexer/profile_indexer_bwd.py +""" + +import torch + +from _common import make_kernel_inputs, time_fn +# Drive the ACTUAL source kernel in +# transformer_engine/jax/triton_extensions/indexer.py (via the indexer_bridge +# adapter). Swap to `indexer_kernels` to profile the self-contained copy. +from indexer_bridge import bwd_h_chunk, score_dscores_chunk + + +def kernel_flops_per_chunk(B, oH, T, S, H_CHUNK, d_i): + # Per chunk: score recompute matmul + the dWo sum-reduction over s. + n = B * oH + return 2 * (n * T * H_CHUNK * S * d_i + n * T * H_CHUNK * S) + + +def torch_reference_chunk(Hq_chunk, Hk, W_o_chunk, dO): + # scores = relu(Hq_chunk @ Hk^T), then + # dWo[...,h] = sum_s relu(scores)[...,h,s] * dO[...,s] + # dscores[...] = (scores>0) * dO[...,s] * W_o[...,h] + scores = torch.einsum("bothi,bosi->boths", Hq_chunk.float(), Hk.float()) + relu_mask = scores > 0 + h_relu = torch.relu(scores) + dWo = torch.einsum("boths,bots->both", h_relu, dO.float()) + dscores = relu_mask.float() * (dO.float()[:, :, :, None, :] + * W_o_chunk.float()[..., None]) + return dscores, dWo + + +CONFIGS = [ + #(B, oH, T, S, H, d_i) + ( 2, 64, 1024, 1024, 64, 128), +] + + +def check_correctness(): + B, oH, T, S, H, d_i = 1, 2, 128, 256, 8, 128 + H_CHUNK = bwd_h_chunk(H) + Hq, Hk, W_o = make_kernel_inputs(B, oH, T, S, H, d_i, torch.bfloat16) + dO = torch.randn((B, oH, T, S), dtype=torch.float32, device="cuda") + Hq_c = Hq[:, :, :, :H_CHUNK, :].contiguous() + W_o_c = W_o[:, :, :, :H_CHUNK].contiguous() + dscores, dWo = score_dscores_chunk(Hq_c, Hk, W_o_c, dO) + ref_ds, ref_dwo = torch_reference_chunk(Hq_c, Hk, W_o_c, dO) + ds_err = (dscores.float() - ref_ds).norm() / ref_ds.norm().clamp_min(1e-9) + dwo_err = (dWo.float() - ref_dwo).norm() / ref_dwo.norm().clamp_min(1e-9) + print(f" correctness (H_CHUNK={H_CHUNK}): dscores rel err = {ds_err.item():.4e}, " + f"dWo rel err = {dwo_err.item():.4e}") + + +def main(): + print(f"device: {torch.cuda.get_device_name(0)}\n") + check_correctness() + print() + for B, oH, T, S, H, d_i in CONFIGS: + H_CHUNK = bwd_h_chunk(H) + n_chunks = H // H_CHUNK + Hq, Hk, W_o = make_kernel_inputs(B, oH, T, S, H, d_i, torch.bfloat16) + dO = torch.randn((B, oH, T, S), dtype=torch.float32, device="cuda") + + # Pre-slice the per-chunk views the original scan feeds the kernel. + chunks = [ + (Hq[:, :, :, c * H_CHUNK:(c + 1) * H_CHUNK, :].contiguous(), + W_o[:, :, :, c * H_CHUNK:(c + 1) * H_CHUNK].contiguous()) + for c in range(n_chunks) + ] + + per_chunk_flops = kernel_flops_per_chunk(B, oH, T, S, H_CHUNK, d_i) + total_flops = per_chunk_flops * n_chunks + + print(f"--- B={B} oH={oH} T={T} S={S} H={H} d_i={d_i} bfloat16 ---") + print(f" H_CHUNK={H_CHUNK} n_chunks={n_chunks}") + print(f" per-chunk work = {per_chunk_flops/1e9:.2f} GFLOPs " + f"full sweep = {total_flops/1e9:.2f} GFLOPs") + + # Single-chunk timing. + Hq_c, W_o_c = chunks[0] + try: + sec = time_fn(lambda: score_dscores_chunk(Hq_c, Hk, W_o_c, dO)) + ms = sec * 1e3 + tflops = per_chunk_flops / sec / 1e12 + print(f" {'one chunk':<14} {ms:8.3f} ms {tflops:6.2f} TFLOP/s") + except Exception as e: # noqa: BLE001 + print(f" {'one chunk':<14} FAILED: {type(e).__name__}: {str(e).splitlines()[0]}") + + # Full chunk-sweep timing (the work one backward pass issues). + def run_all(): + for hq_c, wo_c in chunks: + score_dscores_chunk(hq_c, Hk, wo_c, dO) + + try: + sec = time_fn(run_all) + ms = sec * 1e3 + tflops = total_flops / sec / 1e12 + print(f" {'full sweep':<14} {ms:8.3f} ms {tflops:6.2f} TFLOP/s") + except Exception as e: # noqa: BLE001 + print(f" {'full sweep':<14} FAILED: {type(e).__name__}: {str(e).splitlines()[0]}") + print() + # Report the autotuner-selected config. + from indexer_bridge import _score_dscores_chunk_kernel + print(" Best config: ", _score_dscores_chunk_kernel.best_config) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/standalone_indexer/profile_indexer_topk.py b/benchmarks/standalone_indexer/profile_indexer_topk.py new file mode 100644 index 000000000..663451187 --- /dev/null +++ b/benchmarks/standalone_indexer/profile_indexer_topk.py @@ -0,0 +1,80 @@ +"""Profile the standalone fused score + streaming top-k Triton kernel (bf16). + +Extracted from ``transformer_engine/jax/triton_extensions/indexer.py`` -- +``_score_topk_kernel`` only, no transformer_engine / jax dependency. + +Computes the same scores as the forward kernel but never materializes the +(B, oH, T_t, T_s) score matrix -- returns the top-k indices into T_s directly. +top-k is comparison-only and counted as 0 FLOP, so reported TFLOPS reflect the +score compute. + +Run: + python benchmarks/standalone_indexer/profile_indexer_topk.py +""" + +import torch + +from _common import make_kernel_inputs, time_fn +# Drive the ACTUAL source kernels in +# transformer_engine/jax/triton_extensions/indexer.py (via the indexer_bridge +# adapter). +from indexer_bridge import score_topk + + +def kernel_flops(B, oH, T, S, H, d_i): + # Score matmul + weighted H-reduction; top-k is 0 FLOP. + n = B * oH + return 2 * (n * T * H * S * d_i + n * T * S * H) + + +def torch_scores(Hq, Hk, W_o): + scores = torch.einsum("bothi,bosi->boths", Hq.float(), Hk.float()) + scores = torch.relu(scores) + return torch.einsum("boths,both->bots", scores, W_o.float()) + + +CONFIGS = [ + #(B, oH, T, S, H, d_i) + ( 2, 64, 1024, 1024, 64, 128), +] + +K_TOPK = 512 + + +def check_correctness(): + B, oH, T, S, H, d_i = 1, 2, 64, 256, 8, 128 + k = 64 + Hq, Hk, W_o = make_kernel_inputs(B, oH, T, S, H, d_i, torch.bfloat16) + idx = score_topk(Hq, Hk, W_o, k=k).long() + scores = torch_scores(Hq, Hk, W_o) + # Compare the score *values* at the selected indices (robust to tie order). + sel = torch.gather(scores, -1, idx) + sel_sorted = torch.sort(sel, dim=-1, descending=True).values + ref_vals = torch.topk(scores, k, dim=-1).values + rel = (sel_sorted - ref_vals).norm() / ref_vals.norm().clamp_min(1e-9) + print(f" correctness (k={k}): top-k value rel err vs torch = {rel.item():.4e}") + + +def main(): + print(f"device: {torch.cuda.get_device_name(0)}\nk = {K_TOPK}\n") + check_correctness() + print() + for B, oH, T, S, H, d_i in CONFIGS: + Hq, Hk, W_o = make_kernel_inputs(B, oH, T, S, H, d_i, torch.bfloat16) + flops = kernel_flops(B, oH, T, S, H, d_i) + + print(f"--- B={B} oH={oH} T={T} S={S} H={H} d_i={d_i} bfloat16 ---") + print(f" kernel work = {flops/1e9:.2f} GFLOPs/call (top-k = 0 FLOP)") + + try: + sec = time_fn(lambda: score_topk(Hq, Hk, W_o, k=K_TOPK)) + ms = sec * 1e3 + tflops = flops / sec / 1e12 + print(f" {'score_topk':<14} {ms:8.3f} ms {tflops:6.2f} TFLOP/s") + except Exception as e: # noqa: BLE001 + print(f" {'score_topk':<14} FAILED: {type(e).__name__}: {str(e).splitlines()[0]}") + print() + + +if __name__ == "__main__": + main() diff --git a/transformer_engine/jax/triton_extensions/indexer.py b/transformer_engine/jax/triton_extensions/indexer.py index e6477d6ae..3210cec42 100644 --- a/transformer_engine/jax/triton_extensions/indexer.py +++ b/transformer_engine/jax/triton_extensions/indexer.py @@ -39,20 +39,30 @@ def _score_reduce_autotune_configs(): # (resource exhaustion — VGPR/LDS budget for 64-iter H-loop with that # large an accumulator). Capped at 256. cfgs = [] - for bt in (64, 128, 256): - for bs in (32, 64, 128): - for num_warps in (4, 8): - for num_stages in (1, 2): - cfgs.append(triton.Config( - {"BLOCK_T": bt, "BLOCK_S": bs}, - num_warps=num_warps, num_stages=num_stages, - )) + # cfgs = [ + # triton.Config({"BLOCK_T": bt, "BLOCK_S": bs, "matrix_instr_nonkdim": nk_dim, "waves_per_eu": wpe}, num_warps=nw, num_stages=ns) + # for bt in (64, 128, 256) + # for bs in (32, 64, 128) + # for nk_dim in (16, 32) + # for wpe in (0, 1, 2) + # for nw in (4, 8) + # for ns in (1, 2, 3) + # ] # A few skinny / fat shapes the regular grid above won't hit. cfgs += [ triton.Config({"BLOCK_T": 32, "BLOCK_S": 128}, num_warps=4, num_stages=2), triton.Config({"BLOCK_T": 32, "BLOCK_S": 256}, num_warps=4, num_stages=2), triton.Config({"BLOCK_T": 256, "BLOCK_S": 32}, num_warps=8, num_stages=2), ] + cfgs += [ + triton.Config({"BLOCK_T": bt, "BLOCK_S": bs, "matrix_instr_nonkdim": nk_dim, "waves_per_eu": wpe}, num_warps=nw, num_stages=ns) + for bt in (32, 64) + for bs in (256, 512) + for nk_dim in (16,) + for wpe in (0, 2) # 0 means let backend compiler decide + for nw in (4,) + for ns in (2,) + ] return cfgs @@ -191,9 +201,32 @@ def grid_fn(merged_kwargs): _HBWD_BLOCK_T = 64 -_HBWD_BLOCK_S = 64 +_HBWD_BLOCK_S = 256 +def _score_dscores_chunk_autotune_configs(): + cfgs = [] + cfgs += [ + triton.Config( + {"BLOCK_T": bt, "BLOCK_S": bs, + "matrix_instr_nonkdim": nk_dim, "waves_per_eu": wpe}, + num_warps=nw, num_stages=1) + for bt in (32, 64, 128) + for bs in (128, 256) + for nk_dim in (16, 32) + for wpe in (0, 2) + for nw in (4, 8) + ] + # larger BLOCK_S for long T_s + cfgs += [ + triton.Config({"BLOCK_T": bt, "BLOCK_S": 512}, num_warps=4, num_stages=1) + for bt in (32, 64) + ] + return cfgs + + +@triton.autotune(configs=_score_dscores_chunk_autotune_configs(), + key=["T", "T_s", "H_CHUNK", "d_i"]) @triton.jit def _score_dscores_chunk_kernel( Hq_chunk_ptr, # input (B, oH, T, H_CHUNK, d_i) bf16 @@ -211,22 +244,22 @@ def _score_dscores_chunk_kernel( BLOCK_T: tl.constexpr, BLOCK_S: tl.constexpr, ): - """One CTA handles (T_tile, h_in) for one (b, h_outer). Loops over s_chunks. + """One CTA handles (T_tile, all H_CHUNK heads) for one (b, h_outer). - Each CTA writes its T_tile rows of (dscores_chunk[..., h_in, :], - dW_o_chunk[..., h_in]). dW_o is reduced in registers (sum over s) so - h_relu never lands in HBM -- we compute it on-the-fly and consume it. + Grid: (cdiv(T, BLOCK_T), B * oH). For each s-chunk we load dO_chunk and + Hk_chunk ONCE and reuse them across every head in the chunk -- the key + saving vs the original (which spun a separate CTA per head, each re-reading + dO/Hk). dW_o is reduced in registers (sum over s) per head, so h_relu never + lands in HBM. """ pid_t = tl.program_id(0) - pid_h_bh = tl.program_id(1) - h_in = pid_h_bh % H_CHUNK - pid_bh = pid_h_bh // H_CHUNK + pid_bh = tl.program_id(1) b = (pid_bh // oH).to(tl.int64) h_outer = (pid_bh % oH).to(tl.int64) - h_in_64 = h_in.to(tl.int64) rt = pid_t * BLOCK_T + tl.arange(0, BLOCK_T) rdi = tl.arange(0, d_i) + rhc = tl.arange(0, H_CHUNK) rt_mask = rt < T hq_base = b * (oH * T * H_CHUNK * d_i) + h_outer * (T * H_CHUNK * d_i) @@ -235,64 +268,58 @@ def _score_dscores_chunk_kernel( do_base = b * (oH * T * T_s) + h_outer * (T * T_s) ds_base = b * (oH * T * H_CHUNK * T_s) + h_outer * (T * H_CHUNK * T_s) - # Load Hq[..., t_tile, h_in, :] -> [BLOCK_T, d_i] once per CTA - hq_ptrs = (Hq_chunk_ptr + hq_base - + rt[:, None] * (H_CHUNK * d_i) - + h_in_64 * d_i - + rdi[None, :]) - Hq_h = tl.load(hq_ptrs, mask=rt_mask[:, None], other=0.0) - - # Load W_o[..., t_tile, h_in] -> [BLOCK_T] once per CTA - wo_ptrs = W_o_chunk_ptr + wo_base + rt * H_CHUNK + h_in_64 - w_h = tl.load(wo_ptrs, mask=rt_mask, other=0.0).to(tl.float32) - - # dW_o accumulator: sum_s (h_relu * dO) -- reduced in regs - dWo_acc = tl.zeros((BLOCK_T,), dtype=tl.float32) + # Per-head dW_o accumulators packed as (BLOCK_T, H_CHUNK), reduced over s. + dWo_acc = tl.zeros((BLOCK_T, H_CHUNK), dtype=tl.float32) for s_start in range(0, T_s, BLOCK_S): rs = s_start + tl.arange(0, BLOCK_S) rs_mask = rs < T_s - # Load Hk[..., s_chunk, :] and dO[..., t_tile, s_chunk] + # Load Hk[..., s_chunk, :] and dO[..., t_tile, s_chunk] ONCE per s-chunk + # -- shared across all H_CHUNK heads below. hk_ptrs = Hk_ptr + hk_base + rs[:, None] * d_i + rdi[None, :] Hk_chunk = tl.load(hk_ptrs, mask=rs_mask[:, None], other=0.0) + Hk_T = tl.trans(Hk_chunk) # (d_i, BLOCK_S) do_ptrs = dO_ptr + do_base + rt[:, None] * T_s + rs[None, :] dO_chunk = tl.load( - do_ptrs, - mask=rt_mask[:, None] & rs_mask[None, :], - other=0.0, - ) - - # scores tile in registers (never lands in HBM at full size) - scores = tl.dot(Hq_h, tl.trans(Hk_chunk)) - relu_mask = scores > 0 - h_relu = tl.where(relu_mask, scores, 0.0) - - # dW_o contribution: sum_s (h_relu * dO) - dWo_acc += tl.sum(h_relu * dO_chunk, axis=1) - - # dscores tile = relu_mask * (dO * W_o) - dscores = tl.where(relu_mask, dO_chunk * w_h[:, None], 0.0) - - # Store dscores tile to HBM (bf16). Total dscores_chunk size is - # H_CHUNK x smaller than the full (B,oH,T,H,T_s) tensor. - ds_ptrs = (dscores_chunk_ptr + ds_base - + rt[:, None] * (H_CHUNK * T_s) - + h_in_64 * T_s - + rs[None, :]) - tl.store( - ds_ptrs, - dscores.to(dscores_chunk_ptr.dtype.element_ty), - mask=rt_mask[:, None] & rs_mask[None, :], + do_ptrs, mask=rt_mask[:, None] & rs_mask[None, :], other=0.0, ) - # Store dW_o[..., t_tile, h_in] - dwo_out_ptrs = dWo_chunk_ptr + wo_base + rt * H_CHUNK + h_in_64 + for h in tl.static_range(H_CHUNK): + # Hq/w for head h (small, L2-resident across s-chunks). + Hq_h = tl.load( + Hq_chunk_ptr + hq_base + rt[:, None] * (H_CHUNK * d_i) + + h * d_i + rdi[None, :], + mask=rt_mask[:, None], other=0.0, + ) + w_h = tl.load( + W_o_chunk_ptr + wo_base + rt * H_CHUNK + h, + mask=rt_mask, other=0.0, + ).to(tl.float32) + + scores = tl.dot(Hq_h, Hk_T) # (BLOCK_T, BLOCK_S) + relu_mask = scores > 0 + h_relu = tl.where(relu_mask, scores, 0.0) + + # dW_o[..., h] += sum_s (h_relu * dO); accumulate into column h. + dwo_h = tl.sum(h_relu * dO_chunk, axis=1) # (BLOCK_T,) + dWo_acc += tl.where(rhc[None, :] == h, dwo_h[:, None], 0.0) + + # dscores[..., h, s] = relu_mask * (dO * W_o) + dscores = tl.where(relu_mask, dO_chunk * w_h[:, None], 0.0) + ds_ptrs = (dscores_chunk_ptr + ds_base + + rt[:, None] * (H_CHUNK * T_s) + h * T_s + rs[None, :]) + tl.store( + ds_ptrs, dscores.to(dscores_chunk_ptr.dtype.element_ty), + mask=rt_mask[:, None] & rs_mask[None, :], + ) + + # Store dW_o[..., t_tile, :] for all heads. + dwo_out_ptrs = dWo_chunk_ptr + wo_base + rt[:, None] * H_CHUNK + rhc[None, :] tl.store( - dwo_out_ptrs, - dWo_acc.to(dWo_chunk_ptr.dtype.element_ty), - mask=rt_mask, + dwo_out_ptrs, dWo_acc.to(dWo_chunk_ptr.dtype.element_ty), + mask=rt_mask[:, None], ) @@ -321,21 +348,24 @@ def _score_dscores_chunk_lowering(ctx, Hq_chunk, Hk, W_o_chunk, dO): dO_aval = ctx.avals_in[3] B, oH, T, H_CHUNK, d_i = Hq_aval.shape T_s = dO_aval.shape[-1] - BLOCK_T = _HBWD_BLOCK_T if T >= _HBWD_BLOCK_T else T - BLOCK_S = _HBWD_BLOCK_S if T_s >= _HBWD_BLOCK_S else T_s - n_t_tiles = (T + BLOCK_T - 1) // BLOCK_T + + # Grid is (T-tiles, B*oH) -- one CTA per (T_tile, b, h_outer) covering ALL + # H_CHUNK heads (so dO/Hk are shared across heads), NOT one CTA per head. + # BLOCK_T/BLOCK_S/num_warps/num_stages come from the autotuner; all configs + # pin num_stages=1 (pipelining the s-loop crashes LLVM codegen on Triton + # 3.7.0 / gfx950). The grid depends on the autotuned BLOCK_T. + def grid_fn(merged_kwargs): + bt = merged_kwargs.get("BLOCK_T", _HBWD_BLOCK_T) + return ((T + bt - 1) // bt, B * oH) return triton_call_lowering( ctx, _score_dscores_chunk_kernel, Hq_chunk, Hk, W_o_chunk, dO, - grid=(n_t_tiles, B * oH * H_CHUNK), - num_warps=4, - num_stages=2, + grid=grid_fn, constexprs={ "B": B, "oH": oH, "T": T, "T_s": T_s, "H_CHUNK": H_CHUNK, "d_i": d_i, - "BLOCK_T": BLOCK_T, "BLOCK_S": BLOCK_S, }, ) @@ -509,13 +539,14 @@ def score_reduce_triton(Hq, Hk, W_o, *, out_dtype=None): # (logits[BLOCK_S, BLOCK_T*H] fp32 + Hk_chunk[BLOCK_S, d_i] bf16). # # Constraint: BLOCK_S must divide K (so INNER = K // BLOCK_S is an integer -# >= 1). Configs whose BLOCK_S exceeds K or doesn't divide K are filtered -# out at lowering time — otherwise jaxlib's autotuner would time them as -# zero-work (fast) and pick a bogus winner that returns all-zero indices. +# >= 1). Configs whose BLOCK_S exceeds K or doesn't divide K must be pruned +# (see _prune_topk_configs) — otherwise the autotuner would time them +# as zero-work (fast) and pick a bogus winner that returns all-zero indices. _SCORE_TOPK_CONFIGS = [ - triton.Config({"BLOCK_S": bs, "BLOCK_T": bt}, num_warps=nw, num_stages=ns) + triton.Config({"BLOCK_S": bs, "BLOCK_T": bt, "waves_per_eu": wpe}, num_warps=nw, num_stages=ns) for bt in (1, 2) for bs in (32, 64, 128, 256) + for wpe in (0, 2, 4) for nw in (4, 8) for ns in (1, 2) ] + [ @@ -528,6 +559,26 @@ def score_reduce_triton(Hq, Hk, W_o, *, out_dtype=None): ] +def _prune_topk_configs(configs, named_args, **kwargs): + """early_config_prune for _score_topk_kernel. Keep only configs where + BLOCK_S divides K (INNER = K//BLOCK_S >= 1) and BLOCK_T divides T_t. The + runtime values arrive in named_args or kwargs depending on call style.""" + vals = {**named_args, **kwargs} + k = vals["K"] + T_t = vals["T_t"] + return [ + c for c in configs + if c.kwargs["BLOCK_S"] <= k + and k % c.kwargs["BLOCK_S"] == 0 + and T_t % c.kwargs["BLOCK_T"] == 0 + ] + + +@triton.autotune( + configs=_SCORE_TOPK_CONFIGS, + key=["H", "d_i", "T_s", "K"], + prune_configs_by={"early_config_prune": _prune_topk_configs}, +) @triton.jit def _score_topk_kernel( Hq_ptr, # (B, oH, T_t, H, d_i) bf16 @@ -705,6 +756,132 @@ def _score_topk_kernel( tl.store(out_ptrs, top_k_idx, mask=rt_mask[:, None]) +# --- Single-sort top-k (for S_PAD that fits in registers) ------------- +# +# The streaming kernel above sorts a 2K buffer N_OUTER = S_PAD/K times. When k +# is a large fraction of T_s (e.g. k = T_s/2), that's several sorts of the 2K +# buffer. If all S_PAD candidates fit in registers, scattering them into one +# BLOCK_T*S_PAD buffer and doing a SINGLE descending sort is ~2x less sort work. +_SINGLE_SORT_MAX = 4096 + + +# Configs for the single-sort kernel. BLOCK_S must divide S_PAD; BLOCK_T must +# divide T_t (pruned below). Tuned winner on gfx950 is BLOCK_T=1, BLOCK_S=128. +_SINGLE_TOPK_CONFIGS = [ + triton.Config({"BLOCK_S": bs, "BLOCK_T": bt, "waves_per_eu": wpe}, num_warps=nw, num_stages=1) + for bs in (64, 128, 256) + for bt in (1, 2) + for wpe in (0, 2, 3, 4) + for nw in (4, 8) +] + + +def _prune_single_topk_configs(configs, named_args, **kwargs): + """early_config_prune for _score_topk_single_kernel. Keep only configs where + BLOCK_S divides S_PAD (the static chunk loop tiles it exactly) and BLOCK_T + divides T_t.""" + vals = {**named_args, **kwargs} + S_PAD = vals["S_PAD"] + T_t = vals["T_t"] + return [ + c for c in configs + if S_PAD % c.kwargs["BLOCK_S"] == 0 + and T_t % c.kwargs["BLOCK_T"] == 0 + ] + + +@triton.autotune( + configs=_SINGLE_TOPK_CONFIGS, + key=["H", "d_i", "T_s", "K"], + prune_configs_by={"early_config_prune": _prune_single_topk_configs}, +) +@triton.jit +def _score_topk_single_kernel( + Hq_ptr, Hk_ptr, W_o_ptr, Topk_idx_ptr, + B: tl.constexpr, oH: tl.constexpr, T_t: tl.constexpr, T_s: tl.constexpr, + H: tl.constexpr, d_i: tl.constexpr, K: tl.constexpr, S_PAD: tl.constexpr, + BLOCK_S: tl.constexpr, BLOCK_T: tl.constexpr, +): + """Like ``_score_topk_kernel`` but holds all S_PAD candidates and sorts once. + + Grid: (cdiv(T_t, BLOCK_T), B * oH). Buffer is BLOCK_T*S_PAD packed uint64 + with T encoded in the high bits (same 1D-sort-groups-per-T trick). Requires + BLOCK_S | S_PAD (so the static chunk loop tiles S_PAD exactly); no BLOCK_S|K + constraint is needed since there is no 2K streaming buffer. + """ + pid_t = tl.program_id(0) + pid_bh = tl.program_id(1) + b = (pid_bh // oH).to(tl.int64) + h_outer = (pid_bh % oH).to(tl.int64) + rh = tl.arange(0, H) + rdi = tl.arange(0, d_i) + rs_chunk = tl.arange(0, BLOCK_S) + rk = tl.arange(0, K) + rt_local = tl.arange(0, BLOCK_T) + rt = pid_t * BLOCK_T + rt_local + rt_64 = rt.to(tl.int64) + rt_mask = rt < T_t + + hq_base = b * (oH * T_t * H * d_i) + h_outer * (T_t * H * d_i) + Hq_token = tl.load( + Hq_ptr + hq_base + rt_64[:, None, None] * (H * d_i) + + rh[None, :, None] * d_i + rdi[None, None, :], + mask=rt_mask[:, None, None], other=0.0) + wo_base = b * (oH * T_t * H) + h_outer * (T_t * H) + w_o = tl.load(W_o_ptr + wo_base + rt_64[:, None] * H + rh[None, :], + mask=rt_mask[:, None], other=0.0).to(tl.float32) + Hq_flat = tl.reshape(Hq_token, (BLOCK_T * H, d_i)) + Hq_T = tl.trans(Hq_flat) + w_o_flat = tl.reshape(w_o, (BLOCK_T * H,)) + hk_base = b * (oH * T_s * d_i) + h_outer * (T_s * d_i) + + N_CHUNK: tl.constexpr = S_PAD // BLOCK_S + BIG: tl.constexpr = BLOCK_T * S_PAD + rb = tl.arange(0, BIG) + rb_t = rb // S_PAD + rb_pos = rb % S_PAD + t_enc_per_slot = (BLOCK_T - rb_t).to(tl.uint64) + top_packed = (t_enc_per_slot << 56) | rb_pos.to(tl.uint64) + + for c in tl.static_range(N_CHUNK): + rs = c * BLOCK_S + rs_chunk + rs_mask = rs < T_s + hk_ptrs = Hk_ptr + hk_base + rs[:, None] * d_i + rdi[None, :] + Hk_chunk = tl.load(hk_ptrs, mask=rs_mask[:, None], other=0.0) + logits = tl.dot(Hk_chunk, Hq_T) + logits = tl.maximum(logits, 0.0) + weighted = logits * w_o_flat[None, :] + weighted_3d = tl.reshape(weighted, (BLOCK_S, BLOCK_T, H)) + chunk_scores = tl.sum(weighted_3d, axis=2) + chunk_scores_T = tl.trans(chunk_scores) + bits = chunk_scores_T.to(tl.uint32, bitcast=True) + sign = bits >> 31 + flip_mask = (0 - sign.to(tl.int32)).to(tl.uint32) | 0x80000000 + sortable = bits ^ flip_mask + sortable = tl.where(rs_mask[None, :], sortable, 0) + t_enc_chunk = (BLOCK_T - rt_local).to(tl.uint64) + rs_2d = tl.broadcast_to(rs[None, :], (BLOCK_T, BLOCK_S)) + chunk_packed_2d = ((t_enc_chunk[:, None] << 56) + | (sortable.to(tl.uint64) << 24) | rs_2d.to(tl.uint64)) + chunk_packed_flat = tl.reshape(chunk_packed_2d, (BLOCK_T * BLOCK_S,)) + chunk_offset = c * BLOCK_S + in_slot = (rb_pos >= chunk_offset) & (rb_pos < chunk_offset + BLOCK_S) + j = rb_pos - chunk_offset + flat_idx = tl.where(in_slot, rb_t * BLOCK_S + j, 0).to(tl.int32) + gathered = tl.gather(chunk_packed_flat, flat_idx, axis=0) + top_packed = tl.where(in_slot, gathered, top_packed) + + top_packed = tl.sort(top_packed, descending=True) # SINGLE sort + out_idx = rt_local[:, None] * S_PAD + rk[None, :] + out_idx_flat = tl.reshape(out_idx, (BLOCK_T * K,)).to(tl.int32) + top_k_packed_flat = tl.gather(top_packed, out_idx_flat, axis=0) + top_k_packed = tl.reshape(top_k_packed_flat, (BLOCK_T, K)) + top_k_idx = (top_k_packed & 0xFFFFFF).to(tl.int32) + out_base = b * (oH * T_t * K) + h_outer * (T_t * K) + out_ptrs = Topk_idx_ptr + out_base + rt_64[:, None] * K + rk[None, :] + tl.store(out_ptrs, top_k_idx, mask=rt_mask[:, None]) + + _score_topk_p = extend_core.Primitive("te_indexer_score_topk_triton") _score_topk_p.multiple_results = True @@ -733,30 +910,13 @@ def _score_topk_lowering(ctx, Hq, Hk, W_o, *, k): T_s = Hk_aval.shape[2] S_PAD = _next_pow2(T_s) - # Build a K-filtered autotuner around the plain JIT kernel. We do this at - # lowering time (rather than decorating the kernel at definition) because - # configs with BLOCK_S > k or BLOCK_S that doesn't divide k would compile - # to a kernel where INNER = k // BLOCK_S = 0 — i.e. a no-op that's fastest - # in the autotune timing race. Filtering ensures the runtime picker only - # sees configs that actually do the work. - # - # Also filter BLOCK_T configs that don't evenly divide T_t — we mask the - # tail but unnecessary padding hurts L1/L2 efficiency. - valid_configs = [ - c for c in _SCORE_TOPK_CONFIGS - if c.kwargs["BLOCK_S"] <= k - and k % c.kwargs["BLOCK_S"] == 0 - and T_t % c.kwargs["BLOCK_T"] == 0 - ] - if not valid_configs: - raise ValueError( - f"No valid BLOCK_S/BLOCK_T config for k={k}, T_t={T_t}" - ) - - autotuned_kernel = triton.autotune( - configs=valid_configs, - key=["H", "d_i", "T_s", "K"], - )(_score_topk_kernel) + # Both kernels are self-autotuned (@triton.autotune at definition); invalid + # configs are dropped by their early_config_prune hooks, so the lowering just + # picks the right kernel and launches. Single-sort path when all S_PAD + # candidates fit in registers (~1.5x faster: one sort instead of S_PAD/K + # streaming sorts of a 2K buffer); streaming kernel for very large T_s. + autotuned_kernel = (_score_topk_single_kernel if S_PAD <= _SINGLE_SORT_MAX + else _score_topk_kernel) def grid_fn(merged_kwargs): bt = merged_kwargs.get("BLOCK_T", 1)