From 525a7498f05080321a495c807461e49027d0fb74 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Sat, 25 Apr 2026 04:05:56 +0000 Subject: [PATCH 1/6] Avoid removing usages from quantized weight in linear op Quantized weight tensor may be used across steps, so removing a usage is not safe. Signed-off-by: Tim Moon --- tests/pytorch/test_fusible_ops.py | 263 ++++++++++++++++-- .../pytorch/ops/basic/basic_linear.py | 57 ++-- 2 files changed, 271 insertions(+), 49 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 0f40e92183..592930523c 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -4,7 +4,7 @@ from __future__ import annotations -from collections.abc import Iterable +from collections.abc import Iterable, Sequence import functools import io import math @@ -199,6 +199,18 @@ def make_reference_and_test_tensors( return ref, test +def to_cpu(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + """Convert to an FP64 CPU tensor""" + if tensor is None: + return None + out = tensor.detach() + if isinstance(out, QuantizedTensor): + out = out.dequantize() + out = out.to(dtype=torch.float64, device="cpu") + out = out.requires_grad_(requires_grad=tensor.requires_grad) + return out + + class TestSequentialContainer: """Tests for sequential container""" @@ -3297,25 +3309,17 @@ def test_layernorm_mlp( y_test = forward(x_test) y_test.backward(dy_test) - def to_cpu(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]: - """Convert to FP64 CPU tensor""" - if tensor is None: - return None - out = tensor.detach().to(dtype=torch.float64, device="cpu") - out = out.requires_grad_(requires_grad=tensor.requires_grad) - return out - # Check values tols = {"rtol": 0.25, "atol": 0.5} # Loose tols for sanity checking - torch.testing.assert_close(to_cpu(y_test), y_ref, **tols) - torch.testing.assert_close(to_cpu(x_test.grad), x_ref.grad, **tols) - torch.testing.assert_close(to_cpu(norm.weight.grad), norm_w_ref.grad, **tols) - torch.testing.assert_close(to_cpu(norm.bias.grad), norm_b_ref.grad, **tols) - torch.testing.assert_close(to_cpu(ffn2.weight.grad), w2_ref.grad, **tols) - torch.testing.assert_close(to_cpu(ffn1.weight.grad), w1_ref.grad, **tols) + assert_close(y_test, y_ref, **tols) + assert_close(x_test.grad, x_ref.grad, **tols) + assert_close_grads(norm.weight, norm_w_ref, **tols) + assert_close_grads(norm.bias, norm_b_ref, **tols) + assert_close_grads(ffn2.weight, w2_ref, **tols) + assert_close_grads(ffn1.weight, w1_ref, **tols) if bias: - torch.testing.assert_close(to_cpu(ffn1.bias.grad), b1_ref.grad, **tols) - torch.testing.assert_close(to_cpu(ffn2.bias.grad), b2_ref.grad, **tols) + assert_close_grads(ffn1.bias, b1_ref, **tols) + assert_close_grads(ffn2.bias, b2_ref, **tols) @pytest.mark.parametrize("bias", (False, True)) @pytest.mark.parametrize("dtype", _dtypes) @@ -4543,6 +4547,231 @@ def fuse_ops( torch.testing.assert_close(dw_test, w_ref.grad, **tols) +class TestTrainingLoops: + + def _linear_train_stage( + self, + module: te.ops.Linear, + *, + steps: int = 4, + in_shape: Sequence[int], + out_shape: Sequence[int], + dtype: torch.type, + device: torch.device, + quantization: Optional[str], + recipe: Optional[transformer_engine.common.recipe.Recipe], + ) -> None: + """Perform training steps with linear op""" + + # Expected numerical error + tols = dtype_tols(dtype) + if dtype == torch.float32: + tols = dtype_tols(torch.float16) # TF32 GEMM + if quantization is not None: + tols = quantization_tols(quantization) + + for _ in range(steps): + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + quantization=quantization, + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + quantization=quantization, + test_dtype=dtype, + test_device=device, + ) + w_ref = to_cpu(module.weight) + b_ref = to_cpu(module.bias) + + # Plain PyTorch implementation + y_ref = torch.nn.functional.linear(x_ref, w_ref, bias=b_ref) + y_ref.backward(dy_ref) + + # Implementation with linear op + with te.autocast(enabled=quantization is not None, recipe=recipe): + y_test = module(x_test) + y_test.backward(dy_test) + + # Check results + assert_close(y_test, y_ref, **tols) + assert_close_grads(x_test, x_ref, **tols) + assert_close_grads(module.weight, w_ref, **tols) + assert_close_grads(module.bias, b_ref, **tols) + + # Update parameters with random values + with torch.no_grad(): + module.weight.copy_(torch.empty_like(module.weight).uniform_()) + module.bias.copy_(torch.empty_like(module.bias).uniform_()) + for param in module.parameters(): + param.grad = None + + @torch.inference_mode + def _linear_infer_stage( + self, + module: te.ops.Linear, + *, + steps: int = 4, + in_shape: Sequence[int], + dtype: torch.type, + device: torch.device, + quantization: Optional[str], + recipe: Optional[transformer_engine.common.recipe.Recipe], + ) -> None: + """Perform inference steps with linear op""" + + # Parameter reference values + w_ref = to_cpu(module.weight) + b_ref = to_cpu(module.bias) + + # Expected numerical error + tols = dtype_tols(dtype) + if dtype == torch.float32: + tols = dtype_tols(torch.float16) # TF32 GEMM + if quantization is not None: + tols = quantization_tols(quantization) + + for _ in range(steps): + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + quantization=quantization, + test_dtype=dtype, + test_device=device, + ) + + # Plain PyTorch implementation + y_ref = torch.nn.functional.linear(x_ref, w_ref, bias=b_ref) + + # Implementation with linear op + with te.autocast(enabled=quantization is not None, recipe=recipe): + y_test = module(x_test) + + # Check results + assert_close(y_test, y_ref, **tols) + + @pytest.mark.parametrize("stages", (["train", "infer"] * 2, ["infer", "train"] * 2)) + @pytest.mark.parametrize("quantization", _quantization_list) + @pytest.mark.parametrize("quantized_weight", (False, True)) + def test_linear_training_loop( + self, + *, + stages: Sequence[str], + weight_shape: tuple[int, int] = (32, 32), + in_shape: Sequence[int] = (32, -1), + dtype: Optional[torch.dtype] = None, + device: torch.device = "cuda", + quantization: Optional[str], + quantized_weight: bool, + ) -> None: + """Training loops with linear op""" + if dtype is None: + dtype = torch.bfloat16 if is_bf16_available() else torch.float32 + + # Make input and weight shapes consistent + out_features, in_features = weight_shape + in_shape = list(in_shape)[:-1] + [in_features] + out_shape = in_shape[:-1] + [out_features] + + # Skip invalid configurations + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) + maybe_skip_quantization(quantization, dims=out_shape) + if quantization is None and quantized_weight: + pytest.skip("Quantization scheme is not specified") + + # Construct module with random weights + recipe = make_recipe(quantization) + with te.quantized_model_init(enabled=quantized_weight, recipe=recipe): + module = te.ops.Linear( + in_features, + out_features, + device=device, + dtype=dtype, + ) + with torch.no_grad(): + for param in module.parameters(): + param.copy_(torch.empty_like(param).uniform_()) + + # Training loop stages + for stage in stages: + if stage == "train": + self._linear_train_stage( + module, + in_shape=in_shape, + out_shape=out_shape, + dtype=dtype, + device=device, + quantization=quantization, + recipe=recipe, + ) + elif stage == "infer": + self._linear_infer_stage( + module, + in_shape=in_shape, + dtype=dtype, + device=device, + quantization=quantization, + recipe=recipe, + ) + else: + raise ValueError(f"Unrecognized stage ({stage})") + + @pytest.mark.parametrize("quantization", _quantization_list) + @pytest.mark.parametrize("quantized_weight", (False, True)) + def test_linear_inference_loop( + self, + *, + weight_shape: tuple[int, int] = (32, 32), + in_shape: Sequence[int] = (32, -1), + dtype: Optional[torch.dtype] = None, + device: torch.device = "cuda", + quantization: Optional[str], + quantized_weight: bool, + ) -> None: + """Inference loop with linear op""" + if dtype is None: + dtype = torch.bfloat16 if is_bf16_available() else torch.float32 + + # Make input and weight shapes consistent + out_features, in_features = weight_shape + in_shape = list(in_shape)[:-1] + [in_features] + out_shape = in_shape[:-1] + [out_features] + + # Skip invalid configurations + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) + maybe_skip_quantization(quantization, dims=out_shape) + if quantization is None and quantized_weight: + pytest.skip("Quantization scheme is not specified") + + # Construct module with random weights + recipe = make_recipe(quantization) + with ( + torch.inference_mode(), + te.quantized_model_init(enabled=quantized_weight, recipe=recipe), + ): + module = te.ops.Linear( + in_features, + out_features, + device=device, + dtype=dtype, + ) + for param in module.parameters(): + param.copy_(torch.empty_like(param).uniform_()) + + # Inference loop + self._linear_infer_stage( + module, + in_shape=in_shape, + dtype=dtype, + device=device, + quantization=quantization, + recipe=recipe, + ) + + def test_grouped_gemm_quant_cute_matches_mxfp8_quantized() -> None: if not mxfp8_available: pytest.skip(reason_for_no_mxfp8) diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 17594726cc..98010cb8fa 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -325,27 +325,10 @@ def pre_first_fuser_forward(self) -> None: if self.weight.device.type == "meta": self.reset_parameters() - def pre_fuser_forward(self, *, requires_grad: bool) -> None: - super().pre_fuser_forward(requires_grad=requires_grad) - if FP8GlobalStateManager.is_fp8_enabled(): - # Configure quantizer usages - # Note: We cache the quantized input for backward pass, - # but discard the quantized weights. - weight_requires_grad = requires_grad and self.weight.requires_grad - columnwise_usage = weight_requires_grad - if FP8GlobalStateManager.get_fp8_recipe().backward_override is not None: - columnwise_usage = False - input_quantizer = self.get_quantizer("forward", 0) - weight_quantizer = self.get_quantizer("forward", 1) - grad_output_quantizer = self.get_quantizer("backward", 0) - input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) - weight_quantizer.set_usage(rowwise=True, columnwise=False) - grad_output_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) - def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: super().reset_recipe_state(recipe=recipe) - # Configure input/grad output tensor + # Configure input/grad output quantizers # Note: These tensors are only used internally. If there is no # tensor-parallel communication, they are only used for GEMM. input_quantizer = self.get_quantizer("forward", 0) @@ -370,21 +353,15 @@ def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: # Configure weight quantizer # Note: This function may be called in base class constructor, - # before any basic linear attrs have been set. + # before basic linear attrs have been set. weight_quantizer = self.get_quantizer("forward", 1) - if weight_quantizer is None: - pass - elif is_quantized_tensor(getattr(self, "weight", None)): - # Make sure weight param has correct quantizer - weight_quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled()) - weight_quantizer.internal = False - self.weight.update_quantizer(weight_quantizer.copy()) - else: - # Use internal tensors if quantized weights will not be - # exposed externally - weight_quantizer.internal = ( - not FP8GlobalStateManager.with_fp8_parameters() - and not getattr(self, "_with_quantized_weight", False) + weight = getattr(self, "weight", None) + if weight_quantizer is not None: + # Determine if quantized weight is exposed as parameter + weight_quantizer.internal = not ( + FP8GlobalStateManager.with_fp8_parameters() + or getattr(self, "_with_quantized_weight", False) + or is_quantized_tensor(weight) ) # Recipe-specific configuration @@ -416,6 +393,22 @@ def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: grad_output_quantizer.with_amax_reduction = True grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group + # Update quantizer in quantized weight tensor + if weight_quantizer is not None and is_quantized_tensor(weight): + # Set quantizer usages + # Note: Avoid disabling usages that are already set. The + # weight tensor may be reused across steps, so future + # steps may need usages that are currently unnecessary. + weight_quantizer.set_usage(rowwise=True) + columnwise_usage = torch.is_grad_enabled() + if weight._quantizer is not None and weight._quantizer.columnwise_usage: + columnwise_usage = True + if columnwise_usage: + weight_quantizer.set_usage(columnwise=True) + + # Update weight tensor + weight.update_quantizer(weight_quantizer.copy()) + @staticmethod def _functional_forward( input: torch.Tensor, # pylint: disable=redefined-builtin From 59d417d1281fb2c7d0344f09c47b5a8301d20f53 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Sat, 25 Apr 2026 04:17:49 +0000 Subject: [PATCH 2/6] Tweak test to catch bug when alternating train and infer steps Signed-off-by: Tim Moon --- tests/pytorch/test_fusible_ops.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 592930523c..7e48993cf3 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -4571,6 +4571,14 @@ def _linear_train_stage( tols = quantization_tols(quantization) for _ in range(steps): + # Update parameters with random values to simulate + # optimizer step or FSDP param all-gather + with torch.no_grad(): + module.weight.copy_(torch.empty_like(module.weight).uniform_()) + module.bias.copy_(torch.empty_like(module.bias).uniform_()) + for param in module.parameters(): + param.grad = None + # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, @@ -4602,13 +4610,6 @@ def _linear_train_stage( assert_close_grads(module.weight, w_ref, **tols) assert_close_grads(module.bias, b_ref, **tols) - # Update parameters with random values - with torch.no_grad(): - module.weight.copy_(torch.empty_like(module.weight).uniform_()) - module.bias.copy_(torch.empty_like(module.bias).uniform_()) - for param in module.parameters(): - param.grad = None - @torch.inference_mode def _linear_infer_stage( self, From 4798bdc31d860eaca8c5290a35ee570f43abcedc Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Sat, 25 Apr 2026 22:22:13 +0000 Subject: [PATCH 3/6] Avoid removing usages from quantized weights in grouped linear op Signed-off-by: Tim Moon --- tests/pytorch/test_fusible_ops.py | 4 +- .../pytorch/ops/basic/grouped_linear.py | 88 ++++++++++--------- 2 files changed, 47 insertions(+), 45 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 7e48993cf3..36fb08bfb9 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -4553,7 +4553,7 @@ def _linear_train_stage( self, module: te.ops.Linear, *, - steps: int = 4, + steps: int = 3, in_shape: Sequence[int], out_shape: Sequence[int], dtype: torch.type, @@ -4615,7 +4615,7 @@ def _linear_infer_stage( self, module: te.ops.Linear, *, - steps: int = 4, + steps: int = 3, in_shape: Sequence[int], dtype: torch.type, device: torch.device, diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py index fe5997a71e..bb9941e564 100644 --- a/transformer_engine/pytorch/ops/basic/grouped_linear.py +++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -607,28 +607,6 @@ def pre_first_fuser_forward(self) -> None: f"Expected no biases, but bias {group_idx} is initialized" ) - def pre_fuser_forward(self, *, requires_grad: bool) -> None: - super().pre_fuser_forward(requires_grad=requires_grad) - if FP8GlobalStateManager.is_fp8_enabled(): - # Assume weights have consistent grad requirement - weight_requires_grad = ( - self.weight.requires_grad - if self.single_grouped_weight - else self.weight0.requires_grad - ) - weight_requires_grad = requires_grad and weight_requires_grad - - # Configure quantizer usages - # Note: We cache the quantized input for backward pass, - # but discard the quantized weights. - for group_idx in range(self.num_groups): - input_quantizer = self.get_quantizer("forward", 2 * group_idx) - weight_quantizer = self.get_quantizer("forward", 2 * group_idx + 1) - grad_output_quantizer = self.get_quantizer("backward", group_idx) - input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) - weight_quantizer.set_usage(rowwise=True, columnwise=False) - grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) - def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: super().reset_recipe_state(recipe=recipe) @@ -641,32 +619,31 @@ def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: if grad_output_quantizer is not None: grad_output_quantizer.internal = True - # Handle weight quantizer + # Get weight tensor # Note: This function may be called in base class constructor, - # before any basic linear attrs have been set. - weight_quantizer = self.get_quantizer("forward", 2 * group_idx + 1) - if weight_quantizer is None: - pass - elif is_quantized_tensor(getattr(self, f"weight{group_idx}", None)): - # Make sure weight param has correct quantizer - weight_quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled()) - weight_quantizer.internal = False - if self.single_grouped_weight: - self.weight.quantizer = weight_quantizer.copy() - else: - getattr(self, f"weight{group_idx}").update_quantizer(weight_quantizer.copy()) + # before any grouped linear attrs have been set. + weight = None + weight_is_quantized = False + if getattr(self, "single_grouped_weight", False): + weight = getattr(self, "weight", None) + weight_is_quantized = ( + weight is not None and weight.quantizer is not None + ) else: - # Use internal tensors if quantized weights will not be - # exposed externally - weight_quantizer.internal = ( - not FP8GlobalStateManager.with_fp8_parameters() - and not getattr(self, "_with_quantized_weight", False) - and not self.single_grouped_weight + weight = getattr(self, f"weight{group_idx}", None) + weight_is_quantized = is_quantized_tensor(weight) + + # Configure weight quantizer + weight_quantizer = self.get_quantizer("forward", 2 * group_idx + 1) + if weight_quantizer is not None: + # Determine if quantized weight is exposed as parameter + weight_quantizer.internal = not ( + FP8GlobalStateManager.with_fp8_parameters() + or getattr(self, "_with_quantized_weight", False) + or weight_is_quantized ) # Recipe-specific configuration - # Note: This function may be called in base class constructor, - # before any basic linear attrs have been set. if recipe is not None: if recipe.float8_current_scaling(): input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale @@ -680,6 +657,31 @@ def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: recipe.fp8_quant_bwd_grad.amax_epsilon ) + # Update quantizer in quantized weight tensor + if weight_quantizer is not None and weight_is_quantized: + # Get quantizer from weight tensor + weight_tensor_quantizer = ( + weight.quantizer if self.single_grouped_weight else weight._quantizer + ) + + # Set quantizer usages + # Note: Avoid disabling usages that are already set. The + # weight tensor may be reused across steps, so future + # steps may need usages that are currently unnecessary. + weight_quantizer.set_usage(rowwise=True) + columnwise_usage = torch.is_grad_enabled() + if weight_tensor_quantizer is not None and weight_tensor_quantizer.columnwise_usage: + columnwise_usage = True + if columnwise_usage: + weight_quantizer.set_usage(columnwise=True) + + # Update weight tensor + if self.single_grouped_weight: + if group_idx == 0: + weight.quantizer = weight_quantizer.copy() + else: + weight.update_quantizer(weight_quantizer.copy()) + def op_forward(self, *args, **kwargs): raise RuntimeError( f"{self.__class__.__name__} operation has " From 3b9a8a1a21112b2917f3c612cc81085a1e1ae858 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 25 Apr 2026 23:25:01 +0000 Subject: [PATCH 4/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/ops/basic/grouped_linear.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py index bb9941e564..467ad75516 100644 --- a/transformer_engine/pytorch/ops/basic/grouped_linear.py +++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -626,9 +626,7 @@ def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: weight_is_quantized = False if getattr(self, "single_grouped_weight", False): weight = getattr(self, "weight", None) - weight_is_quantized = ( - weight is not None and weight.quantizer is not None - ) + weight_is_quantized = weight is not None and weight.quantizer is not None else: weight = getattr(self, f"weight{group_idx}", None) weight_is_quantized = is_quantized_tensor(weight) From 63328be0ab62c90f054f2cb5fc9508e554aa13a9 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Sun, 26 Apr 2026 05:02:25 +0000 Subject: [PATCH 5/6] Restore pre-forward quantizer config in ops Turns out we still need this in case the quantizer is used before the forward, e.g. in previous ops or CPU offloading. Signed-off-by: Tim Moon --- .../pytorch/ops/basic/basic_linear.py | 15 ++++++++++++++ .../pytorch/ops/basic/grouped_linear.py | 20 +++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 98010cb8fa..220f6b6710 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -325,6 +325,21 @@ def pre_first_fuser_forward(self) -> None: if self.weight.device.type == "meta": self.reset_parameters() + def pre_fuser_forward(self, *, requires_grad: bool) -> None: + super().pre_fuser_forward(requires_grad=requires_grad) + if FP8GlobalStateManager.is_fp8_enabled(): + # Configure quantizer usages + weight_requires_grad = requires_grad and self.weight.requires_grad + columnwise_usage = weight_requires_grad + if FP8GlobalStateManager.get_fp8_recipe().backward_override is not None: + columnwise_usage = False + input_quantizer = self.get_quantizer("forward", 0) + weight_quantizer = self.get_quantizer("forward", 1) + grad_output_quantizer = self.get_quantizer("backward", 0) + input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) + weight_quantizer.set_usage(rowwise=True, columnwise=requires_grad) + grad_output_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) + def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: super().reset_recipe_state(recipe=recipe) diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py index 467ad75516..31ee4a5785 100644 --- a/transformer_engine/pytorch/ops/basic/grouped_linear.py +++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -607,6 +607,26 @@ def pre_first_fuser_forward(self) -> None: f"Expected no biases, but bias {group_idx} is initialized" ) + def pre_fuser_forward(self, *, requires_grad: bool) -> None: + super().pre_fuser_forward(requires_grad=requires_grad) + if FP8GlobalStateManager.is_fp8_enabled(): + # Assume weights have consistent grad requirement + weight_requires_grad = ( + self.weight.requires_grad + if self.single_grouped_weight + else self.weight0.requires_grad + ) + weight_requires_grad = requires_grad and weight_requires_grad + + # Configure quantizer usages + for group_idx in range(self.num_groups): + input_quantizer = self.get_quantizer("forward", 2 * group_idx) + weight_quantizer = self.get_quantizer("forward", 2 * group_idx + 1) + grad_output_quantizer = self.get_quantizer("backward", group_idx) + input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + weight_quantizer.set_usage(rowwise=True, columnwise=requires_grad) + grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: super().reset_recipe_state(recipe=recipe) From 838ea73f0ef3e81a4a590ca3572801bc9f8de486 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Thu, 30 Apr 2026 20:14:41 +0000 Subject: [PATCH 6/6] Blindly preserve quantizer usages in quantized weight params. Signed-off-by: Tim Moon --- .../pytorch/ops/basic/basic_linear.py | 20 ++++++++----------- .../pytorch/ops/basic/grouped_linear.py | 18 ++++++++--------- 2 files changed, 16 insertions(+), 22 deletions(-) diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 220f6b6710..19fcf62ced 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -410,18 +410,14 @@ def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: # Update quantizer in quantized weight tensor if weight_quantizer is not None and is_quantized_tensor(weight): - # Set quantizer usages - # Note: Avoid disabling usages that are already set. The - # weight tensor may be reused across steps, so future - # steps may need usages that are currently unnecessary. - weight_quantizer.set_usage(rowwise=True) - columnwise_usage = torch.is_grad_enabled() - if weight._quantizer is not None and weight._quantizer.columnwise_usage: - columnwise_usage = True - if columnwise_usage: - weight_quantizer.set_usage(columnwise=True) - - # Update weight tensor + if weight._quantizer is not None: + # Preserve existing usages in weight tensor. Even if a + # usage is currently unnecessary, the weight tensor + # may be used elsewhere. + weight_quantizer.set_usage( + rowwise=weight._quantizer.rowwise_usage, + columnwise=weight._quantizer.columnwise_usage, + ) weight.update_quantizer(weight_quantizer.copy()) @staticmethod diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py index 31ee4a5785..b503cb186b 100644 --- a/transformer_engine/pytorch/ops/basic/grouped_linear.py +++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -682,16 +682,14 @@ def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: weight.quantizer if self.single_grouped_weight else weight._quantizer ) - # Set quantizer usages - # Note: Avoid disabling usages that are already set. The - # weight tensor may be reused across steps, so future - # steps may need usages that are currently unnecessary. - weight_quantizer.set_usage(rowwise=True) - columnwise_usage = torch.is_grad_enabled() - if weight_tensor_quantizer is not None and weight_tensor_quantizer.columnwise_usage: - columnwise_usage = True - if columnwise_usage: - weight_quantizer.set_usage(columnwise=True) + # Preserve existing usages in weight tensor. Even if a + # usage is currently unnecessary, the weight tensor + # may be used elsewhere. + if weight_tensor_quantizer is not None: + weight_quantizer.set_usage( + rowwise=weight_tensor_quantizer.rowwise_usage, + columnwise=weight_tensor_quantizer.columnwise_usage, + ) # Update weight tensor if self.single_grouped_weight: