Skip to content

Support top-K distillation (SDFT/OPSD): teacher top-K sampling + soft-target CE training#1777

Open
atemaguer wants to merge 2 commits into
NovaSky-AI:mainfrom
atemaguer:feat/topk-distillation
Open

Support top-K distillation (SDFT/OPSD): teacher top-K sampling + soft-target CE training#1777
atemaguer wants to merge 2 commits into
NovaSky-AI:mainfrom
atemaguer:feat/topk-distillation

Conversation

@atemaguer

Copy link
Copy Markdown
Contributor

Summary

Adds support for top-K distillation (forward-KL, e.g. SDFT / OPSD) on the JAX/tinker backend. Two pieces are needed end-to-end and this PR adds both:

  1. Teacher side — topk_prompt_logprobs (sampling): return the top-K token distribution at each prompt position. The Tinker API already accepts the topk_prompt_logprobs request field, but the engine left it unimplemented (the response came back null).
  2. Student side — soft top-K cross-entropy (training): train against a per-position distribution over K teacher tokens (a weighted cross-entropy), instead of a single hard target.

Together these let cookbook recipes like SDFT (tinker_cookbook.recipes.sdft, forward-KL topk>0) run against a self-hosted skyrl-tx server.

Motivation

Top-K distillation needs, at each completion position, the teacher's top-K distribution as the training target. Previously:

  • topk_prompt_logprobs returned null, so the teacher distribution couldn't be obtained; and
  • the trainer only supported hard 1-D targets (compute_logprobs + -logprob*mask), so the (num_tokens, K) target/weight tensors were flattened to length N*K and overflowed the N-based max_length, crashing forward_backward.

What changed

Teacher top-K sampling

  • generator.py: compute top-K over the prompt logits already materialized for prompt_logprobs (top-K of the logits minus full-vocab logsumexp = true logprobs), returned per position as [(token_id, logprob), ...] with index 0 = None (Tinker convention).
  • Plumbed topk_prompt_logprobs through SampleInput → engine → backend → SampleOutput.

Soft top-K cross-entropy training (opt-in; standard hard-target path is byte-for-byte unchanged)

  • TensorData gains shape (api + internal types) so 2-D (N, K) loss-fn inputs survive the request boundary; the engine recovers K and sets target_topk.
  • logits_processor.py: compute_topk_logprobs / logits_to_topk_logprobs gather K logprobs per position (≡ compute_logprobs for K=1).
  • JAX backend: separate jitted _forward[_backward]_topk compute the teacher-weighted per-position logprob sum_k w[t,k]·logp(target[t,k]) and use the completion indicator (weights.sum(-1) > 0) as the loss mask — so the existing loss switch, per-sequence normalization and grad accumulation are reused unchanged. Gated on 2-D targets via target_topk > 1.
  • pad_batch_topk pads row-major-flattened (N, K) data to [B, T, K].

Backward compatibility

The standard hard-target (1-D) path is unchanged and selected whenever target_topk <= 1 (i.e. all existing RL/SFT). compute_topk_logprobs with K=1 equals compute_logprobs.

Testing

  • tests/tx/utils/test_topk_logprobs.py (added): top-K gather vs numpy reference, K=1 equivalence with compute_logprobs, weighted-sum == forward cross-entropy. Existing test_logits_processor.py still passes.
  • End-to-end against a deployed skyrl-tx (Qwen3.5-2B): a soft forward_backward with N*Kmax_length (would crash pre-fix) succeeds, and the soft-CE loss decreases monotonically over optim steps (≈155 → 111), confirming gradients flow through the weighted top-K CE. topk_prompt_logprobs returns populated top-K (token_id, logprob) per position.

🤖 Generated with Claude Code

atemaguer and others added 2 commits June 10, 2026 18:03
The Tinker API accepts a `topk_prompt_logprobs` request field but skyrl-tx left
it unimplemented (the engine only returned scalar `prompt_logprobs`, so the
response field came back null). This blocks faithful top-K on-policy
self-distillation (SDFT / OPSD), which needs the teacher's top-K distribution at
each prompt position.

Compute top-K over the same full-prompt logits already materialized for
prompt_logprobs (memory-efficient: top_k of the logits minus full-vocab
logsumexp = true logprobs), and plumb it through the engine -> backend ->
SampleOutput so it serializes to the SDK's expected
`list[Optional[list[tuple[int, float]]]]` shape (index 0 = None, index i = top-K
(token_id, logprob) predicting token i).

- skyrl/tx/utils/generator.py: compute + return top-K in _prefill_and_decode/generate
- skyrl/backends/jax.py: pass topk through sample(), attach to SampleOutput
- skyrl/tinker/{types,engine,api}.py: carry topk_prompt_logprobs end-to-end

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The JAX/tinker training path only supported hard 1D targets (one target token
per position via `compute_logprobs` + `-logprob * mask`). Soft top-K
distillation (SDFT / OPSD forward-KL) needs, at each position, a distribution
over K teacher tokens — a weighted cross-entropy over the teacher's top-K. With
the old path the (num_tokens, K) target/weight tensors were flattened to length
N*K and overflowed the N-based `max_length`, crashing forward_backward.

