Add opt-in FP8 DiT TensorRT engine for SA3-medium #47
Draft
ryanontheinside wants to merge 3 commits into
Producer recipe (build_dit_fp8.py) builds a ModelOpt FP8 GEMM-trunk DiT on top of the FP16-mixed graph: FP8 PTQ on MatMul/Gemm, initializer repair plus activation-scale recalibration, re-applied FP32 islands (RMSNorm/Softmax/RoPE plus the conditioning front-end), and per-channel weight scales. make_calib.py captures calibration inputs from the model's own pingpong generate(), pulling prompts from interface/reprompt.py. ~1.8x faster steps than FP16-mixed at B=1, amortizing further under batched dispatch. Consumer wiring adds sa3-m-fp8 as an opt-in target (excluded from all/all-both and 'build all missing'; built only by explicit name, gated on the published ONNX) and a --precision fp8 selection that pairs the FP8 DiT with the FP16-mixed decoder, guarded so non-medium DiTs cannot request it.
The compounded euler final-latent cosine was previously quoted as a single-prompt range (~0.96 to 0.976). Measured over all 47 reprompt calibration prompts (warmed engines, deterministic euler, L=646) the distribution is mean 0.953, median 0.957, p5 0.915, worst 0.873; the worst single-step latent cosine over all 376 samples is 0.9982. The rollout is chaotic at the early sigmas (an eps=1e-3 input perturbation alone compounds to ~0.967), so the docs now report the distribution with that caveat and point to the production-sampler listening gate (decoded audio tracks the FP16-mixed generation at ~0.90 RMS-curve correlation) as the real acceptance test. Also note that TRT tactic selection is nondeterministic per build: step time varied 10.6-11.0 ms across rebuilds of the same ONNX, so a noticeably slow engine should be rebuilt rather than blamed on the recipe.
Report FP8 DiT parity as the full 47-prompt distribution
FP8 DiT TensorRT engine (opt-in, ~1.8x per step, ~2.5x batched throughput)
Adds an optional FP8 GEMM-trunk TensorRT engine for the SA3-medium DiT, on top
of the existing FP16-mixed recipe. The DiT step is the inner loop of the
pingpong sampler, so this is the highest-leverage place to cut latency. The
engine keeps FP32 inputs/outputs, so it is a drop-in swap for the FP16-mixed
engine at inference (`sa3_trt --precision fp8`), pairing an FP8 DiT with the
existing FP16-mixed decoder.
Why FP8 here (and why batching is the larger win)
The FP16-mixed DiT engine is compute-saturated: a single batched forward barely
amortizes (<=1.09x at B=4), so a ring-buffer pipeline at depth > 1 hits a flat
throughput ceiling. FP8 cuts per-row GEMM compute ~1.8x, which frees the SM
throughput the FP16 engine saturated, so batching amortizes.
The recipe (why it is more than `mtq.quantize`)
`build_dit_fp8.py` takes the published `dit_fp16mixed.onnx` plus a calibration `.npz` and produces `dit_fp8.onnx`:
the island surgery, and ModelOpt/ORT reject that).
`disable_mha_qdq` (attention BMMs stay on the FP16/FP32 path),
`max` calibration.
and recalibrate activation scales on a Q/DQ-bypassed copy with real
conditioning.
Softmax / RoPE) plus the conditioning front-end, which must stay FP32 or the
t>=0.984 timestep features flush.
free at runtime). Activations stay per-tensor (TRT requirement for FP8
activation quant);
`max` calibration, because the activation outliers are signal and percentile clipping regresses parity.
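The scale choices in the recipe can be illustrated with a small NumPy sketch. This is not code from `build_dit_fp8.py`: the function names and toy tensors are hypothetical, and the only fixed fact is that the largest finite FP8 E4M3 value is 448. It shows one scale per weight output channel, a single per-tensor activation scale from `max` calibration, and how a percentile scale would clip an outlier.

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # largest finite value representable in FP8 E4M3


def per_channel_weight_scales(weight):
    """One scale per output channel (row) of an [out, in] weight matrix."""
    return np.abs(weight).max(axis=1) / FP8_E4M3_MAX


def max_activation_scale(samples):
    """`max` calibration: one per-tensor scale from the largest magnitude seen."""
    amax = max(float(np.abs(s).max()) for s in samples)
    return amax / FP8_E4M3_MAX


def percentile_activation_scale(samples, pct):
    """Percentile calibration: values above the percentile would saturate."""
    flat = np.abs(np.concatenate([s.ravel() for s in samples]))
    return float(np.percentile(flat, pct)) / FP8_E4M3_MAX


rng = np.random.default_rng(0)

# Toy weight with one large-magnitude channel; it gets its own scale.
weight = rng.normal(size=(4, 16)).astype(np.float32)
weight[2] *= 50.0
w_scales = per_channel_weight_scales(weight)

# Toy activations with a single outlier (hypothetical value).
acts = [rng.normal(size=(8, 16)).astype(np.float32) for _ in range(4)]
acts[0][0, 0] = 36.0
a_max = max_activation_scale(acts)
a_pct = percentile_activation_scale(acts, 99.0)

print("weight scales:", w_scales)
print("max scale covers up to:", a_max * FP8_E4M3_MAX)
print("99th percentile scale covers up to:", a_pct * FP8_E4M3_MAX)
```

Under the percentile scale the outlier saturates at the clip value, which is the behaviour the recipe avoids by using `max` calibration.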
Calibration data is captured from the model's own `generate()` by
`make_calib.py`, which records the six DiT engine inputs across the pingpong
schedule. Prompts come from this repo's `interface/reprompt.py` Music examples,
the deployment-matched reprompt format. The `.npz` is a reproducible producer
artifact (gitignored, never committed).
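As a rough sketch of what such a capture file could look like, the snippet below stacks per-step records into one `.npz` and reads it back. The six input names are taken from this PR; the array shapes, the `save_calib` / `load_calib` helpers, and the on-disk layout are assumptions for illustration, not the actual format written by `make_calib.py`.

```python
import os
import tempfile

import numpy as np

INPUT_NAMES = ("x", "t", "t5_hidden", "t5_mask", "seconds_total", "local_add_cond")


def save_calib(path, records):
    """Stack per-step records (dict: input name -> array) into one .npz."""
    stacked = {name: np.stack([r[name] for r in records]) for name in INPUT_NAMES}
    np.savez(path, **stacked)


def load_calib(path):
    """Load the stacked arrays back into a plain dict."""
    with np.load(path) as data:
        return {name: data[name] for name in INPUT_NAMES}


rng = np.random.default_rng(0)
records = []
for _prompt in range(2):                      # toy stand-in for the 47 prompts
    for sigma in (1.0, 0.9944, 0.9845):       # first three sigmas of the schedule
        records.append({
            "x": rng.normal(size=(4, 8)).astype(np.float32),          # toy shape
            "t": np.array([sigma], dtype=np.float32),
            "t5_hidden": rng.normal(size=(3, 5)).astype(np.float32),  # toy shape
            "t5_mask": np.ones((3,), dtype=np.float32),
            "seconds_total": np.array([54.0], dtype=np.float32),
            "local_add_cond": np.zeros((4, 8), dtype=np.float32),     # text-to-music
        })

calib_path = os.path.join(tempfile.mkdtemp(), "toy.calib.npz")
save_calib(calib_path, records)
loaded = load_calib(calib_path)
print({name: arr.shape for name, arr in loaded.items()})
```

With 47 prompts and 8 sigmas the leading dimension would be 376, matching the sample count reported in the validation section.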
What is in this PR
Producer (model maintainers):
- `build/make_calib.py` (new): calibration capture from the checkpoint.
- `build/build_dit_fp8.py` (new): the FP8 build recipe above.
Consumer:
- `build/build_from_onnx.py`, `build/build.py`: add `sa3-m-fp8` as an opt-in target. It is excluded from `all` / `all-both` / "build all missing" and is built only by explicit name, gated on the published ONNX existing.
- `scripts/sa3_trt.py`, `scripts/sa3_trt_core.py`: `--precision fp8` selection (FP8 DiT + FP16-mixed decoder).
- `build/README.md`: producer and consumer documentation.
- `.gitignore`: ignore `*.calib.npz`.
Testing and validation
All results below were produced on a single RTX 5090 (sm_120), TensorRT
10.16, on SA3-medium at L=646 (the latent length of a ~54 s generation). Step
times are hardware-dependent; the speedup ratios are the portable claim.
1. Clean-room reproduction of the producer chain
The full producer path was rebuilt from a clean checkout, using only the inputs
a model maintainer has: the SA3-medium checkpoint and the published
`dit_fp16mixed.onnx` (both pulled from HuggingFace), nothing carried over from development.
`make_calib.py` captured a fresh calibration set, then `build_dit_fp8.py` produced the engine end to end.
Calibration capture: 376 samples (47 reprompt Music prompts x 8 sigmas) at L=646, schedule `[1.0, 0.9944, 0.9845, 0.9579, 0.8909, 0.7455, 0.5125, 0.2739]`, `t5_hidden` range `[-52.33, 36.10]`, `x` range `[-5.87, 5.38]`, `local_add_cond` all zero (text-to-music), matching the expected reference profile.
Build stages, all clean:
`max` calibration over the 376 samples.
(`to_timestep_embed.2.bias` 6060 -> 0.12, `layers.22.to_local_embed` 3286 -> 1.30), 417 Q/DQ pairs bypassed, 834 activation scales recalibrated,
0 mask-path pairs.
front-end.
verified.
The resulting engine deserializes and exposes the expected six FP32 inputs
(`x`, `t`, `t5_hidden`, `t5_mask`, `seconds_total`, `local_add_cond`) and the `velocity` FP32 output, with dynamic latent length.
2. Numerical parity vs the FP16-mixed engine
Parity was measured by feeding the captured DiT inputs through both the FP8 and
FP16-mixed engines at batch 1 and comparing outputs.
Single-step latent agreement (`x + dt * v`, the quantity that actually advances
the sampler), over all 376 samples, by sigma:
Worst single-step latent cosine 0.99824, mean 0.99971.
Compounded agreement was measured with an 8-step deterministic euler rollout
per prompt: each engine is chained from the same sigma=1.0 latent with its own
velocities and the final latents are compared, once per calibration prompt
(both engines warmed by the 376 single-step evaluations beforehand). Over all
47 prompts the final-latent cosine distribution is mean 0.953, median 0.957,
p5 0.915, worst 0.873, best 0.990.
Two caveats on reading the compounded number: the rollout is chaotic at the
early sigmas (a 1e-3 relative input perturbation alone compounds to ~0.967
final-latent cosine, and single FP8 steps at sigma 0.994/0.984 dominate the
divergence), and the FP16-mixed engine itself scores only ~0.998 compounded vs
PT eager. So compounded cosine is a guide rather than a gate; the acceptance
test is decoded audio under the production pingpong sampler, judged by ear,
where this engine's generation tracks the FP16-mixed generation at ~0.90
RMS-curve correlation (same conditioning and seeds).
Two negative results worth recording. BF16 was tried and rejected: it
compounds error over the 8 steps (final-latent cos ~0.81) and is audibly
degraded. Dequantizing the conditioning front-end GEMMs (`to_cond_embed`,
`project_in`, `project_out`) was also tried and rejected: it improves every
euler metric (worst single-step latent 0.9982 -> 0.9993, compounded mean
0.953 -> 0.967) yet reproducibly diverges under the production pingpong
sampler: decoded-audio RMS correlation vs the FP16-mixed generation drops
from ~0.90 to ~0.33, audibly a different song, across three independent
engine builds. Euler metrics under-weight the high-sigma steps (dt ~ 0)
exactly where pingpong's denoised prediction `x - t*v` amplifies velocity
error most, so recipe changes here must be gated by the listening test, not
by euler cosines. Under the stochastic pingpong sampler the FP8 engine yields
a different but comparable sample.
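The weighting argument can be made concrete with the reported schedule. If a velocity carries error `e`, the euler update `x + dt * v` carries `dt * e`, while the denoised prediction `x - t*v` carries `t * e`. The sketch below tabulates both factors per step; the final sigma of 0.0 is an assumption, and the helper name is hypothetical.

```python
SIGMAS = [1.0, 0.9944, 0.9845, 0.9579, 0.8909, 0.7455, 0.5125, 0.2739]
FINAL_SIGMA = 0.0  # assumption: the rollout ends at sigma = 0


def error_weights(sigmas, final_sigma=FINAL_SIGMA):
    """Per step: (t, |dt|, t, t / |dt|) = sigma, euler weight, denoised weight, ratio."""
    next_sigmas = list(sigmas[1:]) + [final_sigma]
    rows = []
    for t, t_next in zip(sigmas, next_sigmas):
        euler_w = abs(t_next - t)   # scales velocity error in x + dt * v
        denoised_w = t              # scales velocity error in x - t * v
        rows.append((t, euler_w, denoised_w, denoised_w / euler_w))
    return rows


rows = error_weights(SIGMAS)
for t, euler_w, denoised_w, ratio in rows:
    print(f"sigma={t:.4f}  |dt|={euler_w:.4f}  t={denoised_w:.4f}  t/|dt|={ratio:8.1f}")
```

At the first step the denoised prediction weights velocity error roughly 180 times more than the euler update does, and the ratio falls to 1 at the last step, which is why an euler-only metric can improve while the pingpong output diverges.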
3. Per-step latency (B=1, L=646)
Median of 200 timed steps on a real calibration sample:
The ranges are run-to-run and build-to-build variance. TensorRT tactic
selection is nondeterministic per build (we observed a ~10% slow outlier on
one build of a closely related graph). If a freshly built engine benches
noticeably slower than expected, rebuild it; the ONNX is deterministic, only
the engine compilation varies.
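A minimal timing harness in the spirit of the measurement above might look like this. It is a generic sketch, not the benchmark used for this PR: `step_fn` is a hypothetical callable that must block until the step has actually finished (for a GPU engine, that means synchronizing the stream inside it), and the warmup count is an assumption.

```python
import statistics
import time


def median_step_ms(step_fn, n_timed=200, n_warmup=20):
    """Median wall-clock time of step_fn in milliseconds, after warmup."""
    for _ in range(n_warmup):
        step_fn()  # warm caches and lazy initialization before timing
    times_ms = []
    for _ in range(n_timed):
        start = time.perf_counter()
        step_fn()  # must not return before the work is complete
        times_ms.append((time.perf_counter() - start) * 1e3)
    return statistics.median(times_ms)


def toy_step():
    """CPU stand-in for one engine step."""
    return sum(range(1000))


result_ms = median_step_ms(toy_step)
print(f"median step: {result_ms:.4f} ms")
```

The median is less sensitive than the mean to occasional slow iterations, but it does not hide a slow build: comparing medians across rebuilds of the same ONNX is what exposes tactic-selection outliers.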
4. Batched throughput, depths 1..8
For each engine the input batch dimension was made dynamic, a STRONGLY_TYPED
B=1..8 engine was compiled, each batched row was validated against the serial
B=1 engine, then the median batched step time was benched against serial
dispatch.
`gens/s` is the steady-state generation rate of an 8-step pingpong ring buffer running at that batch (depth).
FP16-mixed:
(serial B=1 reference engine: 20.9 ms)
FP8:
(serial B=1 reference engine: 10.6 ms)
Reading these together:
batch, batching buys <=1.09x, and the ring-buffer ceiling is ~6.5 gens/s.
1.41x at B=4..5, lifting the ceiling to ~16.5 gens/s.
at B=4, 2.65x at B=8), and the end-to-end pipeline throughput a depth > 1
ring buffer actually hits rises from ~6.5 to ~16.5 gens/s, about 2.5x. The
headline 1.8x is the B=1 step; under the batching the pipeline uses, FP8
compounds it.
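The ceilings quoted above can be checked with simple arithmetic. The sketch assumes that one generation costs 8 batched steps and that a batch of B rows completes B generations per 8 steps; that reading of `gens/s` is an assumption, though it lands close to the quoted figures when fed the serial step times and amortization factors from this section.

```python
STEPS = 8  # steps per generation in the pingpong ring buffer


def gens_per_s(batch, batched_step_ms, steps=STEPS):
    """Steady-state rate when `batch` generations advance together each step."""
    return batch * 1000.0 / (steps * batched_step_ms)


def ceiling_gens_per_s(serial_step_ms, amortization, steps=STEPS):
    """Rate at the best amortization, where
    amortization = batch * serial_step_ms / batched_step_ms."""
    return amortization * 1000.0 / (steps * serial_step_ms)


fp16_ceiling = ceiling_gens_per_s(20.9, 1.09)  # about 6.5 gens/s
fp8_ceiling = ceiling_gens_per_s(10.6, 1.41)   # about 16.6 gens/s
pipeline_gain = fp8_ceiling / fp16_ceiling     # about 2.55x

print(f"FP16-mixed ceiling: {fp16_ceiling:.2f} gens/s")
print(f"FP8 ceiling:        {fp8_ceiling:.2f} gens/s")
print(f"pipeline gain:      {pipeline_gain:.2f}x")
```

The two functions agree by construction: `gens_per_s(4, 4 * 20.9 / 1.09)` equals `ceiling_gens_per_s(20.9, 1.09)`.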
5. Batching correctness
Each batched row was validated against the serial B=1 engine on identical
inputs. The FP16-mixed batched engine matched its serial engine at 0.99994
(worst row). The FP8 batched engine matched at a uniform 0.992 to 0.995 across
all rows: this is FP8 kernel-tactic variance between two separate engine builds
(the per-tensor activation scales are the same), not a batch-size
specialization bug, which the uniformity across rows confirms. Timing is
unaffected.
Status: ONNX not yet in the official model repo
`dit_fp8.onnx` + `dit_fp8.onnx.data` are not in
`stabilityai/stable-audio-3-optimized` yet, so `build_from_onnx.py sa3-m-fp8`
and `sa3_trt --precision fp8` will 404 until they are uploaded there under
exactly those filenames. This is why `sa3-m-fp8` is opt-in and kept out of the
default `all` build paths.
The built artifacts from the clean-room run above are staged at
`ryanontheinside/stable-audio-3-optimized-fp8` in the official repo's layout
(`onnx/sa3-m/dit_fp8.onnx` + `.data`, plus a prebuilt `tensorRT/sm_120`
engine), so they can be verified directly or copied into the official repo.
Equally, a maintainer run of the two producer commands
below reproduces them from nothing but the checkpoint and the published
`dit_fp16mixed.onnx`; the staging copy exists for convenience, the recipe is
the source of truth.
Usage
Producer:
Requires `nvidia-modelopt` + `onnxruntime-gpu` on top of the consumer deps.
Consumer (once the ONNX is published):