From 3e96c7e905892eda97e1227de0e516e3e4d08a6f Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Thu, 18 Jun 2026 18:45:53 +0000 Subject: [PATCH 1/4] Initial commit w/ interleaving and cold-caching --- benchmarks/asv/README.md | 265 +++++++++++ benchmarks/asv/__init__.py | 0 benchmarks/asv/asv.conf.json | 16 + benchmarks/asv/bench_attention.py | 102 +++++ benchmarks/asv/bench_casting.py | 100 +++++ benchmarks/asv/bench_gemm.py | 99 +++++ benchmarks/asv/bench_gemm_fp8.py | 104 +++++ benchmarks/asv/bench_grouped_gemm.py | 94 ++++ benchmarks/asv/bench_normalization.py | 83 ++++ benchmarks/asv/compare_results.py | 143 ++++++ benchmarks/asv/driver.py | 613 ++++++++++++++++++++++++++ benchmarks/asv/parser_TEasv.py | 172 ++++++++ benchmarks/asv/requirements.txt | 3 + benchmarks/asv/run_benchmarks.sh | 52 +++ 14 files changed, 1846 insertions(+) create mode 100644 benchmarks/asv/README.md create mode 100644 benchmarks/asv/__init__.py create mode 100644 benchmarks/asv/asv.conf.json create mode 100644 benchmarks/asv/bench_attention.py create mode 100644 benchmarks/asv/bench_casting.py create mode 100644 benchmarks/asv/bench_gemm.py create mode 100644 benchmarks/asv/bench_gemm_fp8.py create mode 100644 benchmarks/asv/bench_grouped_gemm.py create mode 100644 benchmarks/asv/bench_normalization.py create mode 100644 benchmarks/asv/compare_results.py create mode 100644 benchmarks/asv/driver.py create mode 100644 benchmarks/asv/parser_TEasv.py create mode 100644 benchmarks/asv/requirements.txt create mode 100755 benchmarks/asv/run_benchmarks.sh diff --git a/benchmarks/asv/README.md b/benchmarks/asv/README.md new file mode 100644 index 000000000..eee8900e8 --- /dev/null +++ b/benchmarks/asv/README.md @@ -0,0 +1,265 @@ +# Benchmarks for TransformerEngine + +GPU microbenchmarks driven by `driver.py`. Results are written in +[ASV (Air Speed Velocity)](https://asv.readthedocs.io/) JSON format so they +can be browsed with `asv publish` / `asv preview`, but the `asv` CLI is **not** +used to run benchmarks — `driver.py` runs everything in-process. + +## Prerequisites + +- TransformerEngine must already be built and installed in the current Python environment. +- A ROCm or CUDA GPU must be available. +- `asv` is only required if you want the HTML dashboard (`pip install asv`). + +## Running benchmarks + +Each `bench_*.py` file is directly executable, or you can drive them through +`driver.py`. Results are saved to `benchmarks/.asv/results/` in ASV-compatible +format by default. + +```bash +cd benchmarks/asv +python driver.py --all # run every suite +python driver.py bench_gemm # run one suite via driver +python bench_gemm.py # run one suite directly +python bench_gemm.py time_forward # filter to a specific method +python bench_gemm.py -w 5 -n 20 # custom warmup/iteration counts +python bench_casting.py --no-save # skip saving results +python bench_casting.py --cold-cache # flush cache before each sample +python bench_gemm.py --inner 50 # fix inner-loop count to 50 +python bench_gemm.py --target-window-ms 5 # tune inner so each window >=5 ms +``` + +### Timing model: inner loop and cache state + +Each `time_*` method runs the kernel `_inner` times inside a single CUDA event +window and divides by `_inner`, so kernel-launch and CUDA-event jitter +(`~0.5 µs` resolution on AMD) are amortized. By default the driver +**auto-tunes** `_inner` per (combo, method) so each window lasts at least +`--target-window-ms` (default `1.0 ms`): + +| Flag | Effect | +|---|---| +| `--inner auto` (default) | Probe a single invocation, then pick `_inner` so the next timed window lasts ≥ `--target-window-ms`. Capped at 10000. | +| `--inner N` | Force a fixed `_inner = N` (overrides auto-tune). | +| `--target-window-ms T` | Target window duration for `--inner auto` (default `1.0`). | +| `--cold-cache` | Write a `--cache-flush-mb` byte scratch buffer before each sample to evict L2 + Infinity Cache. Implies `--inner=1` (otherwise iterations 2..N would refill the cache and the measurement degenerates back to warm-cache). | +| `--cache-flush-mb M` | Scratch buffer size for `--cold-cache` (default `256`, sized for the MI300 Infinity Cache). | + +Choose the regime that matches the question you're asking: +- **Warm cache, large `_inner`** (default): steady-state kernel throughput, + matches what a hot inner loop in a model sees. Lowest variance. +- **Cold cache, `_inner=1`**: realistic cost of the kernel as an isolated + call into cold memory — closer to what `rocprofv3 --hip-trace` reports + on a freshly launched kernel. Higher variance; bandwidth-bound + benchmarks (cast, normalization) typically run 1.5–3× slower than warm. + +Caveat: the inner loop runs in Python, so each iteration carries +~80–200 ns of interpreter overhead. For sub-microsecond kernels this is +not removable without CUDA graph capture; pick `--inner` deliberately +in that regime or use the cold-cache mode. + +### Sample scheduling: interleaving + +By default the driver does **not** collect a benchmark's samples in one +contiguous block. It samples in round-robin chunks: it sets up a group of +`(method, combo)` benchmarks, then takes one sample from each per round, for +`-n` rounds. This is on by default because *sequential* scheduling (all of A, +then all of B) makes wall-clock time a proxy for benchmark identity — so any +time-correlated GPU noise (thermal warm-up ramp, DVFS throttle, a neighbor +container on a shared GPU) becomes a systematic **bias** between benchmarks +rather than noise. The Monte-Carlo study in `repro/transient_noise_sim.py` +quantifies it: under a 5% thermal ramp a sequential Brunner-Munzel comparison +fires a false positive 86% of the time (α=0.05), and a 20% ramp can flip a real +5% speedup into a reported regression. Round-robin sampling spreads every +benchmark across the same window, so a transient lands on one sample of each +instead of corrupting one benchmark's whole block. + +The per-round visit order is also **randomly permuted** each round (a balanced +randomized design, not a global shuffle). Fixed round-robin would still pin each +benchmark to a constant phase within the round — so a monotonic ramp leaves a +small constant per-benchmark offset, and each benchmark always sees the same +predecessor's cache/clock state. Re-permuting each round makes both uniform in +expectation, turning that residual bias into variance. The shuffle is seeded +(`--seed`, default `0`) so runs stay reproducible. + +| Flag | Effect | +|---|---| +| `--interleave-group N` (default `8`) | Number of benchmarks sampled round-robin together. Each keeps a live GPU instance for the duration of the chunk, so **lower this if a group runs out of memory**; raise it to share the time window across more benchmarks. | +| `--sequential` | Collect each benchmark's samples contiguously (≡ `--interleave-group 1`). Lowest memory, but biased under thermal drift — use only for quick local runs. | +| `--seed S` (default `0`) | Seed for the per-round shuffle, fixed so runs are reproducible. | +| `--no-shuffle` | Use a fixed round-robin order instead of permuting each round. Leaves a small residual ordering/predecessor bias; mainly for debugging. | + +Caveat: interleaving removes *within-run* time-position bias. It does **not** +remove a whole-run thermal offset between two **separately produced** result +files (e.g. a cold baseline run vs. a warm candidate run). For the statistical +comparison below, produce the baseline and candidate result files back-to-back +under similar conditions. + +### Helper script + +`run_benchmarks.sh` wraps common tasks and can be run from anywhere. + +```bash +bash benchmarks/asv/run_benchmarks.sh [options] +``` + +| Command | Description | +|---|---| +| `run [suite] [method]` | Run benchmarks in-process (saves ASV-compatible results) | +| `view` | Build the ASV HTML dashboard from saved results and serve it on `localhost:8080` | +| `list` | List available benchmark suites | +| `compare BASE CAND` | Statistically compare two result JSONs (exits 1 on a significant regression) | + +## How results are stored + +ASV-format JSON files under `benchmarks/.asv/results/`: + +``` +benchmarks/.asv/results/ + my-machine-name/ + machine.json # Hardware/OS metadata (auto-generated by driver) + .json # Timing results for that commit + .json + ... +``` + +Each commit JSON contains the wall-clock timings for every benchmark + parameter combination +run on that machine, including the raw per-call samples (the ASV `samples` +column) used by `compare_results.py`. The `benchmarks/.asv/` directory is in +`.gitignore`. + +## Viewing results + +To browse historical results in a dashboard, point `asv` at the saved JSON: + +```bash +bash benchmarks/asv/run_benchmarks.sh view +# or, manually: +asv publish --config benchmarks/asv/asv.conf.json +asv preview --config benchmarks/asv/asv.conf.json +``` + +`asv.conf.json` exists only to support `publish` / `preview`; benchmarks +themselves are not invoked through `asv`. + +## Comparing two checkouts statistically + +The dashboard plots point estimates (medians), which cannot tell a real +regression from measurement noise. To test whether timing differences between +two checkouts are statistically significant, the driver records the raw per-call +samples in each result file (the ASV `samples` column), and `compare_results.py` +compares them with a Brunner-Munzel test via the +[benchstats](https://github.com/Arech/benchstats) package: + +```bash +pip install -r requirements.txt # benchstats (pulls rich, scipy, numpy) + +cd benchmarks/asv + +# baseline checkout — saves -.json +python driver.py --all -n 20 +# candidate checkout — saves -.json +python driver.py --all -n 20 + +python compare_results.py \ + ../.asv/results//-.json \ + ../.asv/results//-.json +``` + +It prints a table marking each `(benchmark, parameter combination)` as faster +(`<`), slower (`>`), or not significantly different (`~`), and exits `1` when a +significant difference is found, so it can gate CI. + +By default the result filename is derived from the commit hash, so two runs on +the **same** commit (e.g. prototyping against a dirty working tree, where `HEAD` +is unchanged) would overwrite each other. Pass `--label` to fold a tag into the +filename and keep them distinct: + +```bash +python driver.py --all -n 20 --label base # -> -base-.json +# ... edit code (HEAD stays the same) ... +python driver.py --all -n 20 --label cand # -> -cand-.json + +python compare_results.py \ + ../.asv/results//-base-.json \ + ../.asv/results//-cand-.json +``` + +| Flag | Effect | +|---|---| +| `--alpha A` | Significance level for the test (default `0.001`). | +| `--method M` | Statistical test to use (default `brunnermunzel`). | +| `--filter REGEX` | Only compare benchmarks whose name matches `REGEX`. | +| `--always-show-pvalues` | Show p-values for non-significant rows too. | +| `--export-to FILE` | Save the report to a `.txt`/`.svg`/`.html` file. | + +The test is rank-based and needs a reasonable number of samples per benchmark +(≥ ~10 recommended); the default `-n 20` timed iterations satisfies this. Only +timing is tested — throughput (`TFLOPS`/`GB/s`) is a constant-work transform of +time, so a rank test on it is identical; the driver already prints throughput +columns during a run. + +## Writing new benchmarks + +Create a new file in `benchmarks/asv/` following the naming convention `bench_.py`. + +```python +#!/usr/bin/env python3 +import torch +import transformer_engine.pytorch as te + +class BenchSomething: + params = [[1024, 4096], ["config_a", "config_b"]] + param_names = ["M", "config"] + timeout = 300 # seconds, per parameter combination + + # Driver overrides per (combo, method): _inner controls how many kernel + # invocations land in one CUDA event window; _scratch (when not None) is + # written to before each sample to evict the GPU cache. + _inner = 1 + _scratch = None + + def setup(self, M, config): + # Allocate tensors, create modules. + # This runs once per (combo, method); the same instance is reused for + # warmup and timed iterations. + self._evt = [torch.cuda.Event(enable_timing=True) for _ in range(2)] + ... + + def time_forward(self, M, config): + # Use CUDA events for accurate GPU timing. + # Return elapsed seconds per single invocation — the driver uses this + # instead of wall time. Looping inside the event window amortizes + # CUDA event resolution and kernel-launch overhead. + if self._scratch is not None: + self._scratch.fill_(1.0) # cold-cache mode + self._evt[0].record() + for _ in range(self._inner): + self.module(self.x) + self._evt[1].record() + torch.cuda.synchronize() + return self._evt[0].elapsed_time(self._evt[1]) / 1000 / self._inner + + # Optional: define work_ to get throughput columns (TFLOPS / GB/s). + def work_forward(self, M, config): + return {"flops": 2 * M * self.N * self.K} # compute-bound + # return {"bytes": M * self.hidden * 4} # memory-bound + +if __name__ == "__main__": + from driver import run_as_main + run_as_main(__file__) +``` + +Key rules: +- Method names starting with `time_` are automatically timed. +- Use CUDA events and return elapsed seconds **per single invocation** — + divide the event delta by `self._inner` so the driver and the throughput + columns get per-call values regardless of inner-loop count. +- Honor `self._inner` (loop the kernel) and `self._scratch` (write before + recording the start event) so the driver's `--inner` and `--cold-cache` + flags work for your benchmark. +- Optionally define `work_` companions to get TFLOPS or GB/s columns. + These return the per-call work, not per-window work. +- Clear `.grad` attributes in backward benchmarks to prevent memory accumulation. +- The `params` list defines a cross-product; keep the matrix size reasonable. diff --git a/benchmarks/asv/__init__.py b/benchmarks/asv/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/benchmarks/asv/asv.conf.json b/benchmarks/asv/asv.conf.json new file mode 100644 index 000000000..3c1616aac --- /dev/null +++ b/benchmarks/asv/asv.conf.json @@ -0,0 +1,16 @@ +{ + "version": 1, + "project": "TransformerEngine", + "project_url": "https://github.com/ROCm/TransformerEngine", + "repo": "../..", + "branches": ["HEAD"], + "environment_type": "existing", + "install_command": [], + "build_command": [], + "benchmark_dir": ".", + "results_dir": "../.asv/results", + "html_dir": "../.asv/html", + "install_timeout": 600, + "benchmark_timeout": 1200, + "launch_method": "spawn" +} diff --git a/benchmarks/asv/bench_attention.py b/benchmarks/asv/bench_attention.py new file mode 100644 index 000000000..df6314c43 --- /dev/null +++ b/benchmarks/asv/bench_attention.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python3 +############################################################################### +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### +""" +Attention micro-benchmark using te.DotProductAttention. + +Benchmarks fused multi-head attention (with flash attention backend) for +model configurations with grouped-query attention (GQA). + +Models: + - Llama 3 8B (TP=1, TP=8), 70B (TP=8), 405B (TP=8) + - Qwen 2.5 7B (TP=1), 72B (TP=8) + +Forward FLOPs = 4 * batch * num_q_heads * seq_len^2 * head_dim + (two matmuls: Q@K^T and attn@V, each contributing 2*b*h*s^2*d) +Backward FLOPs = 2 * Forward FLOPs (approximately) + +Sources for model configs: + https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json + https://huggingface.co/meta-llama/Llama-3.1-70B/blob/main/config.json + https://huggingface.co/meta-llama/Llama-3.1-405B/blob/main/config.json + https://huggingface.co/Qwen/Qwen2.5-7B-Instruct/blob/main/config.json + https://huggingface.co/Qwen/Qwen2.5-72B-Instruct/blob/main/config.json +""" + +import torch +import transformer_engine.pytorch as te + +BATCH = 2 + +# (num_q_heads, num_kv_heads, head_dim, tp) +MODELS = { + "Llama3-8B_TP1": (32, 8, 128, 1), + "Llama3-8B_TP8": (32, 8, 128, 8), + "Llama3-70B_TP8": (64, 8, 128, 8), + "Llama3-405B_TP8": (128, 8, 128, 8), + "Qwen2.5-7B_TP1": (28, 4, 128, 1), + "Qwen2.5-72B_TP8": (64, 8, 128, 8), +} + + +class BenchAttention: + params = [[1024, 2048, 4096, 8192], list(MODELS)] + param_names = ["seq_len", "model"] + timeout = 300 + _inner = 1 + _scratch = None + + def setup(self, seq_len, model): + n_q, n_kv, hd, tp = MODELS[model] + qh, kvh = n_q // tp, n_kv // tp + dtype = torch.bfloat16 + + self.attn = te.DotProductAttention( + num_attention_heads=qh, kv_channels=hd, + num_gqa_groups=kvh, attn_mask_type="causal", + ).to(device="cuda", dtype=dtype) + + self.q = torch.randn(seq_len, BATCH, qh, hd, dtype=dtype, device="cuda", requires_grad=True) + self.k = torch.randn(seq_len, BATCH, kvh, hd, dtype=dtype, device="cuda", requires_grad=True) + self.v = torch.randn(seq_len, BATCH, kvh, hd, dtype=dtype, device="cuda", requires_grad=True) + self.grad_out = torch.randn_like(self.attn(self.q, self.k, self.v)) + self._evt = [torch.cuda.Event(enable_timing=True) for _ in range(2)] + + def work_forward(self, seq_len, model): + n_q, n_kv, hd, tp = MODELS[model] + qh = n_q // tp + return {"flops": 4 * BATCH * qh * seq_len * seq_len * hd} + + def work_forward_backward(self, seq_len, model): + n_q, n_kv, hd, tp = MODELS[model] + qh = n_q // tp + return {"flops": 3 * 4 * BATCH * qh * seq_len * seq_len * hd} + + def time_forward(self, seq_len, model): + if self._scratch is not None: + self._scratch.fill_(1.0) + self._evt[0].record() + for _ in range(self._inner): + self.attn(self.q, self.k, self.v) + self._evt[1].record() + torch.cuda.synchronize() + return self._evt[0].elapsed_time(self._evt[1]) / 1000 / self._inner + + def time_forward_backward(self, seq_len, model): + if self._scratch is not None: + self._scratch.fill_(1.0) + self._evt[0].record() + for _ in range(self._inner): + out = self.attn(self.q, self.k, self.v) + out.backward(self.grad_out) + self._evt[1].record() + torch.cuda.synchronize() + self.q.grad = self.k.grad = self.v.grad = None + return self._evt[0].elapsed_time(self._evt[1]) / 1000 / self._inner + +if __name__ == "__main__": + from driver import run_as_main + run_as_main(__file__) diff --git a/benchmarks/asv/bench_casting.py b/benchmarks/asv/bench_casting.py new file mode 100644 index 000000000..713aa498e --- /dev/null +++ b/benchmarks/asv/bench_casting.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python3 +############################################################################### +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### +""" +Benchmarks quantization (BF16 -> FP8) and dequantization (FP8 -> BF16) for +both E4M3 (activations/weights) and E5M2 (gradients) formats. + +Shapes are (M, hidden_size) matching the activation tensors from models: + - Llama 3.1 8B, 70B, 405B + - Qwen 2.5 7B, 72B + +These casts are memory-bound; we report GB/s (input + output bytes). + +Sources for model configs: + https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json + https://huggingface.co/meta-llama/Llama-3.1-70B/blob/main/config.json + https://huggingface.co/meta-llama/Llama-3.1-405B/blob/main/config.json + https://huggingface.co/Qwen/Qwen2.5-7B-Instruct/blob/main/config.json + https://huggingface.co/Qwen/Qwen2.5-72B-Instruct/blob/main/config.json +""" + +import torch +from transformer_engine.pytorch import Float8CurrentScalingQuantizer +from transformer_engine_torch import DType as TE_DType + +HIDDEN_SIZES = { + "Llama3-8B": 4096, + "Llama3-70B": 8192, + "Llama3-405B": 16384, + "Qwen2.5-7B": 3584, + "Qwen2.5-72B": 8192, +} + +CAST_CONFIGS = { + "BF16_to_E4M3": ("quantize", TE_DType.kFloat8E4M3), + "E4M3_to_BF16": ("dequantize", TE_DType.kFloat8E4M3), + "BF16_to_E5M2": ("quantize", TE_DType.kFloat8E5M2), + "E5M2_to_BF16": ("dequantize", TE_DType.kFloat8E5M2), +} + + +class BenchCasting: + params = [[1024, 2048, 4096, 8192], list(HIDDEN_SIZES), list(CAST_CONFIGS)] + param_names = ["M", "model", "cast"] + timeout = 120 + # Driver overrides these per (combo, method): _inner is the number of + # kernel invocations per CUDA event window (amortizes launch overhead); + # _scratch, when not None, is fill_()ed before each sample to evict the + # GPU cache. + _inner = 1 + _scratch = None + + def setup(self, M, model, cast): + hidden = HIDDEN_SIZES[model] + direction, fp8_dtype = CAST_CONFIGS[cast] + self.direction = direction + quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=fp8_dtype, + device=torch.device("cuda"), + rowwise=True, + columnwise=False, + ) + if direction == "dequantize": + bf16_tensor = torch.randn(M, hidden, dtype=torch.bfloat16, device="cuda") + self.x = quantizer.quantize(bf16_tensor) + else: + self.x = torch.randn(M, hidden, dtype=torch.bfloat16, device="cuda") + self.quantizer = quantizer + self._evt = [torch.cuda.Event(enable_timing=True) for _ in range(2)] + + def work_cast(self, M, model, cast): + hidden = HIDDEN_SIZES[model] + direction = CAST_CONFIGS[cast][0] + if direction == "quantize": + # Read BF16 (2B) + write FP8 (1B) + write scale + return {"bytes": M * hidden * 3} + else: + # Read FP8 (1B) + read scale + write BF16 (2B) + return {"bytes": M * hidden * 3} + + def time_cast(self, M, model, cast): + if self._scratch is not None: + self._scratch.fill_(1.0) + self._evt[0].record() + if self.direction == "quantize": + for _ in range(self._inner): + self.quantizer.quantize(self.x) + else: + for _ in range(self._inner): + self.x.dequantize(dtype=torch.bfloat16) + self._evt[1].record() + torch.cuda.synchronize() + return self._evt[0].elapsed_time(self._evt[1]) / 1000 / self._inner + +if __name__ == "__main__": + from driver import run_as_main + run_as_main(__file__) diff --git a/benchmarks/asv/bench_gemm.py b/benchmarks/asv/bench_gemm.py new file mode 100644 index 000000000..b1ad40f99 --- /dev/null +++ b/benchmarks/asv/bench_gemm.py @@ -0,0 +1,99 @@ +#!/usr/bin/env python3 +############################################################################### +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### +"""BF16 GEMM benchmarks via te.Linear. + +GEMM shapes derived from transformer layer projections: + QKV, AttnOut, GateUp (SwiGLU), Down. + +Model configuration sources: +- Llama 3 8B (hidden=4096, intermediate=14336, heads=32, kv_heads=8, head_dim=128) + https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json + +- Llama 3 70B (hidden=8192, intermediate=28672, heads=64, kv_heads=8, head_dim=128) + https://huggingface.co/meta-llama/Llama-3.1-70B/blob/main/config.json + +- Llama 3 405B (hidden=16384, intermediate=53248, heads=128, kv_heads=8, head_dim=128) + https://huggingface.co/meta-llama/Llama-3.1-405B/blob/main/config.json + +- Qwen 2.5 7B (hidden=3584, intermediate=18944, heads=28, kv_heads=4, head_dim=128) + https://huggingface.co/Qwen/Qwen2.5-7B-Instruct/blob/main/config.json + +- Qwen 2.5 72B (hidden=8192, intermediate=29568, heads=64, kv_heads=8, head_dim=128) + https://huggingface.co/Qwen/Qwen2.5-72B-Instruct/blob/main/config.json + """ + +import torch +import transformer_engine.pytorch as te + +# (hidden, intermediate, num_q_heads, num_kv_heads, head_dim, tp) +MODELS = { + "Llama3-8B_TP1": (4096, 14336, 32, 8, 128, 1), + "Llama3-8B_TP8": (4096, 14336, 32, 8, 128, 8), + "Llama3-70B_TP8": (8192, 28672, 64, 8, 128, 8), + "Llama3-405B_TP8": (16384, 53248, 128, 8, 128, 8), + "Qwen2.5-7B_TP1": (3584, 18944, 28, 4, 128, 1), + "Qwen2.5-72B_TP8": (8192, 29568, 64, 8, 128, 8), +} + +# Pre-compute (N, K) for each GEMM shape +SHAPES = {} +for _name, (h, inter, nq, nkv, hd, tp) in MODELS.items(): + SHAPES[f"{_name}-QKV"] = ((nq * hd + 2 * nkv * hd) // tp, h) + SHAPES[f"{_name}-AttnOut"] = (h, (nq * hd) // tp) + SHAPES[f"{_name}-GateUp"] = ((2 * inter) // tp, h) + SHAPES[f"{_name}-Down"] = (h, inter // tp) + + +class BenchGemm: + params = [[1024, 2048, 4096, 8192], list(SHAPES)] + param_names = ["M", "shape"] + timeout = 300 + _inner = 1 + _scratch = None + + def setup(self, M, shape): + N, K = SHAPES[shape] + dtype = torch.bfloat16 + self.linear = te.Linear(K, N, bias=False).to(device="cuda", dtype=dtype) + self.x = torch.randn(M, K, dtype=dtype, device="cuda", requires_grad=True) + self.grad_out = torch.randn_like(self.linear(self.x)) + self._evt = [torch.cuda.Event(enable_timing=True) for _ in range(2)] + + def work_forward(self, M, shape): + N, K = SHAPES[shape] + return {"flops": 2 * M * N * K} + + def work_forward_backward(self, M, shape): + N, K = SHAPES[shape] + return {"flops": 3 * 2 * M * N * K} + + def time_forward(self, M, shape): + if self._scratch is not None: + self._scratch.fill_(1.0) + self._evt[0].record() + for _ in range(self._inner): + self.linear(self.x) + self._evt[1].record() + torch.cuda.synchronize() + return self._evt[0].elapsed_time(self._evt[1]) / 1000 / self._inner + + def time_forward_backward(self, M, shape): + if self._scratch is not None: + self._scratch.fill_(1.0) + self._evt[0].record() + for _ in range(self._inner): + out = self.linear(self.x) + out.backward(self.grad_out) + self._evt[1].record() + torch.cuda.synchronize() + self.x.grad = None + self.linear.weight.grad = None + return self._evt[0].elapsed_time(self._evt[1]) / 1000 / self._inner + +if __name__ == "__main__": + from driver import run_as_main + run_as_main(__file__) diff --git a/benchmarks/asv/bench_gemm_fp8.py b/benchmarks/asv/bench_gemm_fp8.py new file mode 100644 index 000000000..8728695e4 --- /dev/null +++ b/benchmarks/asv/bench_gemm_fp8.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python3 +############################################################################### +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### +""" +FP8 GEMM benchmarks via te.Linear under fp8_autocast. + +Same shapes as bench_gemm.py but with FP8 quantized compute: + - Llama 3 8B (TP=1, TP=8), 70B (TP=8), 405B (TP=8) + - Qwen 2.5 7B (TP=1), 72B (TP=8) + +Each model contributes four GEMM shapes: + QKV projection (column-parallel) N = (Qheads + 2*KVheads)*head_dim / TP, K = hidden + Attention output (row-parallel) N = hidden, K = Qheads*head_dim / TP + MLP Gate+Up (column-parallel) N = 2*intermediate / TP, K = hidden (SwiGLU) + MLP Down (row-parallel) N = hidden, K = intermediate / TP + +Sources for model configs: + https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json + https://huggingface.co/meta-llama/Llama-3.1-70B/blob/main/config.json + https://huggingface.co/meta-llama/Llama-3.1-405B/blob/main/config.json + https://huggingface.co/Qwen/Qwen2.5-7B-Instruct/blob/main/config.json + https://huggingface.co/Qwen/Qwen2.5-72B-Instruct/blob/main/config.json +""" + +import torch +import transformer_engine.pytorch as te +from transformer_engine.common.recipe import DelayedScaling, Format + +# (hidden, intermediate, num_q_heads, num_kv_heads, head_dim, tp) +MODELS = { + "Llama3-8B_TP1": (4096, 14336, 32, 8, 128, 1), + "Llama3-8B_TP8": (4096, 14336, 32, 8, 128, 8), + "Llama3-70B_TP8": (8192, 28672, 64, 8, 128, 8), + "Llama3-405B_TP8": (16384, 53248, 128, 8, 128, 8), + "Qwen2.5-7B_TP1": (3584, 18944, 28, 4, 128, 1), + "Qwen2.5-72B_TP8": (8192, 29568, 64, 8, 128, 8), +} + +SHAPES = {} +for _name, (h, inter, nq, nkv, hd, tp) in MODELS.items(): + SHAPES[f"{_name}-QKV"] = ((nq * hd + 2 * nkv * hd) // tp, h) + SHAPES[f"{_name}-AttnOut"] = (h, (nq * hd) // tp) + SHAPES[f"{_name}-GateUp"] = ((2 * inter) // tp, h) + SHAPES[f"{_name}-Down"] = (h, inter // tp) + +FP8_RECIPE = DelayedScaling( + fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max", +) + + +class BenchGemmFP8: + params = [[1024, 2048, 4096, 8192], list(SHAPES)] + param_names = ["M", "shape"] + timeout = 300 + _inner = 1 + _scratch = None + + def setup(self, M, shape): + N, K = SHAPES[shape] + dtype = torch.bfloat16 + self.linear = te.Linear(K, N, bias=False).to(device="cuda", dtype=dtype) + self.x = torch.randn(M, K, dtype=dtype, device="cuda", requires_grad=True) + self.grad_out = torch.randn(M, N, dtype=dtype, device="cuda") + self._evt = [torch.cuda.Event(enable_timing=True) for _ in range(2)] + + def work_forward(self, M, shape): + N, K = SHAPES[shape] + return {"flops": 2 * M * N * K} + + def work_forward_backward(self, M, shape): + N, K = SHAPES[shape] + return {"flops": 3 * 2 * M * N * K} + + def time_forward(self, M, shape): + if self._scratch is not None: + self._scratch.fill_(1.0) + self._evt[0].record() + with te.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE): + for _ in range(self._inner): + self.linear(self.x) + self._evt[1].record() + torch.cuda.synchronize() + return self._evt[0].elapsed_time(self._evt[1]) / 1000 / self._inner + + def time_forward_backward(self, M, shape): + if self._scratch is not None: + self._scratch.fill_(1.0) + self._evt[0].record() + for _ in range(self._inner): + with te.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE): + out = self.linear(self.x) + out.backward(self.grad_out) + self._evt[1].record() + torch.cuda.synchronize() + self.x.grad = None + self.linear.weight.grad = None + return self._evt[0].elapsed_time(self._evt[1]) / 1000 / self._inner + +if __name__ == "__main__": + from driver import run_as_main + run_as_main(__file__) diff --git a/benchmarks/asv/bench_grouped_gemm.py b/benchmarks/asv/bench_grouped_gemm.py new file mode 100644 index 000000000..199f651c6 --- /dev/null +++ b/benchmarks/asv/bench_grouped_gemm.py @@ -0,0 +1,94 @@ +#!/usr/bin/env python3 +############################################################################### +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### +"""Grouped GEMM benchmarks via te.GroupedLinear. + +MoE model configurations with GateUp and Down projections. +Configurations are based on: +https://github.com/AMD-AGI/Primus-Turbo/blob/main/benchmark/ops/config.py +""" + +import torch +import transformer_engine.pytorch as te + +# (n_routed_experts, moe_intermediate_size, hidden_size) +MOE_MODELS = { + "DSV2-Lite": (64, 1408, 2048), + "DSV2": (160, 1536, 5120), + "DSV3": (256, 2048, 7168), + "Grok-V2": (8, 16384, 8192), +} + +# Build (config_key -> (num_gemms, N, K)) mapping +CONFIGS = {} +for model, (n_experts, inter, hidden) in MOE_MODELS.items(): + for ep in [32, 16, 8]: + if n_experts % ep != 0: + continue + B = n_experts // ep + CONFIGS[f"{model}_EP{ep}-GateUp"] = (B, 2 * inter, hidden) + CONFIGS[f"{model}_EP{ep}-Down"] = (B, hidden, inter) + + +class BenchGroupedGemm: + params = [[512, 1024, 2048, 4096], list(CONFIGS)] + param_names = ["M", "config"] + timeout = 300 + _inner = 1 + _scratch = None + + def setup(self, M, config): + B, N, K = CONFIGS[config] + dtype = torch.bfloat16 + + self.module = te.GroupedLinear( + num_gemms=B, in_features=K, out_features=N, bias=False, + ).to(device="cuda", dtype=dtype) + + self.xs = [ + torch.randn(M, K, dtype=dtype, device="cuda", requires_grad=True) + for _ in range(B) + ] + outs = self.module(self.xs) + self.grad_outs = [torch.randn_like(o) for o in outs] + self._evt = [torch.cuda.Event(enable_timing=True) for _ in range(2)] + + def work_forward(self, M, config): + B, N, K = CONFIGS[config] + return {"flops": B * 2 * M * N * K} + + def work_forward_backward(self, M, config): + B, N, K = CONFIGS[config] + return {"flops": B * 3 * 2 * M * N * K} + + def time_forward(self, M, config): + if self._scratch is not None: + self._scratch.fill_(1.0) + self._evt[0].record() + for _ in range(self._inner): + self.module(self.xs) + self._evt[1].record() + torch.cuda.synchronize() + return self._evt[0].elapsed_time(self._evt[1]) / 1000 / self._inner + + def time_forward_backward(self, M, config): + if self._scratch is not None: + self._scratch.fill_(1.0) + self._evt[0].record() + for _ in range(self._inner): + outs = self.module(self.xs) + torch.autograd.backward(outs, self.grad_outs) + self._evt[1].record() + torch.cuda.synchronize() + for x in self.xs: + x.grad = None + for p in self.module.parameters(): + p.grad = None + return self._evt[0].elapsed_time(self._evt[1]) / 1000 / self._inner + +if __name__ == "__main__": + from driver import run_as_main + run_as_main(__file__) diff --git a/benchmarks/asv/bench_normalization.py b/benchmarks/asv/bench_normalization.py new file mode 100644 index 000000000..2b3608bac --- /dev/null +++ b/benchmarks/asv/bench_normalization.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python3 +############################################################################### +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### +""" +RMSNorm and LayerNorm benchmarks on activation-sized tensors. + +Shapes are derived from training workloads: + - Llama 3 8B, 70B, 405B (all use RMSNorm) + - Qwen 2.5 7B, 72B (all use RMSNorm) + +Modern models predominantly use RMSNorm, but we benchmark both +LayerNorm and RMSNorm since TE supports both and they share the +same kernel infrastructure. + +The M dimension (batch * seq_len) is swept across typical training sizes. + +Sources for model configs: + https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json + https://huggingface.co/meta-llama/Llama-3.1-70B/blob/main/config.json + https://huggingface.co/meta-llama/Llama-3.1-405B/blob/main/config.json + https://huggingface.co/Qwen/Qwen2.5-7B-Instruct/blob/main/config.json + https://huggingface.co/Qwen/Qwen2.5-72B-Instruct/blob/main/config.json +""" + +import torch +import transformer_engine.pytorch as te + +NORMS = {"RMSNorm": te.RMSNorm, "LayerNorm": te.LayerNorm} +HIDDEN_SIZES = [3584, 4096, 8192, 16384] + + +class BenchNormalization: + params = [[1024, 2048, 4096, 8192], HIDDEN_SIZES, list(NORMS)] + param_names = ["M", "hidden", "norm_type"] + timeout = 120 + _inner = 1 + _scratch = None + + def setup(self, M, hidden, norm_type): + dtype = torch.bfloat16 + self.norm = NORMS[norm_type](hidden).to(device="cuda", dtype=dtype) + self.x = torch.randn(M, hidden, dtype=dtype, device="cuda", requires_grad=True) + self.grad_out = torch.randn_like(self.norm(self.x)) + self._evt = [torch.cuda.Event(enable_timing=True) for _ in range(2)] + + def work_forward(self, M, hidden, norm_type): + # Read input (2B) + write output (2B) = 4 bytes per element + return {"bytes": M * hidden * 4} + + def work_forward_backward(self, M, hidden, norm_type): + # Fwd: read+write (4B), Bwd: read input+grad_out+write grad_in (6B) = 10B + return {"bytes": M * hidden * 10} + + def time_forward(self, M, hidden, norm_type): + if self._scratch is not None: + self._scratch.fill_(1.0) + self._evt[0].record() + for _ in range(self._inner): + self.norm(self.x) + self._evt[1].record() + torch.cuda.synchronize() + return self._evt[0].elapsed_time(self._evt[1]) / 1000 / self._inner + + def time_forward_backward(self, M, hidden, norm_type): + if self._scratch is not None: + self._scratch.fill_(1.0) + self._evt[0].record() + for _ in range(self._inner): + out = self.norm(self.x) + out.backward(self.grad_out) + self._evt[1].record() + torch.cuda.synchronize() + self.x.grad = None + for p in self.norm.parameters(): + p.grad = None + return self._evt[0].elapsed_time(self._evt[1]) / 1000 / self._inner + +if __name__ == "__main__": + from driver import run_as_main + run_as_main(__file__) diff --git a/benchmarks/asv/compare_results.py b/benchmarks/asv/compare_results.py new file mode 100644 index 000000000..c1313e1a2 --- /dev/null +++ b/benchmarks/asv/compare_results.py @@ -0,0 +1,143 @@ +#!/usr/bin/env python3 +############################################################################### +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### +"""Statistically compare two ASV result JSON files written by ``driver.py``. + +The point-estimate timings in the ASV dashboard cannot tell a real regression +from measurement noise. This tool compares the raw per-call samples stored in +two result files (one per checkout) using a statistical test (Brunner-Munzel by +default) via the benchstats package. It marks each (benchmark, parameter +combination) as faster (``<``), slower (``>``), or not significantly different +(``~``) and exits ``1`` when a significant timing difference is found, so it can +gate CI. A summary line reports how many benchmarks were significantly faster, +significantly slower, or unchanged. Requires ``pip install -r requirements.txt``. + +Usage: + # run the suite on the baseline checkout, then on the candidate checkout, + # pointing each at its own results file, then: + python compare_results.py baseline.json candidate.json + python compare_results.py baseline.json candidate.json --alpha 0.01 + python compare_results.py baseline.json candidate.json --export-to report.svg +""" + +import argparse +import os +import sys + + +def run_stats(args): + """Compare two ASV result JSONs with a statistical test via benchstats. + + Returns a process exit code: 1 if a significant difference is found in the + timing metric, else 0. + """ + import rich.table # noqa: F401 benchstats 3.4.0 render uses rich.table.Table without importing it + from parser_TEasv import parser_TEasv + from benchstats.compare import compareStats + from benchstats.render import renderComparisonResults + from benchstats.common import LoggingConsole, detectExportFormat + + main_metrics = ["time_s"] + + export_fmt = detectExportFormat(args.export_to, None) if args.export_to else None + if export_fmt is not None and os.path.isfile(args.export_to): + os.remove(args.export_to) + + console = LoggingConsole( + record=export_fmt is not None, + log_level=LoggingConsole.LogLevel.Warning, + ) + + s1 = parser_TEasv(args.baseline_json, args.filter, None, debug_log=console).getStats() + s2 = parser_TEasv(args.candidate_json, args.filter, None, debug_log=console).getStats() + + cr = compareStats( + s1, s2, + method=args.method, + alpha=args.alpha, + main_metrics=main_metrics, + debug_log=console, + ) + + renderComparisonResults( + cr, console, + main_metrics=main_metrics, + always_show_pvalues=args.always_show_pvalues, + ) + + # Tally significant results per direction for the timing metric. benchstats + # encodes the outcome of each comparison as set0-vs-set1: "<" means baseline + # < candidate (candidate's time is higher -> slower / a regression), ">" + # means baseline > candidate (candidate faster / a speedup), "~" means not + # significant at alpha. Printed via the console so it is captured by export. + for metric in main_metrics: + counts = {"<": 0, ">": 0, "~": 0} + for bm_res in cr.results.values(): + res = bm_res.get(metric) + if res is not None: + counts[res.result] = counts.get(res.result, 0) + 1 + total = counts["<"] + counts[">"] + counts["~"] + console.print( + f"\nSummary for '{metric}' ({cr.method}, alpha={cr.alpha:g}, " + f"{total} benchmarks):" + ) + console.print(f" candidate faster (significant, '>'): {counts['>']}") + console.print(f" candidate slower (significant, '<'): {counts['<']}") + console.print(f" no significant difference ('~'): {counts['~']}") + + if export_fmt is not None: + if export_fmt == "txt": + console.save_text(args.export_to) + elif export_fmt == "svg": + console.save_svg(args.export_to, title="") + elif export_fmt == "html": + console.save_html(args.export_to) + + if cr.at_least_one_differs: + console.warning( + "At least one significant timing difference was detected (exit 1)." + ) + return 1 + return 0 + + +def main(): + parser = argparse.ArgumentParser( + description="Statistically compare two ASV result JSON files via benchstats.") + parser.add_argument("baseline_json", help="Baseline ASV result JSON") + parser.add_argument("candidate_json", help="Candidate ASV result JSON") + parser.add_argument( + "--filter", default=None, + help="Only compare benchmarks whose name matches this regex.", + ) + parser.add_argument( + "--alpha", type=float, default=0.001, + help="Significance level for the test (default: 0.001).", + ) + parser.add_argument( + "--method", default="brunnermunzel", + help="Statistical test to use (default: brunnermunzel).", + ) + parser.add_argument( + "--always-show-pvalues", action="store_true", + help="Always show p-values, including for non-significant results.", + ) + parser.add_argument( + "--export-to", default=None, metavar="FILE", + help="Export the report to a .txt/.svg/.html file (format from extension).", + ) + args = parser.parse_args() + + # The benchstats parser is imported lazily from the script directory. + script_dir = os.path.dirname(os.path.abspath(__file__)) + if script_dir not in sys.path: + sys.path.insert(0, script_dir) + + return run_stats(args) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/benchmarks/asv/driver.py b/benchmarks/asv/driver.py new file mode 100644 index 000000000..52abcda64 --- /dev/null +++ b/benchmarks/asv/driver.py @@ -0,0 +1,613 @@ +#!/usr/bin/env python3 +############################################################################### +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### +"""ASV benchmark driver — runs bench classes in-process and saves ASV-compatible results. + +Usage: + python driver.py [method_filter] [-w W] [-n N] [--no-save] + python driver.py --all [-w W] [-n N] [--no-save] + python bench_gemm.py [method_filter] [-w W] [-n N] [--no-save] +""" + +import argparse +import glob +import hashlib +import importlib +import inspect +import itertools +import json +import os +import platform +import random +import re +import subprocess +import sys +import textwrap +import time +import numpy as np + + +# --------------------------------------------------------------------------- +# ASV result generation +# --------------------------------------------------------------------------- + +def _get_benchmark_code_and_version(cls, method_name): + """Build the code string and version hash the same way ASV does. + + ASV hashes a code string built from the time_* and setup methods. + The string is class header + indented time method + indented setup, + with no trailing newline. + + Returns (code, version_hash). + """ + time_src = textwrap.dedent(inspect.getsource(getattr(cls, method_name))) + setup_src = textwrap.dedent(inspect.getsource(cls.setup)) + code = ( + f"class {cls.__name__}:\n" + + textwrap.indent(time_src, " ") + "\n" + + textwrap.indent(setup_src, " ") + ).rstrip("\n") + return code, hashlib.sha256(code.encode()).hexdigest() + + +def _format_param_value(v): + """Format a parameter value the way ASV stores it in JSON.""" + if isinstance(v, str): + return f"'{v}'" + return repr(v) + + +def _get_machine_info(): + """Build the params/machine dict ASV expects.""" + machine = platform.node() + info = { + "arch": platform.machine(), + "cpu": "", + "machine": machine, + "num_cpu": str(os.cpu_count()), + "os": f"{platform.system()} {platform.release()}", + "ram": "", + } + try: + with open("/proc/cpuinfo") as f: + for line in f: + if line.startswith("model name"): + info["cpu"] = line.split(":", 1)[1].strip() + break + with open("/proc/meminfo") as f: + for line in f: + if line.startswith("MemTotal"): + info["ram"] = line.split()[1] # kB + break + except OSError: + pass + return machine, info + + +def _get_commit_hash(): + """Get the current git HEAD hash.""" + try: + return subprocess.check_output( + ["git", "rev-parse", "HEAD"], stderr=subprocess.DEVNULL + ).decode().strip() + except Exception: + return "unknown" + + +def _compute_stats(samples): + """Return (median, mean, stdev, ci_lo, ci_hi, q25, q75) for *samples*. + + Quartiles use linear interpolation (numpy default) — more meaningful at + small n than the index-floor approach. stdev is population stdev to + match the prior wire format; CI is a normal-approximation 99% half-width. + """ + s = np.asarray(samples, dtype=np.float64) + mean = float(s.mean()) + stdev = float(s.std(ddof=0)) + median, q25, q75 = (float(x) for x in np.quantile(s, [0.5, 0.25, 0.75])) + ci = 2.576 * stdev / np.sqrt(s.size) # 99% normal-approx half-width + return median, mean, stdev, max(0.0, mean - ci), mean + ci, q25, q75 + + +def _get_results_dir(): + """Read results_dir from asv.conf.json, resolved to an absolute path.""" + conf_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "asv.conf.json") + with open(conf_path) as f: + conf = json.load(f) + conf_dir = os.path.dirname(conf_path) + return os.path.normpath(os.path.join(conf_dir, conf["results_dir"])) + + +def save_asv_results(all_results, bench_meta, label=None): + """Write results and benchmark index to ASV's results directory. + + *label*, when given, is folded into the result filename so multiple runs on + the same commit (e.g. prototyping with a dirty working tree, where the HEAD + hash is unchanged) land in distinct files that ``compare_results.py`` can + compare instead of overwriting each other. + """ + commit_hash = _get_commit_hash() + machine_name, machine_info = _get_machine_info() + env_name = "existing-" + sys.executable.replace("/", "_").strip("_") + results_dir = _get_results_dir() + machine_dir = os.path.join(results_dir, machine_name) + os.makedirs(machine_dir, exist_ok=True) + + # Write machine.json if missing + machine_json = os.path.join(machine_dir, "machine.json") + if not os.path.exists(machine_json): + with open(machine_json, "w") as f: + json.dump({**machine_info, "version": 1}, f, indent=4) + + # Load existing result file or start fresh. A label is sanitized to keep the + # filename safe (no path separators / whitespace) and inserted after the hash. + if label: + safe_label = re.sub(r"[^A-Za-z0-9._-]+", "_", label).strip("_") + filename = f"{commit_hash[:8]}-{safe_label}-{env_name}.json" + else: + filename = f"{commit_hash[:8]}-{env_name}.json" + result_path = os.path.join(machine_dir, filename) + if os.path.exists(result_path): + with open(result_path) as f: + data = json.load(f) + else: + data = { + "commit_hash": commit_hash, + "env_name": env_name, + "date": int(time.time() * 1000), + "params": {**machine_info, "python": sys.executable}, + "python": sys.executable, + "requirements": {}, + "env_vars": {}, + "result_columns": [ + "result", "params", "version", + "started_at", "duration", + "stats_ci_99_a", "stats_ci_99_b", + "stats_q_25", "stats_q_75", + "stats_number", "stats_repeat", + "samples", + ], + "results": {}, + "durations": {}, + "version": 2, + } + + # Merge new results + for bench_key, bench_data in all_results.items(): + data["results"][bench_key] = bench_data + + with open(result_path, "w") as f: + json.dump(data, f, indent=2) + + print(f"\nResults saved to {result_path}") + + # Update benchmarks.json index so ASV dashboard stays in sync + benchmarks_path = os.path.join(results_dir, "benchmarks.json") + if os.path.exists(benchmarks_path): + with open(benchmarks_path) as f: + benchmarks_data = json.load(f) + else: + benchmarks_data = {"version": 2} + + benchmarks_data.update(bench_meta) + + with open(benchmarks_path, "w") as f: + json.dump(benchmarks_data, f, indent=4) + + print(f"Updated {benchmarks_path}") + + +# --------------------------------------------------------------------------- +# Benchmark runner +# --------------------------------------------------------------------------- + +_ASV_META_DEFAULTS = { + "min_run_count": 2, "number": 0, "repeat": 0, "rounds": 2, + "sample_time": 0.01, "type": "time", "unit": "seconds", "warmup_time": -1, +} + + +def _make_scratch(mb): + """Allocate a scratch buffer used to evict the GPU cache between samples. + + Sized by default to exceed the MI300 Infinity Cache (256 MB) and the L2 + (16 MB), so a single fill writes through every level of cache. + """ + import torch # noqa: deferred import — only needed when cold-cache is on + n = max(1, (mb * 1024 * 1024) // 4) # float32 = 4 bytes + return torch.empty(n, dtype=torch.float32, device="cuda") + + +def _autotune_inner(instance, method_name, combo, target_s, max_inner=10000): + """Pick an inner-loop count so one timed window lasts >= target_s. + + The bench class is expected to honor instance._inner inside its time_* + method (loop the kernel that many times in one CUDA event window and + divide). This probe runs two single invocations: one to settle algorithm + selection / cache state, and one to estimate the per-call cost. + """ + method = getattr(instance, method_name) + saved_inner = instance._inner + instance._inner = 1 + try: + method(*combo) # discard: cold cache + autotuner warmup + t_per = method(*combo) # seconds per single invocation + finally: + instance._inner = saved_inner + if t_per is None or t_per <= 0: + return 1 + return max(1, min(max_inner, int(target_s / t_per) + 1)) + + +def _free_gpu_cache(): + """Release cached GPU memory between interleave chunks. + + No-op when torch was never imported (e.g. CPU-only test harnesses), so the + driver stays importable and runnable without torch present. + """ + torch = sys.modules.get("torch") + if torch is not None: + try: + torch.cuda.empty_cache() + except Exception: + pass + + +def run_class( + suite_name, cls, class_name, method_filter=None, + warmup=3, iters=7, + inner="auto", target_window_ms=1.0, + cold_cache=False, cache_flush_mb=256, + interleave_group=8, rng=None, shuffle=True, +): + """Run all benchmarks in a class, returning (results, metadata) dicts. + + Samples are collected in round-robin chunks of ``interleave_group`` + ``(method, combo)`` benchmarks: one sample is taken from each benchmark in + the chunk per round, for ``iters`` rounds. This spreads every benchmark's + samples across the same wall-clock window so time-correlated GPU noise + (thermal ramp, DVFS throttle) becomes shared variance rather than a bias on + whichever benchmark happened to own a contiguous block of time. See + ``repro/transient_noise_sim.py``. ``interleave_group=1`` reproduces the + original contiguous (sequential) behavior; larger groups interleave more + benchmarks but keep that many GPU instances live at once. + + When ``shuffle`` is true the per-round visit order is randomly permuted + (seeded by *rng*, a ``random.Random``; one is created with seed 0 if not + given). Fixed round-robin still pins each benchmark to a constant phase + within the round, so a monotonic ramp leaves a small constant per-benchmark + offset and each benchmark always sees the same predecessor's cache/clock + state. Permuting each round makes both uniform in expectation, turning that + residual bias into variance. The per-round structure is kept (each benchmark + still gets exactly ``iters`` evenly-spread samples) -- a balanced randomized + design, not a global shuffle that could re-cluster a benchmark's samples. + """ + methods = sorted(m for m in dir(cls) if m.startswith("time_")) + if method_filter: + methods = [m for m in methods if method_filter in m] + if not methods: + return {}, {} + + params = getattr(cls, "params", [[]]) + param_names = getattr(cls, "param_names", []) + combos = list(itertools.product(*params)) + asv_params = [[_format_param_value(v) for v in dim] for dim in params] + + # Discover throughput columns from work_* companions + # Each entry: (dict_key, column_header, unit_divisor) + probe_keys = set() + for m in methods: + wfn = getattr(cls, "work_" + m[5:], None) + if wfn: + try: + probe_keys.update(wfn(cls(), *combos[0])) + except Exception: + pass + throughput_cols = [] + if "flops" in probe_keys: + throughput_cols.append(("flops", "TFLOPS", 1e12)) + if "bytes" in probe_keys: + throughput_cols.append(("bytes", "GB/s", 1e9)) + + # Print table header + target_window_s = target_window_ms / 1000.0 + group = max(1, int(interleave_group)) + if rng is None: + rng = random.Random(0) + inner_desc = ( + "cold-cache (inner=1)" if cold_cache + else f"inner={inner}" if inner != "auto" + else f"inner=auto (>={target_window_ms:g}ms window)" + ) + if group == 1: + sched_desc = "sequential" + else: + sched_desc = f"interleaved group={group}, " + ("shuffled" if shuffle else "fixed-order") + print(f"\n{class_name} ({len(combos)} combos x {len(methods)} methods, " + f"{warmup} warmup, {iters} timed, {inner_desc}, {sched_desc})") + extra_hdr = "".join(f" {label:>10}" for _, label, _ in throughput_cols) + HDR = (f" {'median':>10} {'mean':>10} {'stdev':>10}" + f" {'q25':>10} {'q75':>10} {'min':>10} {'max':>10}" + + extra_hdr + f" {'inner':>5} {'method':<30} params") + print("-" * len(HDR)) + print(HDR) + print("-" * len(HDR)) + + all_results = {} + all_meta = {} + + # Per-method result columns, indexed by combo position. Filling by index + # decouples the wire format from the order samples are actually collected in, + # so interleaved scheduling leaves the saved JSON identical to sequential. + n_combos = len(combos) + cols = { + m: {k: [None] * n_combos for k in + ("median", "ci_lo", "ci_hi", "q25", "q75", "number", "repeat", "samples")} + for m in methods + } + versions = {} + for method_name in methods: + bench_key = f"{suite_name}.{class_name}.{method_name}" + code, version = _get_benchmark_code_and_version(cls, method_name) + versions[method_name] = version + all_meta[bench_key] = { + **_ASV_META_DEFAULTS, + "code": code, "name": bench_key, "version": version, + "param_names": list(param_names), "params": asv_params, + "timeout": getattr(cls, "timeout", 300), + } + + def _label(combo): + return ", ".join(f"{nm}={v}" for nm, v in zip(param_names, combo)) + + # Flatten to (method, combo) tasks, method-major so printed rows keep the + # same grouping as before, then sample them in round-robin chunks. + tasks = [(mi, ci) for mi in range(len(methods)) for ci in range(n_combos)] + started_at = int(time.time() * 1000) + t_start = time.perf_counter() + + for chunk_start in range(0, len(tasks), group): + chunk = tasks[chunk_start:chunk_start + group] + + # Setup phase: prepare every benchmark in the chunk (allocate tensors, + # pick _inner, warm up) and keep its instance live for round-robin timing. + live = [] # (instance, method_obj, method_name, combo, combo_idx) + for mi, ci in chunk: + method_name = methods[mi] + combo = combos[ci] + instance = cls() + try: + instance.setup(*combo) + except Exception as e: + print(f" SKIP {_label(combo)} setup failed: {e}") + continue # leaves None in this (method, combo) slot + + # Inner-loop and cache configuration. Cold-cache mode forces + # inner=1 so only the first invocation in the window sees a + # cold cache; otherwise the 2nd..Nth invocations would refill + # it and we'd be back to a warm-cache measurement. + if cold_cache: + instance._scratch = _make_scratch(cache_flush_mb) + instance._inner = 1 + elif inner == "auto": + instance._inner = _autotune_inner( + instance, method_name, combo, target_window_s) + else: + instance._inner = max(1, int(inner)) + + method = getattr(instance, method_name) + for _ in range(warmup): + method(*combo) + live.append((instance, method, method_name, combo, ci)) + + # Timed phase: one sample from each live benchmark per round, so a + # transient spike lands on one sample of each rather than corrupting a + # whole benchmark's contiguous block. The visit order is re-permuted + # each round (when shuffle is on) so no benchmark is pinned to a fixed + # phase / predecessor; chunk_samples stays keyed by the stable index i. + chunk_samples = [[] for _ in live] + order = list(range(len(live))) + for _ in range(iters): + if shuffle and len(order) > 1: + rng.shuffle(order) + for i in order: + instance, method, method_name, combo, ci = live[i] + t0 = time.perf_counter() + result = method(*combo) + wall = time.perf_counter() - t0 + chunk_samples[i].append(wall if result is None else result) + + # Finalize phase: stats, throughput, print, store into the combo slot. + for i, (instance, method, method_name, combo, ci) in enumerate(live): + samples = chunk_samples[i] + median, mean, stdev, ci_lo, ci_hi, q25, q75 = _compute_stats(samples) + s_min, s_max = min(samples), max(samples) + + c = cols[method_name] + c["median"][ci] = median + c["ci_lo"][ci] = ci_lo + c["ci_hi"][ci] = ci_hi + c["q25"][ci] = q25 + c["q75"][ci] = q75 + c["number"][ci] = instance._inner + c["repeat"][ci] = iters + # Keep the raw samples (seconds) for statistical comparison + # (compare_results.py). Rounded to 1 ns to keep the JSON compact + # without losing meaningful timing resolution. + c["samples"][ci] = [round(x, 9) for x in samples] + + # Derive throughput from work_* companion + work = {} + wfn = getattr(instance, "work_" + method_name[5:], None) + if wfn and median > 0: + try: + work = wfn(*combo) + except Exception: + pass + extra_cols = "" + for key, _, divisor in throughput_cols: + if key in work and median > 0: + extra_cols += f" {work[key] / median / divisor:>10.1f}" + else: + extra_cols += f" {'':>10}" + + print(f" {median*1000:>8.3f}ms {mean*1000:>8.3f}ms " + f"{stdev*1000:>8.3f}ms {q25*1000:>8.3f}ms {q75*1000:>8.3f}ms " + f"{s_min*1000:>8.3f}ms {s_max*1000:>8.3f}ms" + f"{extra_cols} " + f"{instance._inner:>5} {method_name:<30} {_label(combo)}") + + # Release this chunk's GPU instances before setting up the next chunk. + live.clear() + _free_gpu_cache() + + duration = time.perf_counter() - t_start + for method_name in methods: + bench_key = f"{suite_name}.{class_name}.{method_name}" + c = cols[method_name] + all_results[bench_key] = [ + c["median"], asv_params, versions[method_name], started_at, + round(duration, 2), + c["ci_lo"], c["ci_hi"], c["q25"], c["q75"], c["number"], c["repeat"], + c["samples"], + ] + + return all_results, all_meta + + +def run_as_main(caller_file=None): + """Run benchmarks from a bench file or from the command line. + + When called with a file path (from a bench file's ``__main__`` block), + the suite is derived from the filename. When called without arguments + (i.e. ``python driver.py bench_gemm``), the suite is taken from argv. + + Usage from a bench file:: + + if __name__ == "__main__": + from driver import run_as_main + run_as_main(__file__) + """ + parser = argparse.ArgumentParser( + description="Run ASV benchmarks directly in-process (no subprocess overhead).") + if caller_file is None: + parser.add_argument("suite", nargs="?", default=None, + help="Benchmark module name (e.g. bench_casting)") + parser.add_argument("--all", action="store_true", + help="Run all bench_*.py suites in the directory") + parser.add_argument("method_filter", nargs="?", default=None, + help="Only run time_* methods containing this string") + parser.add_argument("-w", "--warmup", type=int, default=10, + help="Number of warmup iterations (default: 3)") + parser.add_argument("-n", "--iters", type=int, default=20, + help="Number of timed iterations (default: 7)") + parser.add_argument("--inner", default="auto", + help="Inner kernel invocations per timed window: " + "'auto' (tune to --target-window-ms) or an integer " + "(default: auto). Larger values amortize CUDA event " + "and kernel-launch overhead.") + parser.add_argument("--target-window-ms", type=float, default=1.0, + help="Target duration of one timed window when " + "--inner=auto (default: 1.0 ms).") + parser.add_argument("--cold-cache", action="store_true", + help="Flush the GPU cache (write a >LLC scratch buffer) " + "before each sample. Forces --inner=1 because " + "subsequent inner calls would refill the cache.") + parser.add_argument("--cache-flush-mb", type=int, default=256, + help="Size in MB of the cache-flush buffer for " + "--cold-cache (default: 256, sized for the MI300 " + "Infinity Cache).") + parser.add_argument("--interleave-group", type=int, default=8, + help="Number of (method, combo) benchmarks sampled " + "round-robin together so time-correlated GPU noise " + "(thermal ramp / DVFS throttle) is shared across " + "them instead of biasing whichever benchmark owns a " + "contiguous block of wall-clock time (default: 8). " + "Each benchmark in a group keeps a live GPU " + "instance, so lower this on out-of-memory. 1 = " + "sequential. See repro/transient_noise_sim.py.") + parser.add_argument("--sequential", action="store_true", + help="Collect each benchmark's samples in one contiguous " + "block (equivalent to --interleave-group 1). Lowest " + "memory, but biased under thermal drift.") + parser.add_argument("--seed", type=int, default=0, + help="Seed for the per-round shuffle of the interleave " + "order (default: 0), kept fixed so runs are " + "reproducible.") + parser.add_argument("--no-shuffle", action="store_true", + help="Disable the per-round random permutation and use a " + "fixed round-robin order. Each benchmark then keeps " + "a constant within-round phase and predecessor, " + "leaving a small residual ordering bias.") + parser.add_argument("--no-save", action="store_true", + help="Skip saving results to ASV format") + parser.add_argument("--label", default=None, + help="Tag folded into the result filename " + "(-