[megatron] Fused LM-head log-prob + entropy (avoid full [*, seq, vocab] logit materialization)#1765
…rializing [*, seq, vocab] logits
Adds an opt-in fused LM-head log-prob path for the Megatron backend that folds the
vocab projection into the chunked log-prob so the full [*, seq, vocab // TP] logits
tensor is never materialized — the dominant activation transient on the SFT/logprob
path at large vocab / long context. Numerically equivalent to the stock logits path.
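A back-of-envelope sketch of why that tensor dominates, using shapes chosen purely for illustration. The helper name `logits_shard_bytes`, the token count, and the 1024-token chunk width are assumptions, not values read from the implementation.

```python
def logits_shard_bytes(tokens: int, vocab_size: int, tp_size: int, bytes_per_elem: int) -> int:
    """Bytes held by one rank's [tokens, vocab // TP] logits shard."""
    if vocab_size % tp_size != 0:
        raise ValueError("vocab_size must divide evenly across TP ranks")
    return tokens * (vocab_size // tp_size) * bytes_per_elem


GIB = 1024**3
MIB = 1024**2

# Illustrative shapes (assumptions): 32,768 tokens, 131,072-entry vocab, TP=4, fp32.
full_shard = logits_shard_bytes(32_768, 131_072, 4, 4)
# With sequence chunking, only chunk_size rows of logits are alive at a time.
one_chunk = logits_shard_bytes(1_024, 131_072, 4, 4)

print(f"full logits shard:    {full_shard / GIB:.1f} GiB")
print(f"one 1024-token chunk: {one_chunk / MIB:.1f} MiB")
```

Under these assumed shapes the full shard is 4 GiB while a single chunk is 128 MiB. The chunked path only ever holds the latter.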
Design ported from verl's fused linear-cross-entropy. Two backends behind one flag:
* "torch" (default): a pure-PyTorch vocab-parallel FusedLinearLogprob autograd
Function. Chunks over the sequence dim, recomputes the per-chunk logits in
backward, fp32 softmax math, masks out-of-shard targets. Runs anywhere (CPU/GPU),
no extra deps. Verified bit-exact (fwd + grad-hidden) and fp32-rounding-exact
(grad-weight) vs the stock logits path on a real gloo TP=1/TP=2 group.
* "triton": vendored flash-style Triton kernel (from volcengine/verl, Apache-2.0)
that tiles over the vocab dim so per-chunk logits never materialize (lower memory,
faster). GPU + triton required; falls back to "torch" with a warning otherwise.
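The chunking idea behind the "torch" backend can be sketched without any tensor library. The following is a minimal, dependency-free illustration with plain Python lists standing in for tensors and no tensor parallelism; it is not the `FusedLinearLogprob` implementation, and every function name here is hypothetical. It only shows that folding the projection into a per-chunk log-softmax gives the same per-token log-probs as projecting everything first.

```python
import math


def project(h, weight):
    """Logits for one token: weight is [vocab, hidden], h is [hidden]."""
    return [sum(hk * wk for hk, wk in zip(h, row)) for row in weight]


def logprob_of(logits, target):
    """log_softmax(logits)[target], with the usual max shift for stability."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return logits[target] - lse


def stock_logprobs(hidden, weight, targets):
    # Stock path: materialize logits for every token, then reduce.
    all_logits = [project(h, weight) for h in hidden]
    return [logprob_of(z, t) for z, t in zip(all_logits, targets)]


def fused_logprobs(hidden, weight, targets, chunk_size):
    # Fused path: project one sequence chunk at a time, so at most
    # chunk_size rows of logits exist at any moment.
    out = []
    for start in range(0, len(hidden), chunk_size):
        chunk_logits = [project(h, weight) for h in hidden[start:start + chunk_size]]
        chunk_targets = targets[start:start + chunk_size]
        out.extend(logprob_of(z, t) for z, t in zip(chunk_logits, chunk_targets))
    return out


hidden = [[0.1, -0.2, 0.3], [0.5, 0.4, -0.1], [-0.3, 0.2, 0.0], [0.7, -0.6, 0.2], [0.0, 0.1, -0.4]]
weight = [[0.2, 0.1, -0.3], [-0.5, 0.4, 0.0], [0.3, -0.2, 0.6], [0.1, 0.1, 0.1]]
targets = [2, 0, 3, 1, 2]

stock = stock_logprobs(hidden, weight, targets)
fused = fused_logprobs(hidden, weight, targets, chunk_size=2)
print(stock == fused)
```

The sketch covers the forward pass only. The backward pass described above would recompute each chunk's logits rather than saving them, which this example does not model.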
Integration (model-source-free; no megatron-core edits):
* MegatronModelWrapper wraps output_layer.forward to return the pre-projection
hidden state (skipping the vocab GEMM) and capture the LM-head weight shard. The
weight is resolved from the weight= kwarg the model passes — i.e.
shared_embedding_or_output_weight() — so tied embeddings work with no None-weight
crash. Sequence-parallel hidden is gathered with tensor_parallel_output_grad=True,
whose backward reduce-scatters grad_hidden across TP exactly as ColumnParallelLinear
does, so the fused function's per-shard grad_hidden is reduced correctly.
* Both the training loss_func and the forward-only collection_func (eval / reference
logprobs) route through the fused path; the vocab range is derived from the captured
weight shard, never from hidden.shape[-1].
* Falls back to the stock logits path automatically when the model uses MuP output
scaling (the fused path bypasses _scale_logits) and is restricted to the
cross_entropy loss (the RL entropy/KL terms still need full logits).
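The vocab-range point above can be illustrated with a simulated tensor-parallel group. This is a sketch under stated assumptions: the "all-reduces" are ordinary Python reductions over a list of shards, `shard_vocab` and `vocab_parallel_logprob` are hypothetical names, and the real code path runs on a process group. Each shard's range comes from the number of rows in its weight shard, and a shard that does not own the target contributes zero to the chosen-logit sum.

```python
import math


def shard_vocab(weight, tp_size):
    """Split a [vocab, hidden] weight into (vocab_start, rows) shards."""
    per_rank = len(weight) // tp_size
    return [(r * per_rank, weight[r * per_rank:(r + 1) * per_rank]) for r in range(tp_size)]


def vocab_parallel_logprob(h, shards, target):
    # Each "rank" computes logits only for its own vocab rows.
    local = [
        (start, [sum(a * b for a, b in zip(h, row)) for row in rows])
        for start, rows in shards
    ]
    # Simulated all-reduce(MAX) and all-reduce(SUM) across ranks.
    global_max = max(max(logits) for _, logits in local)
    sum_exp = sum(sum(math.exp(z - global_max) for z in logits) for _, logits in local)
    # Chosen logit: only the owning shard contributes; others add 0.0.
    chosen = 0.0
    for start, logits in local:
        end = start + len(logits)  # range derived from the shard's rows
        if start <= target < end:
            chosen += logits[target - start]
    return chosen - global_max - math.log(sum_exp)


weight = [[0.2, 0.1, -0.3], [-0.5, 0.4, 0.0], [0.3, -0.2, 0.6], [0.1, 0.1, 0.1]]
h = [0.7, -0.6, 0.2]
target = 2

tp1 = vocab_parallel_logprob(h, shard_vocab(weight, 1), target)
tp2 = vocab_parallel_logprob(h, shard_vocab(weight, 2), target)
print(abs(tp1 - tp2) < 1e-12)
```

Note that the hidden size (3 here) never enters the range computation, which is the reason for deriving the range from the weight shard rather than from `hidden.shape[-1]`.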
Config: trainer.fused_linear_logprob (bool, default False -> byte-identical to today)
and trainer.fused_linear_logprob_backend ("torch" | "triton"). Megatron-only.
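How the two flags might resolve to a code path can be sketched as follows. Only the flag names and the described behaviour (default off, "triton" falling back to "torch" with a warning) come from the text; the `FusedLogprobConfig` dataclass, `resolve_backend`, and its boolean arguments are hypothetical stand-ins for whatever the trainer actually does.

```python
import warnings
from dataclasses import dataclass
from typing import Optional


@dataclass
class FusedLogprobConfig:
    """Hypothetical container mirroring the two trainer flags."""
    fused_linear_logprob: bool = False
    fused_linear_logprob_backend: str = "torch"


def resolve_backend(cfg: FusedLogprobConfig, cuda_available: bool, triton_available: bool) -> Optional[str]:
    """Return None for the stock logits path, else "torch" or "triton"."""
    if not cfg.fused_linear_logprob:
        return None  # default: stock path, unchanged behaviour
    backend = cfg.fused_linear_logprob_backend
    if backend not in ("torch", "triton"):
        raise ValueError(f"unknown fused backend: {backend!r}")
    if backend == "triton" and not (cuda_available and triton_available):
        warnings.warn("triton backend unavailable; falling back to torch")
        return "torch"
    return backend


print(resolve_backend(FusedLogprobConfig(), cuda_available=True, triton_available=True))
print(resolve_backend(FusedLogprobConfig(True, "triton"), cuda_available=True, triton_available=True))
```

Checking GPU availability together with the Triton import keeps a CPU-only host on the pure-PyTorch path instead of failing later inside the kernel launch.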
Tests: tests/.../cpu/megatron/test_fused_linear_logprob.py (CPU gloo TP=1/TP=2,
with/without out-of-shard targets, chunked + single-chunk) asserts the fused path
matches the stock logits path for forward, grad-hidden, and grad-weight. A GPU test
covers the Triton backend.
The vendored Triton kernel preserves verl's out-of-bounds-vocab -inf masking fix
(verl-project/verl#2656) and reproduces the NVIDIA + Bytedance Apache-2.0 attribution
(see file header + new top-level NOTICE).
Signed-off-by: dyurk-lila <dyurk@lila.ai>
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…ision policy
Validated on 2xH100 (buzz): all 16 GPU cases pass (TP1/TP2 × chunk {8,1000} ×
{fp32,bf16} × with/without out-of-shard targets).
- Fix an out-of-vocab gradient bug in the Triton adapter backward: it was zeroing
grad_output at fully-OOV positions, but the reference (stock DistributedLogprob /
the pure-torch FusedLinearLogprob) produces d_logits = -softmax * grad_output
there (no owned chosen column => the onehot term drops, the -softmax term remains).
Zeroing diverged the gradient (~0.04 abs). Now grad_output passes through unchanged;
the forward still masks the OOV logprob value to 0. Matches the torch backend exactly.
- Precision policy: production uses the fast default matmul precision (TF32 on Hopper) —
torch/triton's default, and the model trains in bf16 where it's empirically fine
(bf16 cases match the stock path within ~3e-3). A module-level FORCE_FP32_IEEE_PRECISION
flag (default False) lets the test pin IEEE for fp32 cases as a reproducible
"kernel math is exact" sanity check. Measured at TP=1 fp32: default precision fwd
error 1.27e-3 (TF32 rounding), IEEE 9.5e-7 (exact) — confirming the only default-path
gap is tensor-core rounding, not a math error. apply() signature unchanged.
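The out-of-vocab gradient rule from the first item can be checked by hand with a small sketch. It assumes the standard log-softmax gradient, d_logits = (onehot - softmax) * grad_output, applied per shard. The function name and the two-shard toy data are illustrative, not taken from the adapter.

```python
import math


def shard_dlogits(local_logits, vocab_start, target, global_max, global_sum_exp, grad_output):
    """Gradient of one token's log-prob w.r.t. this shard's logits."""
    softmax = [math.exp(z - global_max) / global_sum_exp for z in local_logits]
    # The -softmax term is present on every shard, owned target or not.
    d_logits = [-p * grad_output for p in softmax]
    local_idx = target - vocab_start
    if 0 <= local_idx < len(local_logits):
        # The onehot term exists only on the shard that owns the target column.
        d_logits[local_idx] += grad_output
    return d_logits


# Two shards of a 4-entry vocab; the target (index 1) lives on shard 0.
shards = [(0, [1.0, 2.0]), (2, [0.5, -1.0])]
target = 1
grad_output = 1.0

global_max = max(max(logits) for _, logits in shards)
global_sum_exp = sum(math.exp(z - global_max) for _, logits in shards for z in logits)

grads = [
    shard_dlogits(logits, start, target, global_max, global_sum_exp, grad_output)
    for start, logits in shards
]
print(grads[1])  # non-zero on the shard that does not own the target
```

Zeroing `grad_output` on the non-owning shard would replace `grads[1]` with zeros and break the property that the per-token gradient sums to zero across the whole vocab.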
Signed-off-by: dyurk-lila <dyurk@lila.ai>
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…gprob adapter
The Triton kernel's tl.dot(hidden, weight.T) requires both operands to share a
dtype. On a real hybrid model the pre-head hidden state is fp32 (final norm)
while the LM-head weight is bf16, so the adapter crashed with
"AssertionError: Both operands must be same dtype. Got fp32 and bf16" on the
first real training step (the fp32/fp32 and bf16/bf16 test cases never exposed it).
Mirror the pure-torch FusedLinearLogprob (hidden.to(weight.dtype) @ weight.t())
and ColumnParallelLinear (which casts its input to the weight dtype before the
bf16 GEMM): cast hidden to weight.dtype before the kernel, and cast the returned
d_hidden back to the original hidden dtype so autograd accumulates it into the
fp32 hidden leaf correctly. No-op when dtypes already match.
Signed-off-by: dyurk-lila <dyurk@lila.ai>
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…through the fused path
The fused linear log-prob previously served only the SFT cross_entropy loss and
hard-raised for every other (RL) loss, because the RL loss_func also needs ENTROPY,
which it computed from the full logits the fused path deliberately never
materializes. This wires entropy through the fused kernel, using verl's established
(log_probs, entropy) dual-output, so SFT and ALL RL losses (ppo/grpo/etc.) can use
fusion.
No new kernel math: verl's per-token entropy = logsumexp(logits) - Σ(softmax·logits)
is algebraically identical to stock _VocabParallelEntropy (max + log(sum_exp) - Σpx)
and to the Triton epilogue's entropy; the Triton kernel already computes+returns
entropy and already accepts a dentropy buffer in backward (the adapter was
discarding both).
- model_utils.py: new _distributed_log_softmax_and_entropy (reuses the logprob
  path's MAX+SUM all-reduces + one extra SUM of Σ(softmax·logit)).
  FusedLinearLogprob gains return_entropy (default False → byte-identical); when
  True returns (logprob, entropy) and backward adds the entropy term
  dlogits += softmax·(Σpx − logits)·grad_entropy (== stock _VocabParallelEntropy
  backward; == verl's term). Both public entry points + the dispatcher thread
  return_entropy; entropy is returned raw per-position (not rolled/trimmed).
  vocab_parallel_entropy_packed_sequences gains precomputed_entropy_tokens so the
  packed action-weighting + CP all-reduce stay identical, only the entropy source
  swaps.
- fused_linear_logprob_triton.py: stop discarding verl's entropy; wire the incoming
  entropy grad into the existing dentropy buffer. Entropy is NOT OOV-masked
  (target-independent). No kernel edits.
- megatron_model_wrapper.py: remove the cross_entropy-only raise; route RL entropy
  through the fused dual-output (non-packed slices the action window
  [-num_actions-1:-1]; packed feeds precomputed_entropy_tokens). KL keeps using
  logprobs. SFT stays logprob-only (skips the entropy all-reduce). MuP fallback +
  fused-off byte-identical preserved.
- tests: CPU entropy fwd+backward parity vs stock vocab_parallel_entropy (TP1/TP2,
  chunk, OOV), target-independence, packed-substitution bit-identity,
  slice-alignment guard; GPU Triton entropy parity test. All CPU runners pass
  (logprob regression bit-identical; entropy ~5e-7).
Signed-off-by: dyurk-lila <dyurk@lila.ai>
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
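The entropy identity and the backward term quoted in this commit message can be verified numerically. The sketch below is single-process and dependency-free. `entropy_and_grad` is a hypothetical name, and the check against finite differences is an illustration, not one of the repository's tests.

```python
import math


def entropy_and_grad(logits):
    """Per-token entropy as max + log(sum_exp) - Σ p·z, plus d(entropy)/d(logits)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    sum_exp = sum(exps)
    probs = [e / sum_exp for e in exps]
    sum_pz = sum(p * z for p, z in zip(probs, logits))
    entropy = m + math.log(sum_exp) - sum_pz
    # Backward term from the text: softmax · (Σpx − logits)
    grad = [p * (sum_pz - z) for p, z in zip(probs, logits)]
    return entropy, grad


def shannon_entropy(logits):
    """Reference definition: -Σ p·log p."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return -sum((e / total) * math.log(e / total) for e in exps)


def numeric_grad(logits, eps=1e-6):
    """Central finite differences of the entropy."""
    grads = []
    for i in range(len(logits)):
        up, down = list(logits), list(logits)
        up[i] += eps
        down[i] -= eps
        grads.append((entropy_and_grad(up)[0] - entropy_and_grad(down)[0]) / (2 * eps))
    return grads


logits = [0.3, -1.2, 2.0, 0.7]
entropy, grad = entropy_and_grad(logits)
print(abs(entropy - shannon_entropy(logits)) < 1e-12)
print(max(abs(a - b) for a, b in zip(grad, numeric_grad(logits))) < 1e-6)
```

No target index appears anywhere in the computation, which is why entropy needs no OOV masking.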
…ropy GPU test
The new entropy parity test crashed in its REFERENCE (vocab_parallel_entropy),
which all-reduces over mpu.get_tensor_model_parallel_group(). That group was
uninitialized because the worker only called dist.init_process_group, never
mpu.initialize_model_parallel. (The logprob test sidesteps this by referencing
dist.group.WORLD directly.)
Initialize model-parallel state with tensor_model_parallel_size=world_size
(TP group == WORLD membership, so both the stock reference and the fused path
reduce over identical ranks) and destroy it in the finally. Harness-only fix; the
Triton kernel/adapter were never reached by the failing assertion.
Signed-off-by: dyurk-lila <dyurk@lila.ai>
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Polish the fused linear log-prob feature so it reads and behaves as a native
SkyRL feature, with no internal/development-history references and CI-correct
tests. No functional change to the fused math (CPU gloo parity re-verified at
TP=1/2, fwd bit-exact, grads at fp32 round-off).
CI correctness:
* GPU Triton test: add pytest.mark.megatron (the megatron GPU job selects by
`-m megatron`, so the test was collected but never run); combine with the
existing skipif in a module-level pytestmark list.
* CPU test: stub megatron.core.parallel_state at module import scope when
megatron-core is absent, so collection no longer crashes the CPU CI job
(which installs no --extra megatron) and the pure-torch path actually runs.
The stub is import-scope so it also reaches the mp.spawn workers. Mark the
module pytest.mark.megatron for selection where megatron is installed.
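The import-scope stub described for the CPU test could look roughly like this. It is a sketch, not the test file's code: `ensure_parallel_state_stub` is a hypothetical helper, and the stubbed `get_tensor_model_parallel_group` simply returns `None` as a placeholder.

```python
import importlib
import sys
import types


def ensure_parallel_state_stub() -> bool:
    """Install a minimal megatron.core.parallel_state stub if megatron-core is missing.

    Returns True when a stub was installed, False when the real module imported.
    """
    try:
        importlib.import_module("megatron.core.parallel_state")
        return False
    except ImportError:
        pass
    megatron = types.ModuleType("megatron")
    core = types.ModuleType("megatron.core")
    parallel_state = types.ModuleType("megatron.core.parallel_state")
    # Placeholder accessor (assumption): a real test would return the group it needs.
    parallel_state.get_tensor_model_parallel_group = lambda: None
    megatron.core = core
    core.parallel_state = parallel_state
    sys.modules["megatron"] = megatron
    sys.modules["megatron.core"] = core
    sys.modules["megatron.core.parallel_state"] = parallel_state
    return True


# Runs at module import, before anything imports megatron.core.
ensure_parallel_state_stub()

from megatron.core import parallel_state  # resolves to the stub or the real module

print(callable(parallel_state.get_tensor_model_parallel_group))
```

Running the helper at import scope matters for spawned workers: a spawned process re-imports the test module, so the stub is installed again in each child before the code under test is imported.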
Cleanup:
* Scrub development-history / out-of-band comments ("original downstream
monkey-patch", "accompanying review message", "the original bug", "BUG 2").
* Remove dead code: unused GPU-test helpers (_stock_logprobs/_fused_logprobs)
and their now-unused import; unused partition_vocab_size in
FusedLinearLogprob.forward; drop fully_oov from the Triton save_for_backward
(it was saved but never read in backward).
* DRY: extract the shared chosen-token scatter-add used by both
ChunkedDistributedLogprob.backward and FusedLinearLogprob.backward into
_add_chosen_token_grad so the OOV/index convention stays in lockstep.
* Cache exp(shifted) once in _distributed_log_softmax_and_entropy.
* Fix stale/elided test-path doc pointers and the MuP-fallback pointer; add the
missing lm_head_weight/fused_backend/return_entropy Args to the packed-seq
log-prob docstring.
Formatting: ruff 0.11.9 + black 24.10.0 clean.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Code Review
This pull request introduces a fused LM-head log-prob optimization for the Megatron backend, featuring both a pure-PyTorch chunked implementation and a high-performance Triton-based backend vendored from verl. By fusing the LM-head projection into the chunked log-prob and entropy computation, the PR avoids materializing the massive full-vocab logits tensor, significantly reducing peak memory usage. The feedback recommends enhancing the robustness of the model wrapper by using *args and **kwargs to guard against future Megatron-core signature changes, broadening the exception handling when importing the Triton backend to catch driver or library mismatches, and checking CUDA availability alongside Triton to prevent cryptic errors in CPU-only environments.
def fused_output_layer_forward(input_, weight=None, runtime_gather_output=None):
    w = weight if weight is not None else getattr(layer, "weight", None)
    if w is None:
        # Tied-embedding stage with no allocated head weight AND none passed in:
        # we cannot fuse safely, so defer to the original projection.
        return original(input_, weight=weight, runtime_gather_output=runtime_gather_output)
    hidden = input_
    if getattr(layer, "sequence_parallel", False):
        # Gather the sequence-parallel-scattered hidden across TP. Its backward is a
        # reduce-scatter (tensor_parallel_output_grad=True), which reduces the fused
        # function's per-shard grad_hidden exactly as ColumnParallelLinear would.
        hidden = gather_from_sequence_parallel_region(
            hidden, tensor_parallel_output_grad=True, group=tpg
        )
    holder["weight"] = w
    # Return hidden in the logits position (+ no bias). The model's _scale_logits is
    # identity here (MuP excluded above) and its transpose yields [b, s, h].
    return hidden, None
Using explicit keyword arguments like weight=None and runtime_gather_output=None can lead to TypeError if Megatron-core updates its ParallelLinear.forward signature or if other wrappers pass unexpected arguments. Using *args and **kwargs makes the wrapper future-proof and robust against signature changes.
def fused_output_layer_forward(input_, *args, **kwargs):
    w = kwargs.get("weight", None)
    if w is None and len(args) > 0:
        w = args[0]
    w = w if w is not None else getattr(layer, "weight", None)
    if w is None:
        # Tied-embedding stage with no allocated head weight AND none passed in:
        # we cannot fuse safely, so defer to the original projection.
        return original(input_, *args, **kwargs)
    hidden = input_
    if getattr(layer, "sequence_parallel", False):
        # Gather the sequence-parallel-scattered hidden across TP. Its backward is a
        # reduce-scatter (tensor_parallel_output_grad=True), which reduces the fused
        # function's per-shard grad_hidden exactly as ColumnParallelLinear would.
        hidden = gather_from_sequence_parallel_region(
            hidden, tensor_parallel_output_grad=True, group=tpg
        )
    holder["weight"] = w
    # Return hidden in the logits position (+ no bias). The model's _scale_logits is
    # identity here (MuP excluded above) and its transpose yields [b, s, h].
    return hidden, None
Claude suggested that this is a proper copyright notice, but it should definitely be checked by someone more familiar with the repo's conventions around this.
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Summary
Adds an opt-in fused LM-head path for the Megatron backend that folds the vocab projection into the chunked log-prob (and entropy), so the full `[*, seq, vocab // TP]` logits tensor is never materialized. That tensor is the dominant activation transient on the log-prob path at large vocab + long context (e.g. ~4 GiB fp32 at micro-bs 2 / 32k ctx / 131k vocab / TP4); eliminating it unlocks larger micro-batches / no-recompute without OOM.

Serves SFT and all RL losses (cross_entropy, ppo, grpo, …): the fused kernel returns `(log_probs, entropy)`, verl's established dual-output, so RL's entropy term is satisfied from the fused output instead of needing the full logits. Numerically equivalent to the stock logits path; default OFF → byte-identical to today.

Design ported from verl's fused linear-cross-entropy (`volcengine/verl`, Apache-2.0). Two backends behind one flag:
- `torch` (default): pure-PyTorch, vocab-parallel `FusedLinearLogprob`. Chunks over the seq dim, recomputes per-chunk logits in backward, fp32 softmax/entropy math, masks out-of-shard targets. Runs anywhere (CPU/GPU), no extra deps.
- `triton`: vendored flash-style Triton kernel that tiles over the vocab dim so per-chunk logits never materialize (lower memory floor + faster). GPU + `triton` required; falls back to `torch` with a warning if unavailable.

Config
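Flags: `trainer.fused_linear_logprob` (bool, default `False`) and `trainer.fused_linear_logprob_backend` (`"torch"` | `"triton"`). Megatron-only. Uses the existing `logprobs_chunk_size` for the chunk width.

How it integrates (no megatron-core edits)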
`MegatronModelWrapper` wraps `output_layer.forward` at model-build to return the pre-projection hidden state (skipping the vocab GEMM) and capture the LM-head weight shard. Correctness points:
- SFT (`cross_entropy`) requests logprob-only (skips the entropy all-reduce); RL requests `return_entropy=True`. The fused per-token entropy = `logsumexp(logits) − Σ(softmax·logits)`, algebraically identical to stock `_VocabParallelEntropy` and verl; backward adds the matching `dentropy` term. RL entropy is sliced to the action window (non-packed) or fed through `vocab_parallel_entropy_packed_sequences` via a new `precomputed_entropy_tokens` arg (packed action-weighting + CP all-reduce unchanged; verified as a bit-identical substitution). KL keeps using logprobs.
- Casts `hidden` → LM-head `weight.dtype` before the GEMM (the common hybrid case is fp32 pre-head hidden + bf16 head), mirroring `ColumnParallelLinear`; grad cast back on the way out.
- The weight is resolved from the `weight=` kwarg the model passes (`shared_embedding_or_output_weight()`) → no `output_layer.weight is None` crash (cf. verl#3730).
- Sequence-parallel hidden is gathered with `tensor_parallel_output_grad=True`; its reduce-scatter backward performs the TP reduction of `grad_hidden` exactly as stock `ColumnParallelLinear`. Vocab range derived from the captured weight shard, not `hidden.shape[-1]`.
- Both the training `loss_func` and forward-only `collection_func` route through the fused path. Auto-fallback to stock under MuP output scaling; default fused-off is byte-identical.

Verification
- CPU (`tests/backends/skyrl_train/cpu/megatron/test_fused_linear_logprob.py`): gloo TP1/TP2, chunked + single, with/without OOV. Logprob fwd + grad-hidden bit-exact, grad-weight to fp32 rounding vs the stock logits path; entropy fwd + backward parity vs stock `vocab_parallel_entropy` (~5e-7), entropy target-independence, packed `precomputed_entropy_tokens` bit-identical substitution, slice-alignment guard. Runs on CPU CI (no GPU / no megatron-core needed; `megatron.core.parallel_state` is stubbed when absent).
- GPU (`tests/backends/skyrl_train/gpu/gpu_ci/megatron/test_fused_linear_logprob_triton.py`), 2×H100: 24/24 pass. Logprob 16/16 (incl. fp32+bf16 dtype-mix) + entropy 8/8 (fp32 ~5e-7 / ~1e-8 grads under IEEE; bf16 within tolerance; OOV-invariant). Marked `@pytest.mark.megatron` so the megatron GPU CI job (`-m megatron`) selects it. The fp32 cases pin IEEE (`FORCE_FP32_IEEE_PRECISION`) as a math-is-exact sanity check (default-precision fwd error 1.27e-3 TF32 vs 9.5e-7 IEEE); production uses the fast default (TF32) precision, since the model trains in bf16 where it matches stock within ~3e-3.

Licensing / attribution
verl is Apache-2.0 (same as SkyRL); SkyRL already vendors verl in ~8 files with the `# This code is adapted from VERL … original copyright reproduced below:` convention. The vendored files (`verl/utils/kernel/{kernels.py,linear_cross_entropy.py}`) carry a dual copyright: `Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES` and `Copyright 2024 Bytedance Ltd. and/or its affiliates` (verl is a Bytedance/volcengine project; the NVIDIA line is an NVIDIA-contributed portion). Both are reproduced verbatim in the vendored file header with a `Changes:` note (§4(b)). Adds a top-level `NOTICE` (SkyRL had none) reproducing verl's `Notice.txt`. Preserves verl's out-of-bounds-vocab `-inf` masking fix (verl#2656).

Test plan
- `uv run --isolated --extra dev -- pytest -s tests/backends/skyrl_train/cpu/megatron/test_fused_linear_logprob.py` (CPU, no megatron-core)
- `pytest -s tests/backends/skyrl_train/gpu/gpu_ci/megatron/test_fused_linear_logprob_triton.py -m megatron` on 2×H100: 24/24
- `ruff` + `black` clean (pre-commit)
- `fused_linear_logprob: false` path byte-identical to current behaviour

🤖 Generated with Claude Code