This adds an opt-in soft path, gated on 2D (num_tokens, K) targets, that leaves
the standard hard-target path byte-for-byte unchanged:

- TensorData gains `shape`; the engine recovers (N, K) and sets `target_topk`.
- logits_processor: `compute_topk_logprobs` / `logits_to_topk_logprobs` gather
  K logprobs per position (== compute_logprobs for K=1).
- jax backend: separate jitted `_forward[_backward]_topk` fns compute the
  teacher-weighted per-position logprob
  `sum_k w[t,k] * logp(target[t,k])` and use the completion indicator
  (`weights.sum(-1) > 0`) as the loss mask, so the existing loss switch,
  per-sequence normalization and grad accumulation are reused unchanged.
- pad_batch_topk pads row-major-flattened (N, K) data to [B, T, K].

Tests: tests/tx/utils/test_topk_logprobs.py (gather correctness, K=1 equivalence
with compute_logprobs, weighted-sum == forward cross-entropy).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for soft top-K distillation variants (such as SDFT and OPSD forward-KL CE) in the JAX backend, allowing each position to carry K candidate target tokens and teacher weights. It adds new model forward and loss functions, padding utilities for 3D tensors, and support for returning top-K prompt logprobs during sampling. The review feedback highlights a critical validation gap in prepare_model_pass_batch where mixed batches of 1D and 2D targets can bypass consistency checks and crash the backend. To resolve this, the reviewer suggests initializing target_topk to None to properly track uninitialized states, ensuring strict consistency across requests, and validating that the weights shape matches the target tokens shape.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread skyrl/tinker/engine.py
all_loss_fns = []
all_loss_fn_configs = []
request_batch_slices = []
target_topk = 0 # K for soft top-K distillation targets (0 => standard 1D)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Initialize target_topk to None instead of 0 to correctly distinguish between an uninitialized state and a standard 1D target (which has current_topk = 0). This is necessary to prevent mixed batches of 1D and 2D targets from silently passing the validation and crashing the backend.

Suggested change
target_topk = 0 # K for soft top-K distillation targets (0 => standard 1D)
target_topk = None # K for soft top-K distillation targets (None => not set yet)

Comment thread skyrl/tinker/engine.py
Comment on lines +154 to +160
# Detect 2D (num_tokens, K) targets => soft top-K distillation. The wire
# data is row-major flattened; the backend reshapes using target_topk.
tt_shape = loss_fn_inputs.target_tokens.shape
if tt_shape is not None and len(tt_shape) == 2:
if target_topk not in (0, tt_shape[1]):
raise ValueError(f"Inconsistent top-K across batch: {target_topk} vs {tt_shape[1]}")
target_topk = tt_shape[1]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current validation allows mixed batches of 1D and 2D targets to bypass the consistency check. For example, if the first request is 1D (target_topk remains 0) and the second is 2D (tt_shape[1] = 5), target_topk is set to 5 without raising an error, leading to a crash in the backend.

To prevent this, we should explicitly track the current_topk for every request (defaulting to 0 for 1D targets) and ensure it is consistent across all requests in the batch. Additionally, we should validate that the weights shape matches the target_tokens shape when soft top-K distillation is used.

Suggested change
# Detect 2D (num_tokens, K) targets => soft top-K distillation. The wire
# data is row-major flattened; the backend reshapes using target_topk.
tt_shape = loss_fn_inputs.target_tokens.shape
if tt_shape is not None and len(tt_shape) == 2:
if target_topk not in (0, tt_shape[1]):
raise ValueError(f"Inconsistent top-K across batch: {target_topk} vs {tt_shape[1]}")
target_topk = tt_shape[1]
# Detect 2D (num_tokens, K) targets => soft top-K distillation. The wire
# data is row-major flattened; the backend reshapes using target_topk.
tt_shape = loss_fn_inputs.target_tokens.shape
current_topk = tt_shape[1] if (tt_shape is not None and len(tt_shape) == 2) else 0
if target_topk is None:
target_topk = current_topk
elif target_topk != current_topk:
raise ValueError(f"Inconsistent top-K across batch: {target_topk} vs {current_topk}")
if current_topk > 0:
w_shape = loss_fn_inputs.weights.shape
if w_shape is not None and w_shape != tt_shape:
raise ValueError(f"Weights shape {w_shape} must match target_tokens shape {tt_shape}")

Comment thread skyrl/tinker/engine.py
all_model_ids=all_model_ids,
all_loss_fns=all_loss_fns,
all_loss_fn_configs=all_loss_fn_configs,
target_topk=target_topk,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Pass target_topk or 0 to handle the case where target_topk remains None (e.g., if the batch is empty).

Suggested change
target_topk=target_topk,
target_topk=target_topk or 0,

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant