264 changes: 247 additions & 17 deletions tests/pytorch/test_fusible_ops.py
@@ -4,7 +4,7 @@

from __future__ import annotations

from collections.abc import Iterable
from collections.abc import Iterable, Sequence
import functools
import io
import math
@@ -200,6 +200,18 @@ def make_reference_and_test_tensors(
return ref, test


def to_cpu(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
"""Convert to an FP64 CPU tensor"""
if tensor is None:
return None
out = tensor.detach()
if isinstance(out, QuantizedTensor):
out = out.dequantize()
out = out.to(dtype=torch.float64, device="cpu")
out = out.requires_grad_(requires_grad=tensor.requires_grad)
return out


class MegatronTrainingHelper:
"""Test-side stand-in for the Megatron-Core DDP / MegatronFSDP wrapper.
Megatron's DDP wrapper (and MegatronFSDP) owns the per-parameter
@@ -3368,25 +3380,17 @@ def test_layernorm_mlp(
y_test = forward(x_test)
y_test.backward(dy_test)

def to_cpu(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
"""Convert to FP64 CPU tensor"""
if tensor is None:
return None
out = tensor.detach().to(dtype=torch.float64, device="cpu")
out = out.requires_grad_(requires_grad=tensor.requires_grad)
return out

# Check values
tols = {"rtol": 0.25, "atol": 0.5} # Loose tols for sanity checking
torch.testing.assert_close(to_cpu(y_test), y_ref, **tols)
torch.testing.assert_close(to_cpu(x_test.grad), x_ref.grad, **tols)
torch.testing.assert_close(to_cpu(norm.weight.grad), norm_w_ref.grad, **tols)
torch.testing.assert_close(to_cpu(norm.bias.grad), norm_b_ref.grad, **tols)
torch.testing.assert_close(to_cpu(ffn2.weight.grad), w2_ref.grad, **tols)
torch.testing.assert_close(to_cpu(ffn1.weight.grad), w1_ref.grad, **tols)
assert_close(y_test, y_ref, **tols)
assert_close(x_test.grad, x_ref.grad, **tols)
assert_close_grads(norm.weight, norm_w_ref, **tols)
assert_close_grads(norm.bias, norm_b_ref, **tols)
assert_close_grads(ffn2.weight, w2_ref, **tols)
assert_close_grads(ffn1.weight, w1_ref, **tols)
if bias:
torch.testing.assert_close(to_cpu(ffn1.bias.grad), b1_ref.grad, **tols)
torch.testing.assert_close(to_cpu(ffn2.bias.grad), b2_ref.grad, **tols)
assert_close_grads(ffn1.bias, b1_ref, **tols)
assert_close_grads(ffn2.bias, b2_ref, **tols)

@pytest.mark.parametrize("bias", (False, True))
@pytest.mark.parametrize("dtype", _dtypes)
@@ -4740,6 +4744,232 @@ def fuse_ops(
torch.testing.assert_close(dw_test, w_ref.grad, **tols)


class TestTrainingLoops:

def _linear_train_stage(
self,
module: te.ops.Linear,
*,
steps: int = 3,
in_shape: Sequence[int],
out_shape: Sequence[int],
dtype: torch.dtype,
device: torch.device,
quantization: Optional[str],
recipe: Optional[transformer_engine.common.recipe.Recipe],
) -> None:
"""Perform training steps with linear op"""

# Expected numerical error
tols = dtype_tols(dtype)
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if quantization is not None:
tols = quantization_tols(quantization)

for _ in range(steps):
# Update parameters with random values to simulate
# optimizer step or FSDP param all-gather
with torch.no_grad():
module.weight.copy_(torch.empty_like(module.weight).uniform_())
module.bias.copy_(torch.empty_like(module.bias).uniform_())
for param in module.parameters():
param.grad = None

# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
)
w_ref = to_cpu(module.weight)
b_ref = to_cpu(module.bias)

# Plain PyTorch implementation
y_ref = torch.nn.functional.linear(x_ref, w_ref, bias=b_ref)
y_ref.backward(dy_ref)

# Implementation with linear op
with te.autocast(enabled=quantization is not None, recipe=recipe):
y_test = module(x_test)
y_test.backward(dy_test)

# Check results
assert_close(y_test, y_ref, **tols)
assert_close_grads(x_test, x_ref, **tols)
assert_close_grads(module.weight, w_ref, **tols)
assert_close_grads(module.bias, b_ref, **tols)

@torch.inference_mode
def _linear_infer_stage(
self,
module: te.ops.Linear,
*,
steps: int = 3,
in_shape: Sequence[int],
dtype: torch.dtype,
device: torch.device,
quantization: Optional[str],
recipe: Optional[transformer_engine.common.recipe.Recipe],
) -> None:
"""Perform inference steps with linear op"""

# Parameter reference values
w_ref = to_cpu(module.weight)
b_ref = to_cpu(module.bias)

# Expected numerical error
tols = dtype_tols(dtype)
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if quantization is not None:
tols = quantization_tols(quantization)

for _ in range(steps):
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
)

# Plain PyTorch implementation
y_ref = torch.nn.functional.linear(x_ref, w_ref, bias=b_ref)

# Implementation with linear op
with te.autocast(enabled=quantization is not None, recipe=recipe):
y_test = module(x_test)

# Check results
assert_close(y_test, y_ref, **tols)

@pytest.mark.parametrize("stages", (["train", "infer"] * 2, ["infer", "train"] * 2))
@pytest.mark.parametrize("quantization", _quantization_list)
@pytest.mark.parametrize("quantized_weight", (False, True))
def test_linear_training_loop(
self,
*,
stages: Sequence[str],
weight_shape: tuple[int, int] = (32, 32),
in_shape: Sequence[int] = (32, -1),
dtype: Optional[torch.dtype] = None,
device: torch.device = "cuda",
quantization: Optional[str],
quantized_weight: bool,
) -> None:
"""Training loops with linear op"""
if dtype is None:
dtype = torch.bfloat16 if is_bf16_available() else torch.float32

# Make input and weight shapes consistent
out_features, in_features = weight_shape
in_shape = list(in_shape)[:-1] + [in_features]
out_shape = in_shape[:-1] + [out_features]

# Skip invalid configurations
maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
maybe_skip_quantization(quantization, dims=out_shape)
if quantization is None and quantized_weight:
pytest.skip("Quantization scheme is not specified")

# Construct module with random weights
recipe = make_recipe(quantization)
with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
module = te.ops.Linear(
in_features,
out_features,
device=device,
dtype=dtype,
)
with torch.no_grad():
for param in module.parameters():
param.copy_(torch.empty_like(param).uniform_())

# Training loop stages
for stage in stages:
if stage == "train":
self._linear_train_stage(
module,
in_shape=in_shape,
out_shape=out_shape,
dtype=dtype,
device=device,
quantization=quantization,
recipe=recipe,
)
elif stage == "infer":
self._linear_infer_stage(
module,
in_shape=in_shape,
dtype=dtype,
device=device,
quantization=quantization,
recipe=recipe,
)
else:
raise ValueError(f"Unrecognized stage ({stage})")

@pytest.mark.parametrize("quantization", _quantization_list)
@pytest.mark.parametrize("quantized_weight", (False, True))
def test_linear_inference_loop(
self,
*,
weight_shape: tuple[int, int] = (32, 32),
in_shape: Sequence[int] = (32, -1),
dtype: Optional[torch.dtype] = None,
device: torch.device = "cuda",
quantization: Optional[str],
quantized_weight: bool,
) -> None:
"""Inference loop with linear op"""
if dtype is None:
dtype = torch.bfloat16 if is_bf16_available() else torch.float32

# Make input and weight shapes consistent
out_features, in_features = weight_shape
in_shape = list(in_shape)[:-1] + [in_features]
out_shape = in_shape[:-1] + [out_features]

# Skip invalid configurations
maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
maybe_skip_quantization(quantization, dims=out_shape)
if quantization is None and quantized_weight:
pytest.skip("Quantization scheme is not specified")

# Construct module with random weights
recipe = make_recipe(quantization)
with (
torch.inference_mode(),
te.quantized_model_init(enabled=quantized_weight, recipe=recipe),
):
module = te.ops.Linear(
in_features,
out_features,
device=device,
dtype=dtype,
)
for param in module.parameters():
param.copy_(torch.empty_like(param).uniform_())

# Inference loop
self._linear_infer_stage(
module,
in_shape=in_shape,
dtype=dtype,
device=device,
quantization=quantization,
recipe=recipe,
)


def test_grouped_gemm_quant_cute_matches_mxfp8_quantized() -> None:
if not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
40 changes: 22 additions & 18 deletions transformer_engine/pytorch/ops/basic/basic_linear.py
@@ -329,8 +329,6 @@ def pre_fuser_forward(self, *, requires_grad: bool) -> None:
super().pre_fuser_forward(requires_grad=requires_grad)
if FP8GlobalStateManager.is_fp8_enabled():
# Configure quantizer usages
# Note: We cache the quantized input for backward pass,
# but discard the quantized weights.
Comment on lines -332 to -333
Collaborator Author:
This comment was made incorrect in #1817.

weight_requires_grad = requires_grad and self.weight.requires_grad
columnwise_usage = weight_requires_grad
if FP8GlobalStateManager.get_fp8_recipe().backward_override is not None:
@@ -339,13 +337,13 @@ def pre_fuser_forward(self, *, requires_grad: bool) -> None:
weight_quantizer = self.get_quantizer("forward", 1)
grad_output_quantizer = self.get_quantizer("backward", 0)
input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
weight_quantizer.set_usage(rowwise=True, columnwise=False)
weight_quantizer.set_usage(rowwise=True, columnwise=requires_grad)
grad_output_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)

def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None:
super().reset_recipe_state(recipe=recipe)

# Configure input/grad output tensor
# Configure input/grad output quantizers
# Note: These tensors are only used internally. If there is no
# tensor-parallel communication, they are only used for GEMM.
input_quantizer = self.get_quantizer("forward", 0)
@@ -370,21 +368,15 @@ def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None:

# Configure weight quantizer
# Note: This function may be called in base class constructor,
# before any basic linear attrs have been set.
# before basic linear attrs have been set.
weight_quantizer = self.get_quantizer("forward", 1)
if weight_quantizer is None:
pass
elif is_quantized_tensor(getattr(self, "weight", None)):
# Make sure weight param has correct quantizer
weight_quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled())
weight_quantizer.internal = False
self.weight.update_quantizer(weight_quantizer.copy())
else:
# Use internal tensors if quantized weights will not be
# exposed externally
weight_quantizer.internal = (
not FP8GlobalStateManager.with_fp8_parameters()
and not getattr(self, "_with_quantized_weight", False)
weight = getattr(self, "weight", None)
if weight_quantizer is not None:
# Determine if quantized weight is exposed as parameter
weight_quantizer.internal = not (
FP8GlobalStateManager.with_fp8_parameters()
or getattr(self, "_with_quantized_weight", False)
or is_quantized_tensor(weight)
)

# Recipe-specific configuration
@@ -416,6 +408,18 @@ def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None:
grad_output_quantizer.with_amax_reduction = True
grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group

# Update quantizer in quantized weight tensor
if weight_quantizer is not None and is_quantized_tensor(weight):
if weight._quantizer is not None:
# Preserve existing usages in weight tensor. Even if a
# usage is currently unnecessary, the weight tensor
# may be used elsewhere.
weight_quantizer.set_usage(
rowwise=weight._quantizer.rowwise_usage,
columnwise=weight._quantizer.columnwise_usage,
)
weight.update_quantizer(weight_quantizer.copy())
Comment on lines +411 to +421
Member:
Why did you move this to its own section after the other quantizers? Logically, it should follow the rest of the setup for the weight quantizer.
Also, this logic is a little strange to me. Ultimately, what it does is set the columnwise usage if grad is enabled and then keep it forever. So it should be something like this:

Suggested change
# Update quantizer in quantized weight tensor
if weight_quantizer is not None and is_quantized_tensor(weight):
# Set quantizer usages
# Note: Avoid disabling usages that are already set. The
# weight tensor may be reused across steps, so future
# steps may need usages that are currently unnecessary.
weight_quantizer.set_usage(rowwise=True)
columnwise_usage = torch.is_grad_enabled()
if weight._quantizer is not None and weight._quantizer.columnwise_usage:
columnwise_usage = True
if columnwise_usage:
weight_quantizer.set_usage(columnwise=True)
# Update weight tensor
weight.update_quantizer(weight_quantizer.copy())

# Update quantizer in quantized weight tensor
if weight_quantizer is not None and is_quantized_tensor(weight):
# Set quantizer usages
# Note: Avoid disabling usages that are already set. The
# weight tensor may be reused across steps, so future
# steps may need usages that are currently unnecessary.
if weight._quantizer is None or (not weight._quantizer.columnwise_usage and torch.is_grad_enabled()):
weight_quantizer.set_usage(rowwise=True, columnwise=True)
weight.update_quantizer(weight_quantizer.copy())

(the assumption here is that the rowwise would always be there for the weights).

Member:
Actually, no. I see that the context of this function is resetting the recipe, so sure, we need to create the new quantizer and pass it to the tensor. That makes sense. I don't see the requantization happening in this case, though: if the weight is already a quantized tensor, where is the code that actually applies this change in the _quantizer to the data it holds?

Collaborator Author:
The big problem with the existing code is that we update the weight param quantizer, and then we do more weight quantizer configuration afterwards. We should make sure the weight quantizer is fully configured, and only then update the quantized param. The logic for updating the quantized param was also a little convoluted, so I attempted to make it a bit more clear.

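A minimal sketch of the ordering described above (configure the weight quantizer fully, then update the quantized parameter), assuming the helpers used in the diff (get_quantizer, set_usage, update_quantizer, is_quantized_tensor); the wrapper name _sync_weight_param_quantizer is hypothetical:

def _sync_weight_param_quantizer(self) -> None:
    # Hypothetical helper: finish all quantizer configuration first, then
    # push a copy into the quantized weight parameter.
    weight = getattr(self, "weight", None)
    weight_quantizer = self.get_quantizer("forward", 1)
    if weight_quantizer is None or not is_quantized_tensor(weight):
        return
    # Step 1: fully configure the quantizer (usages, internal flag, any
    # recipe-specific attributes) before it touches the parameter.
    if weight._quantizer is not None:
        weight_quantizer.set_usage(
            rowwise=weight._quantizer.rowwise_usage,
            columnwise=weight._quantizer.columnwise_usage,
        )
    # Step 2: only now hand the finished quantizer to the parameter.
    weight.update_quantizer(weight_quantizer.copy())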
Collaborator Author:
I don't think requantization really makes sense: #2929 (comment)

I can see the argument that we should blindly preserve usages in the weight param quantizer. This logic may be trying too hard to clean up after a user doing something weird.

Member:
I don't agree with that (if I change the recipe to a different type, then trying to run without this requantization would just fail, since e.g. we would try to multiply an MXFP8 tensor with an FP8 CS tensor), but I think this should be addressed in its own PR rather than here.


@staticmethod
def _functional_forward(
input: torch.Tensor, # pylint: disable=redefined-builtin