
[megatron] Fused LM-head log-prob + entropy (avoid full [*, seq, vocab] logit materialization) #1765

Open
dyurk-lila wants to merge 7 commits into NovaSky-AI:main from dyurk-lila:feat/fused-linear-cross-entropy


@dyurk-lila


Summary

Adds an opt-in fused LM-head path for the Megatron backend that folds the vocab projection into the chunked log-prob (and entropy), so the full [*, seq, vocab // TP] logits tensor is never materialized. That tensor is the dominant activation transient on the log-prob path at large vocab + long context (e.g. ~4 GiB fp32 at micro-bs 2 / 32k ctx / 131k vocab / TP4); eliminating it unlocks larger micro-batches / no-recompute without OOM.

Serves SFT and all RL losses (cross_entropy, ppo, grpo, …): the fused kernel returns (log_probs, entropy) — verl's established dual-output — so RL's entropy term is satisfied from the fused output instead of needing the full logits. Numerically equivalent to the stock logits path; default OFF → byte-identical to today.

Design ported from verl's fused linear-cross-entropy (volcengine/verl, Apache-2.0). Two backends behind one flag:

  • torch (default): pure-PyTorch, vocab-parallel FusedLinearLogprob. Chunks over the seq dim, recomputes per-chunk logits in backward, fp32 softmax/entropy math, masks out-of-shard targets. Runs anywhere (CPU/GPU), no extra deps.
  • triton: vendored flash-style Triton kernel — tiles over the vocab dim so per-chunk logits never materialize (lower memory floor + faster). GPU + triton required; falls back to torch with a warning if unavailable.

Config

trainer.fused_linear_logprob: false          # default off → stock behaviour, byte-identical
trainer.fused_linear_logprob_backend: torch  # "torch" | "triton"

Megatron-only. Uses the existing logprobs_chunk_size for the chunk width.
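
The chunking idea can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the PR's FusedLinearLogprob: the function names (project, stock_logprobs, chunked_logprobs) are hypothetical, tensors are nested lists, and there is no tensor parallelism or backward pass. It only shows that projecting one sequence chunk at a time yields the same per-token log-probs while keeping at most [chunk, vocab] logits alive.

```python
import math


def project(hidden, weight):
    """LM-head projection: hidden [rows][h] times weight [vocab][h] -> logits [rows][vocab]."""
    return [[sum(a * b for a, b in zip(h_t, w_v)) for w_v in weight] for h_t in hidden]


def logprobs_from_logits(logits, targets):
    """Stable log-softmax evaluated at the target index of each row."""
    out = []
    for row, tgt in zip(logits, targets):
        m = max(row)
        lse = m + math.log(sum(math.exp(x - m) for x in row))
        out.append(row[tgt] - lse)
    return out


def stock_logprobs(hidden, weight, targets):
    """Reference path: the full [seq, vocab] logits exist at once."""
    return logprobs_from_logits(project(hidden, weight), targets)


def chunked_logprobs(hidden, weight, targets, chunk_size):
    """Fused-style path: only [chunk_size, vocab] logits exist at any time."""
    out = []
    for start in range(0, len(hidden), chunk_size):
        stop = start + chunk_size
        chunk_logits = project(hidden[start:stop], weight)
        out.extend(logprobs_from_logits(chunk_logits, targets[start:stop]))
    return out


if __name__ == "__main__":
    hidden = [[0.1 * (i + j) - 0.3 for j in range(3)] for i in range(5)]
    weight = [[0.2 * (v - j) for j in range(3)] for v in range(4)]
    targets = [0, 3, 1, 2, 3]
    print(stock_logprobs(hidden, weight, targets))
    print(chunked_logprobs(hidden, weight, targets, chunk_size=2))
```

In this sketch chunk_size plays the role the PR assigns to logprobs_chunk_size; the real torch backend additionally recomputes the per-chunk logits in backward instead of saving them.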

How it integrates (no megatron-core edits)

MegatronModelWrapper wraps output_layer.forward at model-build to return the pre-projection hidden state (skipping the vocab GEMM) and capture the LM-head weight shard. Correctness points:

  • Logprob and entropy come from one fused call. SFT (cross_entropy) requests logprob-only (skips the entropy all-reduce); RL requests return_entropy=True. The fused per-token entropy is logsumexp(logits) − Σ(softmax·logits), algebraically identical to stock _VocabParallelEntropy and verl; backward adds the matching dentropy term. RL entropy is sliced to the action window (non-packed) or fed through vocab_parallel_entropy_packed_sequences via a new precomputed_entropy_tokens arg (the packed action-weighting and CP all-reduce are unchanged; the substitution was verified to be bit-identical). KL keeps using logprobs.
  • dtype: the fused path casts hidden → LM-head weight.dtype before the GEMM (the common hybrid case is fp32 pre-head hidden + bf16 head), mirroring ColumnParallelLinear; grad cast back on the way out.
  • Tied embeddings: weight taken from the weight= kwarg the model passes (shared_embedding_or_output_weight()) → no output_layer.weight is None crash (cf. verl#3730).
  • Sequence parallelism: SP hidden gathered with tensor_parallel_output_grad=True; its reduce-scatter backward performs the TP reduction of grad_hidden exactly as stock ColumnParallelLinear. Vocab range derived from the captured weight shard, not hidden.shape[-1].
  • Both training loss_func and forward-only collection_func route through the fused path. Auto-fallback to stock under MuP output scaling; default fused-off is byte-identical.
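
The entropy identity and the out-of-shard target masking from the list above can be sketched for a single token as follows. This is an illustration only: sharded_logprob_and_entropy and reference are hypothetical names, Python's max() and sum() over per-shard partials stand in for the MAX and SUM all-reduces, and summing the masked chosen logit across shards is an assumption about how the target logit is combined.

```python
import math


def sharded_logprob_and_entropy(logits, target, num_shards):
    """Simulate vocab-parallel log-prob + entropy for one token.

    Each simulated rank owns a contiguous vocab shard. Reductions over the
    per-shard partial results play the role of all-reduce collectives.
    """
    vocab = len(logits)
    assert vocab % num_shards == 0, "sketch assumes an evenly divisible vocab"
    width = vocab // num_shards
    shards = [logits[r * width:(r + 1) * width] for r in range(num_shards)]

    global_max = max(max(s) for s in shards)                      # MAX reduce
    sum_exp = sum(sum(math.exp(x - global_max) for x in s)
                  for s in shards)                                # SUM reduce
    sum_exp_logit = sum(sum(math.exp(x - global_max) * x for x in s)
                        for s in shards)                          # extra SUM for entropy

    # Only the shard that owns the target contributes its logit; others add 0.
    chosen = 0.0
    for r, s in enumerate(shards):
        lo = r * width
        if lo <= target < lo + width:
            chosen += s[target - lo]

    lse = global_max + math.log(sum_exp)
    logprob = chosen - lse
    entropy = lse - sum_exp_logit / sum_exp   # logsumexp - sum(softmax * logits)
    return logprob, entropy


def reference(logits, target):
    """Unsharded reference: log softmax at target and -sum(p * log p)."""
    m = max(logits)
    z = sum(math.exp(x - m) for x in logits)
    probs = [math.exp(x - m) / z for x in logits]
    return math.log(probs[target]), -sum(p * math.log(p) for p in probs)


if __name__ == "__main__":
    logits = [0.5, -1.0, 2.0, 0.0, 1.5, -0.5, 3.0, 0.25]
    print(sharded_logprob_and_entropy(logits, target=6, num_shards=4))
    print(reference(logits, target=6))
```

The entropy does not depend on the target at all, which is consistent with the PR treating entropy as target-independent and not applying the OOV mask to it.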

Verification

  • CPU (tests/backends/skyrl_train/cpu/megatron/test_fused_linear_logprob.py): gloo TP1/TP2, chunked + single, with/without OOV. Logprob fwd + grad-hidden bit-exact, grad-weight to fp32 rounding vs the stock logits path; entropy fwd + backward parity vs stock vocab_parallel_entropy (~5e-7), entropy target-independence, packed precomputed_entropy_tokens bit-identical substitution, slice-alignment guard. Runs on CPU CI (no GPU / no megatron-core needed — megatron.core.parallel_state is stubbed when absent).
  • GPU (tests/backends/skyrl_train/gpu/gpu_ci/megatron/test_fused_linear_logprob_triton.py), 2×H100: 24/24 pass. Logprob 16/16 (incl. fp32+bf16 dtype-mix) + entropy 8/8 (fp32 ~5e-7 / ~1e-8 grads under IEEE; bf16 within tolerance; OOV-invariant). Marked @pytest.mark.megatron so the megatron GPU CI job (-m megatron) selects it. The fp32 cases pin IEEE (FORCE_FP32_IEEE_PRECISION) as a math-is-exact sanity check (default-precision fwd error 1.27e-3 TF32 vs 9.5e-7 IEEE); production uses the fast default (TF32) precision — the model trains in bf16 where it matches stock within ~3e-3.

Licensing / attribution

verl is Apache-2.0 (same as SkyRL); SkyRL already vendors verl in ~8 files with the # This code is adapted from VERL … original copyright reproduced below: convention. The vendored files (verl/utils/kernel/{kernels.py,linear_cross_entropy.py}) carry a dual copyright: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES and Copyright 2024 Bytedance Ltd. and/or its affiliates (verl is a Bytedance/volcengine project; the NVIDIA line covers an NVIDIA-contributed portion). Both are reproduced verbatim in the vendored file header with a Changes: note (§4(b)). Adds a top-level NOTICE (SkyRL had none) reproducing verl's Notice.txt. Preserves verl's out-of-bounds-vocab -inf masking fix (verl#2656).

Test plan

  • uv run --isolated --extra dev -- pytest -s tests/backends/skyrl_train/cpu/megatron/test_fused_linear_logprob.py (CPU, no megatron-core)
  • pytest -s tests/backends/skyrl_train/gpu/gpu_ci/megatron/test_fused_linear_logprob_triton.py -m megatron on 2×H100 — 24/24
  • ruff + black clean (pre-commit)
  • Default fused_linear_logprob: false path byte-identical to current behaviour

🤖 Generated with Claude Code

dyurk-lila and others added 6 commits June 8, 2026 15:09
…rializing [*, seq, vocab] logits

Adds an opt-in fused LM-head log-prob path for the Megatron backend that folds the
vocab projection into the chunked log-prob so the full [*, seq, vocab // TP] logits
tensor is never materialized — the dominant activation transient on the SFT/logprob
path at large vocab / long context. Numerically equivalent to the stock logits path.

Design ported from verl's fused linear-cross-entropy. Two backends behind one flag:
  * "torch" (default): a pure-PyTorch vocab-parallel FusedLinearLogprob autograd
    Function. Chunks over the sequence dim, recomputes the per-chunk logits in
    backward, fp32 softmax math, masks out-of-shard targets. Runs anywhere (CPU/GPU),
    no extra deps. Verified bit-exact (fwd + grad-hidden) and fp32-rounding-exact
    (grad-weight) vs the stock logits path on a real gloo TP=1/TP=2 group.
  * "triton": vendored flash-style Triton kernel (from volcengine/verl, Apache-2.0)
    that tiles over the vocab dim so per-chunk logits never materialize (lower memory,
    faster). GPU + triton required; falls back to "torch" with a warning otherwise.

Integration (model-source-free; no megatron-core edits):
  * MegatronModelWrapper wraps output_layer.forward to return the pre-projection
    hidden state (skipping the vocab GEMM) and capture the LM-head weight shard. The
    weight is resolved from the weight= kwarg the model passes — i.e.
    shared_embedding_or_output_weight() — so tied embeddings work with no None-weight
    crash. Sequence-parallel hidden is gathered with tensor_parallel_output_grad=True,
    whose backward reduce-scatters grad_hidden across TP exactly as ColumnParallelLinear
    does, so the fused function's per-shard grad_hidden is reduced correctly.
  * Both the training loss_func and the forward-only collection_func (eval / reference
    logprobs) route through the fused path; the vocab range is derived from the captured
    weight shard, never from hidden.shape[-1].
  * Falls back to the stock logits path automatically when the model uses MuP output
    scaling (the fused path bypasses _scale_logits) and is restricted to the
    cross_entropy loss (the RL entropy/KL terms still need full logits).

Config: trainer.fused_linear_logprob (bool, default False -> byte-identical to today)
and trainer.fused_linear_logprob_backend ("torch" | "triton"). Megatron-only.

Tests: tests/.../cpu/megatron/test_fused_linear_logprob.py (CPU gloo TP=1/TP=2,
with/without out-of-shard targets, chunked + single-chunk) asserts the fused path
matches the stock logits path for forward, grad-hidden, and grad-weight. A GPU test
covers the Triton backend.

The vendored Triton kernel preserves verl's out-of-bounds-vocab -inf masking fix
(verl-project/verl#2656) and reproduces the NVIDIA + Bytedance Apache-2.0 attribution
(see file header + new top-level NOTICE).

Signed-off-by: dyurk-lila <dyurk@lila.ai>
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…ision policy

Validated on 2xH100 (buzz): all 16 GPU cases pass (TP1/TP2 × chunk {8,1000} ×
{fp32,bf16} × with/without out-of-shard targets).

- Fix an out-of-vocab gradient bug in the Triton adapter backward: it was zeroing
  grad_output at fully-OOV positions, but the reference (stock DistributedLogprob /
  the pure-torch FusedLinearLogprob) produces d_logits = -softmax * grad_output
  there (no owned chosen column => the onehot term drops, the -softmax term remains).
  Zeroing diverged the gradient (~0.04 abs). Now grad_output passes through unchanged;
  the forward still masks the OOV logprob value to 0. Matches the torch backend exactly.

- Precision policy: production uses the fast default matmul precision (TF32 on Hopper) —
  torch/triton's default, and the model trains in bf16 where it's empirically fine
  (bf16 cases match the stock path within ~3e-3). A module-level FORCE_FP32_IEEE_PRECISION
  flag (default False) lets the test pin IEEE for fp32 cases as a reproducible
  "kernel math is exact" sanity check. Measured at TP=1 fp32: default precision fwd
  error 1.27e-3 (TF32 rounding), IEEE 9.5e-7 (exact) — confirming the only default-path
  gap is tensor-core rounding, not a math error. apply() signature unchanged.
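
The out-of-vocab gradient fix in the first bullet can be illustrated with a small per-shard sketch. It assumes the standard log-softmax gradient d_logits = (onehot - softmax) * grad_output and reads "no owned chosen column" as "the target lives on another shard"; logsumexp and shard_logprob_grad are hypothetical names, not functions from the adapter.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v)))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def shard_logprob_grad(shard_logits, lse, local_target, grad_output):
    """Gradient of log p(target) w.r.t. the logits held by one vocab shard.

    local_target is the target's index inside this shard, or None when the
    target is owned by another shard. In that case the one-hot term drops
    out but the -softmax term must remain (it is NOT zeroed).
    """
    grad = [-math.exp(x - lse) * grad_output for x in shard_logits]
    if local_target is not None:
        grad[local_target] += grad_output
    return grad


if __name__ == "__main__":
    full = [1.0, -0.5, 2.0, 0.3]          # target index 2 lives on shard 1
    shard0, shard1 = full[:2], full[2:]
    lse = logsumexp(full)
    g0 = shard_logprob_grad(shard0, lse, None, 1.0)   # non-owning shard
    g1 = shard_logprob_grad(shard1, lse, 0, 1.0)      # owning shard
    print(g0, g1, sum(g0) + sum(g1))
```

Zeroing grad_output on the non-owning shard would drop g0 entirely, so the gradients across all shards would no longer sum to zero, which matches the divergence the commit describes.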

Signed-off-by: dyurk-lila <dyurk@lila.ai>
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…gprob adapter

The Triton kernel's tl.dot(hidden, weight.T) requires both operands to share a
dtype. On a real hybrid model the pre-head hidden state is fp32 (final norm) while
the LM-head weight is bf16, so the adapter crashed with
"AssertionError: Both operands must be same dtype. Got fp32 and bf16" on the first
real training step (the fp32/fp32 and bf16/bf16 test cases never exposed it).

Mirror the pure-torch FusedLinearLogprob (hidden.to(weight.dtype) @ weight.t()) and
ColumnParallelLinear (which casts its input to the weight dtype before the bf16
GEMM): cast hidden to weight.dtype before the kernel, and cast the returned d_hidden
back to the original hidden dtype so autograd accumulates it into the fp32 hidden
leaf correctly. No-op when dtypes already match.

Signed-off-by: dyurk-lila <dyurk@lila.ai>
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…through the fused path

The fused linear log-prob previously served only the SFT cross_entropy loss and hard-raised
for every other (RL) loss, because the RL loss_func also needs ENTROPY, which it computed
from the full logits the fused path deliberately never materializes. This wires entropy
through the fused kernel — verl's established (log_probs, entropy) dual-output — so SFT and
ALL RL losses (ppo/grpo/etc.) can use fusion.

No new kernel math: verl's per-token entropy = logsumexp(logits) - Σ(softmax·logits) is
algebraically identical to stock _VocabParallelEntropy (max + log(sum_exp) - Σpx) and to the
Triton epilogue's entropy; the Triton kernel already computes+returns entropy and already
accepts a dentropy buffer in backward (the adapter was discarding both).

- model_utils.py: new _distributed_log_softmax_and_entropy (reuses the logprob path's MAX+SUM
  all-reduces + one extra SUM of Σ(softmax·logit)). FusedLinearLogprob gains return_entropy
  (default False → byte-identical); when True returns (logprob, entropy) and backward adds the
  entropy term dlogits += softmax·(Σpx − logits)·grad_entropy (== stock _VocabParallelEntropy
  backward; == verl's term). Both public entry points + the dispatcher thread return_entropy;
  entropy is returned raw per-position (not rolled/trimmed). vocab_parallel_entropy_packed_sequences
  gains precomputed_entropy_tokens so the packed action-weighting + CP all-reduce stay identical,
  only the entropy source swaps.
- fused_linear_logprob_triton.py: stop discarding verl's entropy; wire the incoming entropy grad
  into the existing dentropy buffer. Entropy is NOT OOV-masked (target-independent). No kernel edits.
- megatron_model_wrapper.py: remove the cross_entropy-only raise; route RL entropy through the
  fused dual-output (non-packed slices the action window [-num_actions-1:-1]; packed feeds
  precomputed_entropy_tokens). KL keeps using logprobs. SFT stays logprob-only (skips the entropy
  all-reduce). MuP fallback + fused-off byte-identical preserved.
- tests: CPU entropy fwd+backward parity vs stock vocab_parallel_entropy (TP1/TP2, chunk, OOV),
  target-independence, packed-substitution bit-identity, slice-alignment guard; GPU Triton entropy
  parity test. All CPU runners pass (logprob regression bit-identical; entropy ~5e-7).
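
The backward term quoted above, dlogits += softmax·(Σpx − logits)·grad_entropy, can be sanity-checked against a finite-difference gradient. The sketch below is unsharded and uses hypothetical helper names (entropy, entropy_grad, numeric_grad); it checks the formula only and says nothing about the PR's actual kernels.

```python
import math


def entropy(logits):
    """Per-token entropy as logsumexp(logits) - sum(softmax * logits)."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    sum_px = sum(math.exp(x - lse) * x for x in logits)
    return lse - sum_px


def entropy_grad(logits, grad_entropy=1.0):
    """Analytic gradient: softmax * (sum_px - logits) * grad_entropy."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    probs = [math.exp(x - lse) for x in logits]
    sum_px = sum(p * x for p, x in zip(probs, logits))
    return [p * (sum_px - x) * grad_entropy for p, x in zip(probs, logits)]


def numeric_grad(fn, xs, eps=1e-6):
    """Central finite differences, one coordinate at a time."""
    grads = []
    for i in range(len(xs)):
        hi = xs[:i] + [xs[i] + eps] + xs[i + 1:]
        lo = xs[:i] + [xs[i] - eps] + xs[i + 1:]
        grads.append((fn(hi) - fn(lo)) / (2.0 * eps))
    return grads


if __name__ == "__main__":
    logits = [0.2, -1.3, 2.1, 0.7]
    print(entropy_grad(logits))
    print(numeric_grad(entropy, logits))
```

The analytic gradient sums to zero over the vocab, as expected for a quantity that is invariant to adding a constant to every logit.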

Signed-off-by: dyurk-lila <dyurk@lila.ai>
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…ropy GPU test

The new entropy parity test crashed in its REFERENCE (vocab_parallel_entropy), which
all-reduces over mpu.get_tensor_model_parallel_group() — uninitialized because the worker
only called dist.init_process_group, never mpu.initialize_model_parallel. (The logprob test
sidesteps this by referencing dist.group.WORLD directly.) Initialize model-parallel state
with tensor_model_parallel_size=world_size (TP group == WORLD membership, so both the stock
reference and the fused path reduce over identical ranks) and destroy it in the finally.
Harness-only fix; the Triton kernel/adapter were never reached by the failing assertion.

Signed-off-by: dyurk-lila <dyurk@lila.ai>
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Polish the fused linear log-prob feature so it reads and behaves as a native
SkyRL feature, with no internal/development-history references and CI-correct
tests. No functional change to the fused math (CPU gloo parity re-verified at
TP=1/2, fwd bit-exact, grads at fp32 round-off).

CI correctness:
* GPU Triton test: add pytest.mark.megatron (the megatron GPU job selects by
  `-m megatron`, so the test was collected but never run); combine with the
  existing skipif in a module-level pytestmark list.
* CPU test: stub megatron.core.parallel_state at module import scope when
  megatron-core is absent, so collection no longer crashes the CPU CI job
  (which installs no --extra megatron) and the pure-torch path actually runs.
  The stub is import-scope so it also reaches the mp.spawn workers. Mark the
  module pytest.mark.megatron for selection where megatron is installed.
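
The import-scope stubbing described in the CPU-test bullet can be sketched with the standard library alone. This is an assumption-laden illustration, not the test file's code: ensure_stub is a hypothetical helper, and demo_pkg.core.parallel_state is a made-up module path standing in for a dependency that may be missing.

```python
import importlib
import sys
import types


def ensure_stub(module_name, attrs):
    """Return the real module if importable, else install and return a stub.

    Parent packages are stubbed too, so `import a.b.c` and
    `from a.b import c` both resolve through sys.modules.
    """
    try:
        return importlib.import_module(module_name)
    except ImportError:
        pass

    parts = module_name.split(".")
    for i in range(1, len(parts) + 1):
        name = ".".join(parts[:i])
        if name not in sys.modules:
            mod = types.ModuleType(name)
            mod.__path__ = []  # mark as a package so submodule lookups work
            sys.modules[name] = mod
        if i > 1:
            # Expose the child as an attribute of its parent package.
            setattr(sys.modules[".".join(parts[:i - 1])], parts[i - 1], sys.modules[name])

    stub = sys.modules[module_name]
    for key, value in attrs.items():
        setattr(stub, key, value)
    return stub


# Runs at import time, so any process that re-imports this file (for example
# a spawned worker) installs the same stub before the dependent import.
parallel_state = ensure_stub(
    "demo_pkg.core.parallel_state",
    {"get_tensor_model_parallel_group": lambda: None},
)
```

Because the call sits at module scope rather than inside a fixture, worker processes that import the module afresh see the stub as well, which is the property the commit message relies on for mp.spawn.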

Cleanup:
* Scrub development-history / out-of-band comments ("original downstream
  monkey-patch", "accompanying review message", "the original bug", "BUG 2").
* Remove dead code: unused GPU-test helpers (_stock_logprobs/_fused_logprobs)
  and their now-unused import; unused partition_vocab_size in
  FusedLinearLogprob.forward; drop fully_oov from the Triton save_for_backward
  (it was saved but never read in backward).
* DRY: extract the shared chosen-token scatter-add used by both
  ChunkedDistributedLogprob.backward and FusedLinearLogprob.backward into
  _add_chosen_token_grad so the OOV/index convention stays in lockstep.
* Cache exp(shifted) once in _distributed_log_softmax_and_entropy.
* Fix stale/elided test-path doc pointers and the MuP-fallback pointer; add the
  missing lm_head_weight/fused_backend/return_entropy Args to the packed-seq
  log-prob docstring.

Formatting: ruff 0.11.9 + black 24.10.0 clean.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces a fused LM-head log-prob optimization for the Megatron backend, featuring both a pure-PyTorch chunked implementation and a high-performance Triton-based backend vendored from verl. By fusing the LM-head projection into the chunked log-prob and entropy computation, the PR avoids materializing the massive full-vocab logits tensor, significantly reducing peak memory usage. The feedback recommends enhancing the robustness of the model wrapper by using *args and **kwargs to guard against future Megatron-core signature changes, broadening the exception handling when importing the Triton backend to catch driver or library mismatches, and checking CUDA availability alongside Triton to prevent cryptic errors in CPU-only environments.


Comment on lines +84 to +101
            def fused_output_layer_forward(input_, weight=None, runtime_gather_output=None):
                w = weight if weight is not None else getattr(layer, "weight", None)
                if w is None:
                    # Tied-embedding stage with no allocated head weight AND none passed in:
                    # we cannot fuse safely; defer to the original projection.
                    return original(input_, weight=weight, runtime_gather_output=runtime_gather_output)
                hidden = input_
                if getattr(layer, "sequence_parallel", False):
                    # Gather the sequence-parallel-scattered hidden across TP. Its backward is a
                    # reduce-scatter (tensor_parallel_output_grad=True), which reduces the fused
                    # function's per-shard grad_hidden exactly as ColumnParallelLinear would.
                    hidden = gather_from_sequence_parallel_region(
                        hidden, tensor_parallel_output_grad=True, group=tpg
                    )
                holder["weight"] = w
                # Return hidden in the logits position (+ no bias). The model's _scale_logits is
                # identity here (MuP excluded above) and its transpose yields [b, s, h].
                return hidden, None


medium

Using explicit keyword arguments like weight=None and runtime_gather_output=None can lead to a TypeError if Megatron-core updates its ParallelLinear.forward signature or if other wrappers pass unexpected arguments. Using *args and **kwargs makes the wrapper more robust against such signature changes.

            def fused_output_layer_forward(input_, *args, **kwargs):
                w = kwargs.get("weight", None)
                if w is None and len(args) > 0:
                    w = args[0]
                w = w if w is not None else getattr(layer, "weight", None)
                if w is None:
                    # Tied-embedding stage with no allocated head weight AND none passed in:
                    # we cannot fuse safely — defer to the original projection.
                    return original(input_, *args, **kwargs)
                hidden = input_
                if getattr(layer, "sequence_parallel", False):
                    # Gather the sequence-parallel-scattered hidden across TP. Its backward is a
                    # reduce-scatter (tensor_parallel_output_grad=True), which reduces the fused
                    # function's per-shard grad_hidden exactly as ColumnParallelLinear would.
                    hidden = gather_from_sequence_parallel_region(
                        hidden, tensor_parallel_output_grad=True, group=tpg
                    )
                holder["weight"] = w
                # Return hidden in the logits position (+ no bias). The model's _scale_logits is
                # identity here (MuP excluded above) and its transpose yields [b, s, h].
                return hidden, None

Comment thread skyrl/backends/skyrl_train/distributed/megatron/model_utils.py Outdated
Comment thread skyrl/backends/skyrl_train/distributed/megatron/fused_linear_logprob_triton.py Outdated
Comment thread NOTICE

Author


Claude suggested that this is a proper copyright notice, but it should definitely be checked by someone more familiar with the repo's conventions around this.

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
