Skip to content

Add opt-in FP8 DiT TensorRT engine for SA3-medium#47

Draft
ryanontheinside wants to merge 3 commits into
Stability-AI:mainfrom
ryanontheinside:feat/dit-fp8
Draft

Add opt-in FP8 DiT TensorRT engine for SA3-medium#47
ryanontheinside wants to merge 3 commits into
Stability-AI:mainfrom
ryanontheinside:feat/dit-fp8

Conversation

@ryanontheinside

@ryanontheinside ryanontheinside commented Jun 8, 2026

Copy link
Copy Markdown

FP8 DiT TensorRT engine (opt-in, ~1.8x per step, ~2.5x batched throughput)

Adds an optional FP8 GEMM-trunk TensorRT engine for the SA3-medium DiT, on top
of the existing FP16-mixed recipe. The DiT step is the inner loop of the
pingpong sampler, so this is the highest-leverage place to cut latency. The
engine keeps FP32 inputs/outputs, so it is a drop-in swap for the FP16-mixed
engine at inference (sa3_trt --precision fp8), pairing an FP8 DiT with the
existing FP16-mixed decoder.

Why FP8 here (and why batching is the larger win)

The FP16-mixed DiT engine is compute-saturated: a single batched forward barely
amortizes (<=1.09x at B=4), so a ring-buffer pipeline at depth > 1 hits a flat
throughput ceiling. FP8 cuts per-row GEMM compute ~1.8x, which frees the SM
throughput the FP16 engine saturated, so batching amortizes.

The recipe (why it is more than mtq.quantize)

build_dit_fp8.py takes the published dit_fp16mixed.onnx plus a calibration
.npz and produces dit_fp8.onnx:

  1. Kahn-sort + opset-19 convert (the FP16-mixed graph is left un-toposorted by
    the island surgery, and ModelOpt/ORT reject that).
  2. ModelOpt FP8 PTQ on MatMul/Gemm only, disable_mha_qdq (attention BMMs stay
    on the FP16/FP32 path), max calibration.
  3. Restore the handful of initializers ModelOpt corrupts during preprocessing,
    and recalibrate activation scales on a Q/DQ-bypassed copy with real
    conditioning.
  4. Re-apply the FP32 islands the FP16-mixed recipe established (RMSNorm /
    Softmax / RoPE) plus the conditioning front-end, which must stay FP32 or the
    t>=0.984 timestep features flush.
  5. Per-channel weight scales along the GEMM N axis (constant-folded at build,
    free at runtime). Activations stay per-tensor (TRT requirement for FP8
    activation quant); max calibration, because the activation outliers are
    signal and percentile clipping regresses parity.

Calibration data is captured from the model's own generate() by
make_calib.py, which records the six DiT engine inputs across the pingpong
schedule. Prompts come from this repo's interface/reprompt.py Music examples,
the deployment-matched reprompt format. The .npz is a reproducible producer
artifact (gitignored, never committed).

What is in this PR

Producer (model maintainers):

  • build/make_calib.py (new): calibration capture from the checkpoint.
  • build/build_dit_fp8.py (new): the FP8 build recipe above.

Consumer:

  • build/build_from_onnx.py, build/build.py: add sa3-m-fp8 as an opt-in
    target. It is excluded from all / all-both / "build all missing" and is
    built only by explicit name, gated on the published ONNX existing.
  • scripts/sa3_trt.py, scripts/sa3_trt_core.py: --precision fp8 selection
    (FP8 DiT + FP16-mixed decoder).
  • build/README.md: producer and consumer documentation.
  • .gitignore: ignore *.calib.npz.

Testing and validation

All results below were produced on a single RTX 5090 (sm_120), TensorRT
10.16, on SA3-medium at L=646 (the latent length of a ~54 s generation). Step
times are hardware-dependent; the speedup ratios are the portable claim.

1. Clean-room reproduction of the producer chain

The full producer path was rebuilt from a clean checkout, using only the inputs
a model maintainer has: the SA3-medium checkpoint and the published
dit_fp16mixed.onnx (both pulled from HuggingFace), nothing carried over from
development. make_calib.py captured a fresh calibration set, then
build_dit_fp8.py produced the engine end to end.

Calibration capture: 376 samples (47 reprompt Music prompts x 8 sigmas) at
L=646, schedule [1.0, 0.9944, 0.9845, 0.9579, 0.8909, 0.7455, 0.5125, 0.2739], t5_hidden range [-52.33, 36.10], x range [-5.87, 5.38],
local_add_cond all zero (text-to-music), matching the expected reference
profile.

