voxtral_tts: enable CUDA backend with 4w quantization (Ampere + Blackwell pre-exported artifacts) #19093

Open
seyeong-han wants to merge 10 commits into pytorch:main from seyeong-han:voxtral-tts
Conversation

seyeong-han (Contributor) commented Apr 23, 2026

Description

This PR brings the ExecuTorch CUDA AOTI backend to examples/models/voxtral_tts. The full pipeline (LM + flow head + codec) now runs on GPU. With --qlinear 4w and --streaming, end-to-end synthesis for a 24-token prompt on an RTX 5080 runs at RTF 0.31x (wall-clock time / audio duration), i.e. over 3× faster than real time, with 2.6 s time-to-first-audio.

Headlines

| Backend | model.ptd | LM time | Total | E2E RTF |
| --- | --- | --- | --- | --- |
| XNNPACK fp32 (CPU) | n/a | 3.2 s | 15.3 s | 4.8x |
| CUDA fp32 + portable codec | 15.8 GB | 11.5 s | 178 s | 51x |
| CUDA 4w + CUDA codec (offline) | 3.4 GB | 2.1 s | 3.7 s | 0.88x ⚡ (A100) |
| CUDA 4w + streaming | 3.4 GB | n/a | 3.85 s | 0.31x ⚡⚡ (RTX 5080) |

Streaming (--streaming --speaker) decouples generation latency from audio length: the first chunk arrives in ~2.6 s, then 2 s chunks emit continuously as the model decodes. On RTX 5080 (sm_120), a 24-token prompt synthesizes 10.3 s of audio while only blocking the speaker for 3.85 s wall — 3.2× faster than real-time.

Numerical parity vs the XNNPACK FP32 baseline with seed=42 (a minimal version of the first check is sketched after the list):

  • Last-position prefill hidden cosine: 0.999994
  • First-frame semantic argmax: identical (3040)
  • First-frame top-5 logits: identical
  • Frame count before END_AUDIO: 39 vs CPU baseline 40 (within bf16 sampling noise)
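
For reference, a minimal version of the first gate; the tensor names are placeholders, not the PR's actual harness (which lives on the voxtral-tts-dev branch):

```python
import torch

def prefill_cosine_gate(cuda_hidden: torch.Tensor, cpu_hidden: torch.Tensor) -> None:
    # Compare last-position prefill hidden states across backends in fp64.
    cos = torch.nn.functional.cosine_similarity(
        cuda_hidden.double().flatten(), cpu_hidden.double().flatten(), dim=0
    )
    assert cos.item() > 0.9999, f"prefill hidden cosine too low: {cos.item():.6f}"
```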

What changed and why

The CUDA enablement broke down into five challenges; a hedged sketch of each fix follows the list:

  1. Causal mask was missing on CUDA. MistralDecoder.forward didn't build any attention mask, so CUDA's triton.sdpa attended over the full zero-initialized [1, H_kv, max_seq_len, D] cache. Fix: port _build_causal_mask_bool from voxtral_realtime/model.py and thread it through every layer (CUDA only — XNNPACK's custom_sdpa infers the prefix length from start_pos internally).

  2. AOTI Triton SDPA only accepts bf16. Initial fix promoted the entire model to bf16, which degraded semantic_head and predict_velocity precision. Fix: isolate bf16 to StaticKVCache (BHSD bf16 buffers) + StandardSDPA (cast Q to bf16 before triton.sdpa, cast result back to input dtype). Model weights stay fp32.

  3. Triton conv autotune has no choices for the codec's ConvTranspose1d shapes, and AOTI's CUDA runtime ships no aoti_torch_cuda_convolution shim. Fix: rewrite Conv1d / ConvTranspose1d as unfold + matmul / matmul + Fold (_conv1d_as_matmul, _conv_transpose1d_as_matmul in model.py). Math is bit-exact (eager parity max abs diff 5.5e-10 in fp32). Triton's batched-matmul autotune found 20 valid kernel choices for the new path where the conv form had 0.

  4. 4w quantization plumbing. --qlinear 4w on CUDA auto-promotes --dtype to bf16 and auto-sets --qlinear-packing-format=tile_packed_to_4d (required by _weight_int4pack_mm). flow_head.input_projection (3072×36) is auto-skipped because K=36 isn't divisible by group_size=32 — caught via skip_incompatible_shapes=True.

  5. Runner bf16 staging. Model.pte exports an lm_input_is_bf16 metadata int. The runner reads it at load time and switches its from_blob(...) calls to bf16 staging buffers when the model is bf16 (quantized exports). Default fp32 path stays untouched.
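
A minimal sketch of the fix in (1), assuming the voxtral_realtime-style boolean convention (True = may attend); the actual helper's signature may differ:

```python
import torch

def _build_causal_mask_bool(start_pos: int, seq_len: int, max_seq_len: int) -> torch.Tensor:
    # Row i is the query at absolute position start_pos + i; it may attend to
    # cache slots 0..start_pos + i and never to the zero-initialized tail.
    q_pos = torch.arange(start_pos, start_pos + seq_len).unsqueeze(1)  # [S, 1]
    k_pos = torch.arange(max_seq_len).unsqueeze(0)                     # [1, L]
    return k_pos <= q_pos                                              # [S, L] bool
```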
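
Sketch of the dtype isolation in (2): bf16 is confined to the attention call while weights stay fp32 (module shape is illustrative, not the PR's exact class):

```python
import torch
import torch.nn.functional as F

class StandardSDPA(torch.nn.Module):
    def forward(self, q, k, v, mask):
        in_dtype = q.dtype
        # triton.sdpa under AOTI only accepts bf16; K and V already arrive
        # as bf16 from the StaticKVCache's BHSD buffers.
        out = F.scaled_dot_product_attention(
            q.to(torch.bfloat16), k, v, attn_mask=mask
        )
        return out.to(in_dtype)  # cast back so downstream fp32 math is unchanged
```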
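
A sketch of the Conv1d direction of (3), assuming unit dilation and groups=1; the PR's _conv1d_as_matmul covers the codec's actual shapes, and the transpose direction pairs matmul with F.fold analogously:

```python
import torch
import torch.nn.functional as F

def _conv1d_as_matmul(x, weight, bias=None, stride=1, padding=0):
    # x: [B, C_in, T], weight: [C_out, C_in, K] -- same layout as F.conv1d.
    k = weight.shape[-1]
    cols = F.unfold(
        x.unsqueeze(-1),            # view as [B, C_in, T, 1] so unfold applies
        kernel_size=(k, 1),
        stride=(stride, 1),
        padding=(padding, 0),
    )                               # [B, C_in*K, T_out]
    out = weight.flatten(1) @ cols  # [C_out, C_in*K] @ [B, C_in*K, T_out]
    if bias is not None:
        out = out + bias[:, None]
    return out                      # [B, C_out, T_out], numerically matches F.conv1d
```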
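
The defaulting and shape gate in (4), as a sketch; the flag names are the PR's CLI flags, the helper bodies are assumptions:

```python
def _apply_cuda_arg_defaults(args):
    # --qlinear 4w on CUDA implies bf16 activations and the 4D tile packing
    # layout that _weight_int4pack_mm expects.
    if args.backend == "cuda" and args.qlinear == "4w":
        args.dtype = "bf16"
        args.qlinear_packing_format = "tile_packed_to_4d"
    return args

def _is_4w_compatible(in_features: int, group_size: int = 32) -> bool:
    # The gate behind skip_incompatible_shapes=True: flow_head.input_projection
    # has K=36, and 36 % 32 != 0, so it stays unquantized.
    return in_features % group_size == 0
```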
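
For (5), ExecuTorch exports can attach constant methods that the runner queries at load time; a schematic of that convention as assumed here (the PR's exact wiring may differ):

```python
from executorch.exir import to_edge

# exported_lm is the torch.export()-ed LM program (placeholder name); the C++
# runner reads the lm_input_is_bf16 int once at load and picks its staging dtype.
edge = to_edge(
    exported_lm,
    constant_methods={"lm_input_is_bf16": 1 if dtype_is_bf16 else 0},
)
```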

An adversarial review pass caught issues (1) and (2) before this PR went out.

How to validate

```bash
unset CPATH && export LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH

# Full pipeline (one-shot script; ~10 min for first-time export + build, ~4 s/run after)
bash examples/models/voxtral_tts/run_cuda_e2e.sh ~/models/Voxtral-4B-TTS-2603

# Listen
ffplay $PWD/voxtral_tts_exports_cuda_4w/sample.wav

# Streaming live playback (pipe raw PCM — first audio in ~2.6 s)
cmake-out/examples/models/voxtral_tts/voxtral_tts_runner \
    --model voxtral_tts_exports_cuda_4w/model.pte \
    --data_path voxtral_tts_exports_cuda_4w/aoti_cuda_blob.ptd \
    --codec voxtral_tts_exports_cuda_4w/codec_decoder.pte \
    --codec_data_path voxtral_tts_exports_cuda_4w/codec_aoti_cuda_blob.ptd \
    --tokenizer ~/models/Voxtral-4B-TTS-2603/tekken.json \
    --voice ~/models/Voxtral-4B-TTS-2603/voice_embedding/neutral_female.pt \
    --text "Hello, how are you today?" \
    --streaming --speaker \
  | ffplay -f f32le -ar 24000 -ac 1 -nodisp -autoexit -
```

Pre-exported artifacts on HF Hub

Sub-real-time CUDA artifacts are distributed at
younghan-meta/Voxtral-4B-TTS-2603-ExecuTorch-CUDA
in per-arch subfolders so users can skip the export step:

| Folder | Arch | GPU | Offline RTF | Streaming RTF | Time-to-first |
| --- | --- | --- | --- | --- | --- |
| sm80/ | Ampere | A100 80 GB | 0.88x | n/a | n/a |
| sm120/ | Blackwell | RTX 5080 16 GB | 1.29x | 0.31x ⚡⚡ | ~2.6 s |

Streaming numbers: 24 text tokens → 10.3 s audio, measured with a warm Triton autotune cache.
Offline numbers: 7 text tokens ("Hello, how are you today?"), same warm-cache condition.

AOTI bakes pre-compiled cubins for the export-time arch into the *.ptd, so cubins aren't compatible across architectures — running an sm_80 blob on a Blackwell card fails with CUDA driver error: invalid argument on the first kernel launch. The README's "Pre-exported artifacts" section documents the per-arch download pattern and the WSL2 LIBRARY_PATH linker gotcha hit during the Blackwell re-export.
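
For example, fetching only the subfolder for your arch with huggingface_hub (repo id from above; subfolder names per the table):

```python
from huggingface_hub import snapshot_download

# Cubins are arch-specific (see above), so pull only the matching folder:
# sm80/* on Ampere, sm120/* on Blackwell.
local_dir = snapshot_download(
    repo_id="younghan-meta/Voxtral-4B-TTS-2603-ExecuTorch-CUDA",
    allow_patterns=["sm120/*"],
)
```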

Suggested review order

  1. README.md — user-visible surface (CUDA section + streaming perf + gotchas table + multi-arch HF artifact section)
  2. model.py — StaticKVCache (BHSD bf16) + StandardSDPA (bf16 cast in/out) + _build_causal_mask_bool + _conv1d_as_matmul / _conv_transpose1d_as_matmul
  3. export_voxtral_tts.py — --backend cuda + --qlinear 4w plumbing, helper extraction (_apply_cuda_arg_defaults, _export_lm_pte, _export_codec_pte)
  4. voxtral_tts_runner.{h,cpp}, main.cpp — --data_path / --codec_data_path flags, bf16 staging path gated on lm_input_is_bf16 metadata
  5. CMakePresets.json — voxtral-tts-cuda preset
  6. run_cuda_e2e.sh — one-shot pipeline script

seyeong-han and others added 5 commits April 16, 2026 10:42
Add the in-progress Voxtral TTS export, runner, parity, and acceptance tooling so the work can be resumed on another machine without losing the current investigation state.

Made-with: Cursor
…NNPACK

Three bugs fixed: codec reshape order (P*T to T*P), flow-matching RNG
(mt19937 to xorshift64+BoxMuller matching C ref), ALiBi slopes off-by-one.
Adds --speaker for live PCM output, parakeet STT gate, quantization
docs and benchmarks.

Authored with Claude.
…eal-time on A100)

Adds full CUDA AOTI support to voxtral_tts. Headlines on A100 80GB for
"Hello, how are you today?" with seed=42:
  XNNPACK fp32 baseline:    15.3s wall clock, RTF 4.8x
  CUDA fp32 + portable codec:  178s, RTF 51x  (codec dominated on CPU)
  CUDA 4w + CUDA codec:       3.7s, RTF 0.88x (sub-real-time)

The 4w-quant + full-CUDA pipeline matches the XNNPACK baseline on prefill
hidden state (cosine 0.999994), first-frame semantic argmax, and top-5 logits.

Suggested review order:
  1. README.md, BENCHMARK.md, PROGRESS.md  -- user-visible surface
  2. model.py                              -- StaticKVCache + StandardSDPA + causal mask + conv-as-matmul codec
  3. export_voxtral_tts.py                 -- --backend cuda + --qlinear 4w plumbing
  4. voxtral_tts_runner.{h,cpp}, main.cpp  -- bf16 staging via lm_input_is_bf16 metadata
  5. CMakePresets.json                     -- voxtral-tts-cuda preset
  6. test_cuda_parity.py                   -- 11 eager-parity gates (CUDA-required, skip otherwise)
  7. run_cuda_e2e.sh                       -- one-shot pipeline script

Authored with Claude (Anthropic) assistance.
…3_5_moe layout)

Internal docs, parity tooling, and developer-only test scripts move to the
voxtral-tts-dev branch. The PR now ships the same kind of files qwen3_5_moe
exposes publicly: model.py, export script, runner, CMake, README.

Removed (kept on voxtral-tts-dev):
  BENCHMARK.md, PROGRESS.md
  voxtral_tts_vs_voxtral_realtime_manager_note.md
  mermaid_architecture_voxtral_tts_parity_gap.md
  parity.py, compare_parity_traces.py
  test_cuda_parity.py, test_eager_e2e.py, test_export_cli.py,
  test_parity.py, test_validation_contract.py,
  test_verify_codec_export.py, test_verify_export_parity.py
  transcribe_apple_speech.swift, transcribe_parakeet.py
  verify_codec_export.py, verify_export_parity.py, verify_xnnpack_transcript.py

Updated README and run_cuda_e2e.sh to drop links to the moved files.

Authored with Claude (Anthropic) assistance.
pytorch-bot (Bot) commented Apr 23, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19093

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 1 Cancelled Job, 3 Unrelated Failures

As of commit a46f783 with merge base 2d53535:

CANCELLED JOB - The following job was cancelled. Please retry:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

BROKEN TRUNK - The following jobs failed but were already failing on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

meta-cla (Bot) added the CLA Signed label on Apr 23, 2026
Young Han and others added 3 commits April 23, 2026 14:20
… layout

- Revert CLAUDE.md edit that slipped into the prior commit (out of scope).
- Add `voxtral_tts-cpu` and `voxtral_tts-cuda` Makefile targets following
  the same pattern voxtral_realtime / qwen3_5_moe use, including .PHONY +
  help-text entries. `make voxtral_tts-cuda` now builds parent ExecuTorch
  with CUDA + the runner in one step.
- Rewrite README.md to mirror qwen3_5_moe's layout: Overview, Prerequisites,
  Export (with options table), Build (one-line `make` command), Run (with
  options table), Troubleshooting. Drops the previous mixed
  Architecture/Quick-Start/Build/Run shape.

Authored with Claude (Anthropic) assistance.
ufmt + clang-format whitespace and import-ordering only. No semantic changes.

Authored with Claude (Anthropic) assistance.
seyeong-han changed the title from "add voxtral_tts: enable CUDA backend with 4w quantization" to "voxtral_tts: enable CUDA backend with 4w quantization (Ampere + Blackwell pre-exported artifacts)" on Apr 24, 2026
seyeong-han (Contributor, Author) commented:

@pytorchbot label "release notes: examples"

pytorch-bot (Bot) added the release notes: examples label on Apr 24, 2026