From 38d375dc4c812ca50ba2a13baab24e945203bcfc Mon Sep 17 00:00:00 2001 From: Xinyu Jiang Date: Thu, 11 Jun 2026 02:12:27 -0400 Subject: [PATCH 1/5] [ROCm] Skip fused bias-grad in wgrad GEMM on fp32-accumulate path On ROCm, hipBLASLt has no algorithm for a bf16 -> fp32-accumulate wgrad GEMM that also fuses the bias-gradient (BGRADB) epilogue: the heuristic returns zero algorithms and the GEMM raises "Unable to find any suitable algorithms". This hits any LayerNormLinear with bias (e.g. Qwen2.5 QKV with add-qkv-bias) when training with fp32 gradient accumulation (--accumulate-allreduce-grads-in-fp32). When wgrad is accumulated into an fp32 main_grad on ROCm, skip the fused dbias and reduce grad_bias separately (grad_output.sum over tokens in fp32, cast to bias dtype) -- mathematically identical to the BGRADB epilogue. CUDA and all other paths are unchanged. Co-Authored-By: Jessica Jiang Signed-off-by: Xinyu Jiang --- .../pytorch/module/layernorm_linear.py | 30 +++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index ed3ef10fe..f5c69233f 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -892,6 +892,19 @@ def backward( dgrad_shape, dtype=ctx.activation_dtype, device=grad_outputs[0].device ) + # ROCm hipBLASLt has no algorithm for a bf16 -> fp32-accumulate wgrad + # GEMM that also fuses the bias-gradient (BGRADB) epilogue, so the + # heuristic returns no algorithms and the GEMM raises. When wgrad is + # accumulated into an fp32 main_grad on ROCm, skip the fusion and + # reduce grad_bias separately below. 
+ rocm_unfuse_dbias = ( + IS_HIP_EXTENSION + and accumulate_wgrad_into_param_main_grad + and grad_bias is None + and not ctx.fp8 + and bias is not None + ) + # Arguments to include in wgrad GEMM closure wgrad_gemm_kwargs = { "out_dtype": ( @@ -905,7 +918,11 @@ def backward( ), "layout": "NT", "out": main_grad if ctx.fuse_wgrad_accumulation else None, - "bias": (bias if (grad_bias is None and not ctx.fp8) else None), + "bias": ( + bias + if (grad_bias is None and not ctx.fp8 and not rocm_unfuse_dbias) + else None + ), "use_split_accumulator": use_split_accumulator, "grad": True, "ub": ub_obj_wgrad, @@ -951,7 +968,16 @@ def wgrad_gemm( # Update grad bias if needed if grad_bias is None: - grad_bias = grad_bias_ + if rocm_unfuse_dbias: + # Fused dbias was suppressed for the ROCm fp32-accumulate + # path; reduce it here (fp32 accumulate, cast to bias dtype). + grad_bias = ( + grad_output.reshape(-1, grad_output.shape[-1]) + .sum(dim=0, dtype=torch.float32) + .to(bias.dtype) + ) + else: + grad_bias = grad_bias_ del grad_bias_ # Deallocate input tensors if permitted From c9904e3910e251e6f5a207a8ec3c2fd1ebb93337 Mon Sep 17 00:00:00 2001 From: Xinyu Jiang Date: Mon, 15 Jun 2026 00:07:17 -0500 Subject: [PATCH 2/5] Hoist ROCm wgrad bias-grad unfuse to general_gemm Move the BGRADB-unfuse workaround from the per-module LayerNormLinear backward up to general_gemm, the single chokepoint every wgrad path funnels through. This covers Linear, LayerNormLinear, LayerNormMLP and the delayed-wgrad store in one place, and fixes the delayed-wgrad path, where the per-module version dropped the bias gradient. CUDA, the forward bias-add path and fp8/fp4 are untouched.
Co-Authored-By: Jessica Jiang Signed-off-by: Xinyu Jiang --- .../pytorch/cpp_extensions/gemm.py | 29 +++++++++++++++++- .../pytorch/module/layernorm_linear.py | 30 ++----------------- 2 files changed, 30 insertions(+), 29 deletions(-) diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 6ddabd6cc..b8e2caecb 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -431,6 +431,27 @@ def general_gemm( ) out_dtype = torch.bfloat16 if use_bf16_tn_output_workaround else out_dtype + # ROCm: hipBLASLt has no algorithm for a bf16 -> fp32-accumulate wgrad GEMM + # that also fuses the bias-gradient (BGRADB) epilogue; the heuristic returns + # no algorithms and the GEMM raises "Unable to find any suitable algorithms". + # Run the GEMM with the default epilogue (no fused bias grad) and reduce the + # bias gradient separately afterwards. Every wgrad path (Linear, + # LayerNormLinear, LayerNormMLP, and the delayed-wgrad store) reads the bias + # gradient from this function's return value, so handling it here covers them + # all. Gated on ROCm + bias-grad (grad + bias) + fp32 main_grad accumulate + + # non-quantized, so CUDA, the forward bias-add path (grad=False) and fp8/fp4 + # are untouched. 
+ rocm_split_dbias = ( + IS_HIP_EXTENSION + and grad + and bias is not None + and accumulate + and out is not None + and out.dtype == torch.float32 + and quantization_params is None + ) + gemm_bias = None if rocm_split_dbias else bias + args = ( A, transa, # transa @@ -439,7 +460,7 @@ def general_gemm( out, quantization_params, TE_DType[out_dtype] if out_dtype is not None else None, - bias, + gemm_bias, bias_dtype, gelu, gelu_in, @@ -460,6 +481,12 @@ def general_gemm( out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs) + if rocm_split_dbias: + # dbias = column-sum of grad_output (operand B) over tokens, accumulated + # in fp32 and cast to the bias dtype to match the fused BGRADB epilogue + # this replaces. + bias_grad = B.reshape(-1, B.shape[-1]).sum(dim=0, dtype=torch.float32).to(bias.dtype) + if IS_HIP_EXTENSION and use_bf16_tn_output_workaround: out = cast_if_needed(out, torch.float32) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index f5c69233f..ed3ef10fe 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -892,19 +892,6 @@ def backward( dgrad_shape, dtype=ctx.activation_dtype, device=grad_outputs[0].device ) - # ROCm hipBLASLt has no algorithm for a bf16 -> fp32-accumulate wgrad - # GEMM that also fuses the bias-gradient (BGRADB) epilogue, so the - # heuristic returns no algorithms and the GEMM raises. When wgrad is - # accumulated into an fp32 main_grad on ROCm, skip the fusion and - # reduce grad_bias separately below. 
- rocm_unfuse_dbias = ( - IS_HIP_EXTENSION - and accumulate_wgrad_into_param_main_grad - and grad_bias is None - and not ctx.fp8 - and bias is not None - ) - # Arguments to include in wgrad GEMM closure wgrad_gemm_kwargs = { "out_dtype": ( @@ -918,11 +905,7 @@ def backward( ), "layout": "NT", "out": main_grad if ctx.fuse_wgrad_accumulation else None, - "bias": ( - bias - if (grad_bias is None and not ctx.fp8 and not rocm_unfuse_dbias) - else None - ), + "bias": (bias if (grad_bias is None and not ctx.fp8) else None), "use_split_accumulator": use_split_accumulator, "grad": True, "ub": ub_obj_wgrad, @@ -968,16 +951,7 @@ def wgrad_gemm( # Update grad bias if needed if grad_bias is None: - if rocm_unfuse_dbias: - # Fused dbias was suppressed for the ROCm fp32-accumulate - # path; reduce it here (fp32 accumulate, cast to bias dtype). - grad_bias = ( - grad_output.reshape(-1, grad_output.shape[-1]) - .sum(dim=0, dtype=torch.float32) - .to(bias.dtype) - ) - else: - grad_bias = grad_bias_ + grad_bias = grad_bias_ del grad_bias_ # Deallocate input tensors if permitted From 1715351269f122f042986f6cfbba32a44165043b Mon Sep 17 00:00:00 2001 From: Xinyu Jiang Date: Wed, 17 Jun 2026 07:33:19 -0500 Subject: [PATCH 3/5] Trim verbose comments in wgrad bias-grad split Signed-off-by: Xinyu Jiang --- .../pytorch/cpp_extensions/gemm.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index b8e2caecb..92e267680 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -431,16 +431,9 @@ def general_gemm( ) out_dtype = torch.bfloat16 if use_bf16_tn_output_workaround else out_dtype - # ROCm: hipBLASLt has no algorithm for a bf16 -> fp32-accumulate wgrad GEMM - # that also fuses the bias-gradient (BGRADB) epilogue; the heuristic returns - # no algorithms and the GEMM raises "Unable to find 
any suitable algorithms". - # Run the GEMM with the default epilogue (no fused bias grad) and reduce the - # bias gradient separately afterwards. Every wgrad path (Linear, - # LayerNormLinear, LayerNormMLP, and the delayed-wgrad store) reads the bias - # gradient from this function's return value, so handling it here covers them - # all. Gated on ROCm + bias-grad (grad + bias) + fp32 main_grad accumulate + - # non-quantized, so CUDA, the forward bias-add path (grad=False) and fp8/fp4 - # are untouched. + # hipBLASLt has no algorithm for the fused bias-grad (BGRADB) epilogue on a + # bf16 -> fp32-accumulate wgrad GEMM, so split it: run the GEMM without the + # fused dbias and reduce the bias gradient separately below. rocm_split_dbias = ( IS_HIP_EXTENSION and grad @@ -482,9 +475,7 @@ def general_gemm( out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs) if rocm_split_dbias: - # dbias = column-sum of grad_output (operand B) over tokens, accumulated - # in fp32 and cast to the bias dtype to match the fused BGRADB epilogue - # this replaces. + # dbias = column-sum of grad_output over tokens, in fp32, cast to bias dtype. bias_grad = B.reshape(-1, B.shape[-1]).sum(dim=0, dtype=torch.float32).to(bias.dtype) if IS_HIP_EXTENSION and use_bf16_tn_output_workaround: From 627e0ef131f8a9063a473591e9f7ceb584b97f01 Mon Sep 17 00:00:00 2001 From: Xinyu Jiang Date: Thu, 18 Jun 2026 03:37:08 +0000 Subject: [PATCH 4/5] [ROCm] Trigger wgrad bias-grad split on fp32 output, not accumulate The hipBLASLt "no suitable algorithm" failure for the fused bias-grad (BGRADB) epilogue is driven by the fp32 output dtype, independent of accumulate, so the split must also cover the non-accumulating (e.g. first-microbatch) wgrad. Also exclude gelu, whose bias-grad is not a plain grad_output sum. Re-enable the ROCm numerics test that was skipped for this case. 
Co-authored-by: Zhiyao Jiang Signed-off-by: Xinyu Jiang --- tests/pytorch/test_numerics.py | 4 ---- transformer_engine/pytorch/cpp_extensions/gemm.py | 7 ++----- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index b634a668c..f81436096 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -1410,10 +1410,6 @@ def test_linear_accuracy_delay_wgrad_compute(dtype, bs, model, bias, fuse_wgrad_ config = model_configs[model] - if IS_HIP_EXTENSION: - if dtype not in (torch.float32,) and fuse_wgrad_accumulation and bias: - pytest.skip(f"ROCm does not support fused wgrad accumulation for {dtype}.") - te_linear_ref = Linear( config.hidden_size, 4 * config.hidden_size, diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 92e267680..bc6893620 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -431,14 +431,12 @@ def general_gemm( ) out_dtype = torch.bfloat16 if use_bf16_tn_output_workaround else out_dtype - # hipBLASLt has no algorithm for the fused bias-grad (BGRADB) epilogue on a - # bf16 -> fp32-accumulate wgrad GEMM, so split it: run the GEMM without the - # fused dbias and reduce the bias gradient separately below. + # hipBLASLt has no fused bias-grad (BGRADB) algorithm for an fp32-output wgrad GEMM, so skip the fused dbias and reduce grad_output below. rocm_split_dbias = ( IS_HIP_EXTENSION and grad and bias is not None - and accumulate + and not gelu and out is not None and out.dtype == torch.float32 and quantization_params is None @@ -475,7 +473,6 @@ def general_gemm( out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs) if rocm_split_dbias: - # dbias = column-sum of grad_output over tokens, in fp32, cast to bias dtype. 
bias_grad = B.reshape(-1, B.shape[-1]).sum(dim=0, dtype=torch.float32).to(bias.dtype) if IS_HIP_EXTENSION and use_bf16_tn_output_workaround: From abe36979464d59477e1000ac9098f6daf0d7ee81 Mon Sep 17 00:00:00 2001 From: Xinyu Jiang Date: Thu, 18 Jun 2026 03:50:57 +0000 Subject: [PATCH 5/5] [ROCm] Re-enable wgrad numerics tests after fp32-output bias-grad split Test-only follow-up to the previous patch in this series. The hipBLASLt "no suitable algorithm" failure for the fused bias-grad (BGRADB) epilogue is driven by the fp32 output dtype, independent of accumulate, so the split (as changed in the previous patch) must also cover non-accumulating (e.g. first-microbatch) wgrad. It also excludes gelu, whose bias-grad is not a plain grad_output sum. Re-enable the Linear / LayerNormLinear / LayerNormMLP wgrad numerics tests skipped for this case (the Linear test was re-enabled in the previous patch; this patch re-enables the LayerNormLinear and LayerNormMLP ones); GroupedLinear routes through general_grouped_gemm and stays skipped. Co-authored-by: Zhiyao Jiang Signed-off-by: Xinyu Jiang --- tests/pytorch/test_numerics.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index f81436096..99d50dcba 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -1704,9 +1704,6 @@ def test_layernorm_linear_accuracy( def test_layernorm_linear_accuracy_delay_wgrad_compute( dtype, bs, model, normalization, zero_centered_gamma, bias, fuse_wgrad_accumulation ): - if IS_HIP_EXTENSION: - if dtype not in (torch.float32,) and fuse_wgrad_accumulation and bias: - pytest.skip(f"ROCm does not support fused wgrad accumulation for {dtype}.") if NVTE_TEST_NVINSPECT_ENABLED: pytest.skip("Delayed wgrad compute is not supported in debug mode.") @@ -1928,10 +1925,6 @@ def test_layernorm_mlp_accuracy_delay_wgrad_compute( config = model_configs[model] - if IS_HIP_EXTENSION: - if dtype not in (torch.float32,) and fuse_wgrad_accumulation and bias: - pytest.skip(f"ROCm does not support fused wgrad accumulation for {dtype}.") - ln_mlp = LayerNormMLP( hidden_size=config.hidden_size, ffn_hidden_size=4 * config.hidden_size,