Build stages, all clean:

  • toposort + opset 17 -> 19.
  • FP8 PTQ: 619 quantizable nodes, MHA Q/DQ disabled, max calibration over the
    376 samples.
  • repair: the two known ModelOpt-corrupted initializers restored
    (to_timestep_embed.2.bias 6060 -> 0.12, layers.22.to_local_embed
    3286 -> 1.30), 417 Q/DQ pairs bypassed, 834 activation scales recalibrated,
    0 mask-path pairs.
  • FP32 islands (hybrid): 2021 island nodes (96 Softmax) plus the conditioning
    front-end.
  • per-channel weights: 220 weight pairs upgraded, 440 per-channel scale vectors
    verified.
  • TensorRT STRONGLY_TYPED compile: 1494 MB engine.

The resulting engine deserializes and exposes the expected six FP32 inputs
(x, t, t5_hidden, t5_mask, seconds_total, local_add_cond) and the
velocity FP32 output, with dynamic latent length.

2. Numerical parity vs the FP16-mixed engine

Parity was measured by feeding the captured DiT inputs through both the FP8 and
FP16-mixed engines at batch 1 and comparing outputs.

Single-step latent agreement (x + dt * v, the quantity that actually advances
the sampler), over all 376 samples, by sigma:

sigma min cos mean cos
1.0000 1.00000 1.00000
0.9944 0.99999 1.00000
0.9845 0.99979 0.99998
0.9579 0.99985 0.99992
0.8909 0.99955 0.99973
0.7455 0.99891 0.99945
0.5125 0.99921 0.99948
0.2739 0.99824 0.99912

Worst single-step latent cosine 0.99824, mean 0.99971.

Compounded agreement was measured with an 8-step deterministic euler rollout
per prompt: each engine is chained from the same sigma=1.0 latent with its own
velocities and the final latents are compared, once per calibration prompt
(both engines warmed by the 376 single-step evaluations beforehand). Over all
47 prompts the final-latent cosine distribution is mean 0.953, median 0.957,
p5 0.915, worst 0.873, best 0.990.

Two caveats on reading the compounded number: the rollout is chaotic at the
early sigmas (a 1e-3 relative input perturbation alone compounds to ~0.967
final-latent cosine, and single FP8 steps at sigma 0.994/0.984 dominate the
divergence), and the FP16-mixed engine itself scores only ~0.998 compounded vs
PT eager. So compounded cosine is a guide rather than a gate; the acceptance
test is decoded audio under the production pingpong sampler, judged by ear,
where this engine's generation tracks the FP16-mixed generation at ~0.90
RMS-curve correlation (same conditioning and seeds).

Two negative results worth recording. BF16 was tried and rejected: it
compounds error over the 8 steps (final-latent cos ~0.81) and is audibly
degraded. Dequantizing the conditioning front-end GEMMs (to_cond_embed,
project_in, project_out) was also tried and rejected: it improves every
euler metric (worst single-step latent 0.9982 -> 0.9993, compounded mean
0.953 -> 0.967) yet reproducibly diverges under the production pingpong
sampler: decoded-audio RMS correlation vs the FP16-mixed generation drops
from ~0.90 to ~0.33, audibly a different song, across three independent
engine builds. Euler metrics under-weight the high-sigma steps (dt ~ 0)
exactly where pingpong's denoised prediction x - t*v amplifies velocity
error most, so recipe changes here must be gated by the listening test, not
by euler cosines. Under the stochastic pingpong sampler the FP8 engine yields
a different but comparable sample.

3. Per-step latency (B=1, L=646)

Median of 200 timed steps on a real calibration sample:

engine step time speedup
FP16-mixed 18.7-19.4 ms 1.0x
FP8 10.6-11.0 ms ~1.8x

The ranges are run-to-run and build-to-build variance. TensorRT tactic
selection is nondeterministic per build (we observed a ~10% slow outlier on
one build of a closely related graph). If a freshly built engine benches
noticeably slower than expected, rebuild it; the ONNX is deterministic, only
the engine compilation varies.

4. Batched throughput, depths 1..8

For each engine the input batch dimension was made dynamic, a STRONGLY_TYPED
B=1..8 engine was compiled, each batched row was validated against the serial
B=1 engine, then the median batched step time was benched against serial
dispatch. gens/s is the steady-state generation rate of an 8-step pingpong
ring buffer running at that batch (depth).

FP16-mixed:

B batched ms per-slot ms speedup gens/s
1 28.4 28.4 0.74x 4.4
2 42.7 21.4 0.98x 5.9
3 60.5 20.2 1.04x 6.2
4 76.9 19.2 1.09x 6.5
5 96.4 19.3 1.09x 6.5
6 125.5 20.9 1.00x 6.0
7 144.1 20.6 1.02x 6.1
8 164.1 20.5 1.02x 6.1

(serial B=1 reference engine: 20.9 ms)

FP8:

B batched ms per-slot ms speedup gens/s
1 11.0 11.0 0.96x 11.3
2 17.1 8.6 1.24x 14.6
3 24.0 8.0 1.33x 15.6
4 30.4 7.6 1.40x 16.5
5 37.6 7.5 1.41x 16.6
6 46.0 7.7 1.38x 16.3
7 53.2 7.6 1.40x 16.5
8 61.9 7.7 1.37x 16.2

