diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index b634a668c..99d50dcba 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -1410,10 +1410,6 @@ def test_linear_accuracy_delay_wgrad_compute(dtype, bs, model, bias, fuse_wgrad_ config = model_configs[model] - if IS_HIP_EXTENSION: - if dtype not in (torch.float32,) and fuse_wgrad_accumulation and bias: - pytest.skip(f"ROCm does not support fused wgrad accumulation for {dtype}.") - te_linear_ref = Linear( config.hidden_size, 4 * config.hidden_size, @@ -1708,9 +1704,6 @@ def test_layernorm_linear_accuracy( def test_layernorm_linear_accuracy_delay_wgrad_compute( dtype, bs, model, normalization, zero_centered_gamma, bias, fuse_wgrad_accumulation ): - if IS_HIP_EXTENSION: - if dtype not in (torch.float32,) and fuse_wgrad_accumulation and bias: - pytest.skip(f"ROCm does not support fused wgrad accumulation for {dtype}.") if NVTE_TEST_NVINSPECT_ENABLED: pytest.skip("Delayed wgrad compute is not supported in debug mode.") @@ -1932,10 +1925,6 @@ def test_layernorm_mlp_accuracy_delay_wgrad_compute( config = model_configs[model] - if IS_HIP_EXTENSION: - if dtype not in (torch.float32,) and fuse_wgrad_accumulation and bias: - pytest.skip(f"ROCm does not support fused wgrad accumulation for {dtype}.") - ln_mlp = LayerNormMLP( hidden_size=config.hidden_size, ffn_hidden_size=4 * config.hidden_size, diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 6ddabd6cc..bc6893620 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -431,6 +431,18 @@ def general_gemm( ) out_dtype = torch.bfloat16 if use_bf16_tn_output_workaround else out_dtype + # hipBLASLt has no fused bias-grad (BGRADB) algorithm for an fp32-output wgrad GEMM, so skip the fused dbias and reduce grad_output below. + rocm_split_dbias = ( + IS_HIP_EXTENSION + and grad + and bias is not None + and not gelu + and out is not None + and out.dtype == torch.float32 + and quantization_params is None + ) + gemm_bias = None if rocm_split_dbias else bias + args = ( A, transa, # transa @@ -439,7 +451,7 @@ def general_gemm( out, quantization_params, TE_DType[out_dtype] if out_dtype is not None else None, - bias, + gemm_bias, bias_dtype, gelu, gelu_in, @@ -460,6 +472,9 @@ def general_gemm( out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs) + if rocm_split_dbias: + bias_grad = B.reshape(-1, B.shape[-1]).sum(dim=0, dtype=torch.float32).to(bias.dtype) + if IS_HIP_EXTENSION and use_bf16_tn_output_workaround: out = cast_if_needed(out, torch.float32)