Support top-K distillation (SDFT/OPSD): teacher top-K sampling + soft-target CE training#1777
Support top-K distillation (SDFT/OPSD): teacher top-K sampling + soft-target CE training#1777atemaguer wants to merge 2 commits into
Conversation
The Tinker API accepts a `topk_prompt_logprobs` request field but skyrl-tx left
it unimplemented (the engine only returned scalar `prompt_logprobs`, so the
response field came back null). This blocks faithful top-K on-policy
self-distillation (SDFT / OPSD), which needs the teacher's top-K distribution at
each prompt position.
Compute top-K over the same full-prompt logits already materialized for
prompt_logprobs (memory-efficient: top_k of the logits minus full-vocab
logsumexp = true logprobs), and plumb it through the engine -> backend ->
SampleOutput so it serializes to the SDK's expected
`list[Optional[list[tuple[int, float]]]]` shape (index 0 = None, index i = top-K
(token_id, logprob) predicting token i).
- skyrl/tx/utils/generator.py: compute + return top-K in _prefill_and_decode/generate
- skyrl/backends/jax.py: pass topk through sample(), attach to SampleOutput
- skyrl/tinker/{types,engine,api}.py: carry topk_prompt_logprobs end-to-end
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The JAX/tinker training path only supported hard 1D targets (one target token per position via `compute_logprobs` + `-logprob * mask`). Soft top-K distillation (SDFT / OPSD forward-KL) needs, at each position, a distribution over K teacher tokens — a weighted cross-entropy over the teacher's top-K. With the old path the (num_tokens, K) target/weight tensors were flattened to length N*K and overflowed the N-based `max_length`, crashing forward_backward. This adds an opt-in soft path, gated on 2D (num_tokens, K) targets, that leaves the standard hard-target path byte-for-byte unchanged: - TensorData gains `shape`; the engine recovers (N, K) and sets `target_topk`. - logits_processor: `compute_topk_logprobs` / `logits_to_topk_logprobs` gather K logprobs per position (== compute_logprobs for K=1). - jax backend: separate jitted `_forward[_backward]_topk` fns compute the teacher-weighted per-position logprob `sum_k w[t,k] * logp(target[t,k])` and use the completion indicator (`weights.sum(-1) > 0`) as the loss mask, so the existing loss switch, per-sequence normalization and grad accumulation are reused unchanged. - pad_batch_topk pads row-major-flattened (N, K) data to [B, T, K]. Tests: tests/tx/utils/test_topk_logprobs.py (gather correctness, K=1 equivalence with compute_logprobs, weighted-sum == forward cross-entropy). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
There was a problem hiding this comment.
Code Review
This pull request introduces support for soft top-K distillation variants (such as SDFT and OPSD forward-KL CE) in the JAX backend, allowing each position to carry K candidate target tokens and teacher weights. It adds new model forward and loss functions, padding utilities for 3D tensors, and support for returning top-K prompt logprobs during sampling. The review feedback highlights a critical validation gap in prepare_model_pass_batch where mixed batches of 1D and 2D targets can bypass consistency checks and crash the backend. To resolve this, the reviewer suggests initializing target_topk to None to properly track uninitialized states, ensuring strict consistency across requests, and validating that the weights shape matches the target tokens shape.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| all_loss_fns = [] | ||
| all_loss_fn_configs = [] | ||
| request_batch_slices = [] | ||
| target_topk = 0 # K for soft top-K distillation targets (0 => standard 1D) |
There was a problem hiding this comment.
Initialize target_topk to None instead of 0 to correctly distinguish between an uninitialized state and a standard 1D target (which has current_topk = 0). This is necessary to prevent mixed batches of 1D and 2D targets from silently passing the validation and crashing the backend.
| target_topk = 0 # K for soft top-K distillation targets (0 => standard 1D) | |
| target_topk = None # K for soft top-K distillation targets (None => not set yet) |
| # Detect 2D (num_tokens, K) targets => soft top-K distillation. The wire | ||
| # data is row-major flattened; the backend reshapes using target_topk. | ||
| tt_shape = loss_fn_inputs.target_tokens.shape | ||
| if tt_shape is not None and len(tt_shape) == 2: | ||
| if target_topk not in (0, tt_shape[1]): | ||
| raise ValueError(f"Inconsistent top-K across batch: {target_topk} vs {tt_shape[1]}") | ||
| target_topk = tt_shape[1] |
There was a problem hiding this comment.
The current validation allows mixed batches of 1D and 2D targets to bypass the consistency check. For example, if the first request is 1D (target_topk remains 0) and the second is 2D (tt_shape[1] = 5), target_topk is set to 5 without raising an error, leading to a crash in the backend.
To prevent this, we should explicitly track the current_topk for every request (defaulting to 0 for 1D targets) and ensure it is consistent across all requests in the batch. Additionally, we should validate that the weights shape matches the target_tokens shape when soft top-K distillation is used.
| # Detect 2D (num_tokens, K) targets => soft top-K distillation. The wire | |
| # data is row-major flattened; the backend reshapes using target_topk. | |
| tt_shape = loss_fn_inputs.target_tokens.shape | |
| if tt_shape is not None and len(tt_shape) == 2: | |
| if target_topk not in (0, tt_shape[1]): | |
| raise ValueError(f"Inconsistent top-K across batch: {target_topk} vs {tt_shape[1]}") | |
| target_topk = tt_shape[1] | |
| # Detect 2D (num_tokens, K) targets => soft top-K distillation. The wire | |
| # data is row-major flattened; the backend reshapes using target_topk. | |
| tt_shape = loss_fn_inputs.target_tokens.shape | |
| current_topk = tt_shape[1] if (tt_shape is not None and len(tt_shape) == 2) else 0 | |
| if target_topk is None: | |
| target_topk = current_topk | |
| elif target_topk != current_topk: | |
| raise ValueError(f"Inconsistent top-K across batch: {target_topk} vs {current_topk}") | |
| if current_topk > 0: | |
| w_shape = loss_fn_inputs.weights.shape | |
| if w_shape is not None and w_shape != tt_shape: | |
| raise ValueError(f"Weights shape {w_shape} must match target_tokens shape {tt_shape}") |
| all_model_ids=all_model_ids, | ||
| all_loss_fns=all_loss_fns, | ||
| all_loss_fn_configs=all_loss_fn_configs, | ||
| target_topk=target_topk, |
Summary
Adds support for top-K distillation (forward-KL, e.g. SDFT / OPSD) on the JAX/tinker backend. Two pieces are needed end-to-end and this PR adds both:
topk_prompt_logprobs(sampling): return the top-K token distribution at each prompt position. The Tinker API already accepts thetopk_prompt_logprobsrequest field, but the engine left it unimplemented (the response came backnull).Together these let cookbook recipes like SDFT (
tinker_cookbook.recipes.sdft, forward-KLtopk>0) run against a self-hosted skyrl-tx server.Motivation
Top-K distillation needs, at each completion position, the teacher's top-K distribution as the training target. Previously:
topk_prompt_logprobsreturnednull, so the teacher distribution couldn't be obtained; andcompute_logprobs+-logprob*mask), so the(num_tokens, K)target/weight tensors were flattened to lengthN*Kand overflowed theN-basedmax_length, crashingforward_backward.What changed
Teacher top-K sampling
generator.py: compute top-K over the prompt logits already materialized forprompt_logprobs(top-K of the logits minus full-vocab logsumexp = true logprobs), returned per position as[(token_id, logprob), ...]with index 0 =None(Tinker convention).topk_prompt_logprobsthroughSampleInput→ engine → backend →SampleOutput.Soft top-K cross-entropy training (opt-in; standard hard-target path is byte-for-byte unchanged)
TensorDatagainsshape(api + internal types) so 2-D(N, K)loss-fn inputs survive the request boundary; the engine recoversKand setstarget_topk.logits_processor.py:compute_topk_logprobs/logits_to_topk_logprobsgather K logprobs per position (≡compute_logprobsfor K=1)._forward[_backward]_topkcompute the teacher-weighted per-position logprobsum_k w[t,k]·logp(target[t,k])and use the completion indicator (weights.sum(-1) > 0) as the loss mask — so the existing loss switch, per-sequence normalization and grad accumulation are reused unchanged. Gated on 2-D targets viatarget_topk > 1.pad_batch_topkpads row-major-flattened(N, K)data to[B, T, K].Backward compatibility
The standard hard-target (1-D) path is unchanged and selected whenever
target_topk <= 1(i.e. all existing RL/SFT).compute_topk_logprobswith K=1 equalscompute_logprobs.Testing
tests/tx/utils/test_topk_logprobs.py(added): top-K gather vs numpy reference, K=1 equivalence withcompute_logprobs, weighted-sum == forward cross-entropy. Existingtest_logits_processor.pystill passes.forward_backwardwithN*K≫max_length(would crash pre-fix) succeeds, and the soft-CE loss decreases monotonically over optim steps (≈155 → 111), confirming gradients flow through the weighted top-K CE.topk_prompt_logprobsreturns populated top-K(token_id, logprob)per position.🤖 Generated with Claude Code