Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 0 additions & 11 deletions tests/pytorch/test_numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1410,10 +1410,6 @@ def test_linear_accuracy_delay_wgrad_compute(dtype, bs, model, bias, fuse_wgrad_

config = model_configs[model]

if IS_HIP_EXTENSION:
if dtype not in (torch.float32,) and fuse_wgrad_accumulation and bias:
pytest.skip(f"ROCm does not support fused wgrad accumulation for {dtype}.")

te_linear_ref = Linear(
config.hidden_size,
4 * config.hidden_size,
Expand Down Expand Up @@ -1708,9 +1704,6 @@ def test_layernorm_linear_accuracy(
def test_layernorm_linear_accuracy_delay_wgrad_compute(
dtype, bs, model, normalization, zero_centered_gamma, bias, fuse_wgrad_accumulation
):
if IS_HIP_EXTENSION:
if dtype not in (torch.float32,) and fuse_wgrad_accumulation and bias:
pytest.skip(f"ROCm does not support fused wgrad accumulation for {dtype}.")
if NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("Delayed wgrad compute is not supported in debug mode.")

Expand Down Expand Up @@ -1932,10 +1925,6 @@ def test_layernorm_mlp_accuracy_delay_wgrad_compute(

config = model_configs[model]

if IS_HIP_EXTENSION:
if dtype not in (torch.float32,) and fuse_wgrad_accumulation and bias:
pytest.skip(f"ROCm does not support fused wgrad accumulation for {dtype}.")

ln_mlp = LayerNormMLP(
hidden_size=config.hidden_size,
ffn_hidden_size=4 * config.hidden_size,
Expand Down
17 changes: 16 additions & 1 deletion transformer_engine/pytorch/cpp_extensions/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,18 @@ def general_gemm(
)
out_dtype = torch.bfloat16 if use_bf16_tn_output_workaround else out_dtype

# hipBLASLt has no fused bias-grad (BGRADB) algorithm for an fp32-output wgrad GEMM, so skip the fused dbias and reduce grad_output below.
rocm_split_dbias = (
IS_HIP_EXTENSION
and grad
and bias is not None
and not gelu
and out is not None
and out.dtype == torch.float32
and quantization_params is None
)
gemm_bias = None if rocm_split_dbias else bias

args = (
A,
transa, # transa
Expand All @@ -439,7 +451,7 @@ def general_gemm(
out,
quantization_params,
TE_DType[out_dtype] if out_dtype is not None else None,
bias,
gemm_bias,
bias_dtype,
gelu,
gelu_in,
Expand All @@ -460,6 +472,9 @@ def general_gemm(

out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs)

if rocm_split_dbias:
bias_grad = B.reshape(-1, B.shape[-1]).sum(dim=0, dtype=torch.float32).to(bias.dtype)

if IS_HIP_EXTENSION and use_bf16_tn_output_workaround:
out = cast_if_needed(out, torch.float32)

Expand Down
Loading