Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions benchmarks/standalone_indexer/_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
#
# See LICENSE for license information.
"""Shared helpers for the standalone indexer-kernel benchmarks (torch-only)."""

import time

import torch


def make_kernel_inputs(B, oH, T_t, T_s, H, d_i, dtype, device="cuda", seed=0):
    """Sample random ``(Hq, Hk, W_o)`` operands in the layout the kernels consume.

    The kernels under test take already-projected operands. In the original
    profilers those came from einsum projections (C_q, H_q, H_k, W_o) of
    generated Q/K/W; the projections are plain GEMMs and not part of the
    Triton kernels, so the operands are drawn directly from a standard normal
    distribution instead.

    Args:
        B, oH, T_t, T_s, H, d_i: Sizes used to build the shapes listed below.
        dtype: Element type shared by all three tensors.
        device: Device on which both the tensors and the generator are created.
        seed: Seed for the private generator, so equal arguments reproduce
            equal tensors.

    Returns:
        Tuple ``(Hq, Hk, W_o)`` shaped ``(B, oH, T_t, H, d_i)``,
        ``(B, oH, T_s, d_i)`` and ``(B, oH, T_t, H)`` respectively.
    """
    rng = torch.Generator(device=device).manual_seed(seed)

    def draw(*shape):
        return torch.randn(shape, dtype=dtype, device=device, generator=rng)

    # All three tensors consume the same generator stream, so the draw order
    # (Hq, then Hk, then W_o) determines the values each one receives.
    Hq = draw(B, oH, T_t, H, d_i)
    Hk = draw(B, oH, T_s, d_i)
    W_o = draw(B, oH, T_t, H)
    return Hq, Hk, W_o


def time_fn(fn, n_warmup=15, n_iter=50):
    """Time a no-arg thunk that launches GPU work.

    ``fn`` is first called ``n_warmup`` times untimed, then ``n_iter`` times
    between two ``torch.cuda.synchronize()`` calls, so the measured interval
    covers the completion of all queued work rather than just the launches.

    Args:
        fn: Zero-argument callable to measure.
        n_warmup: Number of untimed calls made before measuring.
        n_iter: Number of timed calls; must be at least 1.

    Returns:
        Average wall-clock seconds per call over the ``n_iter`` timed calls.

    Raises:
        ValueError: If ``n_iter`` is smaller than 1. Checked before any work
            is launched, instead of dividing by zero after the warm-up.
    """
    if n_iter < 1:
        raise ValueError(f"n_iter must be >= 1; got {n_iter}")

    for _ in range(n_warmup):
        fn()
    # Drain warm-up work so it is not billed to the timed region.
    torch.cuda.synchronize()

    start = time.perf_counter()
    for _ in range(n_iter):
        fn()
    # Wait for the timed launches to finish before reading the clock.
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    return elapsed / n_iter
164 changes: 164 additions & 0 deletions benchmarks/standalone_indexer/indexer_bridge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
#
# See LICENSE for license information.
"""Torch launchers that drive the ACTUAL Triton kernels in the source file
``transformer_engine/jax/triton_extensions/indexer.py``.

The profiling scripts can import their launchers from here instead of from the
self-contained ``indexer_kernels`` copy, so the benchmarks exercise the real
source kernels (forward ``_score_reduce_kernel``, backward
``_score_dscores_chunk_kernel``, top-k ``_score_topk_kernel`` /
``_score_topk_single_kernel``).

Loading trick: importing ``transformer_engine.jax.triton_extensions.indexer``
the normal way runs ``transformer_engine/jax/__init__.py``, which loads the TE
core C library -- broken in a namespace-package / JAX-less setup. The Triton
kernels need none of that, so we ``importlib``-load ``indexer.py`` directly,
pre-stubbing the parent packages (and the ``.utils`` relative import, which is
only used by the JAX lowerings, never by a direct kernel launch). ``jax`` itself
must be importable (it is -- the kernels' module defines JAX primitives at load
time), but no JAX device/compute is used here.
"""

import importlib.util
import os
import sys
import types

import torch
import triton


def _load_origin_indexer():
repo_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
idx_path = os.path.join(
repo_root, "transformer_engine", "jax", "triton_extensions", "indexer.py"
)
if not os.path.isfile(idx_path):
raise FileNotFoundError(f"origin indexer.py not found at {idx_path}")

# Stub parent packages so the relative `from .utils import ...` resolves
# WITHOUT running the real transformer_engine.jax.__init__ (broken core lib).
for name in (
"transformer_engine",
"transformer_engine.jax",
"transformer_engine.jax.triton_extensions",
):
if name not in sys.modules or not hasattr(sys.modules[name], "__path__"):
m = types.ModuleType(name)
m.__path__ = [] # mark as a package
sys.modules[name] = m

util_name = "transformer_engine.jax.triton_extensions.utils"
if util_name not in sys.modules:
u = types.ModuleType(util_name)
# Only the JAX lowerings call this; direct kernel launches don't.
u.triton_call_lowering = lambda *a, **k: None
sys.modules[util_name] = u

mod_name = "transformer_engine.jax.triton_extensions.indexer"
spec = importlib.util.spec_from_file_location(mod_name, idx_path)
mod = importlib.util.module_from_spec(spec)
sys.modules[mod_name] = mod
spec.loader.exec_module(mod)
return mod


# Load the source module once, when this bridge is imported.
_src = _load_origin_indexer()

# Path of the file the kernels below were actually loaded from.
ORIGIN_FILE = _src.__file__

# Kernel objects referenced by the launchers in this file and by the
# profilers (the original note marks ``_score_reduce_kernel`` as a triton
# Autotuner).
_score_reduce_kernel = _src._score_reduce_kernel
_score_dscores_chunk_kernel = _src._score_dscores_chunk_kernel
_score_topk_kernel = _src._score_topk_kernel
_score_topk_single_kernel = _src._score_topk_single_kernel

# Helper and tuning constants re-exported unchanged from the source module.
_next_pow2 = _src._next_pow2
_BWD_H_CHUNK = _src._BWD_H_CHUNK
_HBWD_BLOCK_T = _src._HBWD_BLOCK_T
_HBWD_BLOCK_S = _src._HBWD_BLOCK_S
_SINGLE_SORT_MAX = _src._SINGLE_SORT_MAX
_SCORE_TOPK_CONFIGS = _src._SCORE_TOPK_CONFIGS
_SINGLE_TOPK_CONFIGS = _src._SINGLE_TOPK_CONFIGS


# --- Forward -----------------------------------------------------------------

def score_reduce(Hq, Hk, W_o, out_dtype=None):
    """Launch the source ``_score_reduce_kernel`` on torch tensors.

    Args:
        Hq: 5-D tensor, unpacked as ``(B, oH, T_t, H, d_i)``.
        Hk: 4-D tensor; its third dimension is taken as ``T_s``.
        W_o: 4-D tensor.
        out_dtype: dtype of the result; defaults to ``Hq.dtype``.

    Returns:
        Tensor of shape ``(B, oH, T_t, T_s)`` on ``Hq``'s device, filled by
        the kernel.

    Raises:
        ValueError: If any operand has the wrong number of dimensions. This
            is raised explicitly (not asserted) so the check survives
            ``python -O`` and matches the error style of ``score_topk``.
    """
    if Hq.ndim != 5 or Hk.ndim != 4 or W_o.ndim != 4:
        raise ValueError(
            "expected Hq/Hk/W_o with 5/4/4 dimensions; "
            f"got {Hq.ndim}/{Hk.ndim}/{W_o.ndim}"
        )
    B, oH, T_t, H, d_i = Hq.shape
    T_s = Hk.shape[2]
    if out_dtype is None:
        out_dtype = Hq.dtype

    # The kernel is handed no strides, so pass dense copies where needed.
    Hq, Hk, W_o = Hq.contiguous(), Hk.contiguous(), W_o.contiguous()
    out = torch.empty((B, oH, T_t, T_s), dtype=out_dtype, device=Hq.device)

    def grid(meta):
        # Tiles over T_s and T_t, plus one program per (batch, oH) pair.
        return (
            triton.cdiv(T_s, meta["BLOCK_S"]),
            triton.cdiv(T_t, meta["BLOCK_T"]),
            B * oH,
        )

    _score_reduce_kernel[grid](
        Hq, Hk, W_o, out, B=B, oH=oH, T_t=T_t, T_s=T_s, H=H, d_i=d_i,
    )
    return out


# --- Backward ----------------------------------------------------------------

def bwd_h_chunk(H):
    """Choose the chunk size along ``H`` for the backward launcher.

    Per the original note, this is the same H_CHUNK selection as the source
    ``_score_reduce_bwd``. The candidates ``_BWD_H_CHUNK``, 4 and 2 are tried
    in that order and the first one that divides ``H`` evenly is returned;
    if none does, the result is 1.
    """
    candidates = (_BWD_H_CHUNK, 4, 2)
    return next((c for c in candidates if H % c == 0), 1)


def score_dscores_chunk(Hq_chunk, Hk, W_o_chunk, dO):
    """Launch the source ``_score_dscores_chunk_kernel`` on torch tensors.

    Args:
        Hq_chunk: 5-D tensor, unpacked as ``(B, oH, T, H_CHUNK, d_i)``.
        Hk: Tensor passed through to the kernel after ``contiguous()``.
        W_o_chunk: Tensor passed through to the kernel after ``contiguous()``.
        dO: Tensor whose last dimension is taken as ``T_s``.

    Returns:
        Tuple ``(dscores_chunk, dWo_chunk)`` of kernel outputs with shapes
        ``(B, oH, T, H_CHUNK, T_s)`` and ``(B, oH, T, H_CHUNK)``, allocated
        with ``Hq_chunk``'s dtype and device.
    """
    B, oH, T, H_CHUNK, d_i = Hq_chunk.shape
    T_s = dO.shape[-1]

    # The kernel is handed no strides, so every operand is made dense first.
    Hq_chunk, Hk, W_o_chunk, dO = (
        t.contiguous() for t in (Hq_chunk, Hk, W_o_chunk, dO)
    )

    like_hq = {"dtype": Hq_chunk.dtype, "device": Hq_chunk.device}
    dscores_chunk = torch.empty((B, oH, T, H_CHUNK, T_s), **like_hq)
    dWo_chunk = torch.empty((B, oH, T, H_CHUNK), **like_hq)

    def grid(meta):
        # Tiles over T, plus one program per (batch, oH) pair.
        return (triton.cdiv(T, meta["BLOCK_T"]), B * oH)

    _score_dscores_chunk_kernel[grid](
        Hq_chunk, Hk, W_o_chunk, dO, dscores_chunk, dWo_chunk,
        B=B, oH=oH, T=T, T_s=T_s, H_CHUNK=H_CHUNK, d_i=d_i,
    )
    return dscores_chunk, dWo_chunk


# --- Top-k -------------------------------------------------------------------

def score_topk(Hq, Hk, W_o, k):
    """Launch one of the source top-k kernels on torch tensors.

    ``T_s`` (third dimension of ``Hk``) is padded to ``_next_pow2(T_s)``;
    when that padded size is at most ``_SINGLE_SORT_MAX`` the
    ``_score_topk_single_kernel`` is used, otherwise ``_score_topk_kernel``.

    Args:
        Hq: 5-D tensor, unpacked as ``(B, oH, T_t, H, d_i)``.
        Hk: Tensor whose third dimension is ``T_s``.
        W_o: Tensor passed through to the kernel after ``contiguous()``.
        k: Number of entries per row; a positive power of two, at most ``T_s``.

    Returns:
        ``int32`` tensor of shape ``(B, oH, T_t, k)`` filled by the kernel.

    Raises:
        ValueError: If ``k`` is not a positive power of two or exceeds ``T_s``.
    """
    B, oH, T_t, H, d_i = Hq.shape
    T_s = Hk.shape[2]

    is_pow2 = k > 0 and (k & (k - 1)) == 0
    if not is_pow2:
        raise ValueError(f"k must be a positive power of 2; got {k}")
    if k > T_s:
        raise ValueError(f"k={k} must be <= T_s={T_s}")

    S_PAD = _next_pow2(T_s)
    Hq, Hk, W_o = (t.contiguous() for t in (Hq, Hk, W_o))
    out = torch.empty((B, oH, T_t, k), dtype=torch.int32, device=Hq.device)

    if S_PAD <= _SINGLE_SORT_MAX:
        kernel = _score_topk_single_kernel
    else:
        kernel = _score_topk_kernel

    def grid(meta):
        # Tiles over T_t, plus one program per (batch, oH) pair.
        return (triton.cdiv(T_t, meta["BLOCK_T"]), B * oH)

    kernel[grid](
        Hq, Hk, W_o, out,
        B=B, oH=oH, T_t=T_t, T_s=T_s, H=H, d_i=d_i, K=k, S_PAD=S_PAD,
    )
    return out
88 changes: 88 additions & 0 deletions benchmarks/standalone_indexer/profile_indexer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"""Profile the standalone forward score-reduce Triton kernel (bf16).

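Drives ``_score_reduce_kernel`` from
``transformer_engine/jax/triton_extensions/indexer.py`` through the
``indexer_bridge`` launcher, which loads that file directly instead of
importing the ``transformer_engine`` package (``jax`` itself must still be
importable).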

Measures wall time and effective TFLOPS for the fused kernel:

scores = relu(einsum("...thi,...si->...ths", Hq, Hk)) # never written
O = einsum("...ths,...th->...ts", scores, W_o)

Run:
python benchmarks/standalone_indexer/profile_indexer.py
"""

import torch

from _common import make_kernel_inputs, time_fn
# Drive the ACTUAL source kernel in
# transformer_engine/jax/triton_extensions/indexer.py (via the indexer_bridge
# adapter). Swap to `indexer_kernels` to profile the self-contained copy.
from indexer_bridge import score_reduce


# --- FLOP accounting ------------------------------------------------------------

def kernel_flops(B, oH, T, S, H, d_i):
    """Return the FLOP count attributed to one fused score-reduce call.

    Only the work done inside the kernel is counted; the four projection
    GEMMs (C_q, H_q, H_k, W_o) are excluded. Each multiply-add counts as
    2 flops:

        scores = relu(Hq @ Hk^T) : 2 * B*oH * T * H * S * d_i
        O = sum_h scores * W_o   : 2 * B*oH * T * S * H
    """
    batch = B * oH
    score_macs = batch * T * H * S * d_i
    reduce_macs = batch * T * S * H
    return 2 * (score_macs + reduce_macs)


# --- Reference (torch) for a correctness sanity check ---------------------------

def torch_reference(Hq, Hk, W_o):
    """Compute the score-reduce result with plain torch ops in float32.

    Args:
        Hq: ``(B, oH, T, H, d_i)``.
        Hk: ``(B, oH, S, d_i)``.
        W_o: ``(B, oH, T, H)``.

    Returns:
        ``O`` of shape ``(B, oH, T, S)``: the ReLU of the ``Hq``/``Hk`` inner
        products over ``d_i``, weighted by ``W_o`` and summed over ``H``.
    """
    q32, k32, w32 = Hq.float(), Hk.float(), W_o.float()
    gated = torch.einsum("bothi,bosi->boths", q32, k32).relu()
    return torch.einsum("boths,both->bots", gated, w32)


# --- Driver ---------------------------------------------------------------------

# Problem sizes to benchmark; each row is unpacked by ``main`` as
# (B, oH, T, S, H, d_i).
CONFIGS = [
    (2, 64, 4096, 4096, 64, 128),
]


def check_correctness():
    """Print the kernel's relative L2 error against the torch reference.

    Runs ``score_reduce`` on one small bfloat16 problem and compares it with
    ``torch_reference`` on the same inputs. The error is only reported; no
    threshold is enforced and nothing is returned.
    """
    dims = (1, 2, 128, 256, 8, 128)  # (B, oH, T, S, H, d_i)
    inputs = make_kernel_inputs(*dims, torch.bfloat16)
    got = score_reduce(*inputs).float()
    ref = torch_reference(*inputs)
    # Clamp the denominator so an all-zero reference cannot divide by zero.
    rel = (got - ref).norm() / ref.norm().clamp_min(1e-9)
    rel_err = rel.item()
    print(f" correctness: rel L2 err vs torch fp32 ref = {rel_err:.4e}")


def main():
    """Run the correctness check, then time ``score_reduce`` for each config.

    For every row of ``CONFIGS`` this prints the problem size, the FLOP
    count from ``kernel_flops``, and either the measured time and TFLOP/s
    together with the autotuner's selected config, or a one-line failure
    summary if launching or timing the kernel raised.
    """
    # The module is already loaded by the top-level import; this only binds
    # the kernel object, once, instead of re-importing on every iteration.
    from indexer_bridge import _score_reduce_kernel

    print(f"device: {torch.cuda.get_device_name(0)}\n")
    check_correctness()
    print()
    for B, oH, T, S, H, d_i in CONFIGS:
        Hq, Hk, W_o = make_kernel_inputs(B, oH, T, S, H, d_i, torch.bfloat16)
        flops = kernel_flops(B, oH, T, S, H, d_i)

        print(f"--- B={B} oH={oH} T={T} S={S} H={H} d_i={d_i} bfloat16 ---")
        print(f" kernel work = {flops/1e9:.2f} GFLOPs/call")

        try:
            sec = time_fn(lambda: score_reduce(Hq, Hk, W_o))
            ms = sec * 1e3
            tflops = flops / sec / 1e12
        except Exception as e:  # noqa: BLE001
            # An exception with an empty message has no lines; indexing
            # splitlines() directly would raise IndexError in this handler.
            msg_lines = str(e).splitlines()
            first_line = msg_lines[0] if msg_lines else ""
            print(f" {'score_reduce':<14} FAILED: {type(e).__name__}: {first_line}")
        else:
            print(f" {'score_reduce':<14} {ms:8.3f} ms {tflops:6.2f} TFLOP/s")
            # Report the autotuner-selected config. Only done after a
            # successful run, so a failed launch never shows a stale config.
            print(" Best config: ", getattr(_score_reduce_kernel, "best_config", None))


if __name__ == "__main__":
    main()
Loading