
[ROCm] Fix biased wgrad with fp32 gradient accumulation #634

Open
XinyuJiangCMU wants to merge 5 commits into ROCm:dev from XinyuJiangCMU:rocm-wgrad-bgrad-dbias-fix-v2


@XinyuJiangCMU


Problem

On ROCm, hipBLASLt cannot find a suitable algorithm for a weight gradient (wgrad) GEMM with fp32 output when the bias gradient computation is fused into its epilogue.

This causes training with --add-qkv-bias and --accumulate-allreduce-grads-in-fp32 to fail with:

RuntimeError: Unable to find any suitable algorithms

Fix

Run the weight gradient GEMM without the fused bias gradient and compute the bias gradient separately by summing grad_output over the token dimension in fp32, then casting to the bias dtype.

The fix is implemented in general_gemm, covering delayed weight gradient execution and other callers using the same path. CUDA behavior is unchanged.
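
The arithmetic behind the split can be sketched as follows. This is a minimal, dependency-free illustration and not the code in this PR: plain Python lists stand in for tensors, and `wgrad_gemm`, `bias_grad`, and `unfused_wgrad_with_bias` are hypothetical names chosen for the example. It shows that the bias gradient is a reduction of grad_output over tokens, so it does not need to be produced by the GEMM epilogue.

```python
# Illustrative sketch only: lists of floats stand in for GPU tensors.
# All function names here are hypothetical, not Transformer Engine APIs.

def wgrad_gemm(grad_output, inp):
    """Weight gradient: dW[o][i] = sum over tokens t of grad_output[t][o] * inp[t][i]."""
    tokens = len(grad_output)
    out_features = len(grad_output[0])
    in_features = len(inp[0])
    return [
        [sum(grad_output[t][o] * inp[t][i] for t in range(tokens))
         for i in range(in_features)]
        for o in range(out_features)
    ]

def bias_grad(grad_output):
    """Bias gradient: db[o] = sum over tokens t of grad_output[t][o]."""
    return [sum(column) for column in zip(*grad_output)]

def unfused_wgrad_with_bias(grad_output, inp, main_grad=None):
    """Run the GEMM with no bias-grad epilogue, then reduce grad_output separately.

    If main_grad is given, the weight gradient is accumulated into it,
    mirroring gradient accumulation across microbatches.
    """
    dw = wgrad_gemm(grad_output, inp)
    if main_grad is not None:
        dw = [[m + d for m, d in zip(m_row, d_row)]
              for m_row, d_row in zip(main_grad, dw)]
    return dw, bias_grad(grad_output)

# 3 tokens, 2 output features, 2 input features.
grad_output = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
inp = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

dw, db = unfused_wgrad_with_bias(grad_output, inp)
print(dw)  # [[6.0, 8.0], [8.0, 10.0]]
print(db)  # [9.0, 12.0]
```

In a real implementation the reduction would run on the GPU, and the choice of accumulation precision for the sum matters for low-precision grad_output; the sketch ignores dtypes entirely.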

Testing

Verified on MI350X:

  • The isolated reproduction passes.

  • Qwen2.5-0.5B GSM8K training passes the previously failing backward step.

  • Re-enabled the ROCm wgrad numerics tests previously skipped by grouped GEMM change 434:

    • test_linear_accuracy_delay_wgrad_compute
    • test_layernorm_linear_accuracy_delay_wgrad_compute
    • test_layernorm_mlp_accuracy_delay_wgrad_compute

    All three use general_gemm.

Result:

132 passed, 0 skipped, 0 failed

XinyuJiangCMU and others added 5 commits June 18, 2026 04:41
On ROCm, hipBLASLt has no algorithm for a bf16 -> fp32-accumulate wgrad
GEMM that also fuses the bias-gradient (BGRADB) epilogue: the heuristic
returns zero algorithms and the GEMM raises "Unable to find any suitable
algorithms". This hits any LayerNormLinear with bias (e.g. Qwen2.5 QKV
with add-qkv-bias) when training with fp32 gradient accumulation
(--accumulate-allreduce-grads-in-fp32).

When wgrad is accumulated into an fp32 main_grad on ROCm, skip the fused
dbias and reduce grad_bias separately (grad_output.sum over tokens in
fp32, cast to bias dtype) -- mathematically identical to the BGRADB
epilogue. CUDA and all other paths are unchanged.

Co-Authored-By: Jessica Jiang <jessicajiang324@gmail.com>
Signed-off-by: Xinyu Jiang <xinyuj2@andrew.cmu.edu>
Move the BGRADB-unfuse workaround from the per-module LayerNormLinear backward
up to general_gemm, the single chokepoint every wgrad path funnels through.
This covers Linear, LayerNormLinear, LayerNormMLP and the delayed-wgrad store
in one place, and fixes the delayed-wgrad path that the per-module version
dropped the bias gradient on. CUDA, the forward bias-add path and fp8/fp4 are
untouched.

Co-Authored-By: Jessica Jiang <jessicajiang324@gmail.com>
Signed-off-by: Xinyu Jiang <xinyuj2@andrew.cmu.edu>
Signed-off-by: Xinyu Jiang <xinyuj2@andrew.cmu.edu>
The hipBLASLt "no suitable algorithm" failure for the fused bias-grad (BGRADB) epilogue is driven by the fp32 output dtype, independent of accumulate, so the split must also cover the non-accumulating (e.g. first-microbatch) wgrad. Also exclude gelu, whose bias-grad is not a plain grad_output sum. Re-enable the ROCm numerics test that was skipped for this case.

Co-authored-by: Zhiyao Jiang <jessicajiang324@gmail.com>
Signed-off-by: Xinyu Jiang <xinyuj2@andrew.cmu.edu>
The hipBLASLt "no suitable algorithm" failure for the fused bias-grad (BGRADB) epilogue is driven by the fp32 output dtype, independent of accumulate, so the split must also cover non-accumulating (e.g. first-microbatch) wgrad. Also exclude gelu, whose bias-grad is not a plain grad_output sum. Re-enable the Linear / LayerNormLinear / LayerNormMLP wgrad numerics tests skipped for this case; GroupedLinear routes through general_grouped_gemm and stays skipped.

Co-authored-by: Zhiyao Jiang <jessicajiang324@gmail.com>
Signed-off-by: Xinyu Jiang <xinyuj2@andrew.cmu.edu>
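
Read together, the commit messages above describe when the split applies. The predicate below is a hedged reconstruction of that condition, not the check in the PR: the function name, its parameters, and the dtype strings are assumptions made for illustration, and the actual test inside general_gemm may be expressed differently.

```python
# Hypothetical predicate reconstructed from the commit messages.
# Names and parameters are illustrative assumptions, not the PR's code.

def should_split_bias_grad(*, on_rocm, out_dtype, wants_bias_grad,
                           gelu_epilogue=False, quantized=False,
                           accumulate=False):
    """Return True when the bias gradient should be computed outside the GEMM.

    - ROCm only: CUDA keeps the fused epilogue.
    - Triggered by an fp32 output dtype, whether or not the GEMM accumulates.
    - gelu is excluded: its bias gradient is not a plain grad_output sum.
    - fp8/fp4 (quantized) paths are left untouched.
    """
    del accumulate  # deliberately ignored: the failure does not depend on it
    return (
        on_rocm
        and wants_bias_grad
        and out_dtype == "float32"
        and not gelu_epilogue
        and not quantized
    )

# First microbatch (no accumulation) and later microbatches are both covered.
first = should_split_bias_grad(on_rocm=True, out_dtype="float32",
                               wants_bias_grad=True, accumulate=False)
later = should_split_bias_grad(on_rocm=True, out_dtype="float32",
                               wants_bias_grad=True, accumulate=True)
cuda = should_split_bias_grad(on_rocm=False, out_dtype="float32",
                              wants_bias_grad=True)
print(first, later, cuda)  # True True False
```

GroupedLinear is outside this sketch because, per the last commit message, it routes through general_grouped_gemm rather than general_gemm.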