(serial B=1 reference engine: 10.6 ms)

Reading these together:

  • FP16-mixed is compute-saturated: per-slot stays ~19 to 21 ms regardless of
    batch, batching buys <=1.09x, and the ring-buffer ceiling is ~6.5 gens/s.
  • FP8 exceeds the ceiling: per-slot drops to ~7.6 ms and batching amortizes up to
    1.41x at B=4..5, lifting the ceiling to ~16.5 gens/s.
  • The per-slot FP8 advantage holds ~2.5x across the batch (2.57x at B=1, 2.53x
    at B=4, 2.65x at B=8), and the end-to-end pipeline throughput a depth > 1
    ring buffer actually hits rises from ~6.5 to ~16.5 gens/s, about 2.5x. The
    headline 1.8x is the B=1 step; under the batching the pipeline uses, FP8
    compounds it.

5. Batching correctness

Each batched row was validated against the serial B=1 engine on identical
inputs. The FP16-mixed batched engine matched its serial engine at 0.99994
(worst row). The FP8 batched engine matched at a uniform 0.992 to 0.995 across
all rows: this is FP8 kernel-tactic variance between two separate engine builds
(the per-tensor activation scales are the same), not a batch-size
specialization bug, which the uniformity across rows confirms. Timing is
unaffected.


Status: ONNX not yet in the official model repo

dit_fp8.onnx + dit_fp8.onnx.data are not in
stabilityai/stable-audio-3-optimized yet, so build_from_onnx.py sa3-m-fp8
and sa3_trt --precision fp8 will 404 until they are uploaded there under
exactly those filenames. This is why sa3-m-fp8 is opt-in and kept out of the
default all build paths.

The built artifacts from the clean-room run above are staged at
ryanontheinside/stable-audio-3-optimized-fp8
in the official repo's layout (onnx/sa3-m/dit_fp8.onnx + .data, plus a
prebuilt tensorRT/sm_120 engine), so they can be verified directly or copied
into the official repo. Equally, a maintainer run of the two producer commands
below reproduces them from nothing but the checkpoint and the published
dit_fp16mixed.onnx; the staging copy exists for convenience, the recipe is
the source of truth.

Usage

Producer:

# 1. capture calibration data from the checkpoint
python build/make_calib.py \
  --model-config <MODELS_ROOT>/SA3-M-hf/model_config.json \
  --checkpoint   <MODELS_ROOT>/SA3-M-hf/model.safetensors \
  --out          sa3-m.calib.npz

# 2. build the FP8 ONNX + engine
python build/build_dit_fp8.py \
  --input  <HF_REPO>/onnx/sa3-m/dit_fp16mixed.onnx \
  --calib  sa3-m.calib.npz \
  --onnx   <HF_REPO>/onnx/sa3-m/dit_fp8.onnx \
  --engine ../models/<arch>/sa3-m/dit_fp8.trt

Requires nvidia-modelopt + onnxruntime-gpu on top of the consumer deps.

Consumer (once the ONNX is published):

python build/build_from_onnx.py sa3-m-fp8   # STRONGLY_TYPED compile, no ModelOpt
# then: sa3_trt --precision fp8

Producer recipe (build_dit_fp8.py) builds a ModelOpt FP8 GEMM-trunk DiT on
top of the FP16-mixed graph: FP8 PTQ on MatMul/Gemm, initializer repair plus
activation-scale recalibration, re-applied FP32 islands (RMSNorm/Softmax/RoPE
plus the conditioning front-end), and per-channel weight scales. make_calib.py
captures calibration inputs from the model's own pingpong generate(), pulling
prompts from interface/reprompt.py. ~1.8x faster steps than FP16-mixed at B=1,
amortizing further under batched dispatch.

Consumer wiring adds sa3-m-fp8 as an opt-in target (excluded from all/all-both
and 'build all missing'; built only by explicit name, gated on the published
ONNX) and a --precision fp8 selection that pairs the FP8 DiT with the
FP16-mixed decoder, guarded so non-medium DiTs cannot request it.
The compounded euler final-latent cosine was previously quoted as a
single-prompt range (~0.96 to 0.976). Measured over all 47 reprompt
calibration prompts (warmed engines, deterministic euler, L=646) the
distribution is mean 0.953, median 0.957, p5 0.915, worst 0.873; the
worst single-step latent cosine over all 376 samples is 0.9982. The
rollout is chaotic at the early sigmas (an eps=1e-3 input perturbation
alone compounds to ~0.967), so the docs now report the distribution
with that caveat and point to the production-sampler listening gate
(decoded audio tracks the FP16-mixed generation at ~0.90 RMS-curve
correlation) as the real acceptance test.

Also note that TRT tactic selection is nondeterministic per build:
step time varied 10.6-11.0 ms across rebuilds of the same ONNX, so a
noticeably slow engine should be rebuilt rather than blamed on the
recipe.
Report FP8 DiT parity as the full 47-prompt distribution